#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <random>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
#include "crowd.h"
#include "distro.h"

using namespace std;

/// data point tuple ///

// Builds one label record by copying the three arguments into the members
// of the same name.
// task_id: task identifier (the generators in this file always pass 0)
// work_id: worker identifier (the generators in this file pass the worker index)
// label:   the worker's answer (the generators in this file use +1.0 or -1.0)
data_point::data_point(unsigned int task_id, unsigned int work_id, double label):
    task_id(task_id),
    work_id(work_id),
    label(label)
{}

// Overwrites the task identifier stored in this data point.
void data_point::set_task_id(unsigned int task_id)
{
    // the parameter shadows the member, hence the explicit this->
    this->task_id = task_id;
}

/// fixed budget crowd ///

// Stores the two distributions and the number of workers; nothing is sampled here.
// d_work_acc:   used by create_data_point to draw one accuracy per worker
// d_work_quota: used by create_data_point to draw one label quota per worker
// n_workers:    number of workers simulated by create_data_point
// Both pointers are deleted by ~crowd_budget, so the crowd takes ownership of them.
crowd_budget::crowd_budget(distro* d_work_acc, distro* d_work_quota, unsigned int n_workers):
    d_work_acc(d_work_acc),
    d_work_quota(d_work_quota),
    n_workers(n_workers)
{}

void crowd_budget::create_data_point(std::mt19937& ran_gen, unsigned int t)
{
    // initialise and shuffle all the data at the beginning
    if(t == 0) {
        uniform_real_distribution<double> dist_uni(0.0, 1.0);
        work_acc = d_work_acc->extract_vec(ran_gen, n_workers);
        vector<double> work_quota = d_work_quota->extract_vec(ran_gen, n_workers);

        for(unsigned int j = 0; j < n_workers; ++j) {
            int quota = (int) work_quota[j];
            assert(quota > 0);

            for(unsigned int q = 0; q < (unsigned int) quota; ++q) {
                double label = (dist_uni(ran_gen) < work_acc[j])? +1.0: -1.0;
                data_seq.push_back(data_point(0, j, label));
            }
        }
        shuffle(data_seq.begin(), data_seq.end(), ran_gen);
    }

    // boundary check
    assert(t < data_seq.size());
}

// Frees the two distributions received by the constructor.
crowd_budget::~crowd_budget()
{
    delete d_work_acc;
    delete d_work_quota;
}

void crowd_budget_test()
{
    double p = 0.8;
    int quota = 5;
    distro* d_work_acc = (distro*) new distro_dirac(p);
    distro* d_work_quota = (distro*) new distro_kronecker(quota);
    unsigned int n_workers = 10;
    crowd_budget c(d_work_acc, d_work_quota, n_workers);

    mt19937 ran_gen = well_seeded_mt19937(17);
    c.create_data_point(ran_gen, 0);

    unsigned int correct = 0;
    for(unsigned int t = 0; t < c.data_seq.size(); ++t)
        if(c.data_seq[t].label > 0.0)
            ++correct;

    assert(c.work_acc.size() == n_workers);
    assert(c.data_seq.size() == n_workers * quota);
    assert(correct >= n_workers * quota / 2); // stochastic test: we expect a majority of correct labels
}

/// fixed quota crowd ///

// Stores the two distributions; nothing is sampled here.
// d_work_acc:   used by create_data_point to draw the accuracy of each new worker
// d_work_quota: used by create_data_point to draw the label quota of each new worker
// Both pointers are deleted by ~crowd_quota, so the crowd takes ownership of them.
crowd_quota::crowd_quota(distro* d_work_acc, distro* d_work_quota):
    d_work_acc(d_work_acc),
    d_work_quota(d_work_quota)
{}

// Lazily extends the label sequence one worker at a time.
//
// When t is exactly one past the last existing data point, a new worker is
// created: its accuracy and quota are drawn, and `quota` labels with task id 0
// are appended (+1.0 when a uniform draw in [0, 1) falls below the accuracy,
// -1.0 otherwise). Calls for an index that already exists do nothing.
//
// ran_gen: random engine used for all the draws
// t:       index of the data point the caller is about to read
void crowd_quota::create_data_point(std::mt19937& ran_gen, unsigned int t)
{
    // the caller may ask for at most one point past the end of the sequence
    assert(t <= data_seq.size());

    // the requested point already exists: nothing to generate
    if(t != data_seq.size())
        return;

    // out of data: recruit one more worker, whose index is the current head count
    const unsigned int new_worker = work_acc.size();
    work_acc.push_back(d_work_acc->extract_num(ran_gen));
    const int quota = static_cast<int>(d_work_quota->extract_num(ran_gen));
    assert(quota > 0);

    std::uniform_real_distribution<double> dist_uni(0.0, 1.0);
    const unsigned int n_new_labels = static_cast<unsigned int>(quota);
    for(unsigned int q = 0; q < n_new_labels; ++q) {
        const bool is_correct = dist_uni(ran_gen) < work_acc[new_worker];
        data_seq.push_back(data_point(0, new_worker, is_correct ? +1.0 : -1.0));
    }
}

// Frees the two distributions received by the constructor.
crowd_quota::~crowd_quota()
{
    delete d_work_acc;
    delete d_work_quota;
}

void crowd_quota_test()
{
    double p = 0.8;
    int quota = 5;
    distro* d_work_acc = (distro*) new distro_dirac(p);
    distro* d_work_quota = (distro*) new distro_kronecker(quota);
    crowd_quota c(d_work_acc, d_work_quota);

    unsigned int t_max = 43;
    mt19937 ran_gen = well_seeded_mt19937(17);
    for(unsigned t = 0; t < t_max; ++t)
        c.create_data_point(ran_gen, t);

    unsigned int correct = 0;
    for(unsigned int t = 0; t < c.data_seq.size(); ++t)
        if(c.data_seq[t].label > 0.0)
            ++correct;

    unsigned int n_workers = (t_max + quota - 1) / quota;
    assert(c.work_acc.size() == n_workers);
    assert(c.data_seq.size() == n_workers * quota);
    assert(correct >= n_workers * quota / 2); // stochastic test: we expect a majority of correct labels
}

/// fixed workers crowd ///

// Stores the accuracy distribution and the number of workers; nothing is sampled here.
// d_work_acc: used by create_data_point to draw one accuracy per worker
// n_workers:  number of workers, each contributing one label per round
// The pointer is deleted by ~crowd_workers, so the crowd takes ownership of it.
crowd_workers::crowd_workers(distro* d_work_acc, unsigned int n_workers):
    d_work_acc(d_work_acc),
    n_workers(n_workers)
{}

// Lazily extends the label sequence of a crowd with a fixed set of workers.
//
// The workers' accuracies are drawn once, before the first label is generated.
// Whenever t is exactly one past the last existing data point, a new round is
// appended: one label with task id 0 per worker, +1.0 when a uniform draw in
// [0, 1) falls below the worker's accuracy and -1.0 otherwise. Calls for an
// index that already exists do nothing.
//
// ran_gen: random engine used for all the draws
// t:       index of the data point the caller is about to read
void crowd_workers::create_data_point(std::mt19937& ran_gen, unsigned int t)
{
    // avoid skipping (too many) data points
    assert(t <= data_seq.size());

    // Draw the accuracies only while no label exists yet. A later call with
    // t == 0 must not replace the accuracies that produced the stored labels.
    if(t == 0 && data_seq.empty())
        work_acc = d_work_acc->extract_vec(ran_gen, n_workers);

    // create one more label per worker each time we run out of data
    if(t == data_seq.size()) {
        std::uniform_real_distribution<double> dist_uni(0.0, 1.0);

        for(unsigned int j = 0; j < n_workers; ++j) {
            const double label = (dist_uni(ran_gen) < work_acc[j]) ? +1.0 : -1.0;
            data_seq.push_back(data_point(0, j, label));
        }
    }
}

// Frees the accuracy distribution received by the constructor.
crowd_workers::~crowd_workers()
{
    delete d_work_acc;
}

void crowd_workers_test()
{
    double p = 0.8;
    distro* d_work_acc = (distro*) new distro_dirac(p);
    unsigned int n_workers = 10;
    crowd_workers c(d_work_acc, n_workers);

    unsigned int t_max = 43;
    mt19937 ran_gen = well_seeded_mt19937(17);
    for(unsigned t = 0; t < t_max; ++t)
        c.create_data_point(ran_gen, t);

    unsigned int correct = 0;
    for(unsigned int t = 0; t < c.data_seq.size(); ++t)
        if(c.data_seq[t].label > 0.0)
            ++correct;

    unsigned int quota = (t_max + n_workers - 1) / n_workers;
    assert(c.work_acc.size() == n_workers);
    assert(c.data_seq.size() == n_workers * quota);
    assert(correct >= n_workers * quota / 2); // stochastic test: we expect a majority of correct labels
}

/// read the data from a file ///

// Loads a crowd from a CSV file.
//
// Input format, one label per line:
//     <worker_id>, <task_id>, <worker_label>, <gold_label>
// Both labels are mapped with x -> 2x - 1 (so 0/1 become -1/+1). When the gold
// label does not map to +1.0 the worker label is negated, so that every stored
// label is expressed relative to a gold label of +1.0.
//
// Lines that do not have four comma-separated fields, or whose fields are not
// numeric, are reported on stderr and skipped. After construction:
//   - n_labels  is the number of labels actually stored in data_seq
//   - n_tasks   is the largest task id seen plus one (0 if none)
//   - n_workers is the largest worker id seen plus one (0 if none)
//   - work_acc  has n_workers elements, which are not computed here
// If the file cannot be opened an error is printed and the crowd stays empty.
//
// filename: path of the CSV file to read
crowd_file::crowd_file(const char *filename)
{
    // set the counters before anything can fail, so that they are well defined
    // even when the file cannot be opened
    n_tasks = 0;
    n_workers = 0;
    n_labels = 0;

    std::ifstream csv_file(filename);

    if(!csv_file.is_open()) {
        std::cerr << "Unable to open " << filename << std::endl;
        return;
    }

    std::string line, token[4];

    // line_no counts every line read (for the error messages), whereas n_labels
    // counts only the lines that end up in data_seq
    for(unsigned int line_no = 1; std::getline(csv_file, line); ++line_no) {
        std::stringstream line_stream(line);

        unsigned int n_tokens = 0;
        while(n_tokens < 4 && std::getline(line_stream, token[n_tokens], ','))
            ++n_tokens;

        if(n_tokens != 4) {
            std::cerr << "Wrong format in line " << line_no << std::endl;
            continue;
        }

        unsigned int worker_id, task_id;
        double worker_label, gold_label;
        try {
            worker_id = std::stoul(token[0], nullptr, 10);
            task_id = std::stoul(token[1], nullptr, 10);
            worker_label = 2.0 * std::stod(token[2], nullptr) - 1.0;
            gold_label = 2.0 * std::stod(token[3], nullptr) - 1.0;
        } catch(const std::logic_error&) {
            // std::invalid_argument or std::out_of_range: a field is not a valid
            // number, so treat the line like any other malformed line
            std::cerr << "Wrong format in line " << line_no << std::endl;
            continue;
        }

        if(task_id >= n_tasks) n_tasks = task_id + 1;
        if(worker_id >= n_workers) n_workers = worker_id + 1;
        if(gold_label != 1.0) worker_label = -worker_label; // force all gold_labels to +1.0

        data_seq.push_back(data_point(task_id, worker_id, worker_label));
        ++n_labels;
    }

    work_acc.resize(n_workers);
}

// Nothing to release by hand: the constructor in this file performs no manual allocation.
crowd_file::~crowd_file()
{}

// No data is generated here: the sequence is filled by the constructor, so this
// call only asserts that the data point with index t exists.
// ran_gen: unused by this implementation
// t:       index of the data point the caller is about to read
void crowd_file::create_data_point(std::mt19937& ran_gen, unsigned int t)
{
    assert(t < data_seq.size());
}

/// parser ///

// Builds a crowd from command-line style arguments.
//
// Recognised forms (the distributions are parsed by distro_parse):
//     crowd_budget <n_workers> <accuracy distro> <quota distro>
//     crowd_quota <accuracy distro> <quota distro>
//     crowd_workers <n_workers> <accuracy distro>
//
// argc, argv: remaining arguments; both are advanced past every argument
//             consumed here and by distro_parse
// n_labels:   desired number of labels, 0 if unspecified. Only used by
//             crowd_budget with a kronecker quota: n_workers is then replaced by
//             the smallest number of workers that yields at least n_labels labels
//
// Returns a heap-allocated crowd, or a null pointer after printing an error
// when the arguments cannot be parsed.
crowd* crowd_parse(int *argc, char **argv[], unsigned int n_labels)
{
    if(*argc < 1) {
        std::cerr << "Not enough arguments to parse the crowd type" << std::endl;
        return nullptr;
    }

    const char *crowd_type = (*argv)[0];

    if(std::strcmp(crowd_type, "crowd_budget") == 0) {
        if(*argc < 2) {
            std::cerr << "Not enough arguments to instantiate crowd_budget" << std::endl;
            return nullptr;
        }
        unsigned int n_workers = std::strtoul((*argv)[1], nullptr, 10);
        *argc -= 2; *argv += 2;
        distro* d_work_acc = distro_parse(argc, argv);
        distro* d_work_quota = distro_parse(argc, argv);

        // NOTE(review): assumes distro_parse reports failure with a null pointer,
        // like crowd_parse does - confirm. Never build a crowd around a missing
        // distribution, and do not leak the one that was parsed successfully.
        if(d_work_acc == nullptr || d_work_quota == nullptr) {
            std::cerr << "Unable to parse the distributions of crowd_budget" << std::endl;
            delete d_work_acc;
            delete d_work_quota;
            return nullptr;
        }

        // make an effort to match the workers' output to the required number of labels
        // works only with a fixed quota of labels per workers
        distro_kronecker* d = dynamic_cast<distro_kronecker*>(d_work_quota);
        if(d != nullptr && n_labels > 0 && d->val >= 1) {
            // d->val >= 1 keeps the divisor positive; a smaller quota is left for
            // create_data_point to reject
            const unsigned int work_quota = d->val;
            // ceil(n_labels / work_quota) without the overflow of n_labels + work_quota - 1
            n_workers = n_labels / work_quota + (n_labels % work_quota != 0 ? 1 : 0);
        }
        return new crowd_budget(d_work_acc, d_work_quota, n_workers);

    } else if(std::strcmp(crowd_type, "crowd_quota") == 0) {
        *argc -= 1; *argv += 1;
        distro* d_work_acc = distro_parse(argc, argv);
        distro* d_work_quota = distro_parse(argc, argv);

        if(d_work_acc == nullptr || d_work_quota == nullptr) {
            std::cerr << "Unable to parse the distributions of crowd_quota" << std::endl;
            delete d_work_acc;
            delete d_work_quota;
            return nullptr;
        }
        return new crowd_quota(d_work_acc, d_work_quota);

    } else if(std::strcmp(crowd_type, "crowd_workers") == 0) {
        if(*argc < 2) {
            std::cerr << "Not enough arguments to instantiate crowd_workers" << std::endl;
            return nullptr;
        }
        unsigned int n_workers = std::strtoul((*argv)[1], nullptr, 10);
        *argc -= 2; *argv += 2;
        distro* d_work_acc = distro_parse(argc, argv);

        if(d_work_acc == nullptr) {
            std::cerr << "Unable to parse the distribution of crowd_workers" << std::endl;
            return nullptr;
        }
        return new crowd_workers(d_work_acc, n_workers);
    }

    std::cerr << "Unable to parse the crowd type" << std::endl;
    return nullptr;
}

void crowd_parse_test()
{
    const char *argv_fb[] = {"crowd_budget", "10", "dirac", "0.765", "kronecker", "8", "welcome"};
    const char *argv_fq[] = {"crowd_quota", "dirac", "0.765", "kronecker", "8", "to", "the"};
    const char *argv_fw[] = {"crowd_workers", "10", "dirac", "0.765", "jungle"};
    char **argv_fixed_budget = (char**) argv_fb;
    char **argv_fixed_quota = (char**) argv_fq;
    char **argv_fixed_workers = (char**) argv_fw;

    int argc_fixed_budget = 7;
    int argc_fixed_quota = 7;
    int argc_fixed_workers = 5;

    crowd_budget* c_fixed_budget = (crowd_budget*) crowd_parse(&argc_fixed_budget, &argv_fixed_budget, 23);
    crowd_quota* c_fixed_quota = (crowd_quota*) crowd_parse(&argc_fixed_quota, &argv_fixed_quota, 0);
    crowd_workers* c_fixed_workers = (crowd_workers*) crowd_parse(&argc_fixed_workers, &argv_fixed_workers, 0);

    assert(argc_fixed_budget == 1);
    assert(argc_fixed_quota == 2);
    assert(argc_fixed_workers == 1);

    assert(strcmp(argv_fixed_budget[0], "welcome") == 0);
    assert(strcmp(argv_fixed_quota[0], "to") == 0);
    assert(strcmp(argv_fixed_quota[1], "the") == 0);
    assert(strcmp(argv_fixed_workers[0], "jungle") == 0);

    assert(c_fixed_budget != NULL);
    assert(c_fixed_quota != NULL);
    assert(c_fixed_workers != NULL);

    mt19937 ran_gen = well_seeded_mt19937(17);
    c_fixed_budget->create_data_point(ran_gen, 0);
    c_fixed_quota->create_data_point(ran_gen, 0);
    c_fixed_workers->create_data_point(ran_gen, 0);

    assert(c_fixed_budget->work_acc.size() == 3);
    assert(c_fixed_quota->work_acc.size() == 1);
    assert(c_fixed_workers->work_acc.size() == 10);

    delete c_fixed_budget;
    delete c_fixed_quota;
    delete c_fixed_workers;
}
