#include <algorithm>
#include <array>
#include <cassert>
#include <cerrno>
#include <climits>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <random>
#include <vector>
#include "distro.h"

using namespace std;

// Build an mt19937 whose entire internal state is derived from `seed`.
//
// The state of mt19937 is large (mt19937::state_size words), so the single
// seed value is first expanded by a cheap minstd_rand engine into one value
// per state word; those values are then handed to the generator through a
// seed_seq. The same `seed` always yields the same generator.
mt19937 well_seeded_mt19937(unsigned int seed) {

    // expand the seed: one minstd_rand output per word of mt19937 state
    minstd_rand expander(seed);
    array<int, mt19937::state_size> seed_words;
    for(auto& word : seed_words)
        word = static_cast<int>(expander());

    // the engine's constructor takes bulk seed material as a seed sequence
    seed_seq expanded_seed(seed_words.begin(), seed_words.end());
    mt19937 generator(expanded_seed);

    return generator;
}

/// uniform distro ///

// Uniform distribution between min_val and max_val.
// The interval must be non-empty: min_val has to be strictly below max_val.
distro_uniform::distro_uniform(double min_val, double max_val)
    : min_val(min_val)
    , max_val(max_val)
{
    // here the names denote the constructor parameters, which shadow the members
    assert(max_val > min_val);
}

// Draw one sample from the uniform distribution over [min_val, max_val),
// consuming randomness from ran_gen.
double distro_uniform::extract_num(mt19937& ran_gen)
{
    uniform_real_distribution<double> sampler(min_val, max_val);
    const double sample = sampler(ran_gen);
    return sample;
}

// Draw n samples from the uniform distribution over [min_val, max_val).
// The samples are produced in order from a single distribution object, so the
// result is the same sequence ran_gen would yield for n successive draws.
vector<double> distro_uniform::extract_vec(mt19937& ran_gen, unsigned int n)
{
    uniform_real_distribution<double> sampler(min_val, max_val);

    vector<double> samples(n);
    generate(samples.begin(), samples.end(),
             [&sampler, &ran_gen] { return sampler(ran_gen); });
    return samples;
}

// Self-test for distro_uniform, run with a fixed seed so it is reproducible.
// Checks that samples fall in [lo, hi) and that equal values are (almost)
// never produced, either inside one batch or across two batches.
void distro_uniform_test()
{
    mt19937 ran_gen = well_seeded_mt19937(0);

    // build the distribution under test
    const double lo = -1.0;
    const double hi = 2.0;
    distro_uniform d(lo, hi);

    // a single draw must land in the half-open interval
    const double single = d.extract_num(ran_gen);
    assert(single >= lo);
    assert(single < hi);

    // two consecutive batches taken from the same generator
    const unsigned int n = 100;
    const vector<double> first = d.extract_vec(ran_gen, n);
    const vector<double> second = d.extract_vec(ran_gen, n);

    // no value of the first batch may leave the interval
    for(double x : first) {
        assert(x >= lo);
        assert(x < hi);
    }

    // equal pairs inside the first batch: compare each entry with the ones before it
    unsigned int repeat = 0;
    for(auto it = first.begin(); it != first.end(); ++it)
        repeat += static_cast<unsigned int>(count(first.begin(), it, *it));
    assert(repeat <= 1); // stochastic test: we expect very few repeated entries

    // equal pairs between the two batches
    repeat = 0;
    for(double x : second)
        repeat += static_cast<unsigned int>(count(first.begin(), first.end(), x));
    assert(repeat <= 1); // stochastic test: we expect very few repeated entries
}

/// beta distro ///

// Beta distribution with shape parameters alpha and beta.
// Both shape parameters must be strictly positive.
distro_beta::distro_beta(double alpha, double beta)
    : alpha(alpha)
    , beta(beta)
{
    // here the names denote the constructor parameters, which shadow the members
    assert(0.0 < alpha);
    assert(0.0 < beta);
}

// Produce one beta-distributed number from two gamma-distributed draws:
// with x taken from dist_alpha and y from dist_beta, the result is x / (x + y).
// The alpha draw is always taken from ran_gen before the beta draw.
static double gamma_2_beta(gamma_distribution<double>& dist_alpha,
                           gamma_distribution<double>& dist_beta,
                           mt19937& ran_gen)
{
    const double x = dist_alpha(ran_gen);
    const double y = dist_beta(ran_gen);
    const double total = x + y;
    return x / total;
}

double distro_beta::extract_num(mt19937& ran_gen)
{
    gamma_distribution<double> dist_alpha(alpha, 1.0);
    gamma_distribution<double> dist_beta(beta, 1.0);
    return gamma_2_beta(dist_alpha, dist_beta, ran_gen);
}

vector<double> distro_beta::extract_vec(mt19937& ran_gen, unsigned int n)
{
    gamma_distribution<double> dist_alpha(alpha, 1.0);
    gamma_distribution<double> dist_beta(beta, 1.0);

    vector<double> vec(n);
    for(unsigned int i = 0; i < n; ++i)
        vec[i] = gamma_2_beta(dist_alpha, dist_beta, ran_gen);
    return vec;
}

// Self-test for distro_beta, run with a fixed seed so it is reproducible.
// Checks that samples fall in [0, 1) and that equal values are (almost)
// never produced, either inside one batch or across two batches.
void distro_beta_test()
{
    mt19937 ran_gen = well_seeded_mt19937(0);

    // build the distribution under test
    const double shape_a = 2.0;
    const double shape_b = 1.0;
    distro_beta d(shape_a, shape_b);

    // a single draw must land in the unit interval
    const double single = d.extract_num(ran_gen);
    assert(single >= 0.0);
    assert(single < 1.0);

    // two consecutive batches taken from the same generator
    const unsigned int n = 100;
    const vector<double> first = d.extract_vec(ran_gen, n);
    const vector<double> second = d.extract_vec(ran_gen, n);

    // no value of the first batch may leave the unit interval
    for(double x : first) {
        assert(x >= 0.0);
        assert(x < 1.0);
    }

    // equal pairs inside the first batch: compare each entry with the ones before it
    unsigned int repeat = 0;
    for(auto it = first.begin(); it != first.end(); ++it)
        repeat += static_cast<unsigned int>(count(first.begin(), it, *it));
    assert(repeat <= 1); // stochastic test: we expect very few repeated entries

    // equal pairs between the two batches
    repeat = 0;
    for(double x : second)
        repeat += static_cast<unsigned int>(count(first.begin(), first.end(), x));
    assert(repeat <= 1); // stochastic test: we expect very few repeated entries
}

/// dirac's delta distro ///

// Degenerate distribution that always yields the value p.
// p must lie in the closed interval [0, 1].
distro_dirac::distro_dirac(double p)
    : p(p)
{
    // here p denotes the constructor parameter, which shadows the member
    assert(0.0 <= p);
    assert(1.0 >= p);
}

// Return the fixed value p.
// ran_gen is accepted for interface uniformity with the other distros but is
// neither read nor advanced.
double distro_dirac::extract_num(mt19937& ran_gen)
{
    const double fixed_value = p;
    return fixed_value;
}

// Return a vector of n entries, every one equal to p.
// ran_gen is neither read nor advanced.
vector<double> distro_dirac::extract_vec(mt19937& ran_gen, unsigned int n)
{
    return vector<double>(n, p);
}

// Self-test for distro_dirac: every extracted value, single or batched, has
// to be exactly the value the distribution was built with.
void distro_dirac_test()
{
    mt19937 ran_gen = well_seeded_mt19937(0);

    // build the distribution under test
    const double p = 0.7;
    distro_dirac d(p);

    // a single draw returns p itself
    const double single = d.extract_num(ran_gen);
    assert(single == p);

    // a batch contains nothing but p
    const unsigned int n = 100;
    const vector<double> batch = d.extract_vec(ran_gen, n);
    for(double x : batch)
        assert(x == p);
}

/// kronecker's delta distro ///

// Degenerate distribution that always yields the integer val.
// Unlike the other constructors in this file, no range check is applied:
// any int is accepted.
distro_kronecker::distro_kronecker(int val):
    val(val)
{}

// Return the fixed value val, converted to double.
// ran_gen is neither read nor advanced.
double distro_kronecker::extract_num(mt19937& ran_gen)
{
    return static_cast<double>(val);
}

// Return a vector of n entries, every one equal to val converted to double.
// ran_gen is neither read nor advanced.
vector<double> distro_kronecker::extract_vec(mt19937& ran_gen, unsigned int n)
{
    const double fixed_value = static_cast<double>(val);
    return vector<double>(n, fixed_value);
}

// Self-test for distro_kronecker: every extracted value, single or batched,
// has to equal the integer the distribution was built with.
void distro_kronecker_test()
{
    mt19937 ran_gen = well_seeded_mt19937(0);

    // build the distribution under test
    const int val = -123;
    distro_kronecker d(val);

    // a single draw returns val as a double
    const double single = d.extract_num(ran_gen);
    assert(single == static_cast<double>(val));

    // a batch contains nothing but val
    const unsigned int n = 100;
    const vector<double> batch = d.extract_vec(ran_gen, n);
    for(double x : batch)
        assert(x == val);
}

/// power law distribution ///

// Power-law distribution between v_min and v_max with exponent c_pow.
//
// Besides storing its arguments, the constructor caches the quantities that
// extract_num() and extract_vec() need for every sample:
//   p_min = v_min^(c_pow + 1)
//   p_max = v_max^(c_pow + 1)
//   d_pow = 1 / (c_pow + 1)
//
// c_pow must differ from -1: in that case c_pow + 1 is zero, d_pow would be a
// division by zero and the samples would be meaningless.
distro_power::distro_power(double v_min, double v_max, double c_pow):
    v_min(v_min),
    v_max(v_max),
    c_pow(c_pow)
{
    // here the names denote the constructor parameters, which shadow the members
    const double exponent = c_pow + 1.0;
    assert(exponent != 0.0);

    p_min = pow(v_min, exponent);
    p_max = pow(v_max, exponent);
    d_pow = 1.0 / exponent;
}

// Draw one power-law sample.
// A uniform draw u in [0, 1) is mapped to ((p_max - p_min) * u + p_min)^d_pow
// using the values cached by the constructor; the result is then reflected
// inside the interval by returning v_max - r_pow + v_min.
double distro_power::extract_num(mt19937& ran_gen)
{
    uniform_real_distribution<double> unit(0.0, 1.0);

    const double u = unit(ran_gen);
    const double r_pow = pow((p_max - p_min) * u + p_min, d_pow);
    return v_max - r_pow + v_min;
}

// Draw n power-law samples, in order, from ran_gen.
// Every sample maps a uniform draw u in [0, 1) to
// ((p_max - p_min) * u + p_min)^d_pow and reflects it inside the interval,
// exactly as extract_num() does.
vector<double> distro_power::extract_vec(mt19937& ran_gen, unsigned int n)
{
    uniform_real_distribution<double> unit(0.0, 1.0);

    vector<double> samples(n);
    generate(samples.begin(), samples.end(), [&] {
        const double u = unit(ran_gen);
        const double r_pow = pow((p_max - p_min) * u + p_min, d_pow);
        return v_max - r_pow + v_min;
    });
    return samples;
}

// Self-test for distro_power, run with a fixed seed so it is reproducible.
// Checks the sample range, the (near) absence of repeated values inside one
// batch and across two batches, and that most of the mass sits near v_min.
void distro_power_test()
{
    mt19937 ran_gen = well_seeded_mt19937(0);

    // build the distribution under test
    const double v_min = 1.0;
    const double v_max = 10.0;
    const double c_pow = 4.0;
    distro_power d(v_min, v_max, c_pow);

    // a single draw must land in [v_min, v_max)
    const double single = d.extract_num(ran_gen);
    assert(single >= v_min);
    assert(single < v_max);

    // two consecutive batches taken from the same generator
    const unsigned int n = 100;
    const vector<double> first = d.extract_vec(ran_gen, n);
    const vector<double> second = d.extract_vec(ran_gen, n);

    // no value of the first batch may leave the interval
    for(double x : first) {
        assert(x >= v_min);
        assert(x < v_max);
    }

    // equal pairs inside the first batch: compare each entry with the ones before it
    unsigned int repeat = 0;
    for(auto it = first.begin(); it != first.end(); ++it)
        repeat += static_cast<unsigned int>(count(first.begin(), it, *it));
    assert(repeat <= 1); // stochastic test: we expect very few repeated entries

    // equal pairs between the two batches
    repeat = 0;
    for(double x : second)
        repeat += static_cast<unsigned int>(count(first.begin(), first.end(), x));
    assert(repeat <= 1); // stochastic test: we expect very few repeated entries

    // power law shape: count the entries of the first batch above 3.0
    const unsigned int tail = static_cast<unsigned int>(
        count_if(first.begin(), first.end(), [](double x) { return x > 3.0; }));
    assert(tail < n / 2); // stochastic test: we expect 33% mass on the tail
}

/// exponential distribution ///

// Exponential distribution with rate lambda, shifted so that its samples
// start at v_min.
//
// lambda must be strictly positive: it is passed unchanged to
// std::exponential_distribution, which requires a positive rate. v_min is an
// additive offset and is not restricted.
distro_exponential::distro_exponential(double v_min, double lambda):
    v_min(v_min),
    lambda(lambda)
{
    // here lambda denotes the constructor parameter, which shadows the member
    assert(lambda > 0.0);
}

// Draw one sample: an exponential variate with rate lambda, plus the offset v_min.
double distro_exponential::extract_num(mt19937& ran_gen)
{
    exponential_distribution<double> sampler(lambda);
    const double unshifted = sampler(ran_gen);
    return unshifted + v_min;
}

// Draw n samples, in order, from ran_gen: each one is an exponential variate
// with rate lambda, plus the offset v_min.
vector<double> distro_exponential::extract_vec(mt19937& ran_gen, unsigned int n)
{
    exponential_distribution<double> sampler(lambda);

    vector<double> samples(n);
    generate(samples.begin(), samples.end(),
             [&] { return sampler(ran_gen) + v_min; });
    return samples;
}

// Self-test for distro_exponential, run with a fixed seed so it is reproducible.
// Checks that no sample falls below v_min and that equal values are (almost)
// never produced, either inside one batch or across two batches.
void distro_exponential_test()
{
    mt19937 ran_gen = well_seeded_mt19937(0);

    // build the distribution under test
    const double v_min = 2.0;
    const double lambda = 0.1;
    distro_exponential d(v_min, lambda);

    // a single draw can never fall below the offset
    const double single = d.extract_num(ran_gen);
    assert(single >= v_min);

    // two consecutive batches taken from the same generator
    const unsigned int n = 100;
    const vector<double> first = d.extract_vec(ran_gen, n);
    const vector<double> second = d.extract_vec(ran_gen, n);

    // no value of the first batch may fall below the offset
    for(double x : first)
        assert(x >= v_min);

    // equal pairs inside the first batch: compare each entry with the ones before it
    unsigned int repeat = 0;
    for(auto it = first.begin(); it != first.end(); ++it)
        repeat += static_cast<unsigned int>(count(first.begin(), it, *it));
    assert(repeat <= 1); // stochastic test: we expect very few repeated entries

    // equal pairs between the two batches
    repeat = 0;
    for(double x : second)
        repeat += static_cast<unsigned int>(count(first.begin(), first.end(), x));
    assert(repeat <= 1); // stochastic test: we expect very few repeated entries
}

/// parser ///

// Convert the whole of `text` to a double.
// On success the value is stored in `out` and true is returned. If `text`
// does not start with a number, or has characters left over after it, an
// error naming `distro_name` is printed to cerr, `out` is left untouched and
// false is returned.
static bool distro_parse_double(const char *text, const char *distro_name, double &out)
{
    char *end = nullptr;
    const double value = strtod(text, &end);
    if(end == text || *end != '\0') {
        cerr << "Invalid numeric argument \"" << text << "\" for " << distro_name << endl;
        return false;
    }
    out = value;
    return true;
}

// Convert the whole of `text` to an int (base 10).
// Same contract as distro_parse_double(); values that do not fit in an int
// are rejected as well.
static bool distro_parse_int(const char *text, const char *distro_name, int &out)
{
    char *end = nullptr;
    errno = 0; // strtol signals out-of-range input by setting errno to ERANGE
    const long value = strtol(text, &end, 10);
    const bool malformed = (end == text || *end != '\0');
    const bool out_of_range = (errno == ERANGE || value < INT_MIN || value > INT_MAX);
    if(malformed || out_of_range) {
        cerr << "Invalid integer argument \"" << text << "\" for " << distro_name << endl;
        return false;
    }
    out = static_cast<int>(value);
    return true;
}

// Build a distro from the front of a command line.
//
// (*argv)[0] names the distro type and is followed by its parameters:
//   uniform     min_val max_val
//   beta        alpha beta
//   dirac       p
//   kronecker   val
//   power       v_min v_max c_pow
//   exponential v_min lambda
//
// On success the consumed words are removed from the command line (*argc is
// decreased and *argv advanced by the same amount) and a distro allocated
// with new is returned; the caller has to delete it.
//
// On failure (no words left, unknown type, too few parameters, or a parameter
// that is not a valid number) a message is printed to cerr, NULL is returned
// and *argc / *argv are left unchanged.
distro* distro_parse(int *argc, char **argv[])
{
    if(*argc < 1) {
        cerr << "Not enough arguments to parse the distro type" << endl;
        return nullptr;
    }

    char **args = *argv;

    // NOTE: the C-style casts to distro* are kept from the original code; the
    // class hierarchy is declared in distro.h, which is not visible from here.

    if(strcmp(args[0], "uniform") == 0) {
        if(*argc < 3) {
            cerr << "Not enough arguments to instantiate distro_uniform" << endl;
            return nullptr;
        }
        double min_val = 0.0;
        double max_val = 0.0;
        if(!distro_parse_double(args[1], "distro_uniform", min_val) ||
           !distro_parse_double(args[2], "distro_uniform", max_val))
            return nullptr;
        *argc -= 3; *argv += 3;
        return (distro*) new distro_uniform(min_val, max_val);

    } else if(strcmp(args[0], "beta") == 0) {
        if(*argc < 3) {
            cerr << "Not enough arguments to instantiate distro_beta" << endl;
            return nullptr;
        }
        double alpha = 0.0;
        double beta = 0.0;
        if(!distro_parse_double(args[1], "distro_beta", alpha) ||
           !distro_parse_double(args[2], "distro_beta", beta))
            return nullptr;
        *argc -= 3; *argv += 3;
        return (distro*) new distro_beta(alpha, beta);

    } else if(strcmp(args[0], "dirac") == 0) {
        if(*argc < 2) {
            cerr << "Not enough arguments to instantiate distro_dirac" << endl;
            return nullptr;
        }
        double p = 0.0;
        if(!distro_parse_double(args[1], "distro_dirac", p))
            return nullptr;
        *argc -= 2; *argv += 2;
        return (distro*) new distro_dirac(p);

    } else if(strcmp(args[0], "kronecker") == 0) {
        if(*argc < 2) {
            cerr << "Not enough arguments to instantiate distro_kronecker" << endl;
            return nullptr;
        }
        int val = 0;
        if(!distro_parse_int(args[1], "distro_kronecker", val))
            return nullptr;
        *argc -= 2; *argv += 2;
        return (distro*) new distro_kronecker(val);

    } else if(strcmp(args[0], "power") == 0) {
        if(*argc < 4) {
            cerr << "Not enough arguments to instantiate distro_power" << endl;
            return nullptr;
        }
        double v_min = 0.0;
        double v_max = 0.0;
        double c_pow = 0.0;
        if(!distro_parse_double(args[1], "distro_power", v_min) ||
           !distro_parse_double(args[2], "distro_power", v_max) ||
           !distro_parse_double(args[3], "distro_power", c_pow))
            return nullptr;
        *argc -= 4; *argv += 4;
        return (distro*) new distro_power(v_min, v_max, c_pow);

    } else if(strcmp(args[0], "exponential") == 0) {
        if(*argc < 3) {
            cerr << "Not enough arguments to instantiate distro_exponential" << endl;
            return nullptr;
        }
        double v_min = 0.0;
        double lambda = 0.0;
        if(!distro_parse_double(args[1], "distro_exponential", v_min) ||
           !distro_parse_double(args[2], "distro_exponential", lambda))
            return nullptr;
        *argc -= 3; *argv += 3;
        return (distro*) new distro_exponential(v_min, lambda);
    }

    cerr << "Unable to parse the distro type" << endl;
    return nullptr;
}

// Self-test for distro_parse(): one command line per distro type, some with
// trailing words that the parser must leave unconsumed. Checks the updated
// argc/argv, that an object was returned, and that it holds the parsed values.
void distro_parse_test() {

    // command lines: distro name, its parameters, optional trailing words
    const char *words_uniform[] = {"uniform", "-3.7", "-1.2", "welcome"};
    const char *words_beta[] = {"beta", "1.3", "0.7"};
    const char *words_dirac[] = {"dirac", "0.765", "to", "the"};
    const char *words_kronecker[] = {"kronecker", "8", "jungle"};
    const char *words_power[] = {"power", "1.0", "2.0", "3.0", "idiot!"};
    const char *words_exponential[] = {"exponential", "2.0", "23.0"};

    // distro_parse() wants mutable char** cursors; it only reads the strings
    char **argv_uniform = const_cast<char**>(words_uniform);
    char **argv_beta = const_cast<char**>(words_beta);
    char **argv_dirac = const_cast<char**>(words_dirac);
    char **argv_kronecker = const_cast<char**>(words_kronecker);
    char **argv_power = const_cast<char**>(words_power);
    char **argv_exponential = const_cast<char**>(words_exponential);

    // word counts, taken from the arrays themselves
    int argc_uniform = sizeof(words_uniform) / sizeof(words_uniform[0]);
    int argc_beta = sizeof(words_beta) / sizeof(words_beta[0]);
    int argc_dirac = sizeof(words_dirac) / sizeof(words_dirac[0]);
    int argc_kronecker = sizeof(words_kronecker) / sizeof(words_kronecker[0]);
    int argc_power = sizeof(words_power) / sizeof(words_power[0]);
    int argc_exponential = sizeof(words_exponential) / sizeof(words_exponential[0]);

    // parse every command line
    distro_uniform* d_uniform = (distro_uniform*) distro_parse(&argc_uniform, &argv_uniform);
    distro_beta* d_beta = (distro_beta*) distro_parse(&argc_beta, &argv_beta);
    distro_dirac* d_dirac = (distro_dirac*) distro_parse(&argc_dirac, &argv_dirac);
    distro_kronecker* d_kronecker = (distro_kronecker*) distro_parse(&argc_kronecker, &argv_kronecker);
    distro_power* d_power = (distro_power*) distro_parse(&argc_power, &argv_power);
    distro_exponential* d_exponential = (distro_exponential*) distro_parse(&argc_exponential, &argv_exponential);

    // only the trailing words are still counted
    assert(argc_uniform == 1);
    assert(argc_beta == 0);
    assert(argc_dirac == 2);
    assert(argc_kronecker == 1);
    assert(argc_power == 1);
    assert(argc_exponential == 0);

    // the cursors now point at the trailing words
    assert(strcmp(argv_uniform[0], "welcome") == 0);
    assert(strcmp(argv_dirac[0], "to") == 0);
    assert(strcmp(argv_dirac[1], "the") == 0);
    assert(strcmp(argv_kronecker[0], "jungle") == 0);
    assert(strcmp(argv_power[0], "idiot!") == 0);

    // every parse produced an object
    assert(d_uniform != nullptr);
    assert(d_beta != nullptr);
    assert(d_dirac != nullptr);
    assert(d_kronecker != nullptr);
    assert(d_power != nullptr);
    assert(d_exponential != nullptr);

    // the objects carry the values given on the command line
    assert(d_uniform->min_val == -3.7);
    assert(d_uniform->max_val == -1.2);
    assert(d_beta->alpha == 1.3);
    assert(d_beta->beta == 0.7);
    assert(d_dirac->p == 0.765);
    assert(d_kronecker->val == 8);
    assert(d_power->v_min == 1.0);
    assert(d_power->v_max == 2.0);
    assert(d_power->c_pow == 3.0);
    assert(d_exponential->v_min == 2.0);
    assert(d_exponential->lambda == 23.0);

    // the parser allocates with new, so the test releases the objects
    delete d_uniform;
    delete d_beta;
    delete d_dirac;
    delete d_kronecker;
    delete d_power;
    delete d_exponential;
}
