import gc
import os
import numpy as np
import torch
import random

import gym
from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy

from lib.TsallisDQN import  MVIq, MunchausenTsallisDQN




# Gym environment IDs included in the sweep. 'MountainCar-v0' is currently
# left out; add it to this list to include it again.
envs = ["CartPole-v1", "LunarLander-v2", "Acrobot-v1"]

# Agents to sweep over. The key is used verbatim as a log-directory component
# (./data/gym/<env>/<key>/...) and in progress messages, so it must not carry
# stray whitespace.
# NOTE: earlier revisions used the key ' MVIq' (with a leading space); logs
# from those runs live under a directory named ' MVIq'.
algs = {
    'MVIq': MVIq,
    'MunchausenTsallisDQN': MunchausenTsallisDQN,
}

# Number of independent random seeds run per (env, agent, q) configuration.
num_seeds = 50

# Hyper-parameter sweep: train every agent in `algs` on every environment in
# `envs`, for each value of `q` and each seed in range(num_seeds). Each run
# writes its TensorBoard logs to ./data/gym/<env>/<alg>/q<q>/<seed>/.

# Training budget per run; identical for every run, so computed once.
total_timesteps = int(5e5)

for env_id in envs:
    for alg_name, agent in algs.items():
        # Modify this list of q values to reproduce the results in the paper.
        for q in [2.0, 3.0, 4.0, 5.0]:

            for seed in range(num_seeds):

                # Seed Python, PyTorch and NumPy global RNGs; the agent also
                # receives `seed` through `param` below.
                random.seed(seed)
                torch.manual_seed(seed)
                np.random.seed(seed)

                # Trailing "" makes the path end with a separator.
                logdir = os.path.join(".", 'data', 'gym', env_id, alg_name, 'q' + str(q), str(seed), "")
                env = gym.make(env_id)

                param = {
                        'tensorboard_log': logdir,
                        'seed': seed,
                        # Integer step count (SB3 declares learning_starts as int).
                        'learning_starts': int(1e3),
                        # Constant epsilon: initial == final, so no decay.
                        'exploration_initial_eps': 0.01,
                        'exploration_final_eps': 0.01,
                        'learning_rate': 0.001,
                        'buffer_size': int(5e4),
                        'batch_size': 128,
                        # CartPole uses a shorter target-network update interval
                        # than the other environments.
                        'target_update_interval': 1000 if env_id == 'CartPole-v1' else 2500,
                        # NOTE(review): in SB3's DQN, `tau` is the soft-update
                        # coefficient (1.0 = hard copy); confirm the agents in
                        # `algs` interpret it the same way.
                        'tau': 0.99,
                        'policy_kwargs': dict(activation_fn=torch.nn.ReLU, net_arch=[512, 512]),
                        'entropy_tau': 0.03,
                        'advantage_coef': 0.99,
                        'k': 1/2,
                        'q': q,
                    }

                model = agent('MlpPolicy', env, **param)
                # Include q: runs differing only in q are otherwise
                # indistinguishable in the console output.
                print("running {} on env {} with q={} and seed {}".format(alg_name, env_id, q, seed))
                print(" ")
                model.learn(total_timesteps=total_timesteps)

                # Release per-run resources before the next run: close the
                # environment created above, drop the replay buffer, and force
                # a garbage-collection pass.
                env.close()
                del model.replay_buffer
                gc.collect()

