#include <algorithm>
#include <chrono>
#include <cmath>
#include "VDA_algorithm.h"
using namespace std;

// Constructor: captures the data set and sizes every accumulator.
//
// input_data: pointer to the data set. It is stored as-is in `data` (the
//             destructor in this file does not delete it) and is dereferenced
//             here, so it must not be null.
//
// Derived sizes:
//   N, p : copied from input_data.
//   n    : equal to N.
//   r    : n / 2 / p using integer division; this is the number of rows that
//          select_and_average() adds to each "upper" and each "lower" group.
//
// Accumulators:
//   XXT                        : (p+1) x (p+1) matrix.
//   XY                         : p+1 entries, zeroed.
//   lower_/upper_averaging_Z   : p x p matrices.
//   lower_/upper_averaging_y   : p entries each, zeroed.
//   lower_/upper_averaging_counter : p entries each, zeroed.
//
// NOTE(review): the matrices are only passed through init(); later code
// accumulates into them with +=, so init() is presumably expected to
// zero-fill -- confirm against the matrix class.
VDA_algorithm::VDA_algorithm(the_Data *input_data){
    this -> data = input_data;

    this -> N = input_data -> N;
    this -> p = input_data -> p;
    this -> n = this -> N;

    // Guard the integer division: p == 0 would divide by zero (undefined
    // behaviour). With r == 0 no rows are assigned to any group.
    this -> r = (this -> p != 0) ? (n / 2 / this -> p) : 0;

    XXT.init(p+1, p+1);

    XY.resize(p+1);
    std::fill(XY.begin(), XY.end(), 0);

    lower_averaging_Z.init(p, p);

    lower_averaging_y.resize(p);
    std::fill(lower_averaging_y.begin(), lower_averaging_y.end(), 0);

    lower_averaging_counter.resize(p);
    std::fill(lower_averaging_counter.begin(), lower_averaging_counter.end(), 0);

    upper_averaging_Z.init(p, p);

    upper_averaging_y.resize(p);
    std::fill(upper_averaging_y.begin(), upper_averaging_y.end(), 0);

    upper_averaging_counter.resize(p);
    std::fill(upper_averaging_counter.begin(), upper_averaging_counter.end(), 0);
}

// Destructor: performs no explicit cleanup; in particular the `data` pointer
// received by the constructor is not deleted here.
VDA_algorithm::~VDA_algorithm() = default;

// Runs the whole pipeline in order -- select_and_average(), compute(),
// solve() -- and returns the coefficient vector produced by solve().
vector<T> VDA_algorithm::Execute(){
    select_and_average();
    compute();
    return solve();
}

// Accumulates per-group sums of responses (data->y) and predictors (data->Z).
//
// The data rows are consumed in consecutive chunks of 2*r rows, one chunk per
// group j in [0, p):
//   rows [j*2*r,     j*2*r + r)   are summed into upper_averaging_y[j] and
//                                 row j of upper_averaging_Z;
//   rows [j*2*r + r, j*2*r + 2*r) are summed into lower_averaging_y[j] and
//                                 row j of lower_averaging_Z.
// The matching counter is incremented once per row added. Only sums are
// formed here; the division by the group size happens in compute(). Rows
// with index >= 2*r*p are not read.
//
// NOTE(review): rows are grouped in storage order -- no random permutation is
// applied in this function. If randomised grouping is required, the rows
// must be shuffled before they reach this class.
void VDA_algorithm::select_and_average(){
    for(int group = 0; group < p; ++group){
        const auto upper_first = group * 2 * r;      // first row of the upper half
        const auto lower_first = group * 2 * r + r;  // first row of the lower half

        for(int offset = 0; offset < r; ++offset){
            const auto row = upper_first + offset;
            ++upper_averaging_counter[group];
            upper_averaging_y[group] += data->y[row];
            for(int col = 0; col < p; ++col){
                upper_averaging_Z(group, col) += data->Z(row, col);
            }
        }

        for(int offset = 0; offset < r; ++offset){
            const auto row = lower_first + offset;
            ++lower_averaging_counter[group];
            lower_averaging_y[group] += data->y[row];
            for(int col = 0; col < p; ++col){
                lower_averaging_Z(group, col) += data->Z(row, col);
            }
        }
    }
}

void VDA_algorithm::compute(){
    // add averaged data to XXT and XY
    vector<T> tmp_lower_divider(p);
    vector<T> tmp_upper_divider(p);
    for(int i = 0; i < p; ++i){
        tmp_lower_divider[i] = sqrt(double(lower_averaging_counter[i]));
        if(lower_averaging_counter[i] == 0){
            cout << "No averaging data ???" << endl;
            continue;
        }
        lower_averaging_y[i] /= tmp_lower_divider[i];
        for(int j = 0; j < p; ++j){
            lower_averaging_Z(i, j) /= tmp_lower_divider[i];
        }
    }
    for(int i = 0; i < p; ++i){
        tmp_upper_divider[i] = sqrt(double(upper_averaging_counter[i]));
        if(upper_averaging_counter[i] == 0){
            cout << "No averaging data ???" << endl;
            continue;
        }
        upper_averaging_y[i] /= tmp_upper_divider[i];
        for(int j = 0; j < p; ++j){
            upper_averaging_Z(i, j) /= tmp_upper_divider[i];
        }
    }

    for(int counter = 0; counter < p; ++counter){
        XXT(0, 0) += tmp_lower_divider[counter] * tmp_lower_divider[counter] + tmp_upper_divider[counter] * tmp_upper_divider[counter];
        XY[0] += tmp_lower_divider[counter] * lower_averaging_y[counter] + tmp_upper_divider[counter] * upper_averaging_y[counter];
        for(int i = 1; i< p+1; ++i ){
            XXT(i, 0) += tmp_lower_divider[counter] * lower_averaging_Z(counter, i-1) + tmp_upper_divider[counter] * upper_averaging_Z(counter, i-1);
            XY[i] += lower_averaging_Z(counter, i-1) * lower_averaging_y[counter] + upper_averaging_Z(counter, i-1) * upper_averaging_y[counter];
            for(int j = 1; j<= i; ++j){
                XXT(i, j) += lower_averaging_Z(counter, i-1) * lower_averaging_Z(counter, j-1) + upper_averaging_Z(counter, i-1) * upper_averaging_Z(counter, j-1);
            }
        }
    }

}

// Solves the (p+1) x (p+1) linear system held in XXT / XY by Gaussian
// elimination without pivoting, followed by back-substitution.
//
// XXT and XY are modified in place. The system is addressed in transposed
// form: equation `e` is stored along the second index, i.e. its coefficient
// for unknown `k` is XXT(k, e).
//
// Returns the p+1 solved coefficients. No check is made for a zero pivot.
vector<T> VDA_algorithm::solve(){

    const auto dim = p + 1;

    // Copy the entries XXT(b, a) with b >= a onto their mirrored positions
    // so the whole matrix is populated.
    for(int a = 0; a < dim; ++a){
        for(int b = a; b < dim; ++b){
            XXT(a, b) = XXT(b, a);
        }
    }

    // Forward elimination: remove unknown `pivot` from every later equation.
    for(int pivot = 0; pivot < p; ++pivot){
        for(int target = pivot + 1; target < dim; ++target){
            const T factor = XXT(pivot, target) / XXT(pivot, pivot);
            XY[target] -= factor * XY[pivot];
            for(int k = pivot; k < dim; ++k){
                XXT(k, target) -= factor * XXT(k, pivot);
            }
        }
    }

    // Back-substitution, from the last unknown up to the first.
    vector<T> beta(dim);
    for(int row = p; row >= 0; --row){
        T acc = XY[row];
        for(int k = row + 1; k <= p; ++k){
            acc -= XXT(k, row) * beta[k];
        }
        beta[row] = acc / XXT(row, row);
    }
    return beta;
}
