import os
import pickle as pkl

import botorch
import torch
from botorch import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.optim.fit import fit_gpytorch_mll_torch
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.models import  ExactGP
from gpytorch.means import ConstantMean
from gpytorch.kernels import ScaleKernel
from gpytorch.distributions import MultivariateNormal
import pandas as pd
import numpy as np

from pathlib import Path
import os, sys
ROOT = str(Path(os.path.realpath(__file__)).parent.parent)
sys.path.insert(0, ROOT)

from nap.RL.utils_gp import MixtureKernel


class ExactMixedTypeGPModel(ExactGP):
    """Exact GP over inputs that mix continuous and categorical dimensions.

    The prior has a constant mean and a scaled ``MixtureKernel`` covariance in
    which ``cat_dims`` are treated as categorical and ``cont_dims`` as
    continuous input columns.

    Args:
        train_inputs: Training inputs forwarded to ``ExactGP``.
        train_targets: Training targets forwarded to ``ExactGP``.
        likelihood: GPyTorch likelihood forwarded to ``ExactGP``.
        cat_dims: Indices of the categorical input columns.
        cont_dims: Indices of the continuous input columns.
    """

    def __init__(self, train_inputs, train_targets, likelihood, cat_dims, cont_dims):
        super().__init__(train_inputs, train_targets, likelihood)
        self.mean_module = ConstantMean()
        mixture = MixtureKernel(categorical_dims=cat_dims, continuous_dims=cont_dims)
        self.covar_module = ScaleKernel(mixture)

    def forward(self, x):
        """Return the prior ``MultivariateNormal`` evaluated at inputs ``x``."""
        return MultivariateNormal(self.mean_module(x), self.covar_module(x))


# Helpers for the __main__ driver below: encode each per-cluster CSV dataset,
# pickle the encoded data, and fit/save one mixed-type GP per cluster.


def _dump_pickle(obj, path):
    """Pickle ``obj`` to ``path``, closing the file handle deterministically."""
    with open(path, "wb") as f:
        pkl.dump(obj, f)


def _build_alphabet(conf, cat_names):
    """Map every categorical parameter's allowed values to integer codes.

    Args:
        conf: DataFrame indexed by parameter name with ``'range'`` and
            ``'type'`` columns.
        cat_names: Categorical parameter names to encode.

    Returns:
        Dict ``{name: {raw_value: code}}``. Comma-separated ranges are split
        into values; parameters of type ``'bool'`` additionally accept the
        ``'0'/'1'`` and ``'False'/'True'`` spellings.
    """
    alphabet_hmap = {}
    for cn in cat_names:
        var_range = conf.loc[cn]['range']
        if ',' in var_range:
            var_range = var_range.split(',')
        # NOTE(review): a range without ',' stays a str and is enumerated
        # character by character -- confirm this is intended for such entries.
        alphabet_hmap[cn] = {c: i for i, c in enumerate(var_range)}
        if conf.loc[cn]['type'] == 'bool':
            alphabet_hmap[cn].update({'0': 0, '1': 1, 'False': 0, 'True': 1})
    return alphabet_hmap


def _count_classes(alphabet_hmap):
    """Return the number of distinct integer codes for each parameter, in order."""
    return [len(set(codes.values())) for codes in alphabet_hmap.values()]


def _minmax_normalize(values):
    """Scale each column of the 2-D array ``values`` to [0, 1].

    Constant columns (max == min) would divide by zero and produce NaN, which
    corrupts the saved data and the GP fit; they are mapped to 0 instead.
    """
    values = np.asarray(values, dtype=float)
    col_min = values.min(0)
    col_range = values.max(0) - col_min
    col_range = np.where(col_range == 0, 1.0, col_range)
    return (values - col_min) / col_range


def _standardize(y):
    """Return ``y`` shifted to zero mean and unit std (a zero std is left unscaled)."""
    std = y.std()
    return (y - y.mean()) / (std if std > 0 else 1.0)


def _encode_categorical(raw, cat_names, alphabet_hmap):
    """Replace raw categorical values with their integer codes.

    Args:
        raw: 2-D array with one column per entry of ``cat_names``.
        cat_names: Column names, used to pick each column's alphabet.
        alphabet_hmap: Mapping built by ``_build_alphabet``.

    Returns:
        Integer array of the same shape as ``raw``.

    Raises:
        KeyError: if a value is in the alphabet neither as-is nor as ``str``.
    """
    encoded = np.empty(raw.shape, dtype=int)
    for j, name in enumerate(cat_names):
        codes = alphabet_hmap[name]
        for i, value in enumerate(raw[:, j]):
            # CSV parsing may yield non-str values (e.g. bools/ints) while the
            # alphabet keys are strings, so fall back to the str form.
            try:
                encoded[i, j] = codes[value]
            except KeyError:
                encoded[i, j] = codes[str(value)]
    return encoded


def _fit_gp(X, stdY, num_cont, dataset):
    """Fit a ``SingleTaskGP`` with a mixed continuous/categorical kernel.

    The first ``num_cont`` columns of ``X`` are continuous, the rest
    categorical. Fitting is tried on CPU with ``fit_gpytorch_mll``; on failure
    it is retried on GPU with ``fit_gpytorch_mll_torch`` if CUDA is available.

    Args:
        X: 2-D numpy array of encoded inputs.
        stdY: 1-D numpy array of standardized targets, one per row of ``X``.
        num_cont: Number of leading continuous columns in ``X``.
        dataset: Dataset file name, used only in log messages.

    Returns:
        The fitted model (possibly on GPU), or ``None`` if every attempt failed.
    """
    train_X = torch.from_numpy(X)
    train_Y = torch.from_numpy(stdY).view(-1, 1)
    dims = list(range(train_X.shape[-1]))
    model = SingleTaskGP(
        train_X=train_X,
        train_Y=train_Y,
        mean_module=ConstantMean(),
        covar_module=ScaleKernel(MixtureKernel(
            continuous_dims=dims[:num_cont],
            categorical_dims=dims[num_cont:],
        )),
    )
    mll = ExactMarginalLogLikelihood(model.likelihood, model)

    try:
        mll.cpu()
        fit_gpytorch_mll(mll=mll)
        return model
    except (RuntimeError, botorch.exceptions.errors.ModelFittingError) as e:
        print(e)

    # Only retry on GPU when CUDA is present; otherwise `mll.cuda()` itself
    # raises (AssertionError on CPU-only torch builds) and aborts the run.
    if torch.cuda.is_available():
        try:
            print('Try fit on GPU')
            mll.cuda()
            fit_gpytorch_mll_torch(mll)
            return model
        except RuntimeError:
            pass
    else:
        print('CUDA not available, skipping GPU retry.')

    print(f'Error during the GP fit on {dataset}.')
    # Move parameters off the GPU and drop references so the cache can be freed.
    mll.cpu()
    del model, mll
    torch.cuda.empty_cache()
    return None


if __name__ == '__main__':
    # Fixed from MIP code GP model directly to get variables in exactly the same order
    CAT_NAMES = ['branching/scorefunc', 'branching/preferbinary', 'branching/lpgainnormalize', 'lp/pricing', 'nodeselection/childsel', 'separating/rapidlearning/freq', 'separating/flowcover/freq', 'separating/cmir/freq', 'separating/aggregation/freq', 'separating/gomory/freq', 'separating/impliedbounds/freq', 'separating/strongcg/freq', 'separating/zerohalf/freq', 'separating/clique/freq', 'separating/disjunctive/freq', 'separating/mcf/freq', 'separating/cgmip/freq', 'separating/convexproj/freq', 'separating/eccuts/freq', 'separating/gauge/freq', 'separating/oddcycle/freq', 'heuristics/oneopt/freq', 'heuristics/rounding/freq', 'heuristics/simplerounding/freq', 'heuristics/indicator/freq', 'heuristics/trysol/freq', 'heuristics/zirounding/freq', 'heuristics/subnlp/freq', 'heuristics/adaptivediving/freq', 'heuristics/pscostdiving/freq', 'heuristics/conflictdiving/freq', 'heuristics/nlpdiving/freq', 'heuristics/veclendiving/freq', 'heuristics/shifting/freq', 'heuristics/distributiondiving/freq', 'heuristics/farkasdiving/freq', 'heuristics/fracdiving/freq', 'heuristics/guideddiving/freq', 'heuristics/intshifting/freq', 'heuristics/linesearchdiving/freq', 'heuristics/lpface/freq', 'heuristics/alns/freq', 'heuristics/feaspump/freq', 'heuristics/gins/freq', 'heuristics/objpscostdiving/freq', 'heuristics/rootsoldiving/freq', 'heuristics/randrounding/freq', 'heuristics/rins/freq', 'heuristics/crossover/freq', 'heuristics/mpec/freq', 'heuristics/clique/freq', 'heuristics/multistart/freq', 'heuristics/rens/freq', 'heuristics/trivial/freq', 'heuristics/vbounds/freq', 'heuristics/shiftandpropagate/freq', 'heuristics/completesol/freq', 'heuristics/locks/freq', 'heuristics/ofins/freq', 'heuristics/padm/freq', 'heuristics/reoptsols/freq', 'heuristics/trivialnegation/freq', 'heuristics/undercover/freq', 'heuristics/actconsdiving/freq', 'heuristics/bound/freq', 'heuristics/coefdiving/freq', 'heuristics/dins/freq', 'heuristics/dualval/freq', 'heuristics/fixandinfer/freq', 
'heuristics/intdiving/freq', 'heuristics/octane/freq', 'heuristics/proximity/freq', 'heuristics/repair/freq', 'heuristics/localbranching/freq', 'heuristics/mutation/freq', 'heuristics/zeroobj/freq', 'heuristics/trustregion/freq', 'heuristics/twoopt/freq', 'lp/initalgorithm', 'lp/resolvealgorithm', 'lp/presolving', 'branching/treemodel/enable', 'branching/delaypscostupdate', 'branching/divingpscost', 'constraints/agelimit', 'constraints/obsoleteage', 'constraints/disableenfops', 'heuristics/useuctsubscip', 'propagating/dualfix/maxprerounds', 'propagating/genvbounds/maxprerounds', 'propagating/obbt/maxprerounds', 'propagating/nlobbt/maxprerounds', 'propagating/probing/maxprerounds', 'propagating/pseudoobj/maxprerounds', 'propagating/redcost/maxprerounds', 'propagating/rootredcost/maxprerounds', 'propagating/symmetry/maxprerounds', 'propagating/vbounds/maxprerounds', 'propagating/dualfix/freq', 'propagating/genvbounds/freq', 'propagating/obbt/freq', 'propagating/nlobbt/freq', 'propagating/probing/freq', 'propagating/pseudoobj/freq', 'propagating/redcost/freq', 'propagating/rootredcost/freq', 'propagating/symmetry/freq', 'propagating/vbounds/freq', 'presolving/boundshift/maxrounds', 'presolving/convertinttobin/maxrounds', 'presolving/domcol/maxrounds', 'presolving/dualagg/maxrounds', 'presolving/dualcomp/maxrounds', 'presolving/dualinfer/maxrounds', 'presolving/gateextraction/maxrounds', 'presolving/implics/maxrounds', 'presolving/inttobinary/maxrounds', 'presolving/redvub/maxrounds', 'presolving/trivial/maxrounds', 'presolving/tworowbnd/maxrounds', 'presolving/sparsify/maxrounds', 'presolving/dualsparsify/maxrounds', 'presolving/stuffing/maxrounds']
    CONT_NAMES = ['branching/scorefac', 'branching/clamp', 'branching/midpull', 'branching/midpullreldomtrig', 'lp/colagelimit', 'lp/rowagelimit', 'separating/minortho', 'separating/minorthoroot', 'separating/maxcuts', 'separating/maxcutsroot', 'separating/cutagelimit', 'separating/poolfreq']

    from argparse import ArgumentParser
    parser = ArgumentParser()
    parser.add_argument('--cluster', default=None, type=str)
    args = parser.parse_args()
    if args.cluster is not None:
        clusters = [args.cluster]
    else:
        clusters = ['cluster0', 'cluster1', 'cluster2', 'cluster3', 'cluster4']
    mip2_data_root = os.path.join(ROOT, 'MIP_data')
    # Sorted so the dataset picked for a cluster does not depend on listdir order.
    train_datasets = sorted(d for d in os.listdir(mip2_data_root) if '.csv' in d)
    conf = pd.read_csv(os.path.join(mip2_data_root, 'paras_extended_cat.csv'), index_col=0)
    cat_names = CAT_NAMES
    num_names = CONT_NAMES

    alphabet_hmap = _build_alphabet(conf, cat_names)
    _dump_pickle(alphabet_hmap, os.path.join(mip2_data_root, 'alphabet_hmap.pkl'))
    num_classes = _count_classes(alphabet_hmap)
    _dump_pickle(num_classes, os.path.join(mip2_data_root, 'alphabet_num_classes.pkl'))

    for cluster in clusters:
        matches = [d for d in train_datasets if cluster in d]
        if not matches:
            raise FileNotFoundError(f'No CSV dataset matching {cluster!r} in {mip2_data_root}')
        dataset = matches[0]
        gp_name = f'gp_{cluster}.pt'
        gp_path = os.path.join(mip2_data_root, gp_name)
        data_path = os.path.join(mip2_data_root, f'data_{cluster}.pkl')

        data = pd.read_csv(os.path.join(mip2_data_root, dataset))
        data_num = _minmax_normalize(data[num_names].values)
        data_cat = _encode_categorical(data[cat_names].values, cat_names, alphabet_hmap)

        # Continuous columns first, categorical after: _fit_gp relies on this order.
        X = np.concatenate((data_num, data_cat), axis=-1)
        # NOTE(review): 7200.01 presumably is the solver time limit (plus a small
        # offset), turning runtime into a larger-is-better score -- confirm.
        Y = 7200.01 - data['time'].values
        stdY = _standardize(Y)

        _dump_pickle({'domain': X, 'accs': Y}, data_path)
        print(f"Saved {data_path}")

        if os.path.exists(gp_path):
            print(f'{dataset} GP already fit and saved.')
            continue

        print(f'Fit GP on dataset {dataset} containing {X.shape[0]} points...')
        model = _fit_gp(X, stdY, len(num_names), dataset)
        if model is None:
            continue
        # Save a CPU copy so the checkpoint can be loaded on machines without a GPU.
        model = model.cpu()
        torch.save(model, gp_path)
        print(f"saved model at {gp_path}")
        del model
        torch.cuda.empty_cache()
