//#define DEBUG

#include <emp-sh2pc/emp-sh2pc.h>
#include <emp-tool/emp-tool.h>
#include "emp-shmpc/emp-shmpc.h"
#include "emp-shmpc/constant.h"
#include "algs/shamir.cpp"
#include <iostream>
#include <string>
using namespace emp;
using namespace NTL;
using namespace std;

// Runs the malicious-secure FLTrust aggregation (F^R) with the communication
// steps omitted, in order to benchmark the CPU time of a single committee member.


int main (int argc, char** argv) {

    string arg1 = argv[1];
    string arg2 = argv[2];
    string arg3;
    size_t pos;
    int NUM_PARTICIPANTS = stoi(arg1, &pos);
    int NUM_DIMENSIONS = stoi(arg2, &pos);
    int THETA;
    if (argc > 3) {
        arg3 = argv[3];
        THETA = stoi(arg3, &pos);
    } else {
        THETA = 16; // number of bits in fixed point representation
    }

    const int COMMITTEE_SIZE = 121;

    cout << "NUM_PARTICIPANTS: " << NUM_PARTICIPANTS << endl;
    cout << "NUM_DIMENSIONS: " << NUM_DIMENSIONS << endl;
    cout << "COMMITTEE_SIZE: " << COMMITTEE_SIZE << endl;
    cout << "THETA: " << THETA << endl;



    
    string fn = "fltrust_mal_bench_clients";
    fn.append(argv[1]);
    fn.append("_params");
    fn.append(argv[2]);
    fn.append("_theta");
    if (argc > 3) {
        fn.append(argv[3]);
        fn.append("_");
    } else {
        fn.append("20_");
    }

    time_t t = time(0);   // get time now
    struct tm * now = localtime( & t );

    
    char buffer [80];
    strftime (buffer,80,"%Y-%m-%d-%H-%M-%S.txt",now);
    fn.append(buffer);

    cout << "logfile: " << fn << endl;

    ofstream logfile;
    logfile.open(fn);

    logfile << "NUM_PARTICIPANTS: " << NUM_PARTICIPANTS << endl;
    logfile << "NUM_DIMENSIONS: " << NUM_DIMENSIONS << endl;
    logfile << "COMMITTEE_SIZE: " << COMMITTEE_SIZE << endl;
    logfile << "THETA: " << THETA << endl;
    

    //auto prepro_start = clock_start();
    // setup PRG and NTL
    PRG prg;

    init_ZZp();
    const ZZ_p FE_ONE = ZZ_p(1);
    const ZZ_p FE_ZERO = ZZ_p(0);
    const ZZ_p FE_NEG_ONE = ZZ_p(-1);

    

    // preprocessing:
    // initialize committee member's received data
    // a committee member receives parameters from each client
    // each parameter is encoded as a THETA-bit fixed point and a sign bit
    // each bit is shared individually

    // initialize vector of x values for reconstruction
    Vec<ZZ_p> xs;
    xs.SetLength(COMMITTEE_SIZE);
    for (int i=0; i<COMMITTEE_SIZE; i++) {
        ZZ_p field_x(i+1);
        xs[i] = field_x;
    }

    Vec<ZZ_p> dummy_shares = shamir_share(1, COMMITTEE_SIZE / 3, COMMITTEE_SIZE, prg);

    // use dummy shares to reconstruct
    // in order to get the CPU time
    vector< Vec<ZZ_p> > mag_bits(NUM_PARTICIPANTS); // [N x P*theta]
    for (int i=0; i<NUM_PARTICIPANTS; i++) {
        Vec<ZZ_p> cmb; // client magnitude bits
        cmb.SetLength(THETA * NUM_DIMENSIONS);
        for (int j=0; j<(THETA*NUM_DIMENSIONS); j++) {
            if (i==0) {
                if ( j == 3*THETA + (THETA-1) ) {
                    cmb[j] = FE_ONE;
                    //cmb[j] = ZZ_p(5);
                } else {
                    cmb[j] = FE_ZERO;
                }
            } else {
                if ( j == (THETA-1) ) {
                    cmb[j] = FE_ONE;
                    //cmb[j] = ZZ_p(5);
                } else {
                    cmb[j] = FE_ZERO;
                }
            }
        }
        mag_bits[i] = cmb;
    }

    

    vector< Vec<ZZ_p> > sign_bits(NUM_PARTICIPANTS); // [N x P]
    for (int i=0; i<NUM_PARTICIPANTS; i++) {
        Vec<ZZ_p> csb; // client sign bits
        csb.SetLength(NUM_DIMENSIONS);
        for (int j=0; j<NUM_DIMENSIONS; j++) {
            if (i==1 && j == 0) {
                csb[j] = FE_NEG_ONE;
            } else {
                csb[j] = FE_NEG_ONE;
            }
            
        }
        sign_bits[i] = csb;
    }

    #ifdef DEBUG
    cout << "initialized mag_bits" << endl;
    #endif
    
    
    // PROTOCOL PROPER:
    
    int timer = 0;
    auto start = emp::clock_start();

    // BINARY CHECK FOR mag_bits
    ZZ_p BIN_BATCH_CHECK(0);
    ZZ_p one(1);
    ZZ_p t1, t2, t3, zz_r;
    long p_r;
    for (int i=0; i<NUM_PARTICIPANTS; i++) {
        for(int j=0; j<THETA*NUM_DIMENSIONS; j++) {
            // creating s_i, a share of b(b-1)
            sub(t1, mag_bits[i][j], one); // t1 = b-1
            mul(t2, mag_bits[i][j], t1); // t2 = b*(b-1)

            // generating r_i, a random field element
            prg.random_data(&p_r, sizeof(p_r)); // prg generates random long
            conv(zz_r, p_r); // put random long into ZZ_p
            mul(t3, zz_r, t2); // t3 = <random field element> * b*(b-1)
            // adding up \sum s_i * r_i
            add(BIN_BATCH_CHECK, BIN_BATCH_CHECK, t3); // BIN_BATCH_CHECK += t3
        }
    }
    // reconstruct and reveal BIN_BATCH_CHECK
    // (reconstructing a dummy to obtain the CPU time)
    ZZ_p dummy_check = reconstruct(xs, dummy_shares);
    cout << "BIN_BATCH_CHECK: "; 
    if (BIN_BATCH_CHECK==ZZ_p(0)) { 
        cout << "PASSED" << endl;
    } else {
        cout << "FAILED" << endl;
    }

    ZZ_p SIGN_BATCH_CHECK = ZZ_p(0);
    // SIGN CHECK
    // proves that the sign bit is either -1 or 1
    for (int i=0; i<NUM_PARTICIPANTS; i++) {
        for (int j=0; j<NUM_DIMENSIONS; j++) {
            // creating s_i, a share of (b+1)(b-1)
            sub(t1, sign_bits[i][j], one);
            add(t2, sign_bits[i][j], one);
            mul(t2, t1, t2);

            prg.random_data(&p_r, sizeof(p_r)); // prg generates random long
            conv(zz_r, p_r); // put random long into ZZ_p
            mul(t3, zz_r, t2); // t3 = <random field element> * (b+1)(b-1)
            // adding up \sum s_i * r_i
            add(SIGN_BATCH_CHECK, SIGN_BATCH_CHECK, t3);
        }
    }

    // reconstruct and reveal SIGN_BATCH_CHECK
    // (reconstruct dummy to obtain CPU time)
    ZZ_p dummy_check2 = reconstruct(xs, dummy_shares);
    cout << "SIGN_BATCH_CHECK: "; 
    if (SIGN_BATCH_CHECK==ZZ_p(0)) { 
        cout << "PASSED" << endl;
    } else {
        cout << "FAILED" << endl;
    }

    
    // get debinarized representation of magnitudes
    vector< Vec<ZZ_p> > mag_fes(NUM_PARTICIPANTS); // [N x P]
    for (int i = 0; i<NUM_PARTICIPANTS; i++) {
        Vec<ZZ_p> cmfes; // client magnitude field elements
        cmfes.SetLength(NUM_DIMENSIONS);
        for (int j=0; j<NUM_DIMENSIONS; j++) {
            cmfes[j] = unsigned_debinarize(j*THETA, THETA, mag_bits[i]);
        } 
        mag_fes[i] = cmfes;
    }


    #ifdef DEBUG
    cout << "initialized mag_fes" << endl;
    for (int i=0; i<NUM_PARTICIPANTS; i++) {
        cout << "[";
        for (int j=0; j<NUM_DIMENSIONS; j++) {
            cout << " " << mag_fes[i][j] << " ";
        }
        cout << "]\n";
    }
    #endif


    // UNIT LENGTH CHECK:

    long temp = (1 << (2*(THETA-1))); // overflow issue past THETA=16
    ZZ_p FP_SQUARE_UNIT_LENGTH = ZZ_p(temp);

    ZZ_p UNIT_BATCH_CHECK = ZZ_p(0);
    //long temp_r;
    //ZZ_p zz_r, t1;
    for (int i=0; i<NUM_PARTICIPANTS; i++) {
        ZZ_p check = ZZ_p(0);
        for(int j=0; j<NUM_DIMENSIONS; j++) {
            check += mag_fes[i][j] * mag_fes[i][j]; // <g_i, g_i>
        }
        check -= FP_SQUARE_UNIT_LENGTH; // check = <g_i, g_i> - FP_SQ_UNIT_LENGTH
        prg.random_data(&p_r, sizeof(p_r));
        conv(zz_r, p_r);
        mul(t1, zz_r, check); // t1 = <random field element> * (<g_i, g_i> - FP_SQ_UNIT_LENGTH)
        add(UNIT_BATCH_CHECK, UNIT_BATCH_CHECK, t1); // add them up to batch check them
        #ifdef DEBUG
        cout << " " << check << " ";
        #endif
    }

    // reconstruct and reveal UNIT_BATCH_CHECK
    // (reconstruct dummy to obtain CPU time)
    ZZ_p dummy_check3 = reconstruct(xs, dummy_shares);

    cout << endl;
    cout << "UNIT_BATCH_CHECK: "; // TODO: reconstruct and reveal
    if (UNIT_BATCH_CHECK == ZZ_p(0)) {
        cout << "PASSED" << endl;
    } else {
        cout << "FAILED" << endl;
    }

    // GET TRUST SCORE
    // TS_i = ReLU(cos(\phi)) where \phi is the angle between g_0 and g_i
    // we have cos(angle between vectors v and u) = <v, u> / ||v|| ||u||
    // but since we normalized to unit length,
    // this means that cos(\phi) = <g_i, g_0>.
    // and since we rotated g_0 to align with the x-axis,
    // this means that g_0 = [1, 0, 0, ... 0]
    // which means that <g_i, g_0> = the 0th coordinate of g_i.
    // and we avoid taking a RELU by not accepting a sign bit for the 0th coordinate
    // rather we assume that it is positive
    // So TS_i = 0th coordinate of g_i
    Vec<ZZ_p> trust_scores;
    trust_scores.SetLength(NUM_PARTICIPANTS);
    for (int i=0; i<NUM_PARTICIPANTS; i++) {
        trust_scores[i] = mag_fes[i][0];
    }

    // SUM PARAMETER DELTAS
    // compute \forall p \in P, \sum_{c \in N} mag_p \cdot TS_c \cdot (-1 * sign_p)
    Vec<ZZ_p> param_deltas;
    param_deltas.SetLength(NUM_DIMENSIONS);
    ZZ_p neg_one = ZZ_p(-1);
    for (int j=0; j<NUM_DIMENSIONS; j++) {
        param_deltas[j] = FE_ZERO;
        for (int i=0; i<NUM_PARTICIPANTS; i++) {
            if (j==0) {
                param_deltas[j] += mag_fes[i][j] * trust_scores[i]; // sign bit not used for the 0th coordinate, always positive
            } else {
                param_deltas[j] += mag_fes[i][j] * trust_scores[i] * sign_bits[i][j];
            }
        }
    }

    // RECONSTRUCT PARAMETER DELTAS
    for (int j=0; j<NUM_DIMENSIONS; j++) {
        ZZ_p dummy_check4 = reconstruct(xs, dummy_shares);
    }

    timer = emp::time_from(start);
    cout << "protocol time: " << timer << endl;
    logfile << "protocol time: " << timer << endl;
    
    #ifdef DEBUG
    cout << "PARAM DELTAS: " << endl;
    cout << "[";
    for (int j=0; j<NUM_DIMENSIONS; j++) {
        cout << " " << param_deltas[j] << " ";
    }
    cout << "]\n";
    #endif


    cout << "finished." << endl;


    return 0;
}