import os
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import cm
import matplotlib.ticker as mticker
import math



def EJ(a, T, sigma):
    """Return the closed-form reference value V used by the plotting code.

    For a != 0 this is sigma^2 / (4 a^2) * (exp(2 a T) - 1 - 2 a T); for
    a == 0 it is the a -> 0 limit sigma^2 T^2 / 2.
    NOTE(review): this matches the integral over [0, T] of E[X_t^2] for
    dX = a X dt + sigma dW with X_0 = 0 -- presumably the quantity that the
    saved Vhat estimates target; confirm against the data generator.
    """
    variance_rate = sigma**2
    if a != 0:
        return variance_rate / (4 * a**2) * (np.exp(2 * a * T) - 1 - 2 * a * T)
    # a == 0: the exponential form is 0/0, use its limit instead.
    return variance_rate * (T**2) / 2

def cubic_root(x):
    """Return the real cube root of ``x``, keeping the sign of ``x``.

    ``math.pow`` rejects a negative base with a fractional exponent, so the
    root is taken of ``abs(x)`` and the sign is copied back afterwards.
    """
    magnitude = math.pow(abs(x), 1.0 / 3.0)
    return math.copysign(magnitude, x)

def E1(T, h, a, o):
    """Exact B-independent term of the analytical MSE (see ``MSE``).

    NOTE(review): presumably the squared bias of the step-h estimate; with
    h == T it reduces to EJ(a, T, o)**2.

    Args:
        T: time horizon.
        h: step size.
        a: drift coefficient; must be nonzero (the formula divides by a**4).
        o: noise scale sigma (enters as o**4).
    """
    exp_h = np.exp(2 * a * h)
    exp_T = np.exp(2 * a * T)
    numerator = (o ** 4) * ((-2 * a * h + exp_h - 1) ** 2) * ((exp_T - 1) ** 2)
    denominator = 16 * (a ** 4) * ((exp_h - 1) ** 2)
    return numerator / denominator

def E2(T, h, a, o):
    """Exact coefficient of 1/B in the analytical MSE (see ``MSE``).

    Vanishes when h == T, since the two inner terms then coincide.

    Args:
        T: time horizon.
        h: step size.
        a: drift coefficient; must be nonzero (divides by a**2 and by
            (exp(2 a h) - 1)**2).
        o: noise scale sigma (enters as o**4).
    """
    exp_h = np.exp(2 * a * h)
    exp_T = np.exp(2 * a * T)
    left = h * (exp_T - 1) * (4 * exp_h + exp_T + 1)
    right = (exp_h - 1) * (exp_h + 4 * exp_T + 1) * T
    numerator = (o ** 4) * T * (left - right)
    denominator = 2 * (a ** 2) * ((exp_h - 1) ** 2)
    return numerator / denominator

def E1_h2(T, h, a, o):
    """Leading O(h^2) approximation of ``E1``.

    Returns o^4 * (exp(2 a T) - 1)^2 * h^2 / (16 a^2); requires a != 0.
    """
    growth = np.exp(2 * a * T) - 1
    scaled = growth ** 2 * (h**2) / (16 * (a ** 2))
    return (o ** 4) * scaled

def E2_hinv(T, h, a, o):
    """Leading O(1/h) approximation of ``E2``; requires a != 0.

    Returns -o^4 T (5 + 4aT + e^{2aT}(8aT - 4) - e^{4aT}) / (8 a^4 h).
    """
    exp_2aT = np.exp(2 * a * T)
    exp_4aT = np.exp(4 * a * T)
    bracket = 5 + 4 * a * T + exp_2aT * (8 * a * T - 4) - exp_4aT
    leading = -T * bracket / (8 * (a ** 4) * h)
    return (o ** 4) * leading

def E2_h2_B_inv(T, h, a, o):
    """Higher-accuracy approximation of ``E2``: terms of order 1/h, 1, h, h^2.

    The 1/h term is the same expression as ``E2_hinv``. Requires a != 0.
    """
    exp_2aT = np.exp(2 * a * T)
    exp_4aT = np.exp(4 * a * T)
    # O(1/h)
    coef_hinv = -T * (5 + 4 * a * T + exp_2aT * (8 * a * T - 4) - exp_4aT) / (8 * (a ** 4) * h)
    # O(1)
    coef_h0 = T * (1 + 4 * a * T * exp_2aT - exp_4aT) / (4 * (a ** 3))
    # O(h)
    coef_h1 = -T * (1 + 4 * a * T + exp_2aT * (8 * a * T + 4) - 5 * exp_4aT) * h / (24 * a ** 2)
    # O(h^2)
    coef_h2 = -T * (exp_4aT - 1) * h ** 2 / (12 * a)
    return (o ** 4) * (coef_hinv + coef_h0 + coef_h1 + coef_h2)

def MSE_approx(T_, h_, a_, o_, B_):
    """Leading-order asymptotic MSE: ``E1_h2 + E2_hinv / B_``.

    Requires a_ != 0, because both terms divide by powers of a_.
    """
    return E1_h2(T_, h_, a_, o_) + E2_hinv(T_, h_, a_, o_) / B_

def MSE_approx_higher_acc(T_, h_, a_, o_, B_):
    """Asymptotic MSE using the higher-accuracy variance term.

    Returns ``E1_h2 + E2_h2_B_inv / B_``; requires a_ != 0.
    """
    return E1_h2(T_, h_, a_, o_) + E2_h2_B_inv(T_, h_, a_, o_) / B_

def MSE(T_, h_, a_, o_, B_):
    """Return the exact analytical MSE ``e1 + e2 / B_``.

    ``e1`` does not depend on the budget ``B_`` (presumably the squared bias);
    ``e2`` is divided by ``B_`` (presumably the variance coefficient).

    Args:
        T_: time horizon.
        h_: step size.
        a_: drift coefficient. Any nonzero value uses the closed forms
            ``E1``/``E2``; ``a_ == 0`` uses their a -> 0 limit.
        o_: noise scale sigma (enters as o_**4).
        B_: data budget.
    """
    if a_ != 0:
        # E1/E2 are closed forms valid for any nonzero a (they only divide by
        # powers of a and by exp(2 a h) - 1), matching EJ's a != 0 branch.
        e1 = E1(T_, h_, a_, o_)
        e2 = E2(T_, h_, a_, o_)
    else:
        # a -> 0 limit of E1 and E2. The polynomial part of e2 carries a
        # factor 1/3, so e2 = T^5/(3h) - (2T^4 - 2hT^3 + h^2 T^2)/3.
        # This vanishes at h == T exactly like E2, and equals the a -> 0
        # limit of E2_h2_B_inv.
        e1 = o_**4 * T_**2 / 4 * h_**2
        e2 = (o_**4 * T_**5 / (3 * h_)
              + o_**4 * T_**2 * (-2 * T_**2 + 2 * h_ * T_ - h_**2) / 3)
    return e1 + e2 / B_

def exact_MSE(h_list_, B_, T_, a_, o_):
    """Return ``[MSE(T_, h, a_, o_, B_) for h in h_list_]``, in input order."""
    return [MSE(T_, step, a_, o_, B_) for step in h_list_]

def approx_MSE(h_list_, B_, T_, a_, o_):
    """Return ``MSE_approx`` evaluated at each step size in ``h_list_``."""
    return [MSE_approx(T_, step, a_, o_, B_) for step in h_list_]

def approx_MSE_higher_acc(h_list_, B_, T_, a_, o_):
    """Return ``MSE_approx_higher_acc`` evaluated at each step size in ``h_list_``."""
    return [MSE_approx_higher_acc(T_, step, a_, o_, B_) for step in h_list_]

def y_over_h_vary_a(paper=True, analytical_MSE=False, exact=True):
    """Plot the empirical squared error of V-hat versus h, one curve per drift a.

    Behaviour:
        - Iterates over every (B, T) in the module-level ``B_list`` / ``T_list``.
        - Loads ``data/Vhat_a_{a}_T_{T}_B_{B}_seed_0.npy`` for each ``a`` in
          the module-level ``a_list``.
        - Plots the per-h mean of (V-hat - V)^2 over the trials
          (V = EJ(a, T, 1)), with a +/- one standard-error band.
        - Saves ``obj_T_{T}_B_{B}[_stde][_exact_MSE|_approx_MSE].pdf``.

    Args:
        paper: if True, drop the 3/2/1 smallest step sizes for
            B = 4096/8192/16384 (the original note says: to avoid truncation).
        analytical_MSE: if True, overlay the analytical MSE as dashed lines.
        exact: use the exact MSE (True) or the higher-accuracy asymptotic
            approximation (False) for the overlay.

    Raises:
        FileNotFoundError: if a required ``.npy`` file is missing.
    """
    print("running objective over h, with fixed b and varying a")
    ft_size = 20
    sigma = 1
    N_trials = 50
    path = 'data'
    stde = True
    # Number of smallest step sizes dropped in paper mode, keyed by budget B.
    paper_skip = {4096.0: 3, 8192.0: 2, 16384.0: 1}

    # The approximate MSE divides by a, so a == 0 is excluded in that mode.
    # Filter a local copy rather than removing in place: the module-level
    # a_list must not be mutated, and an in-place remove would raise
    # ValueError on a second (B, T) iteration.
    if analytical_MSE and not exact:
        a_values = [a for a in a_list if a != 0.0]
    else:
        a_values = list(a_list)

    for B in B_list:
        for T in T_list:
            N0 = 2**16
            h0 = T / N0  # 2^-16 * T
            imax = 16
            h_list = 2**np.arange(1, imax + 1) * h0  # [2^-15, ..., 2^0] * T
            print(f"List of h: h = {h_list}")
            print(f"List of a: a = {a_values}")
            print(f"The equivalent a in DT: {np.exp(a_values)}")

            # load data: one (N_trials, len(h_list)) array per drift a
            Vhat_multi = np.zeros((len(a_values), N_trials, len(h_list)))
            for i, a in enumerate(a_values):
                npy_name = os.path.join(path, f'Vhat_a_{a}_T_{T}_B_{B}_seed_0.npy')
                print(f'loading {npy_name}')
                try:
                    with open(npy_name, 'rb') as f:
                        Vhat = np.load(f)
                except FileNotFoundError as err:
                    # Fail loudly instead of leaving this row of zeros in place.
                    raise FileNotFoundError(f"file {npy_name}, not found") from err
                Vhat_multi[i, :, :] = Vhat

            if paper:
                skip = paper_skip.get(B, 0)
                h_list = h_list[skip:]
                Vhat_multi = Vhat_multi[:, :, skip:]

            # Objective to minimize: squared error of V-hat against V.
            fig, ax = plt.subplots()
            for i, a in enumerate(a_values):
                V = EJ(a, T, sigma)
                data = (Vhat_multi[i] - V)**2
                obj = np.mean(data, axis=0)
                obj_stde = np.std(data, axis=0) / np.sqrt(data.shape[0])
                line = ax.plot(h_list, obj, '-o', label=f'a={a}', markeredgewidth=1.5,
                               alpha=0.8, markeredgecolor=(0, 0, 0, 0.2))
                ax.fill_between(h_list, obj + obj_stde, obj - obj_stde, alpha=0.1)

                # overlay analytical MSE in the same colour, dashed
                if analytical_MSE:
                    if exact:
                        obj_analytical = exact_MSE(h_list, B, T, a, sigma)
                    else:
                        obj_analytical = approx_MSE_higher_acc(h_list, B, T, a, sigma)
                    ax.plot(h_list, obj_analytical, '--', alpha=0.8, color=line[0].get_color())

            ax.set_yscale('log')
            ax.set_xscale('log')
            ax.set_title(f'T={T:.0f}, B={B:.0f}', fontsize=ft_size)
            ax.set_ylabel(r'$(\hat{V}_M(h) - V)^2$', fontsize=ft_size)
            ax.set_xlabel("h", fontsize=ft_size)
            ax.spines['right'].set_visible(False)
            ax.spines['top'].set_visible(False)
            ax.legend(fontsize=ft_size-12, loc=4, fancybox=True, framealpha=0.5)
            ax.tick_params(labelsize=ft_size-4)
            stde_str = '_stde' if stde else ''
            if not analytical_MSE:
                analytical_MSE_suffix = ''
            elif exact:
                analytical_MSE_suffix = '_exact_MSE'
            else:
                analytical_MSE_suffix = '_approx_MSE'
            fname = f'obj_T_{T}_B_{B}{stde_str}{analytical_MSE_suffix}'
            fig.savefig(fname+'.pdf', bbox_inches='tight', dpi=300)
            plt.close(fig)  # one figure per (B, T); release it once saved

def y_over_h_vary_b(paper=True, analytical_MSE=False, exact=True):
    """Plot the empirical squared error of V-hat versus h, one curve per budget B.

    Behaviour:
        - Iterates over every (a, T) in the module-level ``a_list`` / ``T_list``.
        - Loads ``data/Vhat_a_{a}_T_{T}_B_{B}_seed_0.npy`` for each ``B`` in
          the module-level ``B_list``.
        - Plots the per-h mean of (V-hat - V)^2 over the trials
          (V = EJ(a, T, 1)), with a +/- one standard-error band.
        - Saves ``obj_T_{T}_a_{a}_B[_stde][_log][_exact_MSE|_approx_MSE].pdf``.

    Args:
        paper: if True, drop the 3/2/1 smallest step sizes from the empirical
            curve for B = 4096/8192/16384.
        analytical_MSE: if True, overlay the analytical MSE as dashed lines.
        exact: use the exact MSE (True) or the higher-accuracy asymptotic
            approximation (False) for the overlay.

    Raises:
        FileNotFoundError: if a required ``.npy`` file is missing.
    """
    print("running objective over h, with fixed a and varying b")
    fs = 20
    sigma = 1
    N_trials = 50
    path = 'data'
    log_val = False  # if True, average log10 squared errors instead
    stde = True
    # Number of smallest step sizes dropped in paper mode, keyed by budget B.
    paper_skip = {4096.0: 3, 8192.0: 2, 16384.0: 1}

    for a in a_list:
        for T in T_list:
            N0 = 2**16
            h0 = T / N0  # 2^-16 * T
            imax = 16
            h_list = 2**np.arange(1, imax + 1) * h0  # [2^-15, ..., 2^0] * T
            print(f"List of h: h = {h_list}")

            # load data: one (N_trials, len(h_list)) array per budget B
            Vhat_multi = np.zeros((len(B_list), N_trials, len(h_list)))
            for i, B in enumerate(B_list):
                npy_name = os.path.join(path, f'Vhat_a_{a}_T_{T}_B_{B}_seed_0.npy')
                print(f'loading {npy_name}')
                try:
                    with open(npy_name, 'rb') as f:
                        Vhat = np.load(f)
                except FileNotFoundError as err:
                    # Fail loudly instead of leaving this row of zeros in place.
                    raise FileNotFoundError(f"file {npy_name}, not found") from err
                Vhat_multi[i, :, :] = Vhat

            # Objective to minimize: squared error of V-hat against V.
            fig, ax = plt.subplots()
            V = EJ(a, T, sigma)
            for i, B in enumerate(B_list):
                skip = paper_skip.get(B, 0) if paper else 0
                h_plot = h_list[skip:]
                data = (Vhat_multi[i, :, skip:] - V)**2
                if log_val:
                    data = np.log10(data)
                obj = np.mean(data, axis=0)
                nb_runs = data.shape[0]
                print(f"nb_runs: {nb_runs}")
                obj_stde = np.std(data, axis=0) / np.sqrt(nb_runs)
                line = ax.plot(h_plot, obj, '-o', label=f'B={B:.0f}', markeredgewidth=1.5,
                               alpha=0.8, markeredgecolor=(0, 0, 0, 0.2))
                ax.fill_between(h_plot, obj + obj_stde, obj - obj_stde, alpha=0.1)

                # overlay analytical MSE in the same colour, dashed
                if analytical_MSE:
                    # NOTE(review): the overlay spans the full h_list even when
                    # paper mode trims the empirical curve (y_over_h_vary_a
                    # plots it on the trimmed grid). Kept as-is; confirm intent.
                    # The overlay is not log-transformed when log_val is True.
                    if exact:
                        obj_analytical = exact_MSE(h_list, B, T, a, sigma)
                    else:
                        obj_analytical = approx_MSE_higher_acc(h_list, B, T, a, sigma)
                    ax.plot(h_list, obj_analytical, '--', alpha=0.8, color=line[0].get_color())

            if not log_val:
                ax.set_yscale('log')
                ax.set_ylabel(r'$(\hat{V}_M(h) - V)^2$', fontsize=fs)
                ax.set_title(f'T={T:.0f}, a={a}', fontsize=fs)
            else:
                ax.set_title(r'$\log(\hat{V}_M(h) - V)^2$'+f', T={T:.0f}, a={a}')
            ax.set_xscale('log')
            ax.set_xlabel("h", fontsize=fs)
            ax.spines['right'].set_visible(False)
            ax.spines['top'].set_visible(False)
            ax.tick_params(labelsize=fs-4)
            ax.legend(fontsize=fs-8, loc='lower left', fancybox=True, framealpha=0.5)
            stde_suffix = '_stde' if stde else ''
            log_suffix = '_log' if log_val else ''
            if not analytical_MSE:
                analytical_MSE_suffix = ''
            elif exact:
                analytical_MSE_suffix = '_exact_MSE'
            else:
                analytical_MSE_suffix = '_approx_MSE'
            fname = f'obj_T_{T}_a_{a}_B{stde_suffix}{log_suffix}{analytical_MSE_suffix}'
            fig.savefig(fname+'.pdf', bbox_inches='tight', dpi=300)
            plt.close(fig)  # one figure per (a, T); release it once saved


# Figure generation. The plotting functions read T_list, B_list and a_list as
# module globals, so those are (re)assigned before each group of calls.

# Fixed drift a, sweep over the data budget B.
T_list = [8.0]
B_list = [4096.0, 8192.0, 16384.0, 32768.0, 65536.0]
a_list = [-1.0]
for _figure_kwargs in (
    {},                                                       # Fig. 1(a)
    {"paper": True, "analytical_MSE": True, "exact": True},   # Fig. 5(a)
    {"paper": True, "analytical_MSE": True, "exact": False},  # Fig. 5(b)
):
    y_over_h_vary_b(**_figure_kwargs)

# Fixed budget B, sweep over the drift a.
B_list = [16384.0]
a_list = [-16.0, -8.0, -4.0, -2.0, -1.0, -0.5, -0.25, 0.0]
for _figure_kwargs in (
    {},                                                       # Fig. 1(b)
    {"paper": True, "analytical_MSE": True, "exact": True},   # Fig. 5(c)
    {"paper": True, "analytical_MSE": True, "exact": False},  # Fig. 5(d)
):
    y_over_h_vary_a(**_figure_kwargs)

