from lpcmdp.env.FrozenLake import *
from lpcmdp.algorithm.utils import *
import argparse
from lpcmdp.algorithm.importance_sampling_solver import *
from test import importance_test
from lpcmdp.algorithm.BC import *
import pandas as pd
from lpcmdp.algorithm.coptidice import Coptidice_trainer


parser = argparse.ArgumentParser()
parser.add_argument('--Algorithm', type=str, default='Importance_Sampling')
parser.add_argument('--Solver', type=str, default='Approximate')
parser.add_argument('--cost_threshold', type=float, default=0)
parser.add_argument('--E_n_delta', type=float, default=0.1)
parser.add_argument('--E_n_delta_k', type=float, default=0.0001)
args = parser.parse_args()

# Environment: an 8x8 FrozenLake grid.
env = FrozenLakeEnv(nrow=8, ncol=8)

# Expert policy obtained by value iteration with the environment's own discount.
# theta is presumably the convergence tolerance (TODO confirm against
# ValueIteration). The resulting expert.pi is handed to
# env.getdataset(expert_pi=...) by every algorithm branch below.
theta = 0.01
expert = ValueIteration(env, theta, env.gamma)
expert.value_iteration()

def _dataset_tensor(dataset, key, column=False):
    """Return ``dataset[key]`` as a float32 torch tensor.

    Args:
        dataset: The mapping returned by ``env.getdataset`` for the BC /
            COptiDICE algorithms, indexed by field name (e.g. ``'action'``).
        key: Name of the field to convert.
        column: If true, reshape the result with ``view(-1, 1)``.

    Returns:
        A ``torch.float32`` tensor built from ``dataset[key]``.
    """
    values = torch.tensor(dataset[key], dtype=torch.float32)
    return values.view(-1, 1) if column else values


if args.Algorithm == 'Importance_Sampling':
    offline_dataset, mu_D_count, r_s_a, c_s_a, P, mu_0 = env.getdataset(expert_pi=expert.pi, percent=1.0, alg=args.Algorithm)
    # Compute the quantities required by the importance-sampling solvers.
    # Empirical state-action distribution of the offline dataset.
    mu_D_s_a = mu_D_count / mu_D_count.sum()

    # Row i of M gets ones in columns [i*action_size, (i+1)*action_size),
    # i.e. (assuming a state-major column layout s*action_size + a -- TODO
    # confirm) it sums the occupancy over all actions of state i.
    M = np.zeros_like(P)
    for i in range(M.shape[0]):
        M[i, i*env.action_size:(i+1)*env.action_size] = np.ones(env.action_size)

    # Normalise every non-empty column of P so it sums to 1. P is modified in
    # place; all-zero columns are left untouched to avoid dividing by zero.
    col_sums = P.sum(axis=0)
    nonempty = col_sums != 0
    P[:, nonempty] /= col_sums[nonempty]
    M = M - env.gamma*P

    # Weight each column / entry by the empirical frequency mu_D_s_a
    # (broadcast over the last axis).
    K_D = M * mu_D_s_a
    u_D = r_s_a * mu_D_s_a
    h_D = c_s_a * mu_D_s_a
    if args.Solver == 'Discrete':
        # NOTE(review): this solver receives (u_D, h_D, K_D) while
        # Importance_Approximate below receives (u_D, K_D, h_D); confirm each
        # order matches the corresponding solver signature.
        w_s_a = Importance_Sampling_Discrete_Solver(env, u_D, h_D, K_D, mu_0, env.gamma, args.cost_threshold, args.E_n_delta, args.E_n_delta_k)
    else:
        Encode_type = 'one_hot'  # other options: 'None', 'binary'
        w_s_a = Importance_Approximate(env, u_D, K_D, h_D, mu_0, env.gamma, args.cost_threshold, args.E_n_delta, args.E_n_delta_k, Encode_type)

    test_episode = 100
    plot_policy(env=env, para=w_s_a, type='w')
    goal_rate = importance_test(env, w_s_a=w_s_a, test_episode=test_episode)
    print(f'Test goal rate in {test_episode} test episodes of {args.Algorithm} is: {goal_rate}')

elif args.Algorithm in ('BC', 'BC-Safe'):
    offline_dataset = env.getdataset(alg=args.Algorithm, percent=1, expert_pi=expert.pi)
    print(offline_dataset.keys())
    Encode_type = 'one_hot'  # other options: 'None', 'binary'

    # TensorDataset yields its tensors in exactly this order; presumably
    # BC_Trainer unpacks batches positionally, so keep the two in sync.
    bc_fields = ('observation', 'action', 'new_observation', 'cost', 'reward', 'goal', 'hole')
    dataset = torch.utils.data.TensorDataset(
        *(_dataset_tensor(offline_dataset, key) for key in bc_fields)
    )
    BC_model = BC_Trainer(env, dataset, Encode_type)

    plot_policy(env=env, para=BC_model, type='model')
    test(env, 'model', BC_model, random=True)

elif args.Algorithm == 'Coptidice':
    offline_dataset = env.getdataset(expert_pi=expert.pi, alg=args.Algorithm, percent=0.5)
    Encode_type = 'one_hot'  # other options: 'None', 'binary'

    # (field, reshape-to-column) pairs. The order differs from the BC branch;
    # presumably Coptidice_trainer unpacks batches positionally in this order.
    coptidice_fields = (
        ('observation', False),
        ('action', False),
        ('reward', True),
        ('cost', True),
        ('new_observation', False),
        ('goal', True),
        ('hole', True),
        ('is_init', True),
    )
    dataset = torch.utils.data.TensorDataset(
        *(_dataset_tensor(offline_dataset, key, column) for key, column in coptidice_fields)
    )
    w_s_a = Coptidice_trainer(env, dataset, Encode_type)
    print(w_s_a)

else:
    # Previously an unknown algorithm fell through every branch and the
    # script exited silently without doing anything.
    raise ValueError(
        f"Unknown --Algorithm {args.Algorithm!r}; expected one of "
        "'Importance_Sampling', 'BC', 'BC-Safe', 'Coptidice'"
    )
