import numpy as np
import pickle
import matplotlib.pyplot as plt
from IPython import embed
from scipy.stats import sem
import argparse
# Command-line interface: --version selects which experiment to plot.
# Validation is delegated to argparse (required + choices) instead of an
# assert, so it still runs under `python -O` and a missing or unknown value
# produces a usage message rather than a bare AssertionError.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--version",
    type=str,
    required=True,
    choices=('cp', 'mc'),
    help="experiment to plot: 'cp' (CartPole) or 'mc' (MountainCar)",
)
res = parser.parse_args()

version = res.version
# Labels for the entries along axis 1 of `values`; they end up in the legend.
ds = [10, 50, 1000, 5000, 25000, 50000]

# Indices kept along axis 0 of `values` (also the T in the selection file
# names). Both experiments use the same ones.
Ts = [0, 1, 2, 3, 4]

if version == 'cp':
    ns = [500, 700, 900, 1100, 1300, 1500]
    values = np.load('data/values.npy')
elif version == 'mc':
    ns = [200, 400, 600, 800]
    # Keep axis-1 entries 0, 1, 3, 4, 5, 6 only.
    # NOTE(review): presumably the stored array covers
    # d = 10, 50, 100, 1000, 5000, 25000, 50000 and index 2 (d = 100) is
    # dropped so the remaining entries line up with `ds` - confirm.
    values = np.load('data/mc_values.npy')[:, [0, 1, 3, 4, 5, 6], :]

# Restrict axis 0 to the chosen Ts and axis 2 to one column per entry of ns.
values = values[Ts][:, :, :len(ns)]

# Aggregate over axis 0. Note that `stds` holds the standard error of the
# mean (scipy.stats.sem), not the standard deviation.
means = values.mean(axis=0)
stds = sem(values, axis=0)

# Value obtained by each selection rule, indexed as
# [position of T in Ts, position of n in ns].
ms_values = np.zeros((len(Ts), len(ns)))
holdout_values = np.zeros((len(Ts), len(ns)))

for k, T in enumerate(Ts):
    for j, n in enumerate(ns):
        # The per-experiment file names differ only in their prefix, which is
        # exactly `version`, so one template covers every case (the former
        # 'pend' branch could never be reached).
        filename = 'selection/{version}_n{n}_T{T}.pkl'.format(
            version=version, n=n, T=T)

        # NOTE(review): pickle.load can execute arbitrary code; only run this
        # on selection files produced by a trusted pipeline.
        with open(filename, 'rb') as f:
            sel = pickle.load(f)

        # Echo the n stored in the file so it can be eyeballed against ns.
        print(sel['n'])

        # sel['ms'] / sel['holdout'] are indices into axis 1 of `values`,
        # i.e. which entry each selection rule picked for this (T, n).
        ms_values[k, j] = values[k, sel['ms'], j]
        holdout_values[k, j] = values[k, sel['holdout'], j]

# Aggregate over axis 0 (the Ts): mean and standard error per n.
ms_means = np.mean(ms_values, axis=0)
holdout_means = np.mean(holdout_values, axis=0)

ms_stds = sem(ms_values, axis=0)
holdout_stds = sem(holdout_values, axis=0)


# NOTE(review): embed() opens an interactive IPython shell at this point; the
# plotting below only runs once that shell is exited. Looks like a debugging
# aid - confirm it is still wanted.
embed()

lw = 3   # line width shared by every curve
fs = 16  # font size for axis labels and title
print(means)

# One dashed curve per row of `means`, labelled with the matching entry of ds.
for i, row_mean in enumerate(means):
    plt.plot(ns, row_mean, label='d: ' + str(ds[i]),
             linestyle='--', linewidth=lw, alpha=.75)
    # alpha=.0: the standard-error band is drawn fully transparent.
    plt.fill_between(ns, row_mean - stds[i], row_mean + stds[i], alpha=.0)

# The two selection rules are drawn as solid curves with a visible
# standard-error band in the same colour.
for label, color, center, spread in (
    ('Hold-out', 'red', holdout_means, holdout_stds),
    ('ModBE', 'blue', ms_means, ms_stds),
):
    plt.plot(ns, center, label=label, color=color, linewidth=lw, alpha=1)
    plt.fill_between(ns, center - spread, center + spread,
                     alpha=.2, color=color)

plt.ylabel('Reward', fontsize=fs)
plt.xlabel('Dataset size (episodes)', fontsize=fs)
plt.grid(linestyle=':', alpha=.9)

# Per-experiment cosmetics: (title, legend location, optional y-limits).
title, legend_loc, y_limits = {
    'cp': ('CartPole', 'lower right', None),
    'mc': ('MountainCar', 'upper left', (-205, -140)),
}.get(version, (None, None, None))

if title is not None:
    plt.title(title, fontsize=fs)
    plt.legend(fontsize=12, loc=legend_loc)
    if y_limits is not None:
        plt.ylim(*y_limits)

plt.tight_layout()
plt.savefig('figures/{version}.pdf'.format(version=version))
plt.show()

