import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

from utils import *

home_dir = "../"

# Names of the datasets processed by the loop below, in processing order.
# Most families follow "<environment>_<parameter>_<index>", so they are generated
# from their naming pattern instead of being spelled out one by one.
data_list = ["ideal"]
data_list += [f"cartpole_{param}_{i}" for param in ("masscart", "lenpole", "masspole") for i in range(3)]
data_list += [f"pendulum_{param}_{i}" for param in ("dt", "l", "m") for i in range(3)]
data_list += [f"walker_{param}_{i}" for param in ("friction", "gravity", "scale") for i in range(3)]
data_list += [f"halfcheetah_{param}_{i}" for param in ("friction", "gravity", "stiffness") for i in range(3)]
# cartpole variants tagged "ppo" and "a2c" (note the different parameter order)
data_list += [
    f"cartpole_{param}_{algo}_{i}"
    for algo in ("ppo", "a2c")
    for param in ("lenpole", "masspole", "masscart")
    for i in range(3)
]
data_list += [f"intersection_{param}_{i}" for param in ("flow", "speed", "length") for i in range(3)]
# advisory-autonomy: four un-indexed names first, then indices 0..4 for each of them
data_list += [f"advisoryautonomy_{scene}_{kind}" for scene in ("ring", "ramp") for kind in ("acc", "vel")]
data_list += [
    f"advisoryautonomy_{scene}_{kind}_{i}"
    for scene in ("ring", "ramp")
    for kind in ("acc", "vel")
    for i in range(5)
]
data_list += [f"no-stop_{param}_{i}" for param in ("green", "penrate", "inflow") for i in range(3)]

for data_name in data_list:
    print("DATA: ", data_name)
    # Re-seed NumPy's global RNG so every dataset starts from the same random state.
    np.random.seed(42)
    RANDOMNESS = False

    (data_transfer, deltas, delta_min, delta_max, slope,
     lower_bound, upper_bound, unguided) = import_data(home_dir, data_name, RANDOMNESS)
    plot_heatmap(data_transfer, home_dir, data_name)

    # Step budget handed to every strategy below.
    num_transfer_steps = 15

    # --- Greedy heuristics (both variants take the same positional arguments) ---
    heuristic_args = (data_transfer, deltas, num_transfer_steps, delta_min, delta_max, slope)
    source_tasks_greedy, performance_greedy = greedy_heursitc_sts(*heuristic_args)
    source_tasks_greedy_marginal, performance_greedy_marginal = greedy_heursitc_sts_marginal(*heuristic_args)

    # --- LinUCB ---
    source_tasks_linucb, transfer_results_LinUCB = run_linucb(data_transfer, deltas, num_transfer_steps)

    # Keyword options shared by every GP-based call in this section.
    gap_options = {'gap_function': 'linear', 'slope': slope}

    # --- GP_new: the calls differ only in the acquisition function ---
    gp_new_args = (home_dir, data_name, deltas, num_transfer_steps, data_transfer)
    source_tasks_GP_new, transfer_results_GP_new, _, _ = model_based_learning_with_GP_new(
        *gp_new_args, acquisition_function='new', **gap_options)
    source_tasks_GP_new_ucb, transfer_results_GP_new_ucb, _, _ = model_based_learning_with_GP_new(
        *gp_new_args, acquisition_function='new_ucb', **gap_options)
    source_tasks_GP_new_ucb_beta, transfer_results_GP_new_ucb_beta, _, _ = model_based_learning_with_GP_new(
        *gp_new_args, acquisition_function='new_ucb_beta', **gap_options)
    source_tasks_GP_new_ucb_beta_log, transfer_results_GP_new_ucb_beta_log, _, _ = model_based_learning_with_GP_new(
        *gp_new_args, acquisition_function='new_ucb_beta_log', **gap_options)
    source_tasks_GP_new_ucb_transfer, transfer_results_GP_new_ucb_transfer, _, _ = model_based_learning_with_GP_new(
        *gp_new_args, acquisition_function='new_ucb_transfer', **gap_options)

    # --- GP: every acquisition function is run with transferred=False (_F) and True (_T) ---
    gp_args = (deltas, num_transfer_steps, data_transfer, lower_bound, upper_bound)
    source_tasks_GP_EI_F, transfer_results_GP_EI_F, _, _ = model_based_learning_with_GP(
        *gp_args, acquisition_function='EI', transferred=False, **gap_options)
    source_tasks_GP_EI_T, transfer_results_GP_EI_T, _, _ = model_based_learning_with_GP(
        *gp_args, acquisition_function='EI', transferred=True, **gap_options)
    source_tasks_GP_UCB_F, transfer_results_GP_UCB_F, _, _ = model_based_learning_with_GP(
        *gp_args, acquisition_function='UCB', transferred=False, **gap_options)
    source_tasks_GP_UCB_T, transfer_results_GP_UCB_T, _, _ = model_based_learning_with_GP(
        *gp_args, acquisition_function='UCB', transferred=True, **gap_options)
    source_tasks_GP_LCB_F, transfer_results_GP_LCB_F, _, _ = model_based_learning_with_GP(
        *gp_args, acquisition_function='LCB', transferred=False, **gap_options)
    source_tasks_GP_LCB_T, transfer_results_GP_LCB_T, _, _ = model_based_learning_with_GP(
        *gp_args, acquisition_function='LCB', transferred=True, **gap_options)
    source_tasks_GP_VR_F, transfer_results_GP_VR_F, _, _ = model_based_learning_with_GP(
        *gp_args, acquisition_function='VR', transferred=False, **gap_options)
    source_tasks_GP_VR_T, transfer_results_GP_VR_T, _, _ = model_based_learning_with_GP(
        *gp_args, acquisition_function='VR', transferred=True, **gap_options)

    # --- Matrix estimation, one run per method ---
    selected_source_task_SoftImpute, performance_SoftImpute, _ = matrix_estimation(
        data_transfer, num_transfer_steps, me_method='SoftImpute')
    selected_source_task_KNN, performance_KNN, _ = matrix_estimation(
        data_transfer, num_transfer_steps, me_method='KNN')
    selected_source_task_MF, performance_MF, _ = matrix_estimation(
        data_transfer, num_transfer_steps, me_method='MatrixFactorization')
    selected_source_task_SimpleFill, performance_SimpleFill, _ = matrix_estimation(
        data_transfer, num_transfer_steps, me_method='SimpleFill')
    
    # --- Greedy temporal transfer learning: select, collect the J matrix, average over axis 0 ---
    source_tasks_gttl = greedy_temporal_transfer_learning(
        deltas, num_transfer_steps, delta_min=delta_min, delta_max=delta_max)
    J_gttl = collect_J_matrix(data_transfer, source_tasks_gttl, deltas, num_transfer_steps)
    performance_greedy_temporal_transfer_learning = J_gttl.mean(axis=0)

    # The coarse-to-fine and fine-to-coarse strategies are each evaluated for these three budgets.
    budgets = (5, 10, 15)

    # --- Coarse-to-fine ---
    source_tasks_cttl5, source_tasks_cttl10, source_tasks_cttl15 = [
        coarse_to_fine_temporal_transfer_learning(deltas, budget, delta_min, delta_max)
        for budget in budgets
    ]
    J_cttl5, J_cttl10, J_cttl15 = [
        collect_J_matrix(data_transfer, tasks, deltas, budget)
        for tasks, budget in zip((source_tasks_cttl5, source_tasks_cttl10, source_tasks_cttl15), budgets)
    ]
    coarse_to_fine_transfer_training_5 = J_cttl5.mean(axis=0)
    coarse_to_fine_transfer_training_10 = J_cttl10.mean(axis=0)
    coarse_to_fine_transfer_training_15 = J_cttl15.mean(axis=0)

    # --- Fine-to-coarse ---
    source_tasks_fttl5, source_tasks_fttl10, source_tasks_fttl15 = [
        fine_to_coarse_temporal_transfer_learning(deltas, budget, delta_min, delta_max)
        for budget in budgets
    ]
    J_fttl5, J_fttl10, J_fttl15 = [
        collect_J_matrix(data_transfer, tasks, deltas, budget)
        for tasks, budget in zip((source_tasks_fttl5, source_tasks_fttl10, source_tasks_fttl15), budgets)
    ]
    fine_to_coarse_transfer_training_5 = J_fttl5.mean(axis=0)
    fine_to_coarse_transfer_training_10 = J_fttl10.mean(axis=0)
    fine_to_coarse_transfer_training_15 = J_fttl15.mean(axis=0)

    # Extend the budget-5 and budget-10 curves by appending (num_transfer_steps - budget)
    # copies of their last value; the budget-15 curves are left untouched.
    pad_5 = num_transfer_steps - 5
    pad_10 = num_transfer_steps - 10
    coarse_to_fine_transfer_training_5 = (
        list(coarse_to_fine_transfer_training_5) + [coarse_to_fine_transfer_training_5[-1]] * pad_5)
    coarse_to_fine_transfer_training_10 = (
        list(coarse_to_fine_transfer_training_10) + [coarse_to_fine_transfer_training_10[-1]] * pad_10)
    fine_to_coarse_transfer_training_5 = (
        list(fine_to_coarse_transfer_training_5) + [fine_to_coarse_transfer_training_5[-1]] * pad_5)
    fine_to_coarse_transfer_training_10 = (
        list(fine_to_coarse_transfer_training_10) + [fine_to_coarse_transfer_training_10[-1]] * pad_10)

    # Oracle transfer: column-wise maximum of data_transfer, averaged, repeated for every step.
    oracle_transfer = [data_transfer.max(axis=0).mean()] * num_transfer_steps

    # Exhaustive training: mean of the entries at positions (i, i).
    # A single positional lookup iloc[i, i] replaces the chained iloc[i][i], whose second
    # lookup is label-based or a deprecated positional fallback depending on the column dtype.
    data_transfer_diagonal = np.array(
        [data_transfer.iloc[i, i] for i in range(len(deltas))], dtype=float)
    exhaustive_training = [data_transfer_diagonal.mean()] * num_transfer_steps

    # Both sequential oracles start from the row with the highest mean.
    best_first_row = data_transfer.mean(axis=1).argmax()
    best_first_score = data_transfer.iloc[best_first_row, :].mean()

    # Rows of data_transfer are selected below as columns of its transpose.
    # NOTE(review): this selects by label, while the candidates are positions 0..len(deltas)-1;
    # presumably the index labels of data_transfer equal their positions - confirm.
    transfer_T = data_transfer.T

    # Sequential oracle: at each step add the row that maximises the score of the selected set,
    # where the score is the mean over data_transfer's columns of the best selected-row value.
    sequential_oracle_training = [best_first_score]
    sot_deltas = [best_first_row]
    for _ in range(num_transfer_steps-1):
        candidate_indices = [x for x in range(len(deltas)) if x not in sot_deltas]
        # Scores are computed once per step (previously the whole list was built twice).
        candidate_scores = [transfer_T[sot_deltas+[i]].max(axis=1).mean() for i in candidate_indices]
        index_tmp = candidate_scores.index(max(candidate_scores))
        sot_deltas.append(candidate_indices[index_tmp])
        # The winning candidate's score is exactly the score of the enlarged set.
        sequential_oracle_training.append(candidate_scores[index_tmp])

    # Sequential marginal oracle: same greedy loop, ranking candidates by their gain over
    # the score of the current selection.
    sequential_marginal_oracle_training = [best_first_score]
    smot_deltas = [best_first_row]
    for _ in range(num_transfer_steps-1):
        candidate_indices = [x for x in range(len(deltas)) if x not in smot_deltas]
        # The current score does not depend on the candidate, so compute it once per step.
        current_score = transfer_T[smot_deltas].max(axis=1).mean()
        marginal_gains = [transfer_T[smot_deltas+[i]].max(axis=1).mean() - current_score for i in candidate_indices]
        index_tmp = marginal_gains.index(max(marginal_gains))
        smot_deltas.append(candidate_indices[index_tmp])
        sequential_marginal_oracle_training.append(transfer_T[smot_deltas].max(axis=1).mean())

    # Random baseline: 100 draws of num_transfer_steps distinct indices into deltas.
    source_tasks_random = []
    for _ in range(100):
        source_tasks_random.append(np.random.choice(range(len(deltas)), num_transfer_steps, replace=False))
    performance_random = [
        evaluate_on_task(data_transfer, drawn_tasks, deltas, num_transfer_steps)
        for drawn_tasks in source_tasks_random
    ]

    # Per-step mean and standard deviation across the 100 random runs.
    performance_random_mean = [
        np.mean([run[step] for run in performance_random]) for step in range(num_transfer_steps)
    ]
    performance_random_std = [
        np.std([run[step] for run in performance_random]) for step in range(num_transfer_steps)
    ]
        
        
    # Hard-coded multi-task reference values: (values at x = 5, 10, 15, 20, final value).
    # They apply to each listed name and to its "_0" .. "_4" variants.
    # A single table replaces the five hand-copied name lists and removes the old
    # unreachable `else` branch, which left `multitask_final` unbound.
    multitask_reference = {
        "advisoryautonomy_ring_acc": (
            [2.89645838099182, 3.24553202733261, 3.36619166298357, 4.08805350213638],
            4.14385668954332,
        ),
        "advisoryautonomy_ring_vel": (
            [4.08000940654336, 4.21513941235835, 4.23068331816736, 4.38254562363065],
            4.37412074898725,
        ),
        "advisoryautonomy_ramp_acc": (
            [4.5501157694424, 4.65746551337931, 5.09541004830744, 5.26686869050595],
            4.52905413769073,
        ),
        "advisoryautonomy_ramp_vel": (
            [4.55376485404882, 4.65968705007209, 4.69240404506602, 4.75936752044232],
            4.41948549907573,
        ),
    }
    multitask_family = next(
        (family for family in multitask_reference
         if data_name == family or data_name in [f"{family}_{i}" for i in range(5)]),
        None,
    )
    if multitask_family is not None:
        multitask_x = [5, 10, 15, 20]
        raw_multitask_y, raw_multitask_final = multitask_reference[multitask_family]

        # Rescale with the bounds returned by import_data: (value - lower) / (upper - lower).
        multitask_y = [(value - lower_bound) / (upper_bound - lower_bound) for value in raw_multitask_y]
        multitask_final = (raw_multitask_final - lower_bound) / (upper_bound - lower_bound)
        # Repeat the final value so it can be drawn as a flat line over all steps.
        multitask_final = [multitask_final] * num_transfer_steps
        
    # Save results
    # One row per strategy, one column per transfer step; dict order fixes the row order.
    result_rows = {
        'LinUCB': transfer_results_LinUCB,
        'GP_new': transfer_results_GP_new,
        'GP_new_ucb': transfer_results_GP_new_ucb,
        'GP_new_ucb_beta': transfer_results_GP_new_ucb_beta,
        'GP_new_ucb_beta_log': transfer_results_GP_new_ucb_beta_log,
        # 'GP_new_ucb_dist': transfer_results_GP_new_ucb_dist,
        'GP_new_ucb_transfer': transfer_results_GP_new_ucb_transfer,
        'GP_EI_F': transfer_results_GP_EI_F,
        'GP_EI_T': transfer_results_GP_EI_T,
        'GP_UCB_F': transfer_results_GP_UCB_F,
        'GP_UCB_T': transfer_results_GP_UCB_T,
        'GP_LCB_F': transfer_results_GP_LCB_F,
        'GP_LCB_T': transfer_results_GP_LCB_T,
        'GP_VR_F': transfer_results_GP_VR_F,
        'GP_VR_T': transfer_results_GP_VR_T,
        'SoftImpute': performance_SoftImpute,
        'KNN': performance_KNN,
        'MF': performance_MF,
        'SimpleFill': performance_SimpleFill,
        'Pseudogreedy Strategy': performance_greedy_temporal_transfer_learning,
        'Equidistant (C2F) Strategy (5)': coarse_to_fine_transfer_training_5,
        'Equidistant (C2F) Strategy (10)': coarse_to_fine_transfer_training_10,
        'Equidistant (C2F) Strategy (15)': coarse_to_fine_transfer_training_15,
        'Equidistant (F2C) Strategy (5)': fine_to_coarse_transfer_training_5,
        'Equidistant (F2C) Strategy (10)': fine_to_coarse_transfer_training_10,
        'Equidistant (F2C) Strategy (15)': fine_to_coarse_transfer_training_15,
        'Oracle Transfer': oracle_transfer,
        'Exhaustive Training': exhaustive_training,
        'Sequential Oracle Training': sequential_oracle_training,
        'Sequential Marginal Oracle Training': sequential_marginal_oracle_training,
        'Greedy Heuristic': performance_greedy,
        'Greedy Heuristic Marginal': performance_greedy_marginal,
        'Random_mean': performance_random_mean,
        'Random_stdev': performance_random_std,
    }

    result = pd.DataFrame(columns=range(num_transfer_steps))
    for row_label, row_values in result_rows.items():
        result.loc[row_label] = row_values

    # Extra column holding each row's mean over all step columns.
    result['Average'] = result.mean(axis=1)

    result.to_csv(f'{home_dir}/analysis/alg_raw_results_{data_name}.csv')
    
    
    # # Plot
    # plt.clf()
    # plt.rcParams['font.family'] = 'sans-serif'
    # plt.figure(figsize=(8,6))
    # # change font size
    # plt.rcParams.update({'font.size': 12})

    # plt.plot(range(1,num_transfer_steps+1), performance_random_mean, '--m.', label='Random')
    # plt.fill_between(range(1,num_transfer_steps+1), np.array(performance_random_mean)-np.array(performance_random_std), np.array(performance_random_mean)+np.array(performance_random_std), color='m', alpha=0.1)
    # plt.plot(range(1,num_transfer_steps+1), oracle_transfer, '--r.', label='Oracle Transfer')
    # plt.plot(range(1,num_transfer_steps+1), exhaustive_training, '--g.', label='Exhaustive Training')
    # plt.plot(range(1,num_transfer_steps+1), sequential_oracle_training, '--b.', label='Sequential Oracle Training')
    # plt.plot(range(1,num_transfer_steps+1), coarse_to_fine_transfer_training_15, label='CTTL (T-RO, $K=15$)', color='grey', linestyle='dashed', marker='.')
    # plt.plot(range(1,num_transfer_steps+1), performance_greedy_temporal_transfer_learning, '--k.', label='GTTL (T-RO)')
    # plt.plot(range(1,num_transfer_steps+1), transfer_results_GP_UCB_T, '-c.', label='GP (UCB) with TL')
    # plt.plot(range(1,num_transfer_steps+1), performance_SimpleFill, '-r.', label='ME (SimpleFill)', alpha=0.5)
    # plt.plot(range(1,num_transfer_steps+1), performance_KNN, '-b.', label='ME (KNN)', alpha=0.5)
    # plt.plot(range(1,num_transfer_steps+1), performance_MF, '-y.', label='ME (MF)', alpha=0.5)
    # plt.plot(range(1,num_transfer_steps+1), performance_SoftImpute, '-c.', label='ME (SoftImpute)', alpha=0.5)
    # plt.plot(range(1,num_transfer_steps+1), transfer_results_GP_new, '-k.', label='GP (new)')
    # plt.legend()
    # # plt.ylim((0.5,1.05))
    # plt.legend(loc="lower right", fontsize=10)
    # plt.ylabel("Normalized aggregate performance")
    # plt.xlabel("Transfer steps")
    # plt.title(data_name)
    # plt.grid(color='gray', linestyle='dashed', alpha=0.5)
    # plt.savefig(home_dir+'/data/figure/result_me_'+data_name+'.png', dpi=600, bbox_inches="tight")

    # Plot
    # Figure comparing the GP-based strategies with the baselines for this dataset.
    plt.rcParams['font.family'] = 'sans-serif'
    # Keep a handle so the figure can be closed after saving; previously a new figure was
    # created on every loop iteration and never released.
    fig = plt.figure(figsize=(8,6))
    # change font size
    plt.rcParams.update({'font.size': 12})

    steps = range(1, num_transfer_steps+1)  # x axis: 1-based transfer step

    plt.plot(steps, performance_random_mean, '--m.', label='Random')
    plt.fill_between(steps, np.array(performance_random_mean)-np.array(performance_random_std), np.array(performance_random_mean)+np.array(performance_random_std), color='m', alpha=0.1)
    plt.plot(steps, transfer_results_LinUCB, '-m.', label='GTTL with LinUCB')
    plt.plot(steps, transfer_results_GP_EI_T, '-r.', label='GP (EI) with TL')
    plt.plot(steps, transfer_results_GP_VR_T, '-y.', label='GP (VR) with TL')
    plt.plot(steps, transfer_results_GP_UCB_T, '-c.', label='GP (UCB) with TL')
    # NOTE(review): all "GP (new ...)" curves share the '-k.' style and cannot be told apart.
    plt.plot(steps, transfer_results_GP_new, '-k.', label='GP (new)')
    plt.plot(steps, transfer_results_GP_new_ucb, '-k.', label='GP (new ucb)')
    plt.plot(steps, transfer_results_GP_new_ucb_beta, '-k.', label='GP (new ucb beta)')
    plt.plot(steps, transfer_results_GP_new_ucb_beta_log, '-k.', label='GP (new ucb beta log)')
    # plt.plot(steps, transfer_results_GP_new_ucb_dist, '-k.', label='GP (new ucb dist)')
    plt.plot(steps, transfer_results_GP_new_ucb_transfer, '-k.', label='GP (new ucb transfer)')
    plt.plot(steps, transfer_results_GP_LCB_T, label='GP (LCB) with TL', color='orange', linestyle='solid', marker='.')
    # plt.plot(steps, performance_SimpleFill, '--r.', label='SimpleFill', alpha=0.5)
    # plt.plot(steps, performance_KNN, '--b.', label='KNN', alpha=0.5)
    plt.plot(steps, performance_MF, '--y.', label='ME (MF)', alpha=0.5)
    # plt.plot(steps, performance_SoftImpute, '--c.', label='SoftImpute', alpha=0.5)
    plt.plot(steps, oracle_transfer, '--r.', label='Oracle Transfer')
    plt.plot(steps, exhaustive_training, '--g.', label='Exhaustive Training')
    plt.plot(steps, sequential_oracle_training, '--b.', label='Sequential Oracle Training')
    plt.plot(steps, coarse_to_fine_transfer_training_15, label='CTTL (T-RO, $K=15$)', color='grey', linestyle='dashed', marker='.')
    plt.plot(steps, coarse_to_fine_transfer_training_10, label='CTTL (T-RO, $K=10$)', color='grey', linestyle='dotted', marker='.')
    plt.plot(steps, coarse_to_fine_transfer_training_5, label='CTTL (T-RO, $K=5$)', color='grey', linestyle='solid', marker='.')
    plt.plot(steps, performance_greedy_temporal_transfer_learning, '--k.', label='GTTL (T-RO)')
    # Drawn a second time so the curve sits on top of the baselines; left unlabelled so the
    # legend no longer lists 'GP (new)' twice.
    plt.plot(steps, transfer_results_GP_new, '-k.')

    plt.xlim((0,num_transfer_steps+1))
    # plt.ylim((0.5,1.05))
    plt.legend(loc="lower right", fontsize=10)
    plt.ylabel("Normalized aggregate performance")
    plt.xlabel("Transfer steps")
    plt.title(data_name)
    plt.grid(color='gray', linestyle='dashed', alpha=0.5)
    # NOTE(review): unlike the other figure names there is no '_' between 'result_gp' and the
    # dataset name; kept unchanged so existing output paths stay the same.
    plt.savefig(home_dir+'/data/figure/result_gp'+data_name+'.png', dpi=600, bbox_inches="tight")
    plt.close(fig)

    # # Plot
    # plt.clf()
    # plt.rcParams['font.family'] = 'sans-serif'
    # plt.figure(figsize=(8,6))
    # # change font size
    # plt.rcParams.update({'font.size': 12})

    # plt.plot(range(1,num_transfer_steps+1), performance_random_mean, '--m.', label='Random')
    # plt.fill_between(range(1,num_transfer_steps+1), np.array(performance_random_mean)-np.array(performance_random_std), np.array(performance_random_mean)+np.array(performance_random_std), color='m', alpha=0.1)
    # plt.plot(range(1,num_transfer_steps+1), oracle_transfer, '--r.', label='Oracle Transfer')
    # plt.plot(range(1,num_transfer_steps+1), exhaustive_training, '--g.', label='Exhaustive Training')
    # plt.plot(range(1,num_transfer_steps+1), sequential_oracle_training, '--b.', label='Sequential Oracle Training')
    # plt.plot(range(1,num_transfer_steps+1), coarse_to_fine_transfer_training_15, label='CTTL (T-RO, $K=15$)', color='grey', linestyle='dashed', marker='.')
    # plt.plot(range(1,num_transfer_steps+1), performance_greedy_temporal_transfer_learning, '--k.', label='GTTL (T-RO)')
    # # plt.plot(range(1,num_transfer_steps+1), transfer_results_GP_new, '-k.', label='GP (new)')
    # plt.legend()
    # # plt.ylim((0.5,1.05))
    # plt.legend(loc="lower right", fontsize=10)
    # plt.ylabel("Normalized aggregate performance")
    # plt.xlabel("Transfer steps")
    # plt.title(data_name)
    # plt.grid(color='gray', linestyle='dashed', alpha=0.5)
    # plt.savefig(home_dir+'/data/figure/result_t-ro_'+data_name+'.png', dpi=600, bbox_inches="tight")
    
    
    # Plot
    # Shorter summary figure: oracle baselines, GP (new) variants, greedy heuristics, random.
    plt.rcParams['font.family'] = 'sans-serif'
    # Keep a handle so the figure can be closed after saving; previously a new figure was
    # created on every loop iteration and never released.
    fig = plt.figure(figsize=(8,6))
    # change font size
    plt.rcParams.update({'font.size': 12})

    steps = range(1, num_transfer_steps+1)  # x axis: 1-based transfer step

    plt.plot(steps, oracle_transfer, '--r.', label='Oracle Transfer')
    plt.plot(steps, exhaustive_training, '--g.', label='Exhaustive Training')
    plt.plot(steps, sequential_oracle_training, '--b.', label='Sequential Oracle Training')
    plt.plot(steps, coarse_to_fine_transfer_training_15, label='CTTL (T-RO, $K=15$)', color='grey', linestyle='dashed', marker='.')
    plt.plot(steps, transfer_results_GP_new, '-k.', label='GP (new)')
    plt.plot(steps, transfer_results_GP_new_ucb, '-k.', label='GP (new ucb)')
    plt.plot(steps, transfer_results_GP_new_ucb_beta, '-k.', label='GP (new ucb beta)')
    # plt.plot(steps, transfer_results_GP_new_ucb_dist, '-k.', label='GP (new ucb dist)')
    plt.plot(steps, transfer_results_GP_new_ucb_transfer, '-k.', label='GP (new ucb transfer)')
    plt.plot(steps, performance_greedy_marginal, '-r.', label='Greedy (marginal)')
    plt.plot(steps, performance_greedy, '-b.', label='Greedy (max)')
    plt.plot(steps, performance_random_mean, '--m.', label='Random')

    # The multi-task reference curves are drawn only for the advisory-autonomy datasets:
    # the four un-indexed names plus their "_0" .. "_4" variants.
    advisory_families = ("advisoryautonomy_ring_acc", "advisoryautonomy_ring_vel",
                         "advisoryautonomy_ramp_acc", "advisoryautonomy_ramp_vel")
    advisory_names = set(advisory_families) | {
        f"{family}_{i}" for family in advisory_families for i in range(5)
    }
    if data_name in advisory_names:
        plt.plot(multitask_x, multitask_y, '--y', marker='x', label='Multi-task Learning')
        plt.plot(steps, multitask_final, '--y.')

    plt.fill_between(steps, np.array(performance_random_mean)-np.array(performance_random_std), np.array(performance_random_mean)+np.array(performance_random_std), color='m', alpha=0.1)
    plt.plot(steps, performance_greedy_temporal_transfer_learning, '--k.', label='GTTL (T-RO)')
    # plt.ylim((0.5,1.05))
    # Single legend call (an earlier argument-less plt.legend() was immediately replaced by this one).
    plt.legend(loc="lower right", fontsize=10)
    plt.ylabel("Normalized aggregate performance")
    plt.xlabel("Transfer steps")
    plt.title(data_name)
    plt.grid(color='gray', linestyle='dashed', alpha=0.5)
    plt.savefig(home_dir+'/data/figure/result_short_'+data_name+'.png', dpi=600, bbox_inches="tight")
    plt.close(fig)
    
