import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.lines as mlines

import os

import warnings
warnings.filterwarnings("ignore")

from scipy.stats import truncnorm

from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel, ExpSineSquared, ConstantKernel as C

from fancyimpute import KNN, SoftImpute, MatrixFactorization, SimpleFill

def get_dataset_settings(home_dir, data_name):
    """Look up the per-dataset constants used for normalization and task search.

    Parameters:
    - home_dir: Unused; kept so existing callers do not need to change.
    - data_name (str): Dataset identifier, e.g. ``"cartpole_lenpole_0"``.

    Returns:
    - dict: A fresh dictionary with the keys ``lower_bound``, ``unguided``,
      ``upper_bound``, ``delta_min``, ``delta_max`` and ``slope``.  For an
      unknown ``data_name`` the dictionary is ``{'error': "data not recognized"}``.
    """
    def _cfg(lower_bound, unguided, upper_bound, delta_min, delta_max, slope=0.005):
        return {
            'lower_bound': lower_bound,
            'unguided': unguided,
            'upper_bound': upper_bound,
            'delta_min': delta_min,
            'delta_max': delta_max,
            'slope': slope,
        }

    def _walker_2d(delta_min, delta_max):
        # Two-parameter walker sweeps share bounds and only differ in the delta range.
        combined = {
            'lower_bound': -200,
            'unguided': -200,
            'upper_bound': 10,
            'slope': 0.005,
        }
        combined.update({'delta_min': delta_min, 'delta_max': delta_max})
        return combined

    def _indexed(stems, count):
        # "<stem>_<i>" for every stem and every i in range(count).
        return ['{}_{}'.format(stem, i) for stem in stems for i in range(count)]

    ring = ["advisoryautonomy_ring_acc", "advisoryautonomy_ring_vel"]
    ramp = ["advisoryautonomy_ramp_acc", "advisoryautonomy_ramp_vel"]
    inter = ["advisoryautonomy_inter_acc", "advisoryautonomy_inter_vel"]
    intersection = ["intersection_flow", "intersection_speed", "intersection_length"]
    cartpole = ["cartpole_lenpole", "cartpole_masspole", "cartpole_masscart"]
    cartpole_algo = [stem + algo for algo in ("_ppo", "_a2c") for stem in cartpole]
    pendulum = ["pendulum_l", "pendulum_m", "pendulum_dt"]
    walker = ["walker_friction", "walker_gravity", "walker_scale"]
    halfcheetah = ["halfcheetah_friction", "halfcheetah_gravity", "halfcheetah_stiffness"]
    no_stop = ["no-stop_green", "no-stop_inflow", "no-stop_penrate"]

    # Order matters only for readability here: every name appears in exactly
    # one group.  Names that an earlier branch of the former if/elif chain
    # already captured (e.g. "walker_friction", "cartpole_lenpole") are listed
    # only in the group that actually won.
    table = [
        (ring + _indexed(ring, 5), _cfg(0, 3.6, 4.5, 0, 39)),
        (ramp + _indexed(ramp, 5), _cfg(0, 3.9299, 8.78, 0, 39)),
        (inter, _cfg(0, 6.8379, 7.7092, 1, 400)),
        (intersection, _cfg(-100, -100, 0, 0, 19)),
        (_indexed(intersection, 3), _cfg(-100, -100, 0, 0, 49, slope=0.001)),
        (["cartpole_masscart"], _cfg(0, 0, 500, 0, 48)),
        (["cartpole_masscart_lenpole"], _cfg(0, 0, 500, 0, 49)),
        (["cartpole_lenpole", "cartpole_masspole"], _cfg(0, 0, 500, 0, 98)),
        (_indexed(cartpole + cartpole_algo, 3), _cfg(0, 0, 500, 0, 99)),
        (cartpole_algo, _cfg(0, 0, 500, 0, 99)),
        (["lunarlander_mainenginepower"], _cfg(-600, -600, 300, 1, 50)),
        (pendulum, _cfg(-2200, -1300, -500, 1, 50)),
        (_indexed(pendulum, 3), _cfg(-2200, -1300, -500, 0, 99)),
        (walker + _indexed(walker, 3), _cfg(-200, -200, 10, 0, 99)),
        (halfcheetah + _indexed(halfcheetah, 3), _cfg(0, 0, 6000, 0, 99, slope=10)),
        (no_stop + _indexed(no_stop, 3), _cfg(0, 0, 15, 0, 49)),
        (["ideal"], _cfg(0, 0, 1, 1, 50)),
        (["ideal_n{}".format(i) for i in range(1, 11)], _cfg(0, 0, 1, 1, 100)),
        (["walker_gravity_friction", "walker_friction_gravity"], _walker_2d(0, 47)),
        (["walker_legh_legw", "walker_legw_legh"], _walker_2d(1, 16)),
    ]

    for names, config in table:
        if data_name in names:
            return config

    return {'error': "data not recognized"}


def import_data(home_dir, data_name, random=False):
    """Load a transfer-result matrix from disk and min-max normalize it.

    Reads ``<home_dir>/data/<data_name>_transfer_result.csv`` (no header row),
    negates the values when ``data_name`` contains ``'intersection'``, and
    rescales every entry with ``(x - lower_bound) / (upper_bound - lower_bound)``
    using the bounds from ``get_dataset_settings``.  The directory
    ``<home_dir>/data/figure`` is created as a side effect.

    Parameters:
    - home_dir (str): Project root containing the ``data`` directory.
    - data_name (str): Dataset identifier known to ``get_dataset_settings``.
    - random (bool): Unused; kept for backward compatibility.

    Returns:
    - tuple: ``(data_transfer_norm, deltas, delta_min, delta_max, slope,
      lower_bound, upper_bound, unguided)`` where ``data_transfer_norm`` is a
      float ``DataFrame`` and ``deltas`` holds the CSV column labels as floats.

    Raises:
    - KeyError: If ``data_name`` has no registered settings.
    - FileNotFoundError: If the CSV file does not exist.
    """
    settings = get_dataset_settings(home_dir, data_name)
    if 'error' in settings:
        # Previously this surfaced as an opaque KeyError('delta_min'); keep the
        # exception type but say what actually went wrong.
        raise KeyError("no settings registered for dataset {!r}".format(data_name))
    delta_min = settings['delta_min']
    delta_max = settings['delta_max']
    slope = settings['slope']
    lower_bound = settings['lower_bound']
    upper_bound = settings['upper_bound']
    unguided = settings['unguided']

    # make directory if it doesn't exist
    os.makedirs(home_dir+'/data/figure', exist_ok=True)

    csv_path = home_dir+'/data/'+data_name+'_transfer_result.csv'
    if not os.path.exists(csv_path):
        # The old code printed a message and then crashed with an
        # UnboundLocalError a few lines later; fail here with the real cause.
        raise FileNotFoundError("No data found for {} (expected {})".format(data_name, csv_path))
    data_transfer = pd.read_csv(csv_path, header=None)
    if 'intersection' in data_name:
        data_transfer = -data_transfer

    deltas = data_transfer.columns.values.astype(float)

    # Force a float array before scaling: an all-integer CSV would otherwise
    # produce an integer array and the normalized values would be truncated.
    raw_values = data_transfer.to_numpy(dtype=float)
    data_transfer_norm = pd.DataFrame((raw_values - lower_bound) / (upper_bound - lower_bound))

    return data_transfer_norm, deltas, delta_min, delta_max, slope, lower_bound, upper_bound, unguided

# Sampling helper for a normal(mean, std) distribution restricted to [lower, upper].
def sample_truncated_normal(mean, std, lower, upper, size):
    """Draw ``size`` samples from a normal distribution truncated to an interval.

    Parameters:
    - mean (float): Location of the untruncated normal.
    - std (float): Scale of the untruncated normal.
    - lower (float): Lower truncation point (in the same units as ``mean``).
    - upper (float): Upper truncation point; may be ``np.inf``.
    - size: Output shape passed through to ``truncnorm.rvs``.

    Returns:
    - numpy array: The sampled values.
    """
    # truncnorm takes its clip points in units of standard deviations from the mean.
    clip_low = (lower - mean) / std
    clip_high = (upper - mean) / std
    return truncnorm.rvs(clip_low, clip_high, loc=mean, scale=std, size=size)

class LinUCB:
    """
    LinUCB Algorithm for Contextual Multi-Armed Bandit Problem.

    Each arm keeps the inverse of its regularized design matrix (updated
    incrementally with the Sherman-Morrison formula) and the running sum of
    reward-weighted contexts.

    Parameters:
    - n_arms (int): Number of arms.
    - n_features (int): Number of features in the context.
    - alpha (float): Confidence level parameter.

    Attributes:
    - A (list of numpy arrays): Inverse of covariance matrices for each arm,
      each of shape (n_features, n_features), initialised to the identity.
    - b (list of numpy arrays): Sum of rewards times contexts for each arm,
      each of shape (n_features, 1).
    """
    def __init__(self, n_arms, n_features, alpha=1):
        self.n_arms = n_arms
        self.n_features = n_features
        self.alpha = alpha
        self.A = [np.eye(n_features) for _ in range(n_arms)]  # Inverse of covariance matrices
        self.b = [np.zeros((n_features, 1)) for _ in range(n_arms)]  # Sum of rewards times contexts

    @staticmethod
    def _as_column(context):
        """Return ``context`` as a float column vector of shape (n, 1)."""
        return np.asarray(context, dtype=float).reshape(-1, 1)

    def choose_arm(self, context):
        """
        Choose the best arm based on the current context.

        Parameters:
        - context (numpy array): The current context, either 1-D with
          n_features entries or a column vector of shape (n_features, 1).

        Returns:
        - int: The index of the chosen arm (first index on ties).
        """
        x = self._as_column(context)
        scores = np.zeros(self.n_arms)
        for i in range(self.n_arms):
            theta = self.A[i] @ self.b[i]
            # .item() extracts the scalar explicitly; assigning a (1, 1) array
            # into scores[i] relies on an implicit conversion that NumPy has
            # deprecated.
            estimate = (theta.T @ x).item()
            width = np.sqrt((x.T @ self.A[i] @ x).item())
            scores[i] = estimate + self.alpha * width
        return np.argmax(scores)

    def update(self, arm, context, reward):
        """
        Update the parameters based on the chosen arm and received reward.

        Parameters:
        - arm (int): The index of the chosen arm.
        - context (numpy array): The context when the arm was chosen (1-D or
          column vector).
        - reward (float): The reward received.
        """
        x = self._as_column(context)
        A_inv = self.A[arm]
        A_x = A_inv @ x
        denom = 1.0 + (x.T @ A_x).item()
        # Sherman-Morrison formula: (M + x x^T)^-1 = M^-1 - M^-1 x x^T M^-1 / (1 + x^T M^-1 x).
        # Updated in place so references handed out by get_params stay valid.
        A_inv -= (A_x @ (x.T @ A_inv)) / denom
        # Using the column form keeps b at shape (n_features, 1) even when a
        # 1-D context is supplied.
        self.b[arm] += reward * x

    def predict_reward(self, arm, context):
        """
        Predict the reward for a given arm and context.

        Parameters:
        - arm (int): The index of the arm.
        - context (numpy array): The context.

        Returns:
        - numpy array: The predicted reward as the product ``theta.T @ context``;
          for a column-vector context this is an array of shape (1, 1).
        """
        theta = self.A[arm] @ self.b[arm]
        return theta.T @ context

    def get_params(self):
        """
        Get the parameters of the model.

        Returns:
        - tuple: A tuple containing the A and b parameters (the live lists,
          not copies).
        """
        return (self.A, self.b)

    def set_params(self, params):
        """
        Set the parameters of the model.

        Parameters:
        - params (tuple): A tuple containing the A and b parameters.
        """
        self.A, self.b = params

    def reset(self):
        """
        Reset the parameters of the model to their initial state.
        """
        self.A = [np.eye(self.n_features) for _ in range(self.n_arms)]
        self.b = [np.zeros((self.n_features, 1)) for _ in range(self.n_arms)]
        
def test_linucb(data_transfer, n_arms, n_features, num_transfer_steps):
    """Run a LinUCB bandit over the transfer matrix and record its trajectory.

    At every step the context is the column-wise maximum over the rows of
    ``data_transfer`` selected so far (all zeros on the first step), and the
    reward of an arm is the matching diagonal entry of ``data_transfer``.

    Parameters:
    - data_transfer (DataFrame): Transfer matrix indexed by arm on both axes.
    - n_arms (int): Number of arms given to the bandit.
    - n_features (int): Length of the context vector.
    - num_transfer_steps (int): Number of arms to pull.

    Returns:
    - tuple: ``(chosen_arms, contexts, rewards, regrets, cum_regrets)``, each a
      list with one entry per step.  Regret is measured against the best
      reward observed up to and including the current step.
    """
    # NOTE(review): despite the ``test_`` prefix this is a driver, not a unit
    # test; pytest may try to collect it if the module matches its file patterns.
    bandit = LinUCB(n_arms, n_features)

    chosen_arms, contexts = [], []
    rewards, regrets, cum_regrets = [], [], []

    for step in range(num_transfer_steps):
        if step > 0:
            rows_so_far = data_transfer.T[chosen_arms].T
            context = np.array(rows_so_far.max(axis=0)).reshape(-1, 1)
        else:
            context = np.zeros((n_features, 1))

        arm = bandit.choose_arm(context)
        payoff = data_transfer[arm][arm]
        bandit.update(arm, context, payoff)

        chosen_arms.append(arm)
        contexts.append(context)
        rewards.append(payoff)
        regrets.append(np.max(rewards) - payoff)
        cum_regrets.append(np.sum(regrets))

    return chosen_arms, contexts, rewards, regrets, cum_regrets

def run_linucb(data_transfer, deltas, num_transfer_steps):
    """Select source tasks with LinUCB and score the resulting transfer curve.

    Parameters:
    - data_transfer (DataFrame): Transfer matrix; row = source, column = target.
    - deltas (numpy array): Task values; each chosen arm is located in this
      array by equality.
    - num_transfer_steps (int): Number of source tasks to select.

    Returns:
    - tuple: ``(chosen_arms, transfer_results_LinUCB)`` where the second item
      is, for each step, the mean over targets of the best value achieved by
      any source chosen so far.
    """
    n_tasks = len(data_transfer)
    chosen_arms, _, _, _, _ = test_linucb(data_transfer, n_tasks, n_tasks, num_transfer_steps)

    n_targets = len(deltas)
    V_linucb = np.zeros((n_targets, num_transfer_steps))

    if n_targets:
        for step in range(num_transfer_steps):
            # The selected source row is the same for every target at this step.
            source_row = data_transfer.iloc[np.where(deltas == chosen_arms[step])[0][0]]
            for target in range(n_targets):
                best = source_row[target]
                if step > 0:
                    best = max(best, V_linucb[target, step - 1])
                V_linucb[target, step] = best

    transfer_results_LinUCB = V_linucb.mean(axis=0)

    return chosen_arms, transfer_results_LinUCB

def greedy_heuristic_sts(data_transfer, deltas, num_transfer_steps, delta_min, delta_max, slope):
    """Greedy source-task selection that targets the worst-served task next.

    The first source is the delta closest to the midpoint of
    ``[delta_min, delta_max]``.  Every later source is the not-yet-selected
    task whose current best transferred performance is lowest.  Each step's
    performance curve and the selected source are drawn on the current
    matplotlib axes.

    Parameters:
    - data_transfer (DataFrame): Transfer matrix; row = source, column = target.
    - deltas (numpy array): Task values, one per row/column of ``data_transfer``.
    - num_transfer_steps (int): Number of source tasks to select.
    - delta_min, delta_max (float): Range used to place the first source.
    - slope: Unused; kept for signature compatibility with sibling strategies.

    Returns:
    - tuple: ``(source_tasks, transfer_results)`` where ``source_tasks`` is an
      array of the selected deltas and ``transfer_results`` holds, per step,
      the mean over targets of the best performance reached so far
      (length ``num_transfer_steps``).
    """
    n_tasks = len(deltas)
    source_tasks = np.zeros(num_transfer_steps)
    J_transfer = np.zeros((n_tasks, num_transfer_steps))
    chosen = None
    for k in range(num_transfer_steps):
        if k == 0:
            midpoint = (delta_max + delta_min)/2
            chosen = deltas[np.abs(deltas - midpoint).argmin()]
        else:
            # Only the first k entries are real selections; the rest of the
            # buffer is still zero-filled and must not block the task at 0.
            already_selected = source_tasks[:k]
            for idx in np.argsort(J_transfer[:, k-1]):
                # argsort yields positions, so map them through deltas rather
                # than treating the position itself as a delta value.
                if deltas[idx] not in already_selected:
                    chosen = deltas[idx]
                    break
            # If every task is already selected, the previous choice is reused.
        source_tasks[k] = chosen
        source_idx = np.where(deltas == chosen)[0][0]
        source_row = data_transfer.iloc[source_idx]
        for j in range(n_tasks):
            value = source_row[j]
            J_transfer[j, k] = value if k == 0 else max(value, J_transfer[j, k-1])
        plt.plot(deltas, J_transfer[:, k], label='step {}'.format(k), c='C{}'.format(k))
        plt.plot(source_tasks[k], J_transfer[source_idx, k], 'o', c='C{}'.format(k))
    plt.legend(fontsize=10)

    # J_transfer already is the running-max matrix; recomputing it with a
    # hard-coded 15 steps broke any other value of num_transfer_steps.
    return source_tasks, J_transfer.mean(axis=0)

def greedy_heuristic_sts_marginal(data_transfer, deltas, num_transfer_steps, delta_min, delta_max, slope):
    """Greedy source-task selection by largest estimated marginal gain.

    The first source is the delta closest to the midpoint of
    ``[delta_min, delta_max]``.  For later steps each candidate ``i`` is
    scored with ``sum_j max(1 - slope*|delta_i - delta_j| - J[j], 0)``, where
    ``J[j]`` is the best performance reached on target ``j`` so far, and the
    highest-scoring task that has not been selected yet is chosen.  Each
    step's performance curve and the selected source are drawn on the current
    matplotlib axes.

    Parameters:
    - data_transfer (DataFrame): Transfer matrix; row = source, column = target.
    - deltas (numpy array): Task values, one per row/column of ``data_transfer``.
    - num_transfer_steps (int): Number of source tasks to select.
    - delta_min, delta_max (float): Range used to place the first source.
    - slope (float): Assumed performance drop per unit of delta distance.

    Returns:
    - tuple: ``(source_tasks, transfer_results)`` where ``source_tasks`` is an
      array of the selected deltas and ``transfer_results`` holds, per step,
      the mean over targets of the best performance reached so far
      (length ``num_transfer_steps``).
    """
    n_tasks = len(deltas)
    source_tasks = np.zeros(num_transfer_steps)
    J_transfer = np.zeros((n_tasks, num_transfer_steps))
    # The assumed transfer matrix depends only on deltas and slope, so build
    # it once instead of once per step.
    U_transfer = 1 - slope*np.abs(deltas[:, None] - deltas[None, :])
    chosen = None
    for k in range(num_transfer_steps):
        if k == 0:
            midpoint = (delta_max + delta_min)/2
            chosen = deltas[np.abs(deltas - midpoint).argmin()]
        else:
            marginal_i = np.zeros(n_tasks)
            for i in range(n_tasks):
                marginal_i[i] = sum([max(U_transfer[i, j] - J_transfer[j, k-1], 0) for j in range(n_tasks)])
            # Only the first k entries are real selections; the rest of the
            # buffer is still zero-filled and must not block the task at 0.
            already_selected = source_tasks[:k]
            for idx in np.argsort(-marginal_i):
                # argsort yields positions, so map them through deltas rather
                # than treating the position itself as a delta value.
                if deltas[idx] not in already_selected:
                    chosen = deltas[idx]
                    break
            # If every task is already selected, the previous choice is reused.
        source_tasks[k] = chosen
        source_idx = np.where(deltas == chosen)[0][0]
        source_row = data_transfer.iloc[source_idx]
        for j in range(n_tasks):
            value = source_row[j]
            J_transfer[j, k] = value if k == 0 else max(value, J_transfer[j, k-1])
        plt.plot(deltas, J_transfer[:, k], label='step {}'.format(k), c='C{}'.format(k))
        plt.plot(source_tasks[k], J_transfer[source_idx, k], 'o', c='C{}'.format(k))
    plt.legend(fontsize=10)

    # J_transfer already is the running-max matrix; recomputing it with a
    # hard-coded 15 steps broke any other value of num_transfer_steps.
    return source_tasks, J_transfer.mean(axis=0)

def greedy_temporal_transfer_learning(deltas, num_transfer_steps, delta_min=1, delta_max=50):
    """Pick source tasks by repeatedly splitting the widest uncovered interval.

    The interval endpoints start as ``[delta_min, delta_max]``.  The first
    pick aims at the midpoint of the range.  Afterwards the widest gap between
    neighbouring endpoints is located: a gap touching the left edge is split
    one third of the way in, a gap touching the right edge two thirds of the
    way in, and an interior gap at its midpoint.  Every aim point is snapped
    to the nearest entry of ``deltas`` and then added to the endpoints.

    Parameters:
    - deltas (numpy array): Candidate task values.
    - num_transfer_steps (int): Number of source tasks to pick.
    - delta_min, delta_max (float): Initial interval endpoints.

    Returns:
    - numpy array: The selected deltas in selection order.
    """
    anchors = np.array([delta_min, delta_max])
    picks = np.zeros(num_transfer_steps)
    for step in range(num_transfer_steps):
        if step == 0:
            aim = (delta_max + delta_min)/2
        else:
            widest = np.diff(anchors).argmax()
            left = anchors[widest]
            right = anchors[widest+1]
            if widest == 0:
                aim = (2*left+right)/3
            elif widest == len(anchors)-2:
                aim = (left+2*right)/3
            else:
                aim = (left+right)/2
        picks[step] = deltas[np.abs(deltas - aim).argmin()]
        anchors = np.sort(np.append(anchors, picks[step]))
    return picks

def coarse_to_fine_temporal_transfer_learning(deltas, budgets, delta_min=1, delta_max=50):
    """Pick ``budgets`` evenly spaced source tasks, from the largest delta down.

    The range ``[delta_min, delta_max]`` is cut into ``budgets`` equal bins;
    the centre of each bin, visited from the top bin to the bottom one, is
    snapped to the nearest entry of ``deltas``.

    Parameters:
    - deltas (numpy array): Candidate task values.
    - budgets (int): Number of source tasks to pick.
    - delta_min, delta_max (float): Range to cover.

    Returns:
    - numpy array: The selected deltas in selection order.
    """
    span = delta_max - delta_min
    picks = np.zeros(budgets)
    for slot in range(budgets):
        centre = delta_max - span / (2*budgets) - span * slot / budgets
        nearest = np.abs(deltas - centre).argmin()
        picks[slot] = deltas[nearest]
    return picks

def fine_to_coarse_temporal_transfer_learning(deltas, budgets, delta_min=1, delta_max=50):
    """Pick ``budgets`` evenly spaced source tasks, from the smallest delta up.

    The range ``[delta_min, delta_max]`` is cut into ``budgets`` equal bins;
    the centre of each bin, visited from the bottom bin to the top one, is
    snapped to the nearest entry of ``deltas``.

    Parameters:
    - deltas (numpy array): Candidate task values.
    - budgets (int): Number of source tasks to pick.
    - delta_min, delta_max (float): Range to cover.

    Returns:
    - numpy array: The selected deltas in selection order.
    """
    span = delta_max - delta_min
    picks = np.zeros(budgets)
    for slot in range(budgets):
        centre = delta_min + span / (2*budgets) + span * slot / budgets
        nearest = np.abs(deltas - centre).argmin()
        picks[slot] = deltas[nearest]
    return picks

def collect_J_matrix(data_transfer, source_tasks, deltas, num_transfer_steps=15):
    """Build the running-best performance matrix for a sequence of source tasks.

    Parameters:
    - data_transfer (DataFrame): Transfer matrix; row = source, column = target.
    - source_tasks (sequence): Selected deltas in selection order; needs at
      least ``num_transfer_steps`` entries, each present in ``deltas``.
    - deltas (numpy array): Task values, one per row/column of ``data_transfer``.
    - num_transfer_steps (int): Number of steps (columns) to compute.

    Returns:
    - numpy array: Shape ``(len(deltas), num_transfer_steps)``; entry
      ``[i, k]`` is the best value on target ``i`` over the first ``k + 1``
      source tasks.
    """
    n_targets = len(deltas)
    J_tmp = np.zeros((n_targets, num_transfer_steps))
    if n_targets == 0:
        # Nothing to fill, and no source row should be looked up.
        return J_tmp
    for step in range(num_transfer_steps):
        # The source row is the same for every target at this step.
        source_row = data_transfer.iloc[np.where(deltas == source_tasks[step])[0][0]]
        for target in range(n_targets):
            score = source_row[target]
            if step > 0:
                score = max(score, J_tmp[target, step - 1])
            J_tmp[target, step] = score
    return J_tmp

def model_based_learning_with_GP(deltas, num_transfer_steps, data_transfer, lower_bound, upper_bound, acquisition_function='UCB', transferred=True, gap_function='linear', slope=5, marginal=False, noise_std=0.1, n_restarts_optimizer=9):
    """Select source tasks sequentially with a Gaussian-process surrogate.

    The first source is the delta closest to the midpoint of ``deltas``.  At
    every later step a GP is fitted on the source tasks chosen so far, an
    acquisition score is computed for every delta, and the first candidate in
    the resulting ordering that has not been selected yet becomes the next
    source.

    Parameters:
    - deltas (numpy array): Task values, one per row/column of ``data_transfer``.
    - num_transfer_steps (int): Number of source tasks to select.
    - data_transfer (DataFrame): Transfer matrix; row = source, column = target.
    - lower_bound, upper_bound: Not referenced in the body.
    - acquisition_function (str): ``'EI'`` (posterior mean only), ``'UCB'``,
      ``'LCB'`` or ``'VR'`` (posterior std only).  Any other value leaves
      ``acquisition`` undefined and the function fails with ``NameError``.
    - transferred (bool): If true, the GP targets are row means of
      ``V_estimate``; otherwise they are entries of ``J_transfer``.
    - gap_function (str): ``'linear'``, ``'linear_estimated'``, ``'true'``, or
      anything else for a diagonal-only estimate.
    - slope (float): Performance drop per index step used by ``'linear'``.
    - marginal (bool): If true, the predicted mean is reduced by the current
      state and clipped at zero before the acquisition is computed.
    - noise_std (float): Its square is passed as the GP's ``alpha``.
    - n_restarts_optimizer (int): Passed through to the GP regressor.

    Returns:
    - tuple: ``(source_tasks, transfer_results, J_transfer, V_estimate)``.
    """
    source_tasks = []
    # J_transfer[i, k]: best observed performance on target i after step k.
    J_transfer = np.zeros((len(deltas), num_transfer_steps))
    # V_estimate[i, j, k]: estimated value of source i on target j at step k.
    V_estimate = np.zeros((len(deltas), len(deltas), num_transfer_steps))
    mean_prediction = np.zeros(len(deltas))
    std_prediction = np.zeros(len(deltas))
    
    delta_min = min(deltas)
    delta_max = max(deltas)

    for k in range(num_transfer_steps):
        if k==0:
            # No observations yet: start from the middle of the delta range.
            tmp = (delta_max + delta_min)/2
        else:
            kernel = 1 * RBF(length_scale=1.0, length_scale_bounds=(1e-2, 1e2))

            gaussian_process = GaussianProcessRegressor(kernel=kernel, alpha=noise_std**2, n_restarts_optimizer=n_restarts_optimizer)
            if transferred:
                gaussian_process.fit(np.array(source_tasks[:k]).reshape(-1, 1), np.array([V_estimate.mean(axis=1)[np.where(deltas==source_tasks[l])[0][0], l] for l in range(k)]))
            else:
                gaussian_process.fit(np.array(source_tasks[:k]).reshape(-1, 1), np.array([J_transfer[np.where(deltas==source_tasks[l])[0][0], l] for l in range(k)]))
            mean_prediction, std_prediction = gaussian_process.predict(deltas.reshape(-1, 1), return_std=True)

            # V_obs_tmp is the matrix built at the end of the previous iteration.
            # NOTE(review): in the non-transferred branch `i` is an index while
            # source_tasks holds delta values -- confirm they are meant to coincide.
            current_state = np.array(V_obs_tmp.mean(axis=1)) if transferred else np.array([data_transfer[i][i] if i in source_tasks else 0 for i in range(len(deltas))])
            mean_prediction = np.maximum(mean_prediction - current_state, 0) if marginal else mean_prediction
            
            # diagonals of data_transfer
            # NOTE(review): data_transfer_diagonal is filled but never read afterwards.
            data_transfer_diagonal = np.zeros(len(deltas))
            for i in range(len(deltas)):
                data_transfer_diagonal[i] = data_transfer.iloc[i][i]

            if acquisition_function == 'EI':
                # Despite the name, this uses the posterior mean directly.
                acquisition = mean_prediction
            elif acquisition_function == 'UCB':
                acquisition = mean_prediction + 1.96*std_prediction
            elif acquisition_function == 'LCB':
                acquisition = mean_prediction - 1.96*std_prediction
            elif acquisition_function == 'VR':
                acquisition = std_prediction
                
            if acquisition_function != 'VR':
                acquisition = np.clip(acquisition, 0, 1)
            
            # Candidates are visited in ascending acquisition order, except for
            # non-transferred 'VR', which is visited in descending order.
            # NOTE(review): ascending order tries the lowest score first -- confirm intent.
            if transferred:
                sorted_idx = np.argsort(acquisition)
            else:
                if acquisition_function != 'VR':
                    sorted_idx = np.argsort(acquisition)
                else:
                    sorted_idx = np.argsort(-acquisition)

            # Take the first candidate that has not been selected yet.
            # NOTE(review): sorted_idx holds positions but is compared against
            # delta values; this matches only if deltas[i] == i -- TODO confirm.
            for idx in range(len(sorted_idx)):
                if deltas[np.abs(deltas - sorted_idx[idx]).argmin()] in source_tasks:
                    continue
                else:
                    tmp = sorted_idx[idx]
                    break
        source_tasks.append(deltas[np.abs(deltas - tmp).argmin()])

        # Update the running-best performance with the newly selected source row.
        for j in range(len(deltas)):
            if k==0:
                J_transfer[j, k] = data_transfer.iloc[np.where(deltas == source_tasks[k])[0][0]][j]
            else:
                J_transfer[j, k] = max(data_transfer.iloc[np.where(deltas == source_tasks[k])[0][0]][j], J_transfer[j, k-1])

        # Diagonal of the estimate: GP mean everywhere, overwritten with the
        # observed J_transfer value at each selected source task.
        V_estimate_tmp = np.zeros((len(deltas), len(deltas)))
        for i in range(len(deltas)):
            V_estimate_tmp[i, i] = mean_prediction[i]
        for l in range(k+1):
            idx = np.where(deltas==source_tasks[l])[0][0]
            V_estimate_tmp[idx, idx] = J_transfer[idx, l]
        
        # Per-source slope of |performance gap| versus index distance, fitted
        # on the true rows of the selected source tasks (computed for every
        # gap_function, used only by 'linear_estimated').
        slopes = {}
        for i in range(len(deltas)):
            if deltas[i] in source_tasks:
                x = []
                y = []
                for j in range(len(deltas)):
                    x.append(abs(i-j))
                    y.append(abs(data_transfer.iloc[i,j]-data_transfer.iloc[i,i]))
                slopes[i] = np.polyfit(x, y, 1)[0]

        # Off-diagonal estimates.  At k == 0, V_estimate[i, j, k-1] reads index
        # -1, i.e. the last step's slice, which is still all zeros at that point.
        if gap_function == 'linear':
            for i in range(len(deltas)):
                for j in range(len(deltas)):
                    V_estimate_tmp[i, j] = max(V_estimate_tmp[i, i] - slope*abs(i-j), 0)
                    V_estimate[i, j, k] = max(V_estimate[i, j, k-1], V_estimate_tmp[i, j])
        elif gap_function == 'linear_estimated':
            for i in range(len(deltas)):
                # Use the fitted slope of the nearest selected source (by index).
                # NOTE: this rebinds the `slope` argument for the rest of the call.
                slope = slopes[np.array(list(slopes.keys()))[np.abs(np.array(list(slopes.keys()))- i).argmin()]]
                for j in range(len(deltas)):
                    V_estimate_tmp[i, j] = max(V_estimate_tmp[i, i] - slope*abs(i-j), 0)
                    V_estimate[i, j, k] = max(V_estimate[i, j, k-1], V_estimate_tmp[i, j])
        elif gap_function == 'true':
            for i in range(len(deltas)):
                for j in range(len(deltas)):
                    V_estimate_tmp[i, j] = data_transfer.iloc[i, j]
                    V_estimate[i, j, k] = max(V_estimate[i, j, k-1], V_estimate_tmp[i, j])
        else:
            # Any other gap_function: keep the diagonal, zero everything else.
            for i in range(len(deltas)):
                for j in range(len(deltas)):
                    if i == j:
                        V_estimate_tmp[i, j] = V_estimate_tmp[i, i]
                    else:
                        V_estimate_tmp[i, j] = 0
                    V_estimate[i, j, k] = max(V_estimate[i, j, k-1], V_estimate_tmp[i, j])
                    
        # Observed matrix: true rows for selected sources, zeros elsewhere.
        # Read at the start of the next iteration when `transferred` is true.
        V_obs_tmp = np.zeros((len(deltas), len(deltas)))
        for i in range(len(deltas)):
            if deltas[i] in source_tasks:
                for j in range(len(deltas)):
                    if i == j:
                        V_obs_tmp[i, j] = data_transfer.iloc[i, i]
                    else:
                        V_obs_tmp[i, j] = data_transfer.iloc[i, j]
    # Mean over targets of the running-best performance, one value per step.
    transfer_results = J_transfer.mean(axis=0)

    return source_tasks, transfer_results, J_transfer, V_estimate

def model_based_learning_with_GP_new(home_dir, data_name, deltas, num_transfer_steps, data_transfer, acquisition_function='new', gap_function='linear', slope=0.005, noise_std=0.1, n_restarts_optimizer=9):
    """Select source tasks sequentially with a GP surrogate and gain-based acquisitions.

    The first source is the delta closest to the midpoint of ``deltas``.  At
    every later step a GP is fitted on the source tasks chosen so far, an
    acquisition score is computed for every delta, and the highest-scoring
    candidate that has not been selected yet becomes the next source.

    Parameters:
    - home_dir, data_name: Not referenced in the body.
    - deltas (numpy array): Task values, one per row/column of ``data_transfer``.
    - num_transfer_steps (int): Number of source tasks to select.
    - data_transfer (DataFrame): Transfer matrix; row = source, column = target.
    - acquisition_function (str): One of ``'EI'``, ``'UCB'``, ``'LCB'``,
      ``'VR'``, ``'new'``, ``'new_ucb'``, ``'new_ucb_beta'``,
      ``'new_ucb_beta_log'``, ``'new_ucb_dist'``, ``'new_ucb_transfer'``.  Any
      other value leaves ``acquisition`` undefined and the function fails
      with ``NameError``.
    - gap_function (str): ``'linear'``, ``'linear_estimated'``, ``'true'``, or
      anything else for a diagonal-only estimate.
    - slope (float): Assumed performance drop per unit of distance.
    - noise_std (float): Its square is passed as the GP's ``alpha``.
    - n_restarts_optimizer (int): Passed through to the GP regressor.

    Returns:
    - tuple: ``(source_tasks, transfer_results, J_transfer, V_estimate)``.
    """
    # collection of source task
    source_tasks = []
    # J_transfer[i,k]: i is the index of the target task, k is the index of the step
    J_transfer = np.zeros((len(deltas), num_transfer_steps))
    # V_estimate[i,j,k]: i is the index of the source task, j is the index of the target task, k is the index of the step
    V_estimate = np.zeros((len(deltas), len(deltas), num_transfer_steps))
    # mean_prediction[i], std_prediction[i]: i is the index of the source task
    mean_prediction = np.zeros(len(deltas))
    std_prediction = np.zeros(len(deltas))
    # V_obs_tmp[i,j]: i is the index of the source task, j is the index of the target task
    V_obs_tmp = np.zeros((len(deltas), len(deltas)))
    # V_estimate_tmp[i,j]: i is the index of the source task, j is the index of the target task
    V_estimate_tmp = np.zeros((len(deltas), len(deltas)))
    
    delta_min = min(deltas)
    delta_max = max(deltas)

    for k in range(num_transfer_steps):
        if k==0:
            # No observations yet: start from the middle of the delta range.
            tmp = (delta_max + delta_min)/2
        else:
            kernel = C(1.0, (1e-3, 1e3)) * RBF(length_scale=1.0, length_scale_bounds=(1e-2, 1e2))

            gaussian_process = GaussianProcessRegressor(kernel=kernel, alpha=noise_std**2, n_restarts_optimizer=n_restarts_optimizer)
            # 'new_ucb_transfer' trains on row means of V_estimate; every other
            # acquisition trains on the observed J_transfer entries.
            if acquisition_function == 'new_ucb_transfer':
                gaussian_process.fit(np.array(source_tasks[:k]).reshape(-1, 1), np.array([V_estimate.mean(axis=1)[np.where(deltas==source_tasks[l])[0][0], l] for l in range(k)]))
            else:
                gaussian_process.fit(np.array(source_tasks[:k]).reshape(-1, 1), np.array([J_transfer[np.where(deltas==source_tasks[l])[0][0], l] for l in range(k)]))
            mean_prediction, std_prediction = gaussian_process.predict(deltas.reshape(-1, 1), return_std=True)
            
            # diagonals of data_transfer
            # NOTE(review): data_transfer_diagonal is filled but never read afterwards.
            data_transfer_diagonal = np.zeros(len(deltas))
            for i in range(len(deltas)):
                data_transfer_diagonal[i] = data_transfer.iloc[i][i]
            
            # In the 'new*' variants, V_obs_tmp.max(axis=0)[j] is the best
            # observed value on target j from the sources selected so far, so
            # each term is a clipped predicted improvement on target j.
            if acquisition_function == 'EI':
                # Despite the name, this uses the posterior mean directly.
                acquisition = mean_prediction
            elif acquisition_function == 'UCB':
                acquisition = mean_prediction + 1.96*std_prediction
            elif acquisition_function == 'LCB':
                acquisition = mean_prediction - 1.96*std_prediction
            elif acquisition_function == 'VR':
                acquisition = std_prediction
            elif acquisition_function == 'new':
                # calculate new acquisition function
                new_acquisition = np.zeros(len(deltas))
                for i in range(len(deltas)):
                    new_acquisition[i] = np.mean([max(mean_prediction[i] - slope*np.abs(deltas[i]-deltas[j]) - V_obs_tmp.max(axis=0)[j], 0) for j in range(len(deltas))])
                acquisition = new_acquisition.copy()
            elif acquisition_function == 'new_ucb':
                new_acquisition = np.zeros(len(deltas))
                # constant exploration weight of 1 for every candidate
                lambdas = [1]*len(deltas)
                for i in range(len(deltas)):
                    new_acquisition[i] = np.mean([max(mean_prediction[i] + lambdas[i]*std_prediction[i] - slope*np.abs(deltas[i]-deltas[j]) - V_obs_tmp.max(axis=0)[j], 0) for j in range(len(deltas))])
                acquisition = new_acquisition.copy()
            elif acquisition_function == 'new_ucb_beta':
                new_acquisition = np.zeros(len(deltas))
                # exploration weight 1/(k+1), shrinking as the step index grows
                lambdas = [1/(k+1)]*len(deltas)
                for i in range(len(deltas)):
                    new_acquisition[i] = np.mean([max(mean_prediction[i] + lambdas[i]*std_prediction[i] - slope*np.abs(deltas[i]-deltas[j]) - V_obs_tmp.max(axis=0)[j], 0) for j in range(len(deltas))])
                acquisition = new_acquisition.copy()
            elif acquisition_function == 'new_ucb_beta_log':
                new_acquisition = np.zeros(len(deltas))
                # lambdas to sqrt(log(k+1))
                lambdas = [np.sqrt(np.log(k+1))]*len(deltas)
                for i in range(len(deltas)):
                    new_acquisition[i] = np.mean([max(mean_prediction[i] + lambdas[i]*std_prediction[i] - slope*np.abs(deltas[i]-deltas[j]) - V_obs_tmp.max(axis=0)[j], 0) for j in range(len(deltas))])
                acquisition = new_acquisition.copy()
            elif acquisition_function == 'new_ucb_dist':
                new_acquisition = np.zeros(len(deltas))
                # 500 exploration weights drawn from a standard normal truncated to [0, inf);
                # the scores are summed over the samples (not averaged).
                lambda_set = sample_truncated_normal(mean=0, std=1, lower=0, upper=np.inf, size=500)
                for lambda_tmp in lambda_set:
                    for i in range(len(deltas)):
                        new_acquisition[i] += np.mean([max(mean_prediction[i] + lambda_tmp*std_prediction[i] - slope*np.abs(deltas[i]-deltas[j]) - V_obs_tmp.max(axis=0)[j], 0) for j in range(len(deltas))])
                acquisition = new_acquisition.copy()
            elif acquisition_function == 'new_ucb_transfer':
                new_acquisition = np.zeros(len(deltas))
                lambdas = [1/(k+1)]*len(deltas)
                # Unlike the other 'new*' variants, no distance penalty is subtracted here.
                for i in range(len(deltas)):
                    new_acquisition[i] = np.mean([max(mean_prediction[i] + lambdas[i]*std_prediction[i] - V_obs_tmp.max(axis=0)[j], 0) for j in range(len(deltas))])
                acquisition = new_acquisition.copy()
            
            # find the next source task that maximizes acquisition function and is not in the source_tasks
            sorted_idx = np.argsort(-acquisition)

            # NOTE(review): sorted_idx holds positions but is compared against
            # delta values; this matches only if deltas[i] == i -- TODO confirm.
            for idx in range(len(sorted_idx)):
                if deltas[np.abs(deltas - sorted_idx[idx]).argmin()] in source_tasks:
                    continue
                else:
                    tmp = sorted_idx[idx]
                    break
            
        source_tasks.append(deltas[np.abs(deltas - tmp).argmin()])

        # Update J_transfer based on the new source task training
        for j in range(len(deltas)):
            if k==0:
                J_transfer[j, k] = data_transfer.iloc[np.where(deltas == source_tasks[k])[0][0]][j]
            else:
                J_transfer[j, k] = max(data_transfer.iloc[np.where(deltas == source_tasks[k])[0][0]][j], J_transfer[j, k-1])

        # Calculate V_estimate for diagonal elements
        # (GP mean everywhere, overwritten with the observed J_transfer value
        # at each selected source task)
        V_estimate_tmp = np.zeros((len(deltas), len(deltas)))
        for i in range(len(deltas)):
            V_estimate_tmp[i, i] = mean_prediction[i]
        for l in range(k+1):
            idx = np.where(deltas==source_tasks[l])[0][0]
            V_estimate_tmp[idx, idx] = J_transfer[idx, l]

        # Calculate V_estimate for non-diagonal elements with different gap functions
        if gap_function == 'linear':
            # Here the estimate is combined with V_obs_tmp from the previous
            # iteration rather than with the previous step of V_estimate.
            # NOTE(review): the penalty uses index distance abs(i-j), while the
            # acquisitions above use delta distance -- confirm this is intended.
            for i in range(len(deltas)):
                for j in range(len(deltas)):
                    V_estimate_tmp[i, j] = max(V_estimate_tmp[i, i] - slope*abs(i-j), 0)
                    # V_estimate[i, j, k] = max(V_estimate[i, j, k-1], V_estimate_tmp[i, j])
                    V_estimate[i, j, k] = max(V_obs_tmp[i, j], V_estimate_tmp[i, j])
        elif gap_function == 'linear_estimated':
            # slope estimation based on trained source tasks' generalization slope
            slopes = {}
            for i in range(len(deltas)):
                if deltas[i] in source_tasks:
                    x = []
                    y = []
                    for j in range(len(deltas)):
                        x.append(abs(i-j))
                        y.append(abs(data_transfer.iloc[i,j]-data_transfer.iloc[i,i]))
                    slopes[i] = np.polyfit(x, y, 1)[0]
            for i in range(len(deltas)):
                # Use the fitted slope of the nearest selected source (by index).
                # NOTE: this rebinds the `slope` argument, so later acquisition
                # computations in this call use the fitted value.
                slope = slopes[np.array(list(slopes.keys()))[np.abs(np.array(list(slopes.keys()))- i).argmin()]]
                for j in range(len(deltas)):
                    V_estimate_tmp[i, j] = max(V_estimate_tmp[i, i] - slope*abs(i-j), 0)
                    V_estimate[i, j, k] = max(V_estimate[i, j, k-1], V_estimate_tmp[i, j])
        elif gap_function == 'true':
            for i in range(len(deltas)):
                for j in range(len(deltas)):
                    V_estimate_tmp[i, j] = data_transfer.iloc[i, j]
                    V_estimate[i, j, k] = max(V_estimate[i, j, k-1], V_estimate_tmp[i, j])
        else:
            # Any other gap_function: keep the diagonal, zero everything else.
            for i in range(len(deltas)):
                for j in range(len(deltas)):
                    if i == j:
                        V_estimate_tmp[i, j] = V_estimate_tmp[i, i]
                    else:
                        V_estimate_tmp[i, j] = 0
                    V_estimate[i, j, k] = max(V_estimate[i, j, k-1], V_estimate_tmp[i, j])
                    
        # Observed matrix: true rows for selected sources, zeros elsewhere.
        # Read by the acquisition functions in the next iteration.
        V_obs_tmp = np.zeros((len(deltas), len(deltas)))
        for i in range(len(deltas)):
            if deltas[i] in source_tasks:
                for j in range(len(deltas)):
                    if i == j:
                        V_obs_tmp[i, j] = data_transfer.iloc[i, i]
                    else:
                        V_obs_tmp[i, j] = data_transfer.iloc[i, j]
    # Mean over targets of the running-best performance, one value per step.
    transfer_results = J_transfer.mean(axis=0)

    return source_tasks, transfer_results, J_transfer, V_estimate


def greedy_heursitc_sts(data_transfer, deltas, num_transfer_steps, delta_min, delta_max, slope):
    """
    Greedily select source tasks by always targeting the worst-served task.

    The first source task is the one whose delta is closest to the midpoint of
    [delta_min, delta_max]. At every later step the tasks are ranked by the best
    transfer performance obtained so far (ascending) and the first task that has
    not been selected yet becomes the next source task.

    Parameters:
    - data_transfer: DataFrame of transfer performance, read positionally as
      iloc[source_index, target_index].
    - deltas: 1-D array of task parameters, one per row/column of data_transfer.
    - num_transfer_steps: Number of source tasks to select.
    - delta_min, delta_max: Range whose midpoint seeds the first selection.
    - slope: Unused here; kept so the signature matches the sibling strategies.

    Returns:
    - source_tasks: Array with the selected deltas, in selection order.
    - transfer_results: collect_J_matrix(...).mean(axis=0) for the selected
      source tasks (collect_J_matrix is defined elsewhere in this file).

    Side effects:
    - Plots the running best performance of every step on the current
      matplotlib axes.
    """
    deltas = np.asarray(deltas)
    num_tasks = len(deltas)
    source_tasks = np.zeros(num_transfer_steps)
    J_transfer = np.zeros((num_tasks, num_transfer_steps))

    # Track selected *indices*: checking membership in the zero-initialised
    # source_tasks array would wrongly treat delta == 0 as already selected.
    chosen = set()
    # Seed: the task closest to the middle of the delta range.
    source_idx = int(np.abs(deltas - (delta_max + delta_min) / 2).argmin())

    for k in range(num_transfer_steps):
        if k > 0:
            # argsort yields task indices (worst performance first); use them
            # as indices rather than comparing them against delta values.
            for candidate in np.argsort(J_transfer[:, k - 1]):
                if int(candidate) not in chosen:
                    source_idx = int(candidate)
                    break
            # If every task is already selected, the previous index is reused.
        chosen.add(source_idx)
        source_tasks[k] = deltas[source_idx]

        # Positional row lookup, done once per step instead of once per target.
        source_row = data_transfer.iloc[source_idx, :num_tasks].to_numpy(dtype=float)
        if k == 0:
            J_transfer[:, k] = source_row
        else:
            J_transfer[:, k] = np.maximum(source_row, J_transfer[:, k - 1])
        plt.plot(deltas, J_transfer[:, k], label='step {}'.format(k))

    # Evaluate with the requested number of steps (was hard-coded to 15).
    transfer_results = collect_J_matrix(
        data_transfer, source_tasks, deltas, num_transfer_steps=num_transfer_steps
    ).mean(axis=0)

    return source_tasks, transfer_results

def greedy_heursitc_sts_marginal(data_transfer, deltas, num_transfer_steps, delta_min, delta_max, slope):
    """
    Greedily select source tasks by the largest modelled marginal gain.

    The first source task is the one whose delta is closest to the midpoint of
    [delta_min, delta_max]. At every later step the performance of training on
    task i and transferring to task j is modelled as
    1 - slope * |deltas[i] - deltas[j]|, and the not-yet-selected task with the
    largest summed positive improvement over the current best performance is
    chosen next.

    Parameters:
    - data_transfer: DataFrame of transfer performance, read positionally as
      iloc[source_index, target_index].
    - deltas: 1-D array of task parameters, one per row/column of data_transfer.
    - num_transfer_steps: Number of source tasks to select.
    - delta_min, delta_max: Range whose midpoint seeds the first selection.
    - slope: Assumed linear performance drop per unit of delta distance.

    Returns:
    - source_tasks: Array with the selected deltas, in selection order.
    - transfer_results: collect_J_matrix(...).mean(axis=0) for the selected
      source tasks (collect_J_matrix is defined elsewhere in this file).

    Side effects:
    - Plots the running best performance of every step on the current
      matplotlib axes.
    """
    deltas = np.asarray(deltas)
    num_tasks = len(deltas)
    source_tasks = np.zeros(num_transfer_steps)
    J_transfer = np.zeros((num_tasks, num_transfer_steps))

    # Modelled source->target performance; it does not depend on the step, so
    # it is built once instead of inside the selection loop.
    U_transfer = 1 - slope * np.abs(deltas[:, None] - deltas[None, :])

    # Track selected *indices*: checking membership in the zero-initialised
    # source_tasks array would wrongly treat delta == 0 as already selected.
    chosen = set()
    # Seed: the task closest to the middle of the delta range.
    source_idx = int(np.abs(deltas - (delta_max + delta_min) / 2).argmin())

    for k in range(num_transfer_steps):
        if k > 0:
            # Row i: summed positive gain over the current best if i is added.
            marginal_gain = np.maximum(U_transfer - J_transfer[:, k - 1], 0).sum(axis=1)
            # argsort yields task indices (largest gain first); use them as
            # indices rather than comparing them against delta values.
            for candidate in np.argsort(-marginal_gain):
                if int(candidate) not in chosen:
                    source_idx = int(candidate)
                    break
            # If every task is already selected, the previous index is reused.
        chosen.add(source_idx)
        source_tasks[k] = deltas[source_idx]

        # Positional row lookup, done once per step instead of once per target.
        source_row = data_transfer.iloc[source_idx, :num_tasks].to_numpy(dtype=float)
        if k == 0:
            J_transfer[:, k] = source_row
        else:
            J_transfer[:, k] = np.maximum(source_row, J_transfer[:, k - 1])
        plt.plot(deltas, J_transfer[:, k], label='step {}'.format(k))

    # Evaluate with the requested number of steps (was hard-coded to 15).
    transfer_results = collect_J_matrix(
        data_transfer, source_tasks, deltas, num_transfer_steps=num_transfer_steps
    ).mean(axis=0)

    return source_tasks, transfer_results

def approximate_rank(matrix, threshold=0.99):
    """
    Calculate the approximate rank of a matrix based on the given threshold.

    The approximate rank is the smallest number of leading singular values whose
    squared sum reaches `threshold` times the squared sum of all singular values.

    Parameters:
    - matrix: Input 2-D matrix for which the rank is to be approximated.
    - threshold: The threshold for variance capture (default is 0.99).

    Returns:
    - rank: The approximate rank of the matrix as an int. Returns 0 when the
      matrix has no variance at all (all singular values are zero), and the
      total number of singular values when the threshold cannot be reached
      (e.g. threshold > 1).
    """
    # Only the singular values are needed, so skip computing U and V.
    singular_values = np.linalg.svd(matrix, compute_uv=False)

    # Cumulative captured variance; its last entry is the total variance.
    cumulative_variance = np.cumsum(singular_values**2)
    if cumulative_variance.size == 0 or cumulative_variance[-1] == 0:
        return 0

    # Dividing by the last cumulative entry (rather than a separately computed
    # sum) guarantees the final ratio is exactly 1.0, so threshold=1.0 works.
    reached = cumulative_variance / cumulative_variance[-1] >= threshold
    if not reached.any():
        # argmax over all-False would silently report rank 1.
        return int(cumulative_variance.size)

    # Index of the first singular value at which the threshold is met, plus one.
    return int(np.argmax(reached)) + 1

def _make_imputer(me_method, max_rank, knn_k=5):
    """Build a fresh fancyimpute solver whose value bounds match data normalized to [0, 1]."""
    # The solvers operate on normalized data, so the bounds must be 0 and 1
    # rather than the min/max of the original (un-normalized) data.
    bounds = {'min_value': 0.0, 'max_value': 1.0}
    if me_method == 'SoftImpute':
        return SoftImpute(max_iters=1000, max_rank=max_rank, verbose=False, **bounds)
    if me_method == 'KNN':
        return KNN(k=knn_k, orientation="columns", verbose=False, **bounds)
    if me_method == 'MatrixFactorization':
        return MatrixFactorization(rank=50, max_iters=500, shrinkage_value=0.001, verbose=False, **bounds)
    if me_method == 'SimpleFill':
        return SimpleFill(fill_method="mean", **bounds)
    raise ValueError("Invalid matrix estimation method")


def matrix_estimation(data_origin, estimation_trials, me_method='SoftImpute'):
    """
    Select source tasks sequentially using matrix completion.

    The transfer matrix is min-max normalized to [0, 1]. At each trial one more
    row (source task) is revealed, the remaining rows are imputed with the
    chosen method, and the next row to reveal is the not-yet-selected row with
    the highest mean in the imputed matrix. The first revealed row is the
    middle row.

    Parameters:
    - data_origin: Transfer-performance matrix (DataFrame or 2-D array-like);
      it is not modified.
    - estimation_trials: Number of rows to select; must be at least 1.
    - me_method: One of 'SoftImpute', 'KNN', 'MatrixFactorization',
      'SimpleFill'.

    Returns:
    - selected_source_tasks: List of selected row indices, in selection order.
    - average_performance_list: For each trial, the mean over columns of the
      best value among the selected rows, in the original data scale.
    - Mean over trials of the squared error between the imputed and the true
      normalized matrix.

    Raises:
    - ValueError: If me_method is not one of the supported methods.
    """
    # Work on a float copy so the normalization is never truncated by an
    # integer dtype and the caller's data stays untouched.
    data = np.array(data_origin, dtype=float)
    df_min = data.min()
    df_max = data.max()

    # normalize data
    data = (data - df_min) / (df_max - df_min)

    apprx_rank = approximate_rank(data)

    PLOT = False

    selected_source_tasks = []
    norm_average_performance_list = []
    total_mse = 0
    for trial in range(estimation_trials):
        if trial == 0:
            available_index = int(data.shape[0] / 2)
        else:
            # Exclude already-selected rows with -inf. (Overwriting them with
            # df_min, a value in the original scale, could rank them *above*
            # every normalized row and select the same row again.)
            row_means = data_filled.mean(axis=1)
            row_means[selected_source_tasks] = -np.inf
            available_index = int(np.argmax(row_means))
        selected_source_tasks.append(available_index)

        # Only the selected rows are observed; everything else is missing.
        observed = np.zeros(data.shape, dtype=bool)
        observed[selected_source_tasks, :] = True
        # np.where keeps genuine zeros (the normalized minimum) as observed
        # values instead of turning them into NaN.
        data_incomplete = np.where(observed, data, np.nan)

        data_filled = _make_imputer(me_method, apprx_rank).fit_transform(data_incomplete)

        total_mse += ((data_filled - data)**2).mean()

        # Best normalized performance per target over the selected sources.
        norm_average_performance_list.append(data[selected_source_tasks].max(axis=0).mean())

    if PLOT:
        plt.imshow(data_filled, interpolation='none')
        plt.title(f"{me_method} at trial {trial}")
        plt.show()

    # Map the normalized averages back to the original data scale.
    average_performance_list = [
        value*(df_max - df_min) + df_min for value in norm_average_performance_list
    ]

    return selected_source_tasks, average_performance_list, total_mse/estimation_trials

def evaluate_on_task(data_transfer, source_tasks, deltas, num_transfer_steps):
    """
    Evaluate a given sequence of source tasks.

    Requires exactly one source task per transfer step and returns the
    axis-0 mean of the matrix produced by collect_J_matrix.
    """
    assert num_transfer_steps == len(source_tasks)
    J_matrix = collect_J_matrix(data_transfer, source_tasks, deltas, num_transfer_steps)
    return J_matrix.mean(axis=0)

def mean_of_list(list):
    """Return the arithmetic mean of the items in a non-empty sequence."""
    total = sum(list)
    count = len(list)
    return total / count

def plot_heatmap(data_transfer, home, data_name):
    """
    Draw the transfer matrix as a heatmap and save it as a PNG.

    Rows are labelled as source tasks and columns as target tasks. The image is
    written to <home>/data/heatmap_<data_name>.png.
    """
    # Reset the current figure, then start a fresh one for the heatmap.
    plt.clf()
    plt.rcParams.update({'font.family': 'sans-serif'})
    plt.figure(figsize=(8, 6))

    plt.rcParams['font.size'] = 12
    plt.imshow(data_transfer, interpolation='none')
    plt.colorbar(orientation='vertical')
    plt.xlabel("Target task")
    plt.ylabel("Source task")

    output_path = home + f'/data/heatmap_{data_name}.png'
    plt.savefig(output_path, bbox_inches="tight", dpi=500)
    
