# this file plots the linear portion of the synthetic data results (bottom row of figure 1)
import os
from os.path import exists

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# load results
# experiment constants: horizon, number of items, and number of trials
n, L, trials = 100000, 100, 5
values = np.arange(4, 18, 2)  # all values of d and K we tested
# Build one (K, d) row per test: the first len(values) rows sweep K with d
# fixed at 10, the remaining rows sweep d with K fixed at 10.  The constant
# column is sized from `values` instead of a hard-coded 7 so the table cannot
# drift out of sync with the sweep range defined above.
fixed_param = 10 * np.ones(values.size)
all_tests = np.vstack((np.column_stack((values, fixed_param)),
                       np.column_stack((fixed_param, values))))
num_tests = all_tests.shape[0]  # number of tests
# first axis indexes the two algorithms being compared, second the round,
# third the test; both arrays are filled in by the loading loop
regret_avg = np.zeros((2, n, num_tests))  # average of the regret across trials
regret_std = np.zeros((2, n, num_tests))  # standard deviation of the regret across trials
# Load the saved regret array for every (K, d) test and reduce it across
# trials.  The integer cast is done once up front rather than on every
# iteration; each row of the cast table unpacks into that test's K and d.
for test_idx, (K, d) in enumerate(all_tests.astype(int)):
    # try to load results for current test, printing error message and exiting if it doesn't exist
    result_file = 'results/syn_lin_' + str(K) + '_' + str(d) + '.npz'
    if not exists(result_file):
        print('error -- file ' + result_file + ' does not exist -- see README')
        # exit with a non-zero status so callers (shell scripts, make) can
        # detect the failure; the bare exit() builtin reports success and is
        # only available when the site module is loaded
        raise SystemExit(1)
    # the context manager closes the underlying .npz archive once the array
    # has been read out of it
    with np.load(result_file) as results:
        regret = results['arr_0']
    # reducing over axis 2 must leave an array that fits the (2, n) slice of
    # the summary arrays -- presumably regret is (2, n, trials); TODO confirm
    regret_avg[:, :, test_idx] = np.mean(regret, axis=2)
    regret_std[:, :, test_idx] = np.std(regret, axis=2)

# if we've reached this line, we've loaded all the data and just need to make the plots
# as a disclaimer, this plotting code is hacked together from incongruous matplotlib examples so may be perplexing

# set up the figure: global font size first, then one row of three panels
matplotlib.rcParams['font.size'] = 18
fig, ax = plt.subplots(nrows=1, ncols=3)
fig.set_size_inches((11, 3))
# the layout calls below are order-sensitive, so they are issued in the same
# sequence as before: horizontal spacing, tight layout, then explicit margins
# NOTE(review): tight_layout() presumably recomputes the subplot spacing, so
# the wspace value set just before it may have no lasting effect -- confirm
fig.subplots_adjust(wspace=0.05)
fig.tight_layout()
fig.subplots_adjust(left=0.075, right=0.98, bottom=0.21, top=0.7)

# first plot: regret at the last round (index n - 1) for tests 0-6, i.e. the
# sweep over K with d fixed at 10
panel = ax[0]
final_avg = regret_avg[:, n - 1, 0:7]
final_std = regret_std[:, n - 1, 0:7]
# per algorithm row: mean curve plus a shaded band of one standard deviation
for row, (line_fmt, shade) in enumerate((('r--', 'r'), ('g:', 'g'))):
    center, spread = final_avg[row, :], final_std[row, :]
    panel.plot(values, center, line_fmt)
    panel.fill_between(values, center - spread, center + spread, color=shade, alpha=0.2)
# axis labels and ranges
panel.set(xlabel=r'$K$', xlim=[values.min(), values.max()],
          ylabel='Regret', ylim=[0, 18e2])
panel.set_xticks([4, 6, 8, 10, 12, 14, 16])
panel.set_yticks([0, 6e2, 12e2, 18e2])
# scientific notation on the y axis
panel.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))
panel.set_title(r'$(n,d)=(10^5,10)$', fontdict={'fontsize': 18}, loc='right')
# tick marks on all four sides, tick labels only on the bottom and left
panel.tick_params(bottom=True, top=True, left=True, right=True,
                  labelbottom=True, labeltop=False, labelleft=True, labelright=False)

# second plot: regret at the last round (index n - 1) for tests 7-13, i.e. the
# sweep over d with K fixed at 10
# the slice ends at 14 (the total number of tests) instead of 16: the old upper
# bound only worked because NumPy silently clips out-of-range slice ends
avg = regret_avg[:, n - 1, 7:14]
std = regret_std[:, n - 1, 7:14]
# row 0: dashed red mean curve with a +/- one standard deviation band
ax[1].plot(values, avg[0, :], 'r--')
ax[1].fill_between(values, avg[0, :] - std[0, :], avg[0, :] + std[0, :], color='r', alpha=0.2)
# row 1: dotted green mean curve with a +/- one standard deviation band
ax[1].plot(values, avg[1, :], 'g:')
ax[1].fill_between(values, avg[1, :] - std[1, :], avg[1, :] + std[1, :], color='g', alpha=0.2)
# NOTE(review): this label is plain text while the other panels use math-mode
# labels (r'$K$', r'$n$') -- confirm whether the difference is intended
ax[1].set(xlabel='d')
ax[1].set(xlim=[min(values), max(values)])
ax[1].set_xticks([4, 6, 8, 10, 12, 14, 16])
ax[1].set(ylabel='Regret')
ax[1].set(ylim=[0, 18e2])
ax[1].set_yticks([0, 6e2, 12e2, 18e2])
# scientific notation on the y axis
ax[1].ticklabel_format(style='sci', axis='y', scilimits=(0, 0))
ax[1].set_title(r'$(n,K)=(10^5,10)$', fontdict={'fontsize': 18}, loc='right')
# tick marks on all four sides, tick labels only on the bottom and left
ax[1].tick_params(labelbottom=True, labeltop=False, labelleft=True, labelright=False,
                  bottom=True, top=True, left=True, right=True)

# third plot: regret at every round for test index 3, which is the K = d = 10
# configuration (values[3] is 10)
panel = ax[2]
rounds = range(n)
curve_avg = regret_avg[:, :, 3]
curve_std = regret_std[:, :, 3]
# per algorithm row: labelled mean curve plus a shaded band of one standard
# deviation; the labels are picked up later when the figure legend is built
for row, line_fmt, shade, name in ((0, 'r--', 'r', 'CascadeLinUCB'),
                                   (1, 'g:', 'g', 'CascadeWOFUL')):
    center, spread = curve_avg[row, :], curve_std[row, :]
    panel.plot(rounds, center, line_fmt, label=name)
    panel.fill_between(rounds, center - spread, center + spread, color=shade, alpha=0.2)
# axis labels and ranges
panel.set(xlabel=r'$n$', xlim=[0, n], ylabel='Regret', ylim=[0, 12e2])
panel.set_xticks([0, 2e4, 4e4, 6e4, 8e4, 10e4])
panel.set_yticks([0, 3e2, 6e2, 9e2, 12e2])
# scientific notation on both axes
panel.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
panel.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))
panel.set_title(r'$d=K=10$', fontdict={'fontsize': 18}, loc='right')
# tick marks on all four sides, tick labels only on the bottom and left
panel.tick_params(bottom=True, top=True, left=True, right=True,
                  labelbottom=True, labeltop=False, labelleft=True, labelright=False)

# add legend and save figure
# the legend entries come from the labelled curves on the third panel and are
# drawn once for the whole figure
handles, labels = ax[2].get_legend_handles_labels()
fig.legend(handles, labels, loc='upper center', ncol=4)
# create the output directory on demand so saving does not fail when the
# 'plots' folder has not been created yet
os.makedirs('plots', exist_ok=True)
fig.savefig('plots/syn_lin.png', dpi=300)
