# This script takes the name of an environment and
# outputs the pdf/png figures for the paper.

# Intermediate steps are:
# 1. read the rewards array
# 2. compute V (a separate step: run it only once and store the result)
# 3. read V from step 2, compute Vhat, and save it for each data budget B
# 4. read the file from step 3, compute (Vhat-V)^2, and plot


import sys
import glob 
from tqdm import trange
import argparse
import numpy as np
from dotmap import DotMap
from os.path import join, exists
from os import makedirs
from logging import info, basicConfig, INFO 
import gtimer as gt
from matplotlib import pyplot as plt

## map the env-id to the time_limit, dt ...

def setup(env_id, nb_runs, shuffle, seed, outdir, show_img):
    """Build the experiment configuration for one environment.

    Args:
        env_id: environment id, with or without a version suffix (e.g.
            "Hopper-v3" or "bipedal_walker"). Only the part before the first
            '-' is used to look up the per-environment tables; the full
            string is used in the log/data directory names.
        nb_runs: number of reward files to use, stored as-is.
        shuffle: shuffle mode, stored as-is.
        seed: random seed, stored as-is.
        outdir: output directory for figures; when None the Vhat log
            directory is used instead.
        show_img: stored as-is (whether plotting should call plt.show()).

    Returns:
        DotMap with the fields V, h_list, data_dir, logdir_V, logdir_Vhat,
        env_id, env_name, dt, time_limit, nb_runs, B, N_trials, shuffle,
        seed, outdir, ylabel, legend and show_img.

    Raises:
        KeyError: if the environment has no entry in the per-environment
            tables.

    Side effects:
        Creates ``log/log_{env_id}_V`` and ``log/log_{env_id}_Vhat`` if they
        do not exist yet.
    """
    args = DotMap()
    env = env_id.strip().split('-')[0]  # just the env name, no version number
    env_name = {"Ant": "Ant-v3", "Swimmer": "Swimmer-v3",
                "Pusher": "Pusher-v2", "InvertedDoublePendulum": "InvertedDoublePendulum-v2",
                "HalfCheetah": "HalfCheetah-v3", "Hopper": "Hopper-v3",
                "Pendulum": "Pendulum-v1", "bipedal_walker": "BipedalWalker-v3"}
    time_limit = {"Ant": 50, "Swimmer": 40, "Reacher": 20, "InvertedPendulum": 40,
                  "Walker2d": 8, "Pusher": 50, "InvertedDoublePendulum": 50,
                  "HalfCheetah": 50, "HumanoidStandup": 15, "Hopper": 8, "Humanoid": 15,
                  "Pendulum": 10, "bipedal_walker": 10}
    # Not read anywhere in this function; kept for reference only.
    dt_default = {"Ant": 0.05, "Swimmer": 0.04, "Reacher": 0.02, "InvertedPendulum": 0.04,
                  "Walker2d": 0.008, "Pusher": 0.05, "InvertedDoublePendulum": 0.05,
                  "HalfCheetah": 0.05, "HumanoidStandup": 0.015, "Hopper": 0.008, "Humanoid": 0.015,
                  "Pendulum": 0.05, "bipedal_walker": 0.02}
    # The following V is computed from 150k episodes (values are stored in log_env_V).
    # NOTE(review): the entries for Reacher, InvertedPendulum, Walker2d,
    # HumanoidStandup and Humanoid are identical to dt_default and look like
    # placeholders -- confirm before enabling those environments.
    V = {"Ant": 124.110, "Swimmer": 14.138, "Reacher": 0.02, "InvertedPendulum": 0.04,
         "Walker2d": 0.008, "Pusher": -7.315, "InvertedDoublePendulum": 467.987,
         "HalfCheetah": 269.059, "HumanoidStandup": 0.015, "Hopper": 9.358, "Humanoid": 0.015,
         "Pendulum": -0.164, "bipedal_walker": 3.086}
    # Alternative V computed from nb_runs=300, i.e. 300k episodes:
    # V = {"Ant": 124.130, "Swimmer": 14.138,
    #        "Pusher": -7.320, "InvertedDoublePendulum": 467.987, # miss hopper 150
    #        "HalfCheetah": 269.058, "Hopper": 9.358,
    #        "Pendulum": -0.164, "bipedal_walker": 3.086}

    # Step sizes h evaluated for each environment.
    h_list_env = {"Pendulum": [0.001, 0.002, 0.004, 0.01, 0.02, 0.04, 0.1],
                  "bipedal_walker": [0.001, 0.002, 0.004, 0.01, 0.02, 0.04, 0.1],
                  "Hopper": [0.001, 0.002, 0.004, 0.01, 0.02, 0.04, 0.1],
                  "HalfCheetah": [0.002, 0.004, 0.01, 0.02, 0.04, 0.1, 0.2, 0.4],
                  "Ant": [0.002, 0.004, 0.01, 0.02, 0.04, 0.1, 0.2, 0.4],
                  "InvertedDoublePendulum": [0.002, 0.004, 0.01, 0.02, 0.04, 0.1, 0.2, 0.4, 1],
                  "Swimmer": [0.002, 0.004, 0.01, 0.02, 0.04, 0.1],
                  "Pusher": [0.002, 0.004, 0.01, 0.02, 0.04, 0.1, 0.2, 0.4, 1],  # similar to the IDP
                  }
    # Base data budget B0 per environment; the budgets are multiples of it.
    B_dict = {"Pendulum": 10000, "bipedal_walker": 10000, "Hopper": 8000, "HalfCheetah": 25000,
              "Ant": 25000, "InvertedDoublePendulum": 25000, "Swimmer": 20000, "Pusher": 25000}
    ratio_list = [1, 2, 5, 10, 20, 40]  # try to reproduce the plot, find hstar
    # ratio_list = [40]
    # ratio_list = [1,2,5,10,20,4,8,16]  # added 4,8,16

    # Fail early with a readable message instead of a bare KeyError from
    # whichever table happens to be indexed first.
    if env not in V or env not in h_list_env:
        raise KeyError(
            f"unsupported environment {env!r} (from env_id={env_id!r}); "
            f"supported: {sorted(h_list_env)}"
        )

    dt = 0.001
    args.V = V[env]
    args.h_list = h_list_env[env]
    B0 = B_dict[env]
    B = [i*B0 for i in ratio_list]

    # load args from logdir if already existed, but should give key args here
    norm_dict = {'bipedal_walker': '_norm'}
    normalize_suffix = norm_dict.get(env, '')

    # the 100_10 here means: 100 episodes x 10 inter -> 1000 episodes in each reward file
    # cpu_suffix: most scripts were run on cpu
    cpu_suffix = '_cpu'
    datadir = f"log/log_{env_id}_ct_data_100_10{normalize_suffix}{cpu_suffix}/array"
    logdir_V = f'log/log_{env_id}_V'
    logdir_Vhat = f'log/log_{env_id}_Vhat'

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    makedirs(logdir_V, exist_ok=True)
    makedirs(logdir_Vhat, exist_ok=True)

    args.data_dir = datadir
    args.logdir_V = logdir_V
    args.logdir_Vhat = logdir_Vhat

    args.env_id = env_id
    args.env_name = env_name[env]
    args.dt = dt
    args.time_limit = time_limit[env]
    args.nb_runs = nb_runs
    args.B = B
    args.N_trials = 30
    args.shuffle = shuffle
    args.seed = seed
    args.outdir = outdir if outdir is not None else logdir_Vhat
    args.ylabel = False
    args.legend = False
    args.show_img = show_img
    return args


def load_data_from_npy(data_dir, T, nb_runs):
    """Load the first ``nb_runs`` reward files for horizon ``T`` and stack them.

    The files read are ``Rewards_{int(T*1000)}_1000_{i}.npy`` for
    i in range(nb_runs) inside ``data_dir``. Their arrays are joined with
    ``np.hstack`` in run order and the result is returned.
    """
    file_paths = [
        join(data_dir, f'Rewards_{int(T*1000)}_1000_{run}.npy')
        for run in range(nb_runs)
    ]
    stacked = np.hstack([np.load(file_path) for file_path in file_paths])
    info(f'Loaded from {file_paths}')
    return stacked

def npy_to_return(data_dir, T, h, nb_runs, is_bipedal=False, is_hopper=False):
    """Load reward arrays file by file and accumulate the total return.

    Files are processed one at a time so the full data set never has to be
    held in memory at once.

    Args:
        data_dir: directory containing the ``Rewards_*_1000_{i}.npy`` files.
        T: time horizon; only the first T/h rows of each file are summed.
        h: step size used to weight the summed rewards.
        nb_runs: number of files to read (indices 0 .. nb_runs-1).
        is_bipedal: read the ``Rewards_20000_...`` files (they hold twice the
            data that is actually needed).
        is_hopper: read the hopper files, whose first 150 runs are stored as
            ``Rewards_50000_...`` and later runs as ``Rewards_8000_...``.

    Returns:
        ``h`` times the sum of all rewards in the first T/h rows of every
        file, summed over all files (0 when no file is read).
    """
    # T/h is expected to be an integer; rounding protects against float
    # quotients such as 2.9999999999999996 being truncated by int().
    nb_steps = int(round(T/h))
    if is_bipedal:  # bipedal contains 2x data that we actually need
        paths = [join(data_dir, f'Rewards_20000_1000_{i}.npy') for i in range(nb_runs)]
    elif is_hopper:  # hopper contains more steps data that we actually need
        if nb_runs >= 150:
            # the files from run 150 onwards have 8000 steps rather than 50k steps
            paths = [join(data_dir, f'Rewards_50000_1000_{i}.npy') for i in range(150)]
            paths += [join(data_dir, f'Rewards_8000_1000_{150+i}.npy') for i in range(nb_runs-150)]
        else:
            paths = [join(data_dir, f'Rewards_50000_1000_{i}.npy') for i in range(nb_runs)]
    else:
        paths = [join(data_dir, f'Rewards_{int(T*1000)}_1000_{i}.npy') for i in range(nb_runs)]

    total_reward = 0
    for i in trange(len(paths)):
        reward_data_curr_batch = np.load(paths[i])
        # Keep only the rows inside the horizon, then collapse the whole
        # batch to a single scalar contribution.
        total_reward += h*np.sum(reward_data_curr_batch[:nb_steps, :])
    return total_reward

def load_from_compressed(data_dir, nb_runs):
    """Return the array stored in ``compressed_reward_{nb_runs}k.npy``.

    The ``{nb_runs}k`` suffix was added later: as of Jan 2023,
    ``compressed_reward.npy`` contains data for 300 runs (1k episodes each)
    and ``compressed_reward_150k.npy`` contains data for 150 runs.
    """
    compressed_path = join(data_dir, f'compressed_reward_{nb_runs}k.npy')
    return np.load(compressed_path)

def compress_reward_from_npy(args, verbose=False):
    """Reduce the raw reward files to one return per (step size, episode).

    For every reward file and every step size ``h`` in ``args.h_list`` the
    rewards are subsampled every ``h/args.dt`` rows (within the first
    ``args.time_limit/args.dt`` rows), summed over time and scaled by ``h``.
    The resulting array of shape ``(len(h_list), nb_runs*1000)`` is saved to
    ``{args.logdir_Vhat}/compressed_reward_{nb_runs}k.npy``.

    Args:
        args: configuration with data_dir, logdir_Vhat, time_limit, env_id,
            nb_runs, dt and h_list.
        verbose: if True, log the stride used for every (file, h) pair.

    Raises:
        ValueError: if a step size in ``h_list`` yields a stride below 1.

    Note:
        Each file is assumed to hold exactly 1000 episodes as columns, since
        results are written in column blocks of 1000.
    """
    data_dir = args.data_dir
    logdir = args.logdir_Vhat
    T = args.time_limit
    env_id = args.env_id
    nb_runs = args.nb_runs
    h0 = args.dt
    h_list = args.h_list
    # T/h0 and h/h0 are expected to be integers; rounding protects against
    # float quotients such as 2.9999999999999996 being truncated by int().
    nb_steps = int(round(T/h0))
    # The subsampling strides do not depend on the file, so compute them once.
    h_ratios = [int(round(h/h0)) for h in h_list]
    for h, h_ratio in zip(h_list, h_ratios):
        if h_ratio < 1:
            raise ValueError(
                f'step size h={h} is too small for the base step h0={h0}: '
                f'subsampling stride would be {h_ratio}'
            )
    is_bipedal = env_id == "bipedal_walker"
    is_hopper = env_id == "Hopper-v3"

    if is_bipedal:  # bipedal contains 2x data that we actually need
        paths = [join(data_dir, f'Rewards_20000_1000_{i}.npy') for i in range(nb_runs)]
    elif is_hopper:  # hopper contains more steps data that we actually need
        if nb_runs >= 150:
            # the files from run 150 onwards have 8000 steps rather than 50k steps
            paths = [join(data_dir, f'Rewards_50000_1000_{i}.npy') for i in range(150)]
            paths += [join(data_dir, f'Rewards_8000_1000_{150+i}.npy') for i in range(nb_runs-150)]
        else:
            paths = [join(data_dir, f'Rewards_50000_1000_{i}.npy') for i in range(nb_runs)]
    else:
        # Paths are built by index (not globbed) so that they follow the
        # ordering of the runs.
        paths = [join(data_dir, f'Rewards_{int(T*1000)}_1000_{i}.npy') for i in range(nb_runs)]

    nb_episodes = nb_runs * 1000
    Jh = np.zeros((len(h_list), nb_episodes))
    for i in trange(len(paths)):
        # (nb_steps, 1k) -> (nb_steps/2, 1k), (nb_steps/4, 1k), ...
        reward_data_curr_batch = np.load(paths[i])[:nb_steps, :]
        for j, (h, h_ratio) in enumerate(zip(h_list, h_ratios)):
            if verbose:
                info(f'h={h}, h_ratio={h_ratio}')
            xh = reward_data_curr_batch[::h_ratio, :]
            Jh[j, 1000*i:1000*(i+1)] = h * np.sum(xh, axis=0)  # shape (1k,)

    out_path = join(logdir, f'compressed_reward_{nb_runs}k.npy')
    np.save(out_path, Jh)
    info(gt.report())
 
def compute_V(args):
    """Estimate the true value V as the mean return over all episodes.

    The raw arrays (about 60 Gb) do not fit in RAM, so the rewards are
    summed file by file via ``npy_to_return`` and then averaged over
    ``nb_runs * 1000`` episodes. The result is logged and also recorded in
    the name of an empty marker file
    ``V_{V}_{env_id}_{nb_episodes}.txt`` inside ``args.logdir_V``.
    """
    info(f"args: {args}")
    env_id = args.env_id
    nb_runs = args.nb_runs

    total_reward = npy_to_return(
        args.data_dir,
        args.time_limit,
        args.dt,
        nb_runs,
        env_id == 'bipedal_walker',
        env_id == 'Hopper-v3',
    )
    nb_episodes = nb_runs * 1000
    gt.stamp(f'Load and compute the total rewards from {nb_episodes} episodes')

    # Average the accumulated return over the episodes.
    V = total_reward / nb_episodes
    info(f"true value is {V}")

    # The file stays empty: the value lives in its name.
    marker_path = join(args.logdir_V, f"V_{V}_{env_id}_{nb_episodes}.txt")
    with open(marker_path, 'w'):
        pass
    info(gt.report())

def run_MSE(args):
    """Compute Vhat for every data budget and step size, and save the results.

    For each budget ``B`` in ``args.B`` and each step size ``h`` in
    ``args.h_list``, ``M = int(B*h/T)`` episodes are averaged per trial;
    trial ``j`` uses the episode columns ``[M*j, M*(j+1))`` of the (possibly
    shuffled) compressed reward array. One file
    ``Vhat_B_{B}{shuffle_suffix}_trials{N_trials}.npy`` of shape
    ``(N_trials, len(h_list))`` is written to ``args.logdir_Vhat`` per budget.

    Args:
        args: configuration with B, logdir_Vhat, time_limit, h_list,
            nb_runs, N_trials, shuffle and seed.

    Raises:
        ValueError: if ``args.shuffle`` is not 0, 1 or 2.
    """
    B_list = args.B
    logdir = args.logdir_Vhat
    T = args.time_limit
    h_list = args.h_list
    nb_runs = args.nb_runs
    N_trials = args.N_trials
    shuffle = args.shuffle
    seed = args.seed
    # Any other value would leave the column order undefined further down.
    if shuffle not in (0, 1, 2):
        raise ValueError(f'shuffle must be 0, 1 or 2, got {shuffle!r}')
    info(f'B_list is {B_list}, \nnb_files: {nb_runs}, \nshuffle mode: {shuffle}')

    # load_the_file
    rewards = load_from_compressed(logdir, nb_runs)
    gt.stamp('Loaded the rewards data')
    nb_episodes = rewards.shape[-1]

    # Column order that stays fixed across budgets (modes 0 and 2).
    fixed_order = None
    if shuffle == 2:  # only shuffle at the beginning
        fixed_order = rewards[:, np.random.permutation(nb_episodes)]
    elif shuffle == 0:  # no shuffle: the original column order is used as is
        fixed_order = rewards

    shuffle_suffix = f'_s{shuffle}_{seed}' if shuffle > 0 else ''
    for B in B_list:
        # shuffle the episodes for each B to reduce correlation
        if shuffle == 1:
            episodes = rewards[:, np.random.permutation(nb_episodes)]
        else:
            episodes = fixed_order

        # nb of episodes per trial for each h, assuming it's an integer
        # (enough data to cover full episodes)
        M_list = [int(B*h/T) for h in h_list]
        for h, M in zip(h_list, M_list):
            if M < 1 or M*N_trials > nb_episodes:
                # Slices past the available data are short or empty, which
                # makes the corresponding means partial or NaN.
                info(f'warning: B={B}, h={h} needs {M} episodes x {N_trials} trials '
                     f'but only {nb_episodes} episodes are available')

        Vhat = np.zeros((N_trials, len(h_list)))
        for j in range(N_trials):
            for i, M in enumerate(M_list):
                Vhat[j, i] = compute_Vhat(i, M*j, M, episodes, False)
            gt.stamp(f'compute Vhat for B={B}, trial {j}')

        # save Vhat
        out_path = join(logdir, f'Vhat_B_{B}{shuffle_suffix}_trials{N_trials}.npy')
        np.save(out_path, Vhat)
        gt.stamp(f'saved Vhat for h={h_list} and B={B}')

    info(gt.report())

def compute_Vhat(h_idx, base_M, M, rewards, verbose=False):
    """Average the returns of M consecutive episodes for one step size.

    Row ``h_idx`` of ``rewards`` holds the per-episode returns for the
    h_idx-th step size; the columns ``base_M .. base_M+M-1`` are the
    episodes that get averaged.
    """
    episode_window = slice(base_M, base_M + M)
    returns = rewards[h_idx, episode_window]
    if verbose:
        info(f'h_idx={h_idx}, M={M}')
        info(f"shape of Jh: {returns.shape}")
    # avg over M trajectories
    return np.mean(returns)

def plot_MSE_over_h(args, no_log_val, std_err=False, scatter=False):
    """Plot the squared error (Vhat - V)^2 against the step size h.

    One curve is drawn per data budget in ``args.B``, read from
    ``Vhat_B_{B}{shuffle_suffix}_trials{N_trials}.npy`` in
    ``args.logdir_Vhat``. The figure is saved as a pdf under
    ``{args.outdir}/img`` and shown if ``args.show_img`` is set.

    Args:
        args: configuration with h_list, time_limit, V, B, N_trials,
            shuffle, seed, outdir, legend, ylabel, env_name, env_id,
            logdir_Vhat and show_img.
        no_log_val: if False, statistics are taken over log10 of the squared
            error; if True, over the raw squared error and the y axis is
            set to log scale instead.
        std_err: if True the shaded band is the standard error over trials,
            otherwise the standard deviation.
        scatter: if True, plot every trial as individual markers instead of
            the mean curve with its band.
    """
    ft_size = 20
    h_list = args.h_list
    T = args.time_limit
    V = args.V
    B_list = args.B
    N_trials = args.N_trials
    shuffle = args.shuffle
    seed = args.seed
    shuffle_suffix = f'_s{shuffle}_{seed}' if shuffle > 0 else ''
    outdir = args.outdir
    legend = args.legend
    ylabel = args.ylabel
    legend_suffix = '' if legend else '_nolegend'
    ylabel_suffix = '' if ylabel else '_noylabel'
    env_name = args.env_name
    info(f"V is {V}; B_list is {B_list}; h_list is {h_list}")

    fig, ax = plt.subplots()
    # Tracked across all budgets, so it must be initialised outside the loop.
    data_global_min = np.inf
    for B in B_list:
        Vhat = np.load(join(args.logdir_Vhat, f'Vhat_B_{B}{shuffle_suffix}_trials{N_trials}.npy'))
        nb_trials = Vhat.shape[0]
        # Vhat: [nb_trials, len(h)]; the file may hold more step sizes than
        # h_list, in which case only the leading columns are used.
        data = (Vhat[:, :len(h_list)] - V)**2
        info(f"shape of the data is {data.shape}")
        if not no_log_val:
            data = np.log10(data)
        obj = np.mean(data, axis=0)
        band = np.std(data, axis=0)
        if std_err:
            band = band / np.sqrt(nb_trials)
        data_global_min = min(data_global_min, np.min(obj))

        min_idx = np.argmin(obj)
        info(f'h* for B={B}: {h_list[min_idx]} at index {min_idx}')

        if not scatter:
            ax.plot(h_list, obj, '-o', linewidth=2, label=f'B={B:.0f}', markeredgewidth=1.5,
                    alpha=0.6,
                    markeredgecolor=(0, 0, 0, 1))
            ax.fill_between(h_list, obj + band, obj - band, alpha=0.1)
        else:
            for j in range(data.shape[0]):
                ax.plot(h_list, data[j, :], 'o', linewidth=2, label=f'run={j+1}', markeredgewidth=1.5,
                        alpha=0.6,
                        markeredgecolor=(0, 0, 0, 1))

    if no_log_val:
        ax.set_yscale('log')
    ax.set_xscale('log')
    ax.set_title(f'{env_name}, T={T:.0f}', fontsize=ft_size)
    ax.set_xlabel("h", fontsize=ft_size)
    if not no_log_val:
        ylabel_str = r'$\log(\hat{V}_M(h) - V)^2$'
    else:
        ylabel_str = r'$(\hat{V}_M(h) - V)^2$'
    if ylabel:
        ax.set_ylabel(ylabel_str, fontsize=ft_size)
    info(f"global min of mean is {data_global_min}")
    # Skip the limit when no finite minimum exists (e.g. all-NaN data),
    # because matplotlib rejects non-finite axis limits.
    if not std_err and np.isfinite(data_global_min):
        ax.set_ylim(bottom=data_global_min-1)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    if legend:
        ax.legend(bbox_to_anchor=(1.05, 0.6))
    fig.tight_layout()
    ax.tick_params(labelsize=ft_size-4)

    mode = '_stde' if std_err else ''
    log_suffix = '_log' if not no_log_val else ''
    B_suffix = f'_B{max(B_list)}'
    fname = f'MSE_T{T}_{args.env_id}{B_suffix}{log_suffix}_trials{N_trials}{shuffle_suffix}{mode}{ylabel_suffix}{legend_suffix}'
    img_dir = join(outdir, 'img')
    makedirs(img_dir, exist_ok=True)
    print('img_dir is: ', img_dir)
    fig.savefig(join(img_dir, fname+'.pdf'), bbox_inches='tight', dpi=300)
    if args.show_img:
        plt.show()

def plot_mujoco_legend(outdir):
    """Render the shared sample-budget legend as a stand-alone pdf.

    Empty lines are plotted with the same marker style as the MSE curves so
    that only their legend entries are drawn; the figure is then cropped to
    the legend's bounding box (plus a small margin) and saved as
    ``{outdir}/img/mujoco_legend_largerB.pdf``.

    Args:
        outdir: base output directory; the ``img`` sub-directory is created
            if it does not exist.
    """
    fig, ax = plt.subplots()
    legend_names = ["Sample Budget", r"$B_0$", r"$2\times B_0$", r"$5\times B_0$", r"$10\times B_0$", r"$20\times B_0$", r"$40\times B_0$"]
    # For the dense budget use instead:
    # legend_names = ["Sample Budget", r"$B_0$", r"$2\times B_0$",r"$4\times B_0$",r"$5\times B_0$",r"$8\times B_0$",r"$10\times B_0$",r"$16\times B_0$", r"$20\times B_0$"]

    # First entry is an invisible line that acts as the legend title.
    ax.plot([], [], '', lw=0, color='w', label=legend_names[0])
    for name in legend_names[1:]:
        ax.plot([], [], "-o", lw=2, label=name, markeredgewidth=1.5,
                alpha=0.6,
                markeredgecolor=(0, 0, 0, 1))

    legend = ax.legend(loc="lower left", ncol=len(ax.lines), framealpha=1, frameon=False)
    legend._legend_box.align = "left"
    fig = legend.figure
    # Draw once so the legend has a real extent to measure.
    fig.canvas.draw()
    bbox = legend.get_window_extent()
    expand = [-5, -5, 5, 5]  # margin around the legend, in display units
    bbox = bbox.from_extents(*(bbox.extents + np.array(expand)))
    bbox = bbox.transformed(fig.dpi_scale_trans.inverted())
    plt.axis('off')

    # savefig does not create missing directories.
    img_dir = join(outdir, 'img')
    makedirs(img_dir, exist_ok=True)
    fig.savefig(join(img_dir, 'mujoco_legend_largerB.pdf'), dpi=600, bbox_inches=bbox)


if __name__ == '__main__':
    # Command-line entry point. Exactly one stage runs per invocation, chosen
    # by the first flag that is set among --compute_V, --compress_data,
    # --plot and --plot_legend; with none of them, run_MSE is executed.
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_id', type=str, default=None, help='env id')
    parser.add_argument('--nb_runs', type=int, default=5,
                        help='nb of runs, each with 1k episodes')
    parser.add_argument('--seed', type=int, default=1234,
                        help='seed')
    parser.add_argument('--out_dir', type=str, default=None,
                        help='output dir for figs')
    parser.add_argument('--compute_V', action='store_true',
                        help='if true, compute V')
    parser.add_argument('--plot', action='store_true',
                        help='if true, create plot')
    parser.add_argument('--compress_data', action='store_true',
                        help='if true, create compressed data')
    parser.add_argument('--show_img', action='store_true',
                        help='if true, run plt.show()')
    parser.add_argument('--no_log_val', action='store_true',
                        help='if true, then not compute the statistics for log value')
    parser.add_argument('--shuffle', type=int, default=1,
                        help='1: shuffle the episodes each time; 2: shuffle at the start, 0: no shuffle')
    parser.add_argument('--plot_legend', action='store_true',
                        help='if true, then plot the legend fig')

    args = parser.parse_args()
    # setup() needs the environment name for every stage; without this check
    # a missing --env_id surfaces as an AttributeError on None inside setup().
    if args.env_id is None:
        parser.error('--env_id is required')

    basicConfig(stream=sys.stdout, level=INFO)
    # workflow from ct data for each env:
    # 1. compute_V only once, store it,
    # 2. compress_data only once, store it as compressed_reward_{nb_runs}k.npy
    # 3. run run_MSE, get Vhat_B_{}.npy for each data budget
    # 4. plot from reading Vhat files
    seed = args.seed
    config = setup(args.env_id, args.nb_runs, args.shuffle, seed, args.out_dir, args.show_img)
    std_err = True
    np.random.seed(seed)  # 1234

    if args.compute_V:
        compute_V(config)
    elif args.compress_data:
        compress_reward_from_npy(config)
    elif args.plot:
        plot_MSE_over_h(config, args.no_log_val, std_err=std_err)
    elif args.plot_legend:
        # config.outdir falls back to the Vhat log directory when --out_dir
        # is not given; args.out_dir may be None.
        plot_mujoco_legend(config.outdir)
    else:
        run_MSE(config)
