import os
import shutil
import math
import torch as th
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import numpy as np
from matplotlib.patches import *
from matplotlib.lines import Line2D
from tueplots import bundles, fontsizes
# Apply the tueplots NeurIPS 2024 bundle to matplotlib's global rcParams.
plt.rcParams.update(bundles.neurips2024())
# NOTE(review): the return value of this call is discarded. If tueplots'
# fontsizes helpers only *return* an rcParams dict (as bundles.* does above),
# this line has no effect — confirm and either wrap it in
# plt.rcParams.update(...) or drop it.
fontsizes.neurips2024()

# Root directory that holds the per-run experiment outputs.
base_dir = 'output/density_estimation'

# Structural causal models covered by the experiments.
scms = [
    'simpson_nlin',
    'triangle_nlin',
    'largebd_nlin',
]

# Density-model backbones and estimation methods that were run.
models = [
    'maf',
    'nice',
]
methods = [
    'exom',
    'causal_nf',
    'rs',
    'rs_causal_nf',
]

# Values of the run parameter `j` and the random seeds that were swept.
js = [1, 3, 5]
seeds = [0, 7, 42, 3407, 65535]

# Integer associated with each SCM name
# (presumably its number of variables — verify against the SCM definitions).
dims = dict(
    simpson_nlin=4,
    triangle_nlin=3,
    largebd_nlin=9,
)


def get_log(method, scm, model, j, seed, seed2):
    """Load the validation log of a single experiment run.

    The run directory is ``{base_dir}/{method}/{scm}/{config}[,{seed}][,cfs={seed2}]``
    where ``config`` depends on the method: the flow-based methods
    ('exom', 'causal_nf') encode the model and ``j`` plus a fixed
    hyper-parameter string, the sampling baselines ('rs', 'rs_causal_nf')
    only encode ``j``.

    Args:
        method: One of 'exom', 'causal_nf', 'rs', 'rs_causal_nf'.
        scm: SCM name used as a path component.
        model: Model name; only used by 'exom' and 'causal_nf'.
        j: Run parameter embedded in the directory name.
        seed: Seed suffix; omitted from the path when not positive.
        seed2: Second seed, appended as ``cfs=<seed2>``; omitted when not
            positive.

    Returns:
        The object stored in ``logs/val_logs.pt``, or ``False`` when that file
        does not exist (callers test the result with ``if not log``).

    Raises:
        ValueError: If ``method`` is not one of the known methods.
    """
    if method in ('exom', 'causal_nf'):
        sub_dir = f'm={model},j={j},c=e+t.w_e.w_t,k=mb1.mb1.em,r=attn,t=5,h=64x2'
    elif method in ('rs', 'rs_causal_nf'):
        sub_dir = f'j={j}'
    else:
        # Previously fell through and crashed with UnboundLocalError below.
        raise ValueError(f'unknown method: {method!r}')
    sub_dir = f'{method}/{scm}/' + sub_dir
    if seed > 0:
        sub_dir += f',{seed}'
    if seed2 > 0:
        sub_dir += f',cfs={seed2}'
    log_path = os.path.join(base_dir, sub_dir, 'logs/val_logs.pt')
    if not os.path.exists(log_path):
        return False
    # The estimates are moved to CPU before use anyway; mapping to CPU at load
    # time lets logs written on a GPU machine be read on a CPU-only one.
    return th.load(log_path, map_location='cpu')


def get_estimate(log):
    """Return the estimates stored under the last key of ``log`` as one array.

    ``log[last_key]`` is indexed by ``0 .. len - 1``; each entry's
    ``'estimate'`` tensor is concatenated along the last dimension and the
    result is detached, moved to CPU and converted to a numpy array.
    """
    final_key = list(log)[-1]
    batches = log[final_key]
    pieces = []
    for idx in range(len(batches)):
        pieces.append(batches[idx]['estimate'])
    merged = th.cat(pieces, dim=-1)
    return merged.detach().cpu().numpy()


def df_csv_o(path):
    """Return the per-sample estimates of the 'rs' and 'exom' runs.

    If ``path`` already exists it is treated as a cache and read back with
    ``pd.read_csv``. Otherwise every (scm, method, model, j, seed) run whose
    log can be found via ``get_log`` is collected into a long-format
    DataFrame with columns
    ``scm, method, model, seed, j, id, estimate`` (one row per estimate),
    which is written to ``path`` and returned.

    'rs' has no model and is stored with model ``'-'``; 'exom' is collected
    once per entry of ``models``. Runs with a missing log are skipped.
    """
    if os.path.exists(path):
        return pd.read_csv(path)
    data = {
        'scm': [],
        'method': [],
        'model': [],
        'seed': [],
        'j': [],
        'id': [],
        'estimate': [],
    }

    def add_log(scm, method, model):
        for j in js:
            print(scm, method, model, j)
            for s in seeds:
                log = get_log(method, scm, model, j, s, 0)
                if not log:
                    continue
                for i, estimate in enumerate(get_estimate(log)):
                    data['scm'].append(scm)
                    data['method'].append(method)
                    data['model'].append(model)
                    data['seed'].append(s)
                    data['j'].append(j)
                    data['id'].append(i)
                    data['estimate'].append(estimate)

    for scm in scms:
        add_log(scm, 'rs', '-')
        for model in models:
            add_log(scm, 'exom', model)

    df = pd.DataFrame.from_dict(data)
    # index=False: otherwise the row index is written as an unnamed column and
    # a cached run returns a frame with an extra 'Unnamed: 0' column that a
    # fresh run does not have.
    df.to_csv(path, index=False)
    return df


# Estimates of the 'rs' / 'exom' runs; loaded from the CSV cache when it
# exists, otherwise rebuilt from the run logs and written to that path.
dfo = df_csv_o('script/figure/density_estimation_original.csv')


def df_csv_p(path):
    """Return the per-sample estimates of the 'rs_causal_nf' and 'causal_nf' runs.

    If ``path`` already exists it is treated as a cache and read back with
    ``pd.read_csv``. Otherwise every (scm, method, model, j, seed, seed2) run
    whose log can be found via ``get_log`` is collected into a long-format
    DataFrame with columns
    ``scm, method, model, seed, seed2, j, id, estimate`` (one row per
    estimate), which is written to ``path`` and returned.

    Both ``seed`` and ``seed2`` range over ``seeds``. 'rs_causal_nf' has no
    model and is stored with model ``'-'``; 'causal_nf' is collected once per
    entry of ``models``. Runs with a missing log are skipped.
    """
    if os.path.exists(path):
        return pd.read_csv(path)
    data = {
        'scm': [],
        'method': [],
        'model': [],
        'seed': [],
        'seed2': [],
        'j': [],
        'id': [],
        'estimate': [],
    }

    def add_log(scm, method, model):
        for j in js:
            print(scm, method, model, j)
            for s in seeds:
                for s2 in seeds:
                    log = get_log(method, scm, model, j, s, s2)
                    if not log:
                        continue
                    for i, estimate in enumerate(get_estimate(log)):
                        data['scm'].append(scm)
                        data['method'].append(method)
                        data['model'].append(model)
                        data['seed'].append(s)
                        data['seed2'].append(s2)
                        data['j'].append(j)
                        data['id'].append(i)
                        data['estimate'].append(estimate)

    for scm in scms:
        add_log(scm, 'rs_causal_nf', '-')
        for model in models:
            add_log(scm, 'causal_nf', model)

    df = pd.DataFrame.from_dict(data)
    # index=False: otherwise the row index is written as an unnamed column and
    # a cached run returns a frame with an extra 'Unnamed: 0' column that a
    # fresh run does not have.
    df.to_csv(path, index=False)
    return df


# Estimates of the 'rs_causal_nf' / 'causal_nf' runs; loaded from the CSV
# cache when it exists, otherwise rebuilt from the run logs and written there.
dfp = df_csv_p('script/figure/density_estimation_proxy.csv')


# dim[scm][j] is a 1-D numpy array with one value per validation sample: the
# sum of that sample's 'w_e_batched' entries after flattening to
# (-1, j * last_dim). ci_95 uses it as the per-sample root degree
# (estimate ** (1 / dim)).
# NOTE(review): presumably 'w_e_batched' is a 0/1 indicator so the sum is a
# count of active dimensions — confirm against the dataset writer.
dim = {}
for scm in scms:
    dim[scm] = {}
    for j in js:
        val_dataset_path = f'script/counterfactual/density_estimation/val_saves/{scm}_{j}.pt'
        # The result is converted to a CPU numpy array below, so load straight
        # onto the CPU; this also works when the file was saved from a GPU and
        # no GPU is available here.
        v = th.load(val_dataset_path, map_location='cpu')
        p = v['w_e_batched']
        p = p.reshape(-1, j * p.size(-1)).float().sum(dim=-1)
        dim[scm][j] = p.detach().cpu().numpy()


def select(df: pd.DataFrame, scm, j, method, model=None, seed2=None):
    """Return the ``id``, ``seed`` and ``estimate`` columns of matching rows.

    Rows must match ``scm``, ``j`` and ``method``. ``model`` and ``seed2`` are
    optional extra equality filters that are applied only when they are not
    ``None`` (``seed2`` requires a ``'seed2'`` column in ``df``).
    """
    keep = (df['scm'] == scm) & (df['j'] == j) & (df['method'] == method)
    if model is not None:
        keep = keep & (df['model'] == model)
    if seed2 is not None:
        keep = keep & (df['seed2'] == seed2)
    wanted_columns = ['id', 'seed', 'estimate']
    return df.loc[keep, wanted_columns]


def to_exp_np(df: pd.DataFrame):
    """Turn a long-format selection into an exponentiated (n_ids, -1) matrix.

    Rows are sorted by the ``'id'`` column; the first column is read as the
    id and the last column as the value. The values are reshaped so that each
    row of the result holds all values of one id, and ``exp`` is applied
    elementwise.
    """
    ordered = df.sort_values(['id']).to_numpy()
    ids = ordered[:, 0]
    n_rows = ids.max().astype(int) + 1
    per_id = ordered[:, -1].reshape(n_rows, -1)
    return np.exp(per_id)


def ci_95(X: np.ndarray, dim: np.ndarray):
    """Return twice the mean per-row standard deviation of ``X ** (1 / dim)``.

    Each row of ``X`` is raised to the power ``1 / dim`` of that row, entries
    that are exactly zero afterwards are masked out, the standard deviation is
    taken along the last axis and averaged over rows, and the result is
    doubled.
    """
    # Column vector so that every row of X gets its own exponent.
    exponent = np.reshape(1 / dim, (-1, 1))
    rooted = (X ** exponent).astype(X.dtype)
    rooted = np.ma.masked_array(rooted, rooted == 0)
    row_std = rooted.std(axis=-1)
    return 2 * row_std.mean(axis=0)


def table(scms=('simpson_nlin', 'triangle_nlin', 'largebd_nlin')):
    """Build the long-format table of ``ci_95`` values per item, SCM and ``j``.

    Items are emitted in the order ``rs_o, rs_p, maf_o, maf_p, nice_o,
    nice_p``. ``*_o`` rows are computed from ``dfo`` ('rs' / 'exom' runs),
    ``*_p`` rows from ``dfp`` ('rs_causal_nf' / 'causal_nf' runs), where the
    value is the mean of ``ci_95`` over every ``seed2`` that has data.
    Combinations without any data are omitted from the result.

    Args:
        scms: Iterable of SCM names to include.

    Returns:
        DataFrame with columns ``item``, ``scm``, ``j``, ``ci95``.
    """
    data = {
        'item': [],
        'scm': [],
        'j': [],
        'ci95': [],
    }

    def add_data(item, scm, j, ci95):
        data['item'].append(item)
        data['scm'].append(scm)
        data['j'].append(j)
        data['ci95'].append(ci95)

    def bias_o(model: str):
        for scm in scms:
            for j in js:
                if model == 'rs':
                    Y = select(dfo, scm, j, 'rs')
                else:
                    Y = select(dfo, scm, j, 'exom', model=model)
                if len(Y) == 0:
                    # No run for this combination: skip it, as bias_p does,
                    # instead of crashing in to_exp_np on an empty selection.
                    continue
                x = ci_95(to_exp_np(Y), dim[scm][j])
                add_data(f'{model}_o', scm, j, x)

    def bias_p(model: str):
        for scm in scms:
            for j in js:
                x1 = []
                for s2 in seeds:
                    if model == 'rs':
                        Y = select(dfp, scm, j, 'rs_causal_nf', seed2=s2)
                    else:
                        Y = select(dfp, scm, j, 'causal_nf',
                                   model=model, seed2=s2)
                    if len(Y) > 0:
                        x1.append(ci_95(to_exp_np(Y), dim[scm][j]))
                if len(x1) == 0:
                    continue
                add_data(f'{model}_p', scm, j, np.mean(x1))

    for model in ('rs', 'maf', 'nice'):
        bias_o(model)
        bias_p(model)

    return pd.DataFrame.from_dict(data)


# Persist the table for the LaTeX renderers below. The file name spelling
# ('denstiy') is what tab2() and tabc91() read, so it must be changed in all
# three places together or not at all.
table().to_csv('script/figure/denstiy_ci95.csv')


# str.format template for table 2, consumed by tab2(). Doubled braces are
# literal LaTeX braces; the 42 bare `{}` fields (6 rows x 7 cells) are filled
# positionally in the order tab2() appends its values.
# NOTE: the name `tmpl` is rebound further down for another table, so tab2()
# has to be called before that rebinding.
tmpl = """
\\label{{tab:2}}
\\centering
\\begin{{tabular}}{{ccccccccc}}
    \\toprule
    \\multicolumn{{2}}{{c}}{{}} & \\multicolumn{{3}}{{c}}{{SIMPSON-NLIN}} & \\multicolumn{{4}}{{c}}{{FAIRNESS}}\\\\
    \\cmidrule(r){{3-5}} \\cmidrule(r){{6-9}}
    Method & SCM & $|s|=1$ & $|s|=3$ & $|s|=5$ & ATE & ETT & NDE & CtfDE\\\\
    \\midrule
    \\multirow{{2}}{{*}}{{RS}} & O & {} & {} & {} & {} & {} & {} & {}\\\\
    & P & {} & {} & {} & {} & {} & {} & {}\\\\
    \\midrule
    \\multirow{{2}}{{*}}{{EXOM[MAF]}} & O & {} & {} & {} & {} & {} & {} & {}\\\\
    & P & {} & {} & {} & {} & {} & {} & {}\\\\
    \\midrule
    \\multirow{{2}}{{*}}{{EXOM[NICE]}} & O & {} & {} & {} & {} & {} & {} & {}\\\\
    & P & {} & {} & {} & {} & {} & {} & {}\\\\
    \\bottomrule
\\end{{tabular}}
"""


def tab2():
    """Render table 2 to ``script/figure/tabs/tab2.tex``.

    Reads ``script/figure/denstiy_ci95.csv`` (density results, SCM
    'simpson_nlin', j in 1/3/5) and ``script/figure/effect_ci95.csv`` (effect
    results, SCM 'fairness', queries ate/ett/nde/ctfde) and fills the
    module-level ``tmpl`` with one cell per (method, o/p, column), in the
    order rs, maf, nice. A cell is ``$\\pm<value>$`` with three decimals, or
    ``-`` when the value is NaN or the row is missing.

    Does nothing when the effect CSV does not exist yet.
    """
    df1 = pd.read_csv('script/figure/denstiy_ci95.csv')
    if not os.path.exists('script/figure/effect_ci95.csv'):
        # The effect CSV is not produced by this file; skip the table
        # until it is available.
        return
    df2 = pd.read_csv('script/figure/effect_ci95.csv')

    def cell(result):
        # Format a one-row 'ci95' selection; '-' marks missing or NaN values.
        if len(result) == 0:
            return '-'
        val = float(result.to_numpy().item())
        # NaN never compares equal to anything (including math.nan), so it
        # has to be detected with isnan.
        if math.isnan(val):
            return '-'
        return '$\\pm{:0.3f}$'.format(val)

    values = []
    for method in ['rs', 'maf', 'nice']:
        for i in ['o', 'p']:
            item = f'{method}_{i}'
            for j in [1, 3, 5]:
                values.append(cell(df1[(df1['item'] == item) &
                                       (df1['j'] == j) &
                                       (df1['scm'] == 'simpson_nlin')][['ci95']]))
            for q in ['ate', 'ett', 'nde', 'ctfde']:
                values.append(cell(df2[(df2['item'] == item) &
                                       (df2['q'] == q) &
                                       (df2['scm'] == 'fairness')][['ci95']]))
    rendered = tmpl.format(*values)
    out_path = 'script/figure/tabs/tab2.tex'
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, 'w', encoding='utf-8') as f:
        f.write(rendered)


# Must run here: tab2() reads the module-level `tmpl`, which is rebound to a
# different template right below.
tab2()


# str.format template for table C.9.1, consumed by tabc91(). Doubled braces
# are literal LaTeX braces; the 36 bare `{}` fields (6 rows x 6 cells) are
# filled positionally: SIMPSON-NLIN |s|=1,3,5, TRIANGLE-NLIN |s|=1,3,
# LARGEBD-NLIN |s|=1. The last column is filled from scm 'largebd_nlin', so
# its header reads LARGEBD-NLIN like the other two (it used to say
# LARGEBD-LIN).
tmpl = """
\\label{{tab:91}}
\\centering
\\begin{{tabular}}{{cccccccc}}
    \\toprule
    \\multicolumn{{2}}{{c}}{{}} & \\multicolumn{{3}}{{c}}{{SIMPSON-NLIN}} & \\multicolumn{{2}}{{c}}{{TRIANGLE-NLIN}} & \\multicolumn{{1}}{{c}}{{LARGEBD-NLIN}}\\\\
    \\cmidrule(r){{3-5}} \\cmidrule(r){{6-7}} \\cmidrule(r){{8-8}}
    Method & SCM & $|s|=1$ & $|s|=3$ & $|s|=5$ & $|s|=1$ & $|s|=3$ & $|s|=1$ \\\\
    \\midrule
    \\multirow{{2}}{{*}}{{RS}} & O & {} & {} & {} & {} & {} & {}\\\\
    & P & {} & {} & {} & {} & {} & {}\\\\
    \\midrule
    \\multirow{{2}}{{*}}{{EXOM[MAF]}} & O & {} & {} & {} & {} & {} & {}\\\\
    & P & {} & {} & {} & {} & {} & {}\\\\
    \\midrule
    \\multirow{{2}}{{*}}{{EXOM[NICE]}} & O & {} & {} & {} & {} & {} & {}\\\\
    & P & {} & {} & {} & {} & {} & {}\\\\
    \\bottomrule
\\end{{tabular}}
"""


def tabc91():
    """Render table C.9.1 to ``script/figure/tabs/tabc91.tex``.

    Reads ``script/figure/denstiy_ci95.csv`` and fills the module-level
    ``tmpl`` with one cell per (method, o/p, scm, j), in the order rs, maf,
    nice. Only j <= 3 is shown for 'triangle_nlin' and only j == 1 for
    'largebd_nlin'. A cell is ``$\\pm<value>$`` with three decimals, or ``-``
    when the row is missing or its value is NaN.
    """
    df1 = pd.read_csv('script/figure/denstiy_ci95.csv')
    # Largest j that has a column in the table, per SCM.
    max_j = {'simpson_nlin': 5, 'triangle_nlin': 3, 'largebd_nlin': 1}
    values = []
    for method in ['rs', 'maf', 'nice']:
        for i in ['o', 'p']:
            item = f'{method}_{i}'
            for scm in ['simpson_nlin', 'triangle_nlin', 'largebd_nlin']:
                for j in [1, 3, 5]:
                    if j > max_j[scm]:
                        continue
                    result = df1[(df1['item'] == item) &
                                 (df1['j'] == j) &
                                 (df1['scm'] == scm)][['ci95']]
                    s = '-'
                    if len(result) > 0:
                        val = float(result.to_numpy().item())
                        if not math.isnan(val):
                            s = '$\\pm{:0.3f}$'.format(val)
                    values.append(s)
    rendered = tmpl.format(*values)
    out_path = 'script/figure/tabs/tabc91.tex'
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, 'w', encoding='utf-8') as f:
        f.write(rendered)


# Render the second table with the template bound to `tmpl` just above.
tabc91()
