import gc
import gym
import os

from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3.common.env_util import make_atari_env

from lib.MunchausenQRDQN import  MVIq, TsallisQRDQN, MunchausenQRDQN


def _train_single_run(agent, alg_name, env_name, seed, num_steps):
    """Train one algorithm on one Atari game with one seed, then clean up.

    Args:
        agent: Algorithm class to instantiate; it is called as
            ``agent('CnnPolicy', env, verbose=0, tensorboard_log=log_dir)``.
        alg_name: Short label of the algorithm. Used as a path component of
            the log directory, as the TensorBoard run name and in the
            console message.
        env_name: Atari game name without the ``NoFrameskip-v4`` suffix.
        seed: Seed forwarded to ``make_atari_env``; also a path component of
            the log directory.
        num_steps: Number of timesteps passed to ``model.learn``.
    """
    game_name = env_name + 'NoFrameskip-v4'
    # The trailing "" makes the path end with a separator.
    log_dir = os.path.join(".", 'data', 'distributional', game_name, alg_name, str(seed), "")
    os.makedirs(log_dir, exist_ok=True)

    env = make_atari_env(game_name, seed=seed, monitor_dir=log_dir)
    env = VecFrameStack(env, n_stack=4)

    try:
        model = agent('CnnPolicy', env, verbose=0, tensorboard_log=log_dir)

        print(" ")
        print("running {} on env {} with seed {}".format(alg_name, game_name, seed))
        print(" ")

        # tb_log_name is a run-name prefix that stable-baselines3 joins onto
        # tensorboard_log; passing the directory again would nest the whole
        # path inside itself, so only the algorithm label is used here.
        model.learn(total_timesteps=num_steps, tb_log_name=alg_name)

        # free memory of replay buffer
        model.replay_buffer.reset()
        del model.replay_buffer
        del model
    finally:
        # Close the vectorized env so its emulators are not kept alive across
        # the remaining (game, algorithm, seed) combinations.
        env.close()
        gc.collect()


if __name__ == '__main__':

    # Label -> algorithm class. The labels become directory names, so they
    # must not contain stray whitespace.
    alg_names = {
        'MVIq': MVIq,
        'MDQN': MunchausenQRDQN,
        'TsallisQRDQN': TsallisQRDQN,
    }

    # NOTE(review): not referenced anywhere in the visible script.
    project_name = 'distributional_rl'

    # learn() takes an integer number of timesteps; 5e7 on its own is a float.
    num_steps = int(5e7)

    # only a subset of games is shown here as example
    envs = [
        'Assault',
        'Asterix',
        'Atlantis',
        'BeamRider',
        'Breakout',
        'Enduro',
        'MsPacman',
        'Pong',
        'Seaquest',
        'SpaceInvaders',
        'Jamesbond',
        'Frostbite',
        'Zaxxon',
        'KungFuMaster',
    ]

    seeds = [0]

    for env_name in envs:
        for alg_name, agent in alg_names.items():
            for seed in seeds:
                _train_single_run(agent, alg_name, env_name, seed, num_steps)

