import torch
import torch.nn
import torch.nn.functional as F
from IPython import embed
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from matplotlib.cm import viridis
import time
import argparse
import torch.nn.functional as F
import time
import os
import pickle
from dataset import TrajDataset
from net import Net, Transformer, TransformerTall, TransformerBERT
from lqr_env import LQREnv, LQRController, TransformerController
from darkroom_env import DarkroomEnv, DarkroomEnvStitch, DarkroomEnvPermuted, DarkroomEnvVec, DarkroomOptPolicy, DarkroomTransformerController, RandCommit
import bandit_env
from bandit_env import BanditTransformerController, BanditEnv, OptPolicy, GreedyOptPolicy, PessMeanPolicy, EmpMeanPolicy, UCBPolicy
from bandit_env import TopKBanditTransformerController, TopKBanditEnv, TopKRandCommitPolicy, LinUCB, ThompsonSamplingPolicy
import pandas as pd
import scipy.stats
from evals import eval_darkroom
# Tensors and the model are placed on the GPU when CUDA is available,
# otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')




def analyze_pess(consts, eval_trajs, context_len=None, reward_var=None):
    """Deploy the pessimistic-mean baseline once per (constant, trajectory).

    For every value in ``consts`` a fresh ``BanditEnv`` and ``PessMeanPolicy``
    are built for each evaluation trajectory, the policy receives the
    trajectory's roll-in data as context, and the rewards returned by
    ``env.deploy_eval`` are recorded.

    Args:
        consts: Iterable of constants forwarded to
            ``PessMeanPolicy(env, const=...)``.
        eval_trajs: Sequence of trajectory dicts providing the keys
            ``'means'``, ``'rollin_xs'``, ``'rollin_us'``, ``'rollin_xps'``
            and ``'rollin_rs'``.
        context_len: Second positional argument of ``BanditEnv``. When
            ``None`` the script-level global ``H`` is used (the original
            behaviour).
        reward_var: Value passed as ``var=`` to ``BanditEnv``. When ``None``
            the script-level global ``var`` is used (the original behaviour).

    Returns:
        ``np.array`` of the collected rewards, indexed first by constant and
        then by trajectory.
    """
    # The globals only exist when this file runs as a script; explicit
    # arguments make the function usable when imported from elsewhere.
    if context_len is None:
        context_len = H
    if reward_var is None:
        reward_var = var

    # NOTE(review): the main script builds its BanditEnv with `horizon` and
    # `type=bandit_type`, whereas this function uses `H` and the default
    # type -- confirm that the difference is intended.

    # The context tensors do not depend on the constant, so convert and move
    # them to the device once per trajectory instead of once per pair.
    batches = [
        {
            'rollin_xs': torch.tensor(traj['rollin_xs'][None, :, :]).float().to(device),
            'rollin_us': torch.tensor(traj['rollin_us'][None, :, :]).float().to(device),
            'rollin_xps': torch.tensor(traj['rollin_xps'][None, :, :]).float().to(device),
            'rollin_rs': torch.tensor(traj['rollin_rs'][None, :, None]).float().to(device),
        }
        for traj in eval_trajs
    ]

    const_values = []
    for const in consts:
        all_rs = []
        for traj, batch in zip(eval_trajs, batches):
            env = BanditEnv(traj['means'], context_len, var=reward_var)
            pess = PessMeanPolicy(env, const=const)
            # Hand over a fresh dict so keys a policy might add to its batch
            # cannot leak into the next deployment.
            pess.set_batch(dict(batch))

            _, _, _, rs_pess = env.deploy_eval(pess)
            all_rs.append(rs_pess)

        const_values.append(all_rs)

    return np.array(const_values)
            


if __name__ == '__main__':
    # Make sure the expected output folders exist.
    os.makedirs('figs/loss', exist_ok=True)
    os.makedirs('models', exist_ok=True)

    def _str2bool(value):
        """Parse a flag value: only the text 'true' (any case) yields True."""
        return str(value).lower() == 'true'

    parser = argparse.ArgumentParser()
    parser.add_argument("--envs", type=int, required=False, default=1000, help="Envs")
    parser.add_argument("--hists", type=int, required=False, default=1, help="Histories")
    parser.add_argument("--samples", type=int, required=False, default=1, help="Samples")
    parser.add_argument("--H", type=int, required=False, default=10, help="Context horizon")
    parser.add_argument("--embd", type=int, required=False, default=32, help="Embedding")
    # The next three help texts used to be copy-pasted from neighbouring options.
    parser.add_argument("--head", type=int, required=False, default=1, help="Number of attention heads")
    parser.add_argument("--layer", type=int, required=False, default=3, help="Number of layers")
    parser.add_argument("--lr", type=float, required=False, default=1e-3, help="Learning rate")
    parser.add_argument("--dim", type=int, required=False, default=1, help="Dimension")
    parser.add_argument("--epoch", type=int, required=False, default=-1, help="Epoch to evaluate")
    parser.add_argument("--opt", type=int, required=False, default=0, help="Optimizer type")
    parser.add_argument("--dropout", type=float, required=False, default=0, help="Dropout")
    parser.add_argument("--var", type=float, required=False, default=0.0, help="Variance")
    parser.add_argument("--cov", type=float, required=False, default=0.0, help="Coverage")
    parser.add_argument("--test_cov", type=float, required=False, default=-1.0, help="Test coverage")
    parser.add_argument("--trans", type=int, required=False, default=0, help="Transformer type")
    parser.add_argument("--hor", type=int, required=False, default=-1, help="Episode horizon (for mdp)")
    parser.add_argument("--env", type=str, required=True, help="Environment")
    parser.add_argument("--k", type=int, required=False, default=1, help="Top K value")
    parser.add_argument("--alg", type=str, required=False, default="random", help="Algorithm to generate data")
    parser.add_argument('--dataset_prefix', type=str, required=False, default='', help="Dataset prefix")
    parser.add_argument('--model_prefix', type=str, required=False, default='models', help="Model prefix")
    parser.add_argument(
        "--eval_with_expert_trajs",
        type=_str2bool,
        required=False,
        default=False,
        help="Whether to evaluate with expert context")
    parser.add_argument(
        "--eval_in_train_tasks",
        type=_str2bool,
        required=False,
        default=False,
        help="Whether to evaluate in train tasks")

    parser.add_argument('--full', default=False, action='store_true')
    parser.add_argument('--shuffle', default=False, action='store_true')
    parser.add_argument('--test', default=False, action='store_true')
    parser.add_argument("--include_partial_hist", default=False, action='store_true')
    parser.add_argument("--grow_context", default=False, action='store_true')

    # Work with a plain dict from here on.
    args = vars(parser.parse_args())
    print("Args:")
    print(args)

    # --- dataset sizes (they are also part of the checkpoint file name) ------
    n_envs, n_hists, n_samples = args['envs'], args['hists'], args['samples']
    H = args['H']

    # State and action sizes both start out equal to --dim; the
    # per-environment setup further down may override them.
    dx = du = dim = args['dim']

    # --- model / optimisation hyper-parameters -------------------------------
    n_embd, n_head, n_layer = args['embd'], args['head'], args['layer']
    lr, epoch = args['lr'], args['epoch']
    shuffle, full = args['shuffle'], args['full']
    opt, dropout = args['opt'], args['dropout']
    var, trans = args['var'], args['trans']
    cov, test_cov = args['cov'], args['test_cov']

    # --- environment / data selection ----------------------------------------
    envname, horizon = args['env'], args['hor']
    use_test, topk, alg = args['test'], args['k'], args['alg']
    dataset_prefix, model_prefix = args['dataset_prefix'], args['model_prefix']

    # --- evaluation switches -------------------------------------------------
    include_partial_hist = args['include_partial_hist']
    grow_context = args['grow_context']
    eval_with_expert_trajs = args['eval_with_expert_trajs']
    eval_in_train_tasks = args['eval_in_train_tasks']

    # Hard-coded switches (not exposed on the command line).
    use_net, save_video = False, True

    # A negative value means "not given": fall back to the training coverage
    # and to the context horizon respectively.
    test_cov = cov if test_cov < 0 else test_cov
    horizon = H if horizon < 0 else horizon

    # Every checkpoint name begins with the same hyper-parameter stem; only
    # the environment-specific tail differs between the branches below.
    prefix = envname
    name_stem = (
        f'{prefix}_trans{trans}_full{full}_shuf{shuffle}_opt{opt}_lr{lr}'
        f'_do{dropout}_embd{n_embd}_layer{n_layer}_head{n_head}'
        f'_envs{n_envs}_hists{n_hists}_samples{n_samples}'
    )

    if envname in ['bandit', 'bandit_ood']:
        bandit = True
        dx = 1
        filename = f'{name_stem}_var{var}_cov{cov}_H{H}_d{dim}'
        bandit_type = 'uniform'

    elif envname == 'bandit_thompson':
        bandit = True
        dx = 1
        filename = f'{name_stem}_var{var}_cov{cov}_H{H}_d{dim}'
        bandit_type = 'bernoulli'

    elif envname == 'bandit_topk':
        bandit = True
        dx = 1
        filename = f'{name_stem}_var{var}_k{topk}_H{H}_d{dim}'
        # This branch used to leave `bandit_type` undefined, so the online
        # evaluation config (which reads it for 'bandit_topk' too) raised a
        # NameError.
        # NOTE(review): 'uniform' is assumed here to match the plain bandit
        # setting -- confirm the intended value for top-k bandits.
        bandit_type = 'uniform'

    elif envname in ['darkroom', 'darkroom_heldout', 'darkroom_stitch', 'darkroom_permuted'] or envname.startswith('darkroom_heldout'):
        bandit = False
        dx = 2
        du = 5
        filename = f'{name_stem}_H{H}_d{dim}_alg{alg}'

    else:
        raise NotImplementedError

    # Hyper-parameters handed to the model constructor.
    config = {
        'H': H,
        'dx': dx,
        'du': du,
        'n_layer': n_layer,
        'n_embd': n_embd,
        'n_head': n_head,
        'Q': False,
        'full': full,
        'dropout': dropout,
    }

    # Pick the architecture: `use_net` wins, otherwise --trans selects the
    # transformer variant (0, 1, anything else).
    if use_net:
        model_cls = Net
    elif trans == 0:
        model_cls = Transformer
    elif trans == 1:
        model_cls = TransformerTall
    else:
        model_cls = TransformerBERT
    model = model_cls(config).to(device)

    # A negative --epoch selects the checkpoint saved without an epoch suffix.
    if epoch < 0:
        model_path = f'{model_prefix}/{filename}.pt'
    else:
        model_path = f'{model_prefix}/{filename}_epoch{epoch}.pt'

    # map_location lets a checkpoint that was saved on a GPU be loaded on a
    # CPU-only machine (and vice versa) instead of failing during torch.load.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint)
    model.eval()
    model.config['full'] = False
    
    # Work out which evaluation dataset to read and the name under which the
    # results of this run are stored.
    n_eval = 100
    H_eval = 10

    if envname in ('bandit', 'bandit_ood', 'bandit_thompson'):
        save_filename = f'{filename}_testcov{test_cov}_hor{horizon}.pkl'
        eval_filepath = f'datasets/trajs_eval_{envname}_envs{n_eval}_H{horizon}_d{dim}_var{var}_cov{test_cov}.pkl'

    elif envname == 'bandit_topk':
        save_filename = f'{filename}_hor{horizon}.pkl'
        eval_filepath = f'datasets/trajs_eval_{envname}_envs{n_eval}_H{horizon}_d{dim}_var{var}_k{topk}.pkl'

    elif envname.startswith('darkroom_heldout') or envname in ('darkroom', 'darkroom_heldout', 'darkroom_stitch', 'darkroom_permuted'):
        # The data-collection algorithm decides the default dataset folder;
        # unknown algorithms are rejected even when --dataset_prefix is given.
        if alg in ("random", "random_ppo"):
            prefix = "datasets"
        elif alg == "ppo":
            prefix = "ppo_datasets"
        else:
            raise NotImplementedError

        # An explicit --dataset_prefix overrides the folder chosen above.
        prefix = dataset_prefix if dataset_prefix else prefix

        if not use_test:
            traj_str = 'expert_trajs' if eval_with_expert_trajs else 'trajs'
            traj_str = traj_str + ('_train' if eval_in_train_tasks else '_eval')
            if envname.startswith('darkroom_heldout'):
                # Covers 'darkroom_heldout_random_init*' as well: every
                # held-out variant reads the same file.
                eval_filepath = f'{prefix}/{traj_str}_darkroom_heldout_envs{n_eval}_H{horizon}_d{dim}.pkl'
            else:
                # NOTE(review): traj_str already begins with 'trajs' or
                # 'expert_trajs', so this path reads 'trajs_trajs_eval_...'
                # -- confirm that this matches the dataset file names.
                eval_filepath = f'{prefix}/trajs_{traj_str}_{envname}_envs{n_eval}_H{horizon}_d{dim}.pkl'
        else:
            # --test: evaluate on the test split; the train split is loaded
            # as well so its goals can be compared against.
            eval_filepath = f'{prefix}/trajs_{envname}_envs{n_envs}_hists{n_hists}_samples{n_samples}_H{horizon}_d{dim}_test.pkl'
            eval_filepath_train = f'{prefix}/trajs_{envname}_envs{n_envs}_hists{n_hists}_samples{n_samples}_H{horizon}_d{dim}_train.pkl'

        save_filename = f'{filename}_hor{horizon}_test{use_test}_iph{include_partial_hist}_gc{grow_context}_train{eval_in_train_tasks}_expert{eval_with_expert_trajs}.pkl'

    else:
        raise ValueError(f'Environment {envname} not supported')


    # NOTE: unpickling can execute arbitrary code -- only load dataset files
    # that come from a trusted source.
    with open(eval_filepath, 'rb') as file:
        eval_trajs = pickle.load(file)

    # The original condition lacked `envname in`, so the (always truthy) list
    # literal made this block run for every environment when --test was set,
    # including bandits, for which `eval_filepath_train` is never defined.
    is_darkroom_env = (
        envname in ['darkroom', 'darkroom_heldout', 'darkroom_stitch', 'darkroom_permuted']
        or envname.startswith('darkroom_heldout')
    )
    if is_darkroom_env and use_test:
        with open(eval_filepath_train, 'rb') as file:
            eval_trajs_train = pickle.load(file)

        # Goals that occur in the train split (set: O(1) membership tests).
        train_goals = {tuple(traj['goal']) for traj in eval_trajs_train}

        if envname != 'darkroom_stitch' and envname != 'darkroom_permuted':    # allow repeat goals
            # Evaluate only on goals that never occur in the train split.
            eval_trajs = [traj for traj in eval_trajs if tuple(traj['goal']) not in train_goals]

        # Keep the first trajectory of every distinct goal, preserving order.
        seen_goals = set()
        unique_trajs = []
        for traj in eval_trajs:
            goal = tuple(traj['goal'])
            if goal not in seen_goals:
                seen_goals.add(goal)
                unique_trajs.append(traj)
        eval_trajs = unique_trajs

        # An empty list would make the doubling loop below spin forever, so
        # fail explicitly (an `assert` would be stripped under `python -O`).
        if not eval_trajs:
            raise ValueError("No eval trajs found")

        # Repeat the trajectories until there are at least n_eval of them,
        # then truncate to exactly n_eval.
        while len(eval_trajs) < n_eval:
            eval_trajs += eval_trajs
        eval_trajs = eval_trajs[:n_eval]


    # Never evaluate more trajectories than were actually loaded.
    n_eval = min(n_eval, len(eval_trajs))

    # One figure folder per plot type, grouped by the evaluated epoch.
    evals_filename = f"evals_epoch{epoch}"
    figure_dirs = (
        f'figs/{evals_filename}',
        f'figs/{evals_filename}/pess',
        f'figs/{evals_filename}/bar',
        f'figs/{evals_filename}/lines',
        f'figs/{evals_filename}/online',
    )
    for out_dir in figure_dirs:
        if not os.path.exists(out_dir):
            os.makedirs(out_dir, exist_ok=True)

    # Per-run folder for the images written when save_video is enabled.
    if save_video:
        if not os.path.exists(f'videos/{save_filename}/{evals_filename}'):
            os.makedirs(f'videos/{save_filename}/{evals_filename}', exist_ok=True)


    # ONLINE EVALUATION
    # NOTE(review): 'bandit_ood' is accepted elsewhere in this script but is
    # not part of this condition, so it gets no online evaluation -- confirm
    # that this is intended.
    if envname == 'bandit' or envname == 'bandit_topk' or envname == 'bandit_thompson':
        # `eval_bandit` was referenced without ever being imported, so this
        # branch raised a NameError. It is imported locally so that darkroom
        # runs are unaffected.
        # NOTE(review): assumes the `evals` package, which already provides
        # `eval_darkroom`, also exposes `eval_bandit` -- confirm.
        from evals import eval_bandit

        config = {
            'H': H,
            'horizon': horizon,
            'var': var,
            'n_eval': n_eval,
            'envname': envname,
            'k': topk,
            'type': bandit_type,
        }
        eval_bandit.online(eval_trajs, model, **config)
        plt.savefig(f'figs/{evals_filename}/online/{save_filename}.png')
        plt.clf()
    elif envname in ['darkroom', 'darkroom_heldout', 'darkroom_stitch', 'darkroom_permuted'] or envname.startswith('darkroom_heldout'):
        config = {
            'Heps': 40,
            'horizon': horizon,
            'H': H,
            # The online darkroom evaluation is capped at 20 trajectories.
            'n_eval': min(20, n_eval),
            'dim': dim,
            'stitch': envname == 'darkroom_stitch',
            'permuted': envname == 'darkroom_permuted',
            'random_init': False,  # True if envname == 'darkroom_heldout_random_init' else False,
            'include_partial_hist': include_partial_hist,
            'grow_context': grow_context,
            'filename': f'{save_filename}/{evals_filename}',
        }
        start_time = time.time()
        eval_darkroom.online_vec(eval_trajs, model, **config)
        print("Online evaluation took: ", time.time() - start_time, " seconds")
        plt.savefig(f'figs/{evals_filename}/online/{save_filename}.png')
        plt.clf()









    # Accumulators, one entry per evaluated trajectory. Each `all_rs_*` list
    # holds the summed reward that one policy obtained on that trajectory.
    all_xs = []
    all_rs_lnr = []
    all_rs_greedy = []
    all_rs_opt = []
    all_rs_emp = []
    all_rs_pess = []
    all_rs_rnd = []
    all_rs_lnr_greedy = []
    all_rs_lin = []
    all_rs_thmp = []

    # Darkroom environments and their trajectories are collected here so the
    # transformer controllers can be deployed on all of them at once later.
    envs = []
    trajs = []

    # OFFLINE EVALUATION SINGLE
    # Every policy is given the trajectory's roll-in data as context and is
    # then deployed once in a freshly built environment.
    for i_eval in range(n_eval):
        print(f"Eval traj: {i_eval}")

        traj = eval_trajs[i_eval]        
        # Context tensors with a leading batch axis of size 1.
        batch = {
            'rollin_xs': torch.tensor(traj['rollin_xs'][None,:,:]).float().to(device),
            'rollin_us': torch.tensor(traj['rollin_us'][None,:,:]).float().to(device),
            'rollin_xps': torch.tensor(traj['rollin_xps'][None,:,:]).float().to(device),
            'rollin_rs': torch.tensor(traj['rollin_rs'][None,:,None]).reshape(1,-1,1).float().to(device)
        }


        if envname in ['bandit', 'bandit_ood', 'bandit_thompson']:
            means = traj['means']
            env = BanditEnv(means, horizon, var=var, type=bandit_type)       # naming issue here for length of contexts

            # Baselines plus the learned controller (`lnr`), deployed greedily
            # here (sample=False).
            true_opt = OptPolicy(env)
            greedy = GreedyOptPolicy(env)
            lnr = BanditTransformerController(model, sample=False)
            emp = EmpMeanPolicy(env)
            pess = PessMeanPolicy(env, .8)
            thmp = ThompsonSamplingPolicy(env, var=var)

            true_opt.set_batch(batch)
            greedy.set_batch(batch)
            lnr.set_batch(batch)
            emp.set_batch(batch)
            pess.set_batch(batch)
            thmp.set_batch(batch)

            # NOTE(review): the order of these deployments may matter if the
            # environment or the policies draw random numbers -- keep it
            # stable when comparing against earlier runs.
            xs_greedy, us_greedy, xps_greedy, rs_greedy = env.deploy_eval(greedy)
            xs_opt, us_opt, xps_opt, rs_opt = env.deploy_eval(true_opt)
            xs_lnr, us_lnr, xps_lnr, rs_lnr = env.deploy_eval(lnr)
            xs_emp, us_emp, xps_emp, rs_emp = env.deploy_eval(emp)
            xs_pess, us_pess, xps_pess, rs_pess = env.deploy_eval(pess)
            xs_thmp, us_thmp, xps_thmp, rs_thmp = env.deploy_eval(thmp)

            all_xs.append((xs_opt, xs_lnr))
            all_rs_opt.append(np.sum(rs_opt))
            all_rs_lnr.append(np.sum(rs_lnr))
            all_rs_greedy.append(np.sum(rs_greedy))
            all_rs_emp.append(np.sum(rs_emp))
            all_rs_pess.append(np.sum(rs_pess))
            all_rs_thmp.append(np.sum(rs_thmp))


        
        elif envname == 'bandit_topk':
            means = traj['means']
            env = TopKBanditEnv(means, horizon, var=var, k=topk)

            true_opt = OptPolicy(env)
            greedy = GreedyOptPolicy(env)
            lnr = TopKBanditTransformerController(model, k=topk, sample=False)
            rnd = TopKRandCommitPolicy(env, topk, horizon, immediate=True)
            lin = LinUCB(env, topk, const=0.0)

            true_opt.set_batch(batch)
            greedy.set_batch(batch)
            lnr.set_batch(batch)
            rnd.set_batch(batch)
            lin.set_batch(batch)

            xs_greedy, us_greedy, xps_greedy, rs_greedy = env.deploy_eval(greedy)
            xs_opt, us_opt, xps_opt, rs_opt = env.deploy_eval(true_opt)
            xs_lnr, us_lnr, xps_lnr, rs_lnr = env.deploy_eval(lnr)
            xs_rnd, us_rnd, xps_rnd, rs_rnd = env.deploy_eval(rnd)
            xs_lin, us_lin, xps_lin, rs_lin = env.deploy_eval(lin)

            all_xs.append((xs_opt, xs_lnr))
            all_rs_opt.append(np.sum(rs_opt))
            all_rs_lnr.append(np.sum(rs_lnr))
            all_rs_greedy.append(np.sum(rs_greedy))
            all_rs_rnd.append(np.sum(rs_rnd))
            all_rs_lin.append(np.sum(rs_lin))



        elif envname in ['darkroom', 'darkroom_heldout', 'darkroom_stitch', 'darkroom_permuted'] or envname.startswith('darkroom_heldout'):
            goal = traj['goal']
            # The environment class depends on the darkroom variant.
            if envname == 'darkroom_stitch':
                env = DarkroomEnvStitch(dim, goal, H, eval=True)
            elif envname == 'darkroom_permuted':
                env = DarkroomEnvPermuted(dim, traj['perm_index'], H)
            else:
                env = DarkroomEnv(dim, goal, H, random_init=False)

            # Only the optimal and random-commit baselines run per trajectory;
            # the transformer controllers are deployed after this loop on the
            # environments collected in `envs`.
            true_opt = DarkroomOptPolicy(env)
            rnd = RandCommit(env)

            true_opt.set_batch(batch)
            # lnr.set_batch(batch)
            # lnr_greedy.set_batch(batch)
            rnd.set_batch(batch)

            xs_opt, us_opt, xps_opt, rs_opt = env.deploy_eval(true_opt)
            xs_rnd, us_rnd, xps_rnd, rs_rnd = env.deploy_eval(rnd)


            # all_xs.append((xs_opt, xs_lnr))
            all_rs_opt.append(np.sum(rs_opt))
            # all_rs_lnr.append(np.sum(rs_lnr))
            # all_rs_lnr_greedy.append(np.sum(rs_lnr_greedy))
            all_rs_rnd.append(np.sum(rs_rnd))
            envs.append(env)
            trajs.append(traj)

    # Darkroom only: deploy the transformer controllers on every collected
    # environment at once, then optionally draw one picture per trajectory.
    if envname.startswith('darkroom_heldout') or envname in ['darkroom', 'darkroom_heldout', 'darkroom_stitch', 'darkroom_permuted']:
        print("Running darkroom offline evaluations in parallel")
        vec_env = DarkroomEnvVec(envs)
        lnr = DarkroomTransformerController(model, batch_size=n_eval, sample=True)
        lnr_greedy = DarkroomTransformerController(model, batch_size=n_eval, sample=False)

        def _stacked(key):
            """Stack one context field of all trajectories into a float tensor."""
            return torch.tensor(np.array([traj[key] for traj in trajs])).float().to(device)

        batch = {
            'rollin_xs': _stacked('rollin_xs'),
            'rollin_us': _stacked('rollin_us'),
            'rollin_xps': _stacked('rollin_xps'),
            # Rewards get a trailing axis of size 1.
            'rollin_rs': torch.tensor(np.array([traj['rollin_rs'][:, None] for traj in trajs])).reshape(n_eval,-1,1).float().to(device),
        }
        lnr.set_batch(batch)
        lnr_greedy.set_batch(batch)

        # Sampling controller first, greedy controller second.
        deploy_kwargs = {'include_partial_hist': include_partial_hist, 'grow_context': grow_context}
        xs_lnr, us_lnr, xps_lnr, rs_lnr = vec_env.deploy_eval(lnr, **deploy_kwargs)
        xs_lnr_greedy, us_lnr_greedy, xps_lnr_greedy, rs_lnr_greedy = vec_env.deploy_eval(lnr_greedy, **deploy_kwargs)

        # Total reward per trajectory.
        all_rs_lnr = np.sum(rs_lnr, axis=-1)
        all_rs_lnr_greedy = np.sum(rs_lnr_greedy, axis=-1)

        if save_video:
            # Arrow offset drawn for each action index.
            directions = {
                0: (-0.1, 0),
                1: (0.1, 0),
                2: (0, 0.1),
                3: (0, -0.1),
                4: (0, 0),
            }

            for i_eval in range(n_eval):
                # Context: one arrow per context step, coloured by step index.
                ctx_states = batch['rollin_xs'][i_eval].cpu().numpy().astype(np.float64)
                ctx_actions = batch['rollin_us'][i_eval].cpu().numpy().astype(np.float64)
                ctx_actions = np.argmax(ctx_actions, axis=-1)
                ctx_dim0, ctx_dim1 = ctx_states[:, 0], ctx_states[:, 1]

                n_ctx = len(ctx_dim0)
                arrow_colors = [viridis(step / n_ctx) for step in range(n_ctx)]
                arrow_us = [directions[action][1] for action in ctx_actions]
                arrow_vs = [directions[action][0] for action in ctx_actions]
                plt.quiver(ctx_dim1, ctx_dim0, arrow_us, arrow_vs, color=arrow_colors, alpha=0.5, scale=3)

                # Rollout of the greedy controller.
                rollout = xs_lnr_greedy[i_eval].astype(np.float64)
                plt.scatter(rollout[:, 1], rollout[:, 0], c='g', marker='x', s=200)

                # Goal of this environment.
                goal_pos = vec_env.envs[i_eval].goal
                plt.scatter(goal_pos[1], goal_pos[0], marker='o', facecolors='none', edgecolors='b', s=200)

                plt.ylim(-1, 10)
                plt.xlim(-1, 10)
                plt.gca().invert_yaxis()
                plt.savefig(f'videos/{save_filename}/{evals_filename}/test_offline_traj{i_eval}.png')
                plt.clf()

    # Pick, per environment family, which reward lists are reported and for
    # which of them the gap to the optimal policy is computed.
    if envname in ['bandit', 'bandit_ood', 'bandit_thompson']:
        reward_lists = {
            'opt': all_rs_opt,
            'lnr': all_rs_lnr,
            'emp': all_rs_emp,
            'pess': all_rs_pess,
            'thmp': all_rs_thmp,
        }
        subopt_order = ['lnr', 'emp', 'pess', 'thmp']
    elif envname == 'bandit_topk':
        reward_lists = {
            'opt': all_rs_opt,
            'lnr': all_rs_lnr,
            'greedy': all_rs_greedy,
            'rnd': all_rs_rnd,
            'lin': all_rs_lin,
        }
        subopt_order = ['lnr', 'greedy', 'rnd', 'lin']
    elif envname.startswith('darkroom_heldout') or envname in ['darkroom', 'darkroom_heldout', 'darkroom_stitch', 'darkroom_permuted']:
        reward_lists = {
            'opt': all_rs_opt,
            'lnr': all_rs_lnr,
            'rnd': all_rs_rnd,
            'lnr_greedy': all_rs_lnr_greedy,
        }
        subopt_order = ['lnr', 'lnr_greedy', 'rnd']
    else:
        raise NotImplementedError

    baselines = {name: np.array(rewards) for name, rewards in reward_lists.items()}
    # Suboptimality = reward of the optimal policy minus reward of the policy.
    subopt_baselines = {name: baselines['opt'] - baselines[name] for name in subopt_order}

    baselines_means = {}
    for name, rewards in baselines.items():
        baselines_means[name] = np.mean(rewards)
    subopt_baselines_means = {}
    for name, gaps in subopt_baselines.items():
        subopt_baselines_means[name] = np.mean(gaps)

    # Bar chart of the mean total reward of every policy.
    colors = plt.cm.viridis(np.linspace(0, 1, len(baselines_means)))
    plt.bar(baselines_means.keys(), baselines_means.values(), color=colors)
    plt.title(f'Mean Reward on {n_eval} Trajectories')
    plt.savefig(f'figs/{evals_filename}/bar/{save_filename}_bar.png')
    plt.clf()





    # PESSIMISM ANALYSIS

    if envname in ['bandit', 'bandit_ood', 'bandit_thompson']:
        # Sweep the constant of the pessimistic-mean baseline.
        pess_consts = np.linspace(0, 5, 41)

        # `all_rs_opt` only covers the first n_eval trajectories, so restrict
        # the sweep to the same ones; otherwise the subtraction below fails
        # with a shape mismatch whenever the dataset holds more than n_eval.
        # NOTE(review): `[:, :, 0]` keeps only the first reward entry of each
        # deployment while `all_rs_opt` holds reward sums -- confirm that a
        # deployment returns a single reward here.
        pess_values = analyze_pess(pess_consts, eval_trajs[:n_eval])[:, :, 0]
        pess_subopts = np.asarray(all_rs_opt) - pess_values
        pess_means = np.mean(pess_subopts, axis=1)
        pess_sems = scipy.stats.sem(pess_subopts, axis=1)

        # plot the means and sems in error band
        plt.plot(pess_consts, pess_means, label='PESS mean')
        plt.fill_between(pess_consts, pess_means-pess_sems, pess_means+pess_sems, alpha=.5)
        # Horizontal reference line: mean suboptimality of the learned model.
        plt.plot(pess_consts, np.ones(len(pess_consts))*subopt_baselines_means['lnr'], linestyle='--', label='LNR')
        # The curves carry labels, so actually show them.
        plt.legend()
        plt.savefig(f'figs/{evals_filename}/pess/{save_filename}_pess.png')
        plt.clf()




    # OFFLINE EVALUATION GRAPH


    def run_controllers(controllers):
        """Deploy every controller with growing amounts of context.

        For each of the first ``n_eval`` evaluation trajectories and each
        context length ``cutoff`` in ``1..horizon``, a fresh environment is
        built, every controller receives the first ``cutoff`` context steps
        and is deployed once.

        Args:
            controllers: Dict mapping a name to a controller object that
                offers ``set_env`` and ``set_batch`` and can be passed to
                ``env.deploy_eval``.

        Returns:
            Dict with the same keys; each value is a flat list of summed
            rewards ordered by trajectory first and cutoff second
            (``n_eval * horizon`` entries).

        Raises:
            NotImplementedError: If ``envname`` is not one of the bandit
                environments.
        """
        all_rs = {name: [] for name in controllers}
        for i_eval in range(n_eval):
            print(f"Eval traj: {i_eval}")

            # Everything that depends only on the trajectory is prepared once
            # instead of once per cutoff.
            traj = eval_trajs[i_eval]
            means = traj['means']
            full_context = {
                'rollin_xs': torch.tensor(traj['rollin_xs'][None, :, :]).float().to(device),
                'rollin_us': torch.tensor(traj['rollin_us'][None, :, :]).float().to(device),
                'rollin_xps': torch.tensor(traj['rollin_xps'][None, :, :]).float().to(device),
                'rollin_rs': torch.tensor(traj['rollin_rs'][None, :, None]).float().to(device),
            }

            for cutoff in range(1, horizon + 1):
                # 'bandit_ood' is accepted by the caller's controller setup,
                # so it has to be accepted here as well.
                if envname in ['bandit', 'bandit_ood', 'bandit_thompson']:
                    env = BanditEnv(means, horizon, var=var, type=bandit_type)
                elif envname == 'bandit_topk':
                    env = TopKBanditEnv(means, horizon, var=var, k=topk)
                else:
                    raise NotImplementedError

                # Only the first `cutoff` context steps are visible.
                batch = {
                    field: tensor[:, :cutoff, :]
                    for field, tensor in full_context.items()
                }

                for key, controller in controllers.items():
                    controller.set_env(env)
                    controller.set_batch(batch)
                    xs, us, xps, rs = env.deploy_eval(controller)
                    all_rs[key].append(np.sum(rs))

        return all_rs

    # Controllers compared in the suboptimality-vs-context-size graph.
    # `env` is the environment left over from the last iteration of the
    # per-trajectory loop above; run_controllers hands each controller a
    # fresh environment through set_env before deploying it.
    if envname in ['bandit', 'bandit_ood', 'bandit_thompson']:

        true_opt = OptPolicy(env)
        greedy = GreedyOptPolicy(env)
        lnr = BanditTransformerController(model, sample=True)
        lnr_greedy = BanditTransformerController(model, sample=False)
        emp = EmpMeanPolicy(env)
        thmp = ThompsonSamplingPolicy(env, var=var)
        pess = PessMeanPolicy(env, const=.8)
        controllers = {
            'opt': true_opt,
            # 'greedy': greedy,
            'lnr': lnr,
            'lnr_greedy': lnr_greedy,
            'emp': emp,
            'pess': pess,
            'thmp': thmp
        }

    elif envname == 'bandit_topk':
        true_opt = OptPolicy(env)
        greedy = GreedyOptPolicy(env)
        lnr = TopKBanditTransformerController(model, k=topk, sample=True)
        lnr_greedy = TopKBanditTransformerController(model, k=topk, sample=False)
        rnd = TopKRandCommitPolicy(env, topk, horizon, immediate=True)
        lin = LinUCB(env, topk, const=0.0)
        controllers = {
            'opt': true_opt,
            'greedy': greedy,
            'lnr': lnr,
            'lnr_greedy': lnr_greedy,
            'rnd': rnd,
            'lin': lin,
        }
    else:
        # run_controllers only knows how to build bandit environments. This
        # branch used to raise NotImplementedError, which made every darkroom
        # run end in a traceback after all of its results had been written.
        controllers = None
        print(f"Skipping the context-size graph: not available for environment '{envname}'")

    if controllers is not None:
        baselines = run_controllers(controllers)
        baselines = {k: np.array(v) for k, v in baselines.items()}

        # calculate suboptimality baselines which are the same keys minus opt
        subopt_baselines = {
            k: baselines['opt'] - v for k, v in baselines.items() if k != 'opt'
        }

        context_sizes = np.arange(1, horizon + 1)
        for key, subopt in subopt_baselines.items():
            # One row per trajectory, one column per context size.
            values = subopt.reshape(n_eval, horizon)
            means = np.mean(values, axis=0)
            sems = scipy.stats.sem(values, axis=0)
            plt.plot(context_sizes, means, label=key)
            plt.fill_between(context_sizes, means-sems, means+sems, alpha=.2)

        # The old guard tested for 'figs/graphs/' but created a different
        # folder, so the target could be missing when that path existed.
        os.makedirs(f'figs/{evals_filename}/graph', exist_ok=True)
        plt.yscale('log')
        plt.legend()
        plt.title('Suboptimality w.r.t optimal')
        plt.xlabel('Dataset size')
        plt.savefig(f'figs/{evals_filename}/graph/{save_filename}_mean_log.png')
        plt.clf()



    