#ifndef INFER_H_INCLUDED
#define INFER_H_INCLUDED

#include <random>
#include <vector>
#include "crowd.h"

// List of the tasks a worker executed in the past.
class worker_history {
public:
    // NOTE(review): presumably one entry per executed task, holding the time
    // step at which the worker produced its label -- confirm against the
    // implementation file.
    std::vector<unsigned int> time_steps;

    // Equality comparison between two histories (defined out of line).
    bool operator==(const worker_history& other) const;
};

// Inference algorithm (abstract base class).
//
// Declares the interface shared by every inference algorithm in this header:
// two inference entry points and three error metrics, all pure virtual.
class infer {
public:
    // NOTE(review): presumably one history per worker -- confirm.
    std::vector<worker_history> work_hist;
    // NOTE(review): presumably one value per task -- confirm.
    std::vector<double> task_odds;

    // Virtual so that a concrete algorithm can be destroyed through an infer*.
    virtual ~infer() = default;

    virtual void infer_full(std::mt19937& ran_gen, crowd* c, unsigned int t) = 0; // inference from scratch
    virtual void infer_update(std::mt19937& ran_gen, crowd* c, unsigned int t) = 0; // add one more label

    // Error metrics of the current estimate.
    // NOTE(review): the exact definition of each metric lives in the
    // implementations and is not visible from this header.
    virtual double error_num() = 0;
    virtual double error_rate() = 0;
    virtual double error_predict() = 0;
};

// weighted majority voting (with access to the ground-truth worker accuracy)
class infer_weighted: public infer {
private:
    unsigned int t_next;
public:
    std::vector<double> work_weight;

    infer_weighted(unsigned int n_tasks);

    void infer_full(std::mt19937& ran_gen, crowd* c, unsigned int t);
    void infer_update(std::mt19937& ran_gen, crowd* c, unsigned int t);

    double error_num();
    double error_rate();
    double error_predict();
};

// class infer_golden: public infer

// majority voting
class infer_majority: public infer {
private:
    unsigned int t_next;
public:
    double work_weight;

    // initialise work_prior = math_expit(1.0) for the classic (non-Bayesian) algorithm
    infer_majority(unsigned int n_tasks, double work_prior);

    void infer_full(std::mt19937& ran_gen, crowd* c, unsigned int t);
    void infer_update(std::mt19937& ran_gen, crowd* c, unsigned int t);

    double error_num();
    double error_rate();
    double error_predict();
};

// acyclic Bayesian estimation
class infer_acyclic: public infer {
private:
    unsigned int t_next;
public:
    std::vector<double> task_post;
    std::vector<double> work_weight;
    double alpha;
    double beta;

    infer_acyclic(unsigned int n_tasks, double alpha, double beta);

    void infer_full(std::mt19937& ran_gen, crowd* c, unsigned int t);
    void infer_update(std::mt19937& ran_gen, crowd* c, unsigned int t);

    double error_num();
    double error_rate();
    double error_predict();
};

// delayed acyclic Bayesian estimation
class infer_delayed: public infer {
private:
    unsigned int t_next;
public:
    std::vector<std::vector<double>> view_odds;
    std::vector<std::vector<double>> view_post;
    double alpha;
    double beta;

    infer_delayed(unsigned int n_tasks, double alpha, double beta);

    void infer_full(std::mt19937& ran_gen, crowd* c, unsigned int t);
    void infer_update(std::mt19937& ran_gen, crowd* c, unsigned int t);

    double error_num();
    double error_rate();
    double error_predict();
};

// optimised acyclic Bayesian estimation
class infer_quick: public infer {
private:
    void dataset_pass(crowd* c, int n_bit, unsigned int t_max);
public:
    std::vector<std::vector<double>> sub_odds;
    std::vector<std::vector<double>> sub_post;
    double alpha;
    double beta;

    infer_quick(unsigned int n_tasks, double alpha, double beta);

    void infer_full(std::mt19937& ran_gen, crowd* c, unsigned int t);
    void infer_update(std::mt19937& ran_gen, crowd* c, unsigned int t);

    double error_num();
    double error_rate();
    double error_predict();
};

// variational Bayesian estimation [Liu et al. 2012]
class infer_variational: public infer {
private:
    unsigned int t_next;
    std::vector<double> correct;

    void maximisation(crowd* c);
    void expectation(crowd* c);
public:
    std::vector<double> task_post;
    std::vector<double> work_weight;
    double alpha;
    double beta;
    unsigned int n_iter_full;
    unsigned int n_iter_update;

    infer_variational(unsigned int n_tasks, double alpha, double beta, unsigned int n_iter_full, unsigned int n_iter_update);

    void infer_full(std::mt19937& ran_gen, crowd* c, unsigned int t);
    void infer_update(std::mt19937& ran_gen, crowd* c, unsigned int t);

    double error_num();
    double error_rate();
    double error_predict();
};

// matrix factorisation inference [Karger et al. 2012]
class infer_eigen: public infer {
private:
    unsigned int t_next;

    void power_iteration(crowd *c);
public:
    std::vector<double> work_weight;
    unsigned int n_iter_full;
    unsigned int n_iter_update;

    infer_eigen(unsigned int n_tasks, unsigned int n_iter_full, unsigned int n_iter_update);

    void infer_full(std::mt19937& ran_gen, crowd* c, unsigned int t);
    void infer_update(std::mt19937& ran_gen, crowd* c, unsigned int t);

    double error_num();
    double error_rate();
    double error_predict();
};

// gibbs sampling on the exact marginal
class infer_montecarlo: public infer {
private:
    unsigned int t_next;

    void gibbs_sampling(std::mt19937& ran_gen, crowd* c);
public:
    std::vector<double> work_right;
    std::vector<double> work_wrong;
    std::vector<double> work_weight;
    std::vector<double> task_label;
    std::vector<double> task_sample;
    double alpha;
    double beta;
    unsigned int n_iter_full;
    unsigned int n_iter_update;

    infer_montecarlo(unsigned int n_tasks, double alpha, double beta, unsigned int n_iter_full, unsigned int n_iter_update);

    void infer_full(std::mt19937& ran_gen, crowd* c, unsigned int t);
    void infer_update(std::mt19937& ran_gen, crowd* c, unsigned int t);

    double error_num();
    double error_rate();
    double error_predict();
};

// Particle filter: a single particle.
class particle {
public:
    // NOTE(review): presumably one class assignment per task -- confirm.
    std::vector<double> task_cls;
    double log_weight;

    // Default constructor: empty task_cls. log_weight is set to zero so that
    // a default-constructed particle never holds an indeterminate value.
    particle() : log_weight(0.0) {}
    particle(unsigned int n_tasks, std::mt19937& ran_gen);

    // NOTE(review): presumably a Gibbs move over the labels up to t_max --
    // confirm against the implementation.
    void gibbs_move(std::mt19937& ran_gen,
                    crowd* c,
                    unsigned int t_max,
                    double alpha,
                    double beta);
    // NOTE(review): presumably updates log_weight for the label at time step
    // t_new -- confirm against the implementation.
    void update_weight(std::vector<worker_history>& work_hist,
                       crowd* c,
                       unsigned int t_new,
                       double alpha,
                       double beta);
};

class infer_particle: public infer {
private:
    unsigned int t_next;
public:
    std::vector<particle> part_swarm;
    std::vector<double> task_post;

    double alpha;
    double beta;
    unsigned int n_particles;
    unsigned int n_updates;

    infer_particle(unsigned int n_tasks, double alpha, double beta, unsigned int n_particles, unsigned int n_updates);

    void infer_full(std::mt19937& ran_gen, crowd* c, unsigned int t);
    void infer_update(std::mt19937& ran_gen, crowd* c, unsigned int t);

    double error_num();
    double error_rate();
    double error_predict();
};

// triangular estimation
class infer_triangle: public infer {
private:
    unsigned int t_next;
public:
    std::vector<std::vector<double>> work_corr; // C_{ij}
    std::vector<std::vector<double>> work_norm; // N_{ij}
    std::vector<std::vector<double>> work_prod; // M_{ij}
    std::vector<double> work_theta;
    std::vector<double> work_weight;

    unsigned int bool_iter_full;

    infer_triangle(unsigned int n_tasks, unsigned int bool_iter_full);

    void infer_full(std::mt19937& ran_gen, crowd* c, unsigned int t);
    void infer_update(std::mt19937& ran_gen, crowd* c, unsigned int t);

    double error_num();
    double error_rate();
    double error_predict();
};

// Infer parser: returns a pointer to an inference algorithm.
// NOTE(review): argc and argv are taken by pointer, presumably so that the
// parser can consume the arguments it recognises; the returned object is
// presumably heap-allocated and owned by the caller -- confirm both.
infer* infer_parse(int *argc, char **argv[]);

// Home-made unit tests: one per inference algorithm, plus one for the parser.
// NOTE(review): how a failure is reported is decided by the definitions,
// which are not visible from this header.
void infer_weighted_test();
void infer_majority_test();
void infer_acyclic_test();
void infer_delayed_test();
void infer_quick_test();
void infer_variational_test();
void infer_eigen_test();
void infer_montecarlo_test();
void infer_particle_test();
void infer_triangle_test();
void infer_parse_test();

// Maths functions.
// NOTE(review): the descriptions below are inferred from the names only;
// confirm each one against its definition.
// Presumably the logistic function 1 / (1 + exp(-x)).
double math_expit(double x);
// Presumably the inverse of math_expit, log(x / (1 - x)).
double math_logit(double x);
// Presumably math_logit guarded against arguments at or near 0 and 1.
double math_logit_safe(double x);
// Takes vec by non-const reference; presumably rescales it in place so that
// its entries sum to one.
void math_normalise_1(std::vector<double>& vec);
// Presumably the logarithm of the Beta function B(a, b).
double math_log_beta(double a, double b);

// NOTE(review): presumably a table-based log-gamma helper, with LGU_TBL_MAX
// the number of table entries and lgammau_init() required before the first
// call to lgammau() -- confirm against the definitions.
#define LGU_TBL_MAX 1024
void lgammau_init();
double lgammau(double n);

#endif // INFER_H_INCLUDED
