import multiprocessing as mp
import os.path
from multiprocessing import get_context

import pandas as pd

from active_ranking.base.model import ActiveLearner, ActiveClassificationLeaner
from active_ranking.base.model import ActiveLearnerOld, ActiveNaiveLeaner
from active_ranking.base.model import DTrackingLearner, PassiveLearner, \
    MessyRank
from active_ranking.scenarios import inputs
from active_ranking.scenarios.algorithms import Scenario


class Experiment:
    """Run a scenario repeatedly with one learning method and summarise it.

    The scenario parameters are looked up by name in the ``inputs`` module;
    per-repeat results are gathered into a long-format table (one row per
    recorded point, tagged with the repeat ``id``) that is cached as CSV
    under ``./results/``.
    """

    def __init__(
            self,
            method: callable, n_repeat: int,
            scenario_name: str,
    ):
        """Load scenario parameters and prepare an empty result store.

        :param method: callable invoked as
            ``method(n_0, n_max, j_max, d, eta)``; it must also expose a
            ``name`` attribute, used to build the results file name.
        :param n_repeat: number of independent repetitions to run.
        :param scenario_name: name of a mapping defined in ``inputs`` that
            provides the keys ``n_0``, ``n_max``, ``j_max``, ``d`` and
            ``eta``.
        :raises KeyError: if ``scenario_name`` is not defined in ``inputs``.
        """
        scenario = inputs.__dict__[scenario_name]

        self.method = method
        self.n_repeat = n_repeat
        self.n_0 = scenario["n_0"]
        self.n_max = scenario["n_max"]
        self.j_max = scenario["j_max"]
        self.d = scenario["d"]
        self.true_eta = scenario["eta"]
        self.result = {}
        self.name = f"{method.name}_" \
                    f"{scenario_name}"

    def run(self):
        """Execute all repetitions in a spawn-based process pool, then analyse.

        A repetition that raises in its worker is reported (printed) and
        skipped, so the remaining repetitions are still analysed.
        """
        # Use ~80% of the cores, but never fewer than one worker:
        # ``int(1 * 0.8) == 0`` and ``Pool(processes=0)`` raises ValueError.
        n_processes = max(1, int(mp.cpu_count() * 0.8))
        args = (self.n_0, self.n_max, self.j_max,
                self.d, self.true_eta, self.method)
        with get_context('spawn').Pool(processes=n_processes) as pool:
            pending = [pool.apply_async(parallelizer, args=args)
                       for _ in range(self.n_repeat)]
            pool.close()
            pool.join()
            # Collect in submission order; ``get`` re-raises the worker's
            # own exception, which is far more useful than an IndexError.
            for i, async_result in enumerate(pending):
                try:
                    self.result[i] = async_result.get()
                except Exception as e:
                    print(e)
        self.__analysis()

    def __analysis(self):
        """Build the long-format table from ``self.result``, then stats + save.

        :raises RuntimeError: if no repetition produced a result.
        """
        frames = [
            pd.DataFrame(dict(
                id=i,
                inf_norm=res["inf_norm"],
                one_norm=res["one_norm"],
                n_sample=res["n_sample"],
            ))
            for i, res in self.result.items()
        ]
        if not frames:
            raise RuntimeError(
                f"No successful repetitions to analyse for {self.name}")
        # Concatenate once instead of growing the table inside the loop.
        self._table = pd.concat(frames, axis=0, ignore_index=True)
        self.stats()
        self.save()

    def save(self):
        """Write the result table to ``./results/<name>`` (no index column)."""
        os.makedirs("./results/", exist_ok=True)
        self._table.to_csv(f"./results/{self.name}", index=False)

    def load(self):
        """Read a previously saved table and recompute the statistics."""
        self._table = pd.read_csv(f"./results/{self.name}")
        self.stats()

    def run_or_load(self):
        """Load cached results if the results file exists, otherwise run."""
        if os.path.exists(f"./results/{self.name}"):
            self.load()
        else:
            self.run()

    def stats(self):
        """Compute per-``n_sample`` summary statistics over all repeats."""
        grouped = self._table.groupby(["n_sample"])
        self.mean_n_sample = grouped.mean()
        self.std_n_sample = grouped.std()
        # Previously a copy-paste of ``std()``; now the actual 95% quantile.
        self.quantile_95 = grouped.quantile(0.95)
        self.min = grouped.min()
        self.max = grouped.max()


def parallelizer(n_0, n_max, j_max, d, true_eta, function):
    """Run one repetition and reduce its outcome to a picklable dict.

    ``function`` is called as ``function(n_0, n_max, j_max, d, true_eta)``
    and its return value must expose ``norm_infinity``, ``norm_one`` and
    ``n_sample`` attributes, which are copied into the returned dict under
    the keys ``inf_norm``, ``one_norm`` and ``n_sample``.
    """
    outcome = function(n_0, n_max, j_max, d, true_eta)
    return {
        "inf_norm": outcome.norm_infinity,
        "one_norm": outcome.norm_one,
        "n_sample": outcome.n_sample,
    }


# One Scenario wrapper per learner implementation, built at import time.
active_rank = Scenario(ActiveLearnerOld)
active_naive_rank = Scenario(ActiveNaiveLeaner)
passive_rank = Scenario(PassiveLearner)
active_d_tracking = Scenario(DTrackingLearner)
active_rank_new = Scenario(ActiveLearner)
active_classification = Scenario(ActiveClassificationLeaner)
messy_rank = Scenario(MessyRank)

# Subset of the scenarios above gathered into ``models``; active_rank,
# active_d_tracking and active_rank_new are defined but not listed here.
models = [
    active_naive_rank,
    passive_rank,
    active_classification,
    messy_rank,
]
