from tqdm import tqdm
import pandas as pd
import cvxpy as cp
import numpy as np

import pickle
from itertools import combinations
from collections import defaultdict

def load_data(path, env_name):
    """Load the pickled results object for one environment.

    The file read is ``path + env_name + '/results_' + env_name + '.p'``.
    The pieces are joined by plain string concatenation, so ``path`` must
    already end with a path separator.

    Args:
        path: Directory prefix (including trailing separator) that contains
            one sub-directory per environment.
        env_name: Name of the environment sub-directory / results file.

    Returns:
        Whatever object was pickled in the results file.
    """
    results_file = path + env_name + '/results_' + env_name + '.p'
    # NOTE: pickle can execute arbitrary code while loading; only point this
    # at results files produced by a trusted run.
    # The context manager closes the handle even if unpickling fails.
    with open(results_file, 'rb') as fh:
        return pickle.load(fh)

def get_policy_perf(d):
    """Return the entry stored under the ``'pi_e_value'`` key of ``d``."""
    return d['pi_e_value']

def get_ope_bootstrap(d, ope_name, i):
    """Return row ``i`` of the bootstrap scores stored for ``ope_name``.

    The table is looked up as ``d['all_bs_scores'][ope_name]`` and indexed
    with ``[i, :]``.
    """
    # presumably the table is (num_datasets, num_bootstrap) -- TODO confirm
    bootstrap_table = d['all_bs_scores'][ope_name]
    return bootstrap_table[i, :]

def get_ope_score(d, ope_name, i):
    """Return element ``i`` of ``d['all_true_scores'][ope_name]``."""
    # presumably one point estimate per dataset -- TODO confirm
    per_dataset_scores = d['all_true_scores'][ope_name]
    return per_dataset_scores[i]

def solve_for_alpha(ope_scores, error_matrix_A, verbose=True):
    """Solve for ensemble weights and return the weighted OPE score.

    Minimises the quadratic form ``x^T A x`` over ``x`` subject to the single
    constraint ``sum(x) == 1`` (the weights are not otherwise bounded), then
    combines ``ope_scores`` with the resulting weights.

    Args:
        ope_scores: Array of per-estimator scores; it is flattened before
            being multiplied element-wise with the weights.
        error_matrix_A: Square matrix used in the quadratic form; its first
            dimension sets the number of weights.
        verbose: When true, print the optimal objective value returned by
            the solver.

    Returns:
        Tuple ``(alpha, score)`` where ``alpha`` is the flattened weight
        vector and ``score`` is ``sum(ope_scores * alpha)``.

    Raises:
        RuntimeError: If the solver finishes without assigning a value to
            the weight variable.
    """
    n = error_matrix_A.shape[0]

    x = cp.Variable((n, 1))
    objective = cp.Minimize(cp.quad_form(x, error_matrix_A))
    constraints = [cp.sum(x) == 1]
    prob = cp.Problem(objective, constraints)

    loss = prob.solve()
    if verbose:
        print('cvxpy loss: ', loss)

    # Without this check a failed solve surfaces as an opaque
    # "'NoneType' object has no attribute 'flatten'" error below.
    if x.value is None:
        raise RuntimeError(
            f'cvxpy did not return a solution (status: {prob.status})'
        )

    alpha = x.value.flatten()
    score = (ope_scores.flatten() * alpha).sum()

    return alpha, score


def create_mse_matrix(ope_scores, ope_bootstrapped_scores):
    """Build the bootstrap error matrix and per-estimator error summaries.

    For each estimator ``j`` the deviation of its bootstrap scores from its
    point score, ``ope_bootstrapped_scores[j, :] - ope_scores[j]``, is stored
    as row ``j`` of an intermediate matrix. The error matrix is that matrix
    times its transpose, divided by the number of bootstrap samples.

    Args:
        ope_scores: Array with one point score per estimator (first axis).
        ope_bootstrapped_scores: 2-D array, one row of bootstrap scores per
            estimator.

    Returns:
        Tuple ``(error_matrix_A, ope_mse, est_bias, est_variance)``:
        the error matrix, its diagonal, the mean deviation per estimator,
        and the mean squared deviation per estimator (measured around the
        point score, not around the bootstrap mean).
    """
    n_estimators = ope_scores.shape[0]
    num_bootstrap = ope_bootstrapped_scores.shape[1]

    est_bias = np.zeros(n_estimators)
    est_variance = np.zeros(n_estimators)
    pre_A = np.zeros((n_estimators, num_bootstrap))

    # not adding the scale_ratio (we don't have n, k yet, and it doesn't
    # change optimization, only helps MSE)
    for idx in range(n_estimators):
        deviation = ope_bootstrapped_scores[idx, :] - ope_scores[idx]
        pre_A[idx, :] = deviation
        est_bias[idx] = deviation.mean()
        est_variance[idx] = np.mean(deviation ** 2)

    # TODO: we need to do the scaling here
    outer_products = np.matmul(pre_A, pre_A.T)
    error_matrix_A = (1 / num_bootstrap) * outer_products
    ope_mse = error_matrix_A.diagonal()

    return error_matrix_A, ope_mse, est_bias, est_variance


def run_experiment(path, env_names, estimator_names, env_type=None,
                   true_mdp_sample_times=10, n_copies=100,
                   output_dir='graph_results'):
    """Run the ensemble OPE comparison over several environments.

    For every environment the results object is loaded with ``load_data``;
    for each of ``true_mdp_sample_times`` dataset indices the component
    estimator scores and their bootstrap scores are gathered, the error
    matrix is built with ``create_mse_matrix`` and three ensemble scores are
    recorded alongside the component scores:

    * ``'OPERA'``     -- weighted score returned by ``solve_for_alpha``;
    * ``'SwitchOPE'`` -- score of the estimator with the smallest diagonal
      entry of the error matrix;
    * ``'AvgOPE'``    -- plain mean of the component scores.

    The mean squared error of every score list against
    ``get_policy_perf(d)`` is then computed per environment.

    Side effects: writes ``ensemble_ope_mat.pkl``, ``ensemble_mse_mat.pkl``
    and ``{env_type}_mse.csv`` into ``output_dir``.

    Args:
        path: Directory prefix forwarded to ``load_data``.
        env_names: Iterable of environment names to process.
        estimator_names: Names of the component estimators to ensemble.
        env_type: Label used as the prefix of the CSV file name. When left
            as ``None`` a module-level ``env_type`` is used if one exists,
            otherwise ``'all'``.
        true_mdp_sample_times: Number of dataset indices read per
            environment.
        n_copies: Number of bootstrap scores expected per estimator and
            dataset index.
        output_dir: Directory the result files are written to; created if
            it does not exist.

    Returns:
        Tuple ``(opera_across, switch_across, avg_across)``: lists with one
        MSE per environment for the three ensemble methods.
    """
    import os  # local import: the module header does not import os

    if env_type is None:
        # The original body read an undefined name here; keep honouring a
        # module-level ``env_type`` if a caller relied on defining one.
        env_type = globals().get('env_type', 'all')

    # Fail early (before the expensive loop) if the directory is unusable.
    os.makedirs(output_dir, exist_ok=True)

    ensemble_names = ['OPERA', 'SwitchOPE', 'AvgOPE']
    column_names = ensemble_names + [f'{e}' for e in estimator_names]

    ensemble_ope_mat = {}
    ensemble_mse_mat = {}

    for env_name in tqdm(env_names):

        d = load_data(path, env_name)

        # ensemble scores and component OPE scores are collected side by side
        ensemble_ope_mat[env_name] = defaultdict(list)
        ensemble_mse_mat[env_name] = {}

        for i in range(true_mdp_sample_times):
            ope_scores = np.zeros(len(estimator_names))
            ope_bootstrapped_scores = np.zeros((len(estimator_names), n_copies))
            for j, ope_name in enumerate(estimator_names):
                ope_scores[j] = get_ope_score(d, ope_name, i)
                ope_bootstrapped_scores[j, :] = get_ope_bootstrap(d, ope_name, i)

            error_matrix_A, ope_mse, _, _ = create_mse_matrix(
                ope_scores, ope_bootstrapped_scores)

            _, score = solve_for_alpha(ope_scores, error_matrix_A)

            ensemble_ope_mat[env_name]['OPERA'].append(score)

            mse_smallest_idx = np.argmin(ope_mse)
            ensemble_ope_mat[env_name]['SwitchOPE'].append(ope_scores[mse_smallest_idx])
            ensemble_ope_mat[env_name]['AvgOPE'].append(ope_scores.mean())

            for j, ope_name in enumerate(estimator_names):
                ensemble_ope_mat[env_name][ope_name].append(ope_scores[j])

        # MSE of every collected score list against the true policy value
        true_perf = get_policy_perf(d)
        for ope_name, collected_scores in ensemble_ope_mat[env_name].items():
            ensemble_mse_mat[env_name][ope_name] = np.mean(
                (np.array(collected_scores) - true_perf) ** 2)

    with open(f"{output_dir}/ensemble_ope_mat.pkl", "wb") as fh:
        pickle.dump(ensemble_ope_mat, fh)
    with open(f"{output_dir}/ensemble_mse_mat.pkl", "wb") as fh:
        pickle.dump(ensemble_mse_mat, fh)

    opera_across = []
    switch_across = []
    avg_across = []

    # create the MSE csv: one row per environment, one column per method
    header = ['env_name'] + column_names
    rows = []
    for env_name, ope_name_to_mse in ensemble_mse_mat.items():
        rows.append([env_name] + [ope_name_to_mse[name] for name in column_names])
        opera_across.append(ope_name_to_mse['OPERA'])
        switch_across.append(ope_name_to_mse['SwitchOPE'])
        avg_across.append(ope_name_to_mse['AvgOPE'])

    df = pd.DataFrame(rows, columns=header)
    df.to_csv(f'{output_dir}/{env_type}_mse.csv', index=False)

    return opera_across, switch_across, avg_across

if __name__ == '__main__':
    # ``result_path`` and ``env_names`` were referenced here without being
    # defined anywhere in this file, so the script failed with NameError.
    # They are now taken from the command line.
    import argparse

    parser = argparse.ArgumentParser(
        description='Compare ensemble OPE methods across environments.')
    parser.add_argument(
        'result_path',
        help='directory prefix holding one sub-directory per environment; '
             'must end with a path separator because it is concatenated '
             'directly with the environment name')
    parser.add_argument(
        'env_names', nargs='+',
        help='one or more environment names to process')
    args = parser.parse_args()

    estimators = ['IS', 'WIS']
    run_experiment(args.result_path, args.env_names, estimators)
