#include <emp-sh2pc/emp-sh2pc.h>
#include <emp-tool/emp-tool.h>
#include "emp-shmpc/emp-shmpc.h"
#include "emp-shmpc/constant.h"
#include "algs/shamir.cpp"
//#include <cassert>
#include <ctime>
#include <exception>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
using namespace emp;
using namespace NTL;
using namespace std;

// this runs malicious RSA aggregation without communication steps


int main (int argc, char** argv) {
    // renaming constants from emp-shmpc/constant.h
    //const int NUM_PARTICIPANTS = CONST_POINTS;
    //const int COMMITTEE_SIZE = NUM_PARTIES;

    //const int NUM_DIMENSIONS = 10000;
    string arg1 = argv[1];
    string arg2 = argv[2];
    size_t pos;
    int NUM_PARTICIPANTS = stoi(arg1, &pos);
    int NUM_DIMENSIONS = stoi(arg2, &pos);

    //const int NUM_PARTICIPANTS = 50;
    const int COMMITTEE_SIZE = 46;


    string fn = "rsa_mal_bench_clients";
    fn.append(argv[1]);
    fn.append("_params");
    fn.append(argv[2]);
    fn.append("_");

    time_t t = time(0);   // get time now
    struct tm * now = localtime( & t );

    char buffer [80];
    strftime (buffer,80,"%Y-%m-%d-%H-%M.txt",now);
    fn.append(buffer);

    ofstream logfile;
    logfile.open(fn);
    cout << "logfile: " << fn << endl;
    cout << "NUM_PARTICIPANTS: " << NUM_PARTICIPANTS << endl;
    cout << "NUM_DIMENSIONS: " << NUM_DIMENSIONS << endl;
    cout << "COMMITTEE_SIZE: " << COMMITTEE_SIZE << endl;

    logfile << "NUM_PARTICIPANTS: " << NUM_PARTICIPANTS << endl;
    logfile << "NUM_DIMENSIONS: " << NUM_DIMENSIONS << endl;
    logfile << "COMMITTEE_SIZE: " << COMMITTEE_SIZE << endl;

    auto prepro_start = clock_start();
    // setup PRG and NTL
    PRG prg;

    init_ZZp();

    // TEST PREPROCESSING:
    // steps that are not necessary during the actually aggregation,
    // but need to be performed in order to do local computation benchmarking.

    // NOTE: here we are generating shares of participant updates locally
    // this is for the sake of benchmarking. 
    // In practice, each committee member will receive a NUM_PARTICIPANTS * NUM_DIMENSIONS vector of shares
    // which will store a NUM_DIMENSIONS vector of shares received from each client.
    
    // GENERATE MASTER LIST OF UPDATE SHARES
    // master_shares is a NUM_PARTICIPANTS * NUM_DIMENSIONS * COMMITTEE_SIZE vector
    // each entry is an NTL::ZZ_p representing a share of a client update
    // in total, the shares reconstruct NUM_PARTICIPANTS * NUM_DIMENSIONS values
    vector<vector<Vec<ZZ_p> > > master_shares(NUM_PARTICIPANTS);

    for (int i=0; i<NUM_PARTICIPANTS; i++) {
        vector<Vec<ZZ_p>> p_share(NUM_DIMENSIONS);
        for (int j=0; j<NUM_DIMENSIONS; j++) {
            p_share[j] = shamir_share(1, COMMITTEE_SIZE / 2, COMMITTEE_SIZE, prg);
        }
        master_shares[i] = p_share;
    }

    cout << "generated master_shares" << endl;
    logfile << "generated master_shares" << endl;

    // NOTE: here we are generating shares of b(b-1) for participant updates locally
    // In practice, each committee member will generate one b(b-1) check value for each
    // of the received NUM_PARTICIPANTS * NUM_DIMENSIONS shares.

    /*
    // GENERATE MASTER LIST OF CHECK SHARES
    // master_check_shares is a NUM_PARTICIPANTS * NUM_DIMENSIONS * COMMITTEE_SIZE vector
    // each entry is an NTL::ZZ_p representing a share of b(b-1) where b is a client update for a single parameter.
    // in total, the shares reconstruct binary-value checks for NUM_PARTICIPANTS * NUM_DIMENSIONS values

    vector<vector<Vec<ZZ_p> > > master_check_shares(NUM_PARTICIPANTS);

    for (int i=0; i<NUM_PARTICIPANTS; i++) {
        vector<Vec<ZZ_p> > v(NUM_DIMENSIONS);
        for (int j=0; j<NUM_DIMENSIONS; j++) {
            Vec<ZZ_p> s;
            s.SetLength(COMMITTEE_SIZE);
            for (int k=0; k<COMMITTEE_SIZE; k++) {
                s[k] = check_binary_help(master_shares[i][j][k]);
            }
            v[j] = s;
        }
        master_check_shares[i] = v;
    }

    cout << "generated master_check_shares" << endl;
    */

    // generate master shares of batch check
    // generate randomness -- same across all committee members by using shared random seed
    vector< vector<long> > master_r(NUM_PARTICIPANTS);
    long temp_r;
    for (int i=0; i<NUM_PARTICIPANTS; i++) {
        vector<long> rv(NUM_DIMENSIONS);
        for (int j=0; j<NUM_DIMENSIONS; j++) {
            prg.random_data(&temp_r, sizeof(temp_r));
            rv[j] = temp_r;
        }
        master_r[i] = rv;
    }

    Vec<ZZ_p> master_batch_check; // shares of the batch check
    master_batch_check.SetLength(COMMITTEE_SIZE);
    
    for (int p=0; p<COMMITTEE_SIZE; p++) {
        ZZ_p bcs(0); // batch check sum for each party
        ZZ_p one(1);
        ZZ_p t1, t2, t3, zz_r;
        for (int i=0; i<NUM_PARTICIPANTS; i++) {
            for(int j=0; j<NUM_DIMENSIONS; j++) {
                //cs[j] = check_binary_help(master_shares[i][j][party-1]);
                // doing check_binary_help without creating new ZZ_p objects
                sub(t1, master_shares[i][j][p], one); // t1 = b-1
                mul(t2, master_shares[i][j][p], t1); // t2 = b*(b-1)
                
                //prg.random_data(&p_r, sizeof(p_r)); // prg generates random long
                conv(zz_r, master_r[i][j]); // put random long into ZZ_p
                mul(t3, zz_r, t2); // t3 = <random field element> * b*(b-1)
                add(bcs, bcs, t3); // batch_check_sum += t3
            }
        }
        master_batch_check[p] = bcs;
    }
    

    // GENERATE MASTER LIST OF SUM SHARES
    // master_sum_shares is a COMMITTEE_SIZE * NUM_DIMENSIONS vector
    // each entry is an NTL::ZZ_p representing a share of summed client updates for a single parameter.
    vector<Vec<ZZ_p> > master_sum_shares(COMMITTEE_SIZE);
    for (int i=0; i<COMMITTEE_SIZE; i++) {
        Vec<ZZ_p> ss;
        ss.SetLength(NUM_DIMENSIONS);
        for (int j=0; j<NUM_DIMENSIONS; j++) {
            for (int k=0; k<NUM_PARTICIPANTS; k++) {
                ss[j] += master_shares[k][j][i]; // master_shares indexed by participants, dimensions, committee
            }
        }
        master_sum_shares[i] = ss;
    }
    int prepro_t = time_from(prepro_start);

    // TEST PREPROCESSING FINISHED
    // PROTOCOL START

    int timer = 0;
    auto start = clock_start();
    // INITIALIZATION

    // port not actually used, party used to index in places
    //int port, party;
    //parse_party_and_port(argv, &party, &port);

    int party = 1;

    
    // GENERATE BATCH CHECK
    // the batch check proceeds by each committee member locally generating a share of
    // \sum r_i * s_i
    // where r_i is a uniform random field element sampled from the PRG
    // and s_i is a share of b*(b-1) for each b in the [<num clients> x <num params>] vector of shares that each party possesses
    // we need to check that all s_i are equal to zero, and this is accomplished by checking
    // whether \sum r_i * s_i == 0, which is performed later in the protocol.
    
    ZZ_p batch_check_sum(0);
    ZZ_p one(1);
    ZZ_p t1, t2, t3, zz_r;
    long p_r;
    for (int i=0; i<NUM_PARTICIPANTS; i++) {
        for(int j=0; j<NUM_DIMENSIONS; j++) {
            // creating s_i, a share of b(b-1)
            sub(t1, master_shares[i][j][party-1], one); // t1 = b-1
            mul(t2, master_shares[i][j][party-1], t1); // t2 = b*(b-1)
            
            // generating r_i, a random field element
            prg.random_data(&p_r, sizeof(p_r)); // prg generates random long
            conv(zz_r, p_r); // put random long into ZZ_p
            mul(t3, zz_r, t2); // t3 = <random field element> * b*(b-1)
            // adding up \sum s_i * r_i
            add(batch_check_sum, batch_check_sum, t3); // batch_check_sum += t3
        }
    }
    

    /*
    // code from before batch check was implemented, stores shares of b(b-1) for each b
    // in the [<num clients> * <num params>] vector, so that they can be checked individually later
    vector< Vec<ZZ_p> > check_shares(NUM_PARTICIPANTS);
    ZZ_p one(1);
    ZZ_p t1, t2;
    for (int i=0; i<NUM_PARTICIPANTS; i++) {
        Vec<ZZ_p> cs; 
        cs.SetLength(NUM_DIMENSIONS); 
        for(int j=0; j<NUM_DIMENSIONS; j++) {
            //cs[j] = check_binary_help(master_shares[i][j][party-1]);
            // doing check_binary_help without creating new ZZ_p objects
            sub(t1, master_shares[i][j][party-1], one); // t1 = b-1
            mul(t2, master_shares[i][j][party-1], t1); // t2 = b*(b-1)
            cs[j] = t2;
        }
        check_shares[i] = cs;
    }
    */
    

    auto gen_bincheck_t = time_from(start);
    timer = time_from(start);
    
    cout << "generated check_shares" << endl;
    logfile << "generated check_shares" << endl;
    
    // GENERATE SHARES OF PARAMETER DELTAS
    // sum_shares is a NUM_DIMENSIONS vector
    // each entry is a share of \sum_{p \in clients} b_p -- for each parameter
    // these shares reconstruct the global update
    Vec<ZZ_p> sum_shares;
    sum_shares.SetLength(NUM_DIMENSIONS);
    // initialize to zero
    // TODO: check if this is done automatically
    for (int i=0; i<NUM_DIMENSIONS; i++) {
        sum_shares[i] = ZZ_p(0);
    }
    // take the sum
    // TODO: NTL has vector summation, this can potentially be optimized
    for (int j=0; j<NUM_PARTICIPANTS; j++) {
        for (int i=0; i<NUM_DIMENSIONS; i++) {
            sum_shares[i] += master_shares[j][i][party-1];
        }
    }

    auto sum_shares_t = time_from(start) - timer;
    timer = time_from(start);

    cout << "generated sum_shares" << endl;
    logfile << "generated sum_shares" << endl;

    // CHECK THAT UPDATES ARE BINARY-VALUED

    // initialize vector of x values for reconstruction
    Vec<ZZ_p> xs;
    xs.SetLength(COMMITTEE_SIZE);
    for (int i=0; i<COMMITTEE_SIZE; i++) {
        ZZ_p field_x(i+1);
        xs[i] = field_x;
    }
    
    /*
    // previous code from before batch check was implemented
    // needed to check b(b-1)==0 for each b in the [<num client> x <num param>] vector of shares
    ZZ_p r_check;
    ZZ_p field_zero(0);
    ZZ_pX temp_poly;
    for (int i=0; i<NUM_PARTICIPANTS; i++) {
        for (int j=0; j<NUM_DIMENSIONS; j++) {
            //r_check = reconstruct(xs, master_check_shares[i][j]);
            temp_poly = interpolate(xs, master_check_shares[i][j]);
            r_check = eval(temp_poly, field_zero);
            if (!(r_check==field_zero)) {
                cout << "ABORT\n";
            }
        }
    }
    */

    // NOTE: in practice, each committee member should send and receive
    // their share of batch_check_sum to/from each other committee member.
    // for benchmarking the local computation, we reference the master list of batch check shares generated during
    // the TEST PREPROCESSING step.
    // We copy the shares over to a new vector simulate the time spent on allocation for this process.

    Vec<ZZ_p> batch_check_shares;
    batch_check_shares.SetLength(COMMITTEE_SIZE);
    for (int p=0; p<COMMITTEE_SIZE; p++) {
        batch_check_shares[p] = master_batch_check[p];
    }

    ZZ_p r_check;
    ZZ_p field_zero(0);
    r_check = reconstruct(xs, batch_check_shares);
    // if \sum s_i * r_i != 0, then some share is not binary valued, so abort.
    if (!(r_check==field_zero)) {
        cout << "ABORT\n";
        logfile << "ABORT\n";
    }

    auto bin_check_t = time_from(start) - timer;
    timer = time_from(start);
    cout << "ran binary-value check" << endl;
    logfile << "ran binary-value check" << endl;

    // RECONSTRUCT PARAMETER DELTAS
    // in practice, each party will send each other party shares of the parameter deltas for reconstruction
    // for local computation benchmarking, we reference the precomputed master list to find the shares of each party, and reconstruct from the shares
    Vec<ZZ_p> param_deltas; // vector of reconstructed parameter deltas
    param_deltas.SetLength(NUM_DIMENSIONS);

    Vec<ZZ_p> param_shares; // vector to hold 1 share from each party so that it can be reconstructed
    param_shares.SetLength(COMMITTEE_SIZE);


    for (int i=0; i<NUM_DIMENSIONS; i++) {
        // load the shares into param_shares so they can be reconstructed
        for (int j=0; j<COMMITTEE_SIZE; j++) {
            param_shares[j] = master_sum_shares[j][i];
        }
        param_deltas[i] = reconstruct(xs, param_shares);
    }
    auto rec_delta_t = time_from(start) - timer;
    timer = time_from(start);


    // PROTOCOL END
    cout << "preprocessing time (not part of protocol): " << prepro_t << endl;
    cout << endl;

    cout << "generating binary check values: " << gen_bincheck_t << endl;
    cout << "summing parameter deltas: " << sum_shares_t << endl;
    cout << "binary value checking: " << bin_check_t << endl;
    cout << "reconstructing deltas: " << rec_delta_t << endl;
    cout << "total protocol time: " << time_from(start) << endl;

    cout << "total time w/ preprocessing: " << time_from(start) + prepro_t << endl;

    cout << "preprocessing time (not part of protocol): " << prepro_t << endl;
    cout << endl;

    logfile << "generating binary check values: " << gen_bincheck_t << endl;
    logfile << "summing parameter deltas: " << sum_shares_t << endl;
    logfile << "binary value checking: " << bin_check_t << endl;
    logfile << "reconstructing deltas: " << rec_delta_t << endl;
    logfile << "total protocol time: " << time_from(start) << endl;

    logfile << "total time w/ preprocessing: " << time_from(start) + prepro_t << endl;

    // SANITY CHECK
    // all party updates are given as 1, 
    // so each parameter delta should be NUM_PARTICIPANTS
    // in practice this is not part of the computation
    // NOTE: this sanity check is hardcoded for the case where all
    // participants give an update of 1. It will fail in any other cases,
    // even ones where parties submit valid binary values. It is only for debugging purposes.

    int sc=0;
    ZZ_p target(NUM_PARTICIPANTS);
    for (int i=0; i<NUM_DIMENSIONS; i++) {
        if (param_deltas[i]!=target) {
            cout << "param_deltas[" << i << "] WRONG! " << param_deltas[i] << endl;
            logfile << "param_deltas[" << i << "] WRONG! " << param_deltas[i] << endl;
            sc++;
        }
    }
    if (sc==0) {
        cout << "SANITY CHECK PASSED" << endl;
        logfile << "SANITY CHECK PASSED" << endl;
    }

    return 0;
}