from comet_ml import Experiment

from rlpyt.utils.launching.affinity import make_affinity, quick_affinity_code, affinity_from_code
from rlpyt.samplers.async_.gpu_sampler import AsyncGpuSampler
from rlpyt.samplers.parallel.gpu.sampler import GpuSampler

from rlpyt.samplers.parallel.gpu.collectors import GpuWaitResetCollector
from rlpyt.samplers.async_.collectors import DbGpuResetCollector, DbCpuResetCollector

from rlpyt.samplers.serial.sampler import SerialSampler
from rlpyt.samplers.serial.collectors import SerialEvalCollector

from rlpyt.envs.atari.atari_env import AtariEnv, AtariTrajInfo

from rlpyt.algos.dqn.cat_dqn import CategoricalDQN
from rlpyt.agents.dqn.atari.atari_catdqn_agent import AtariCatDqnAgent
from rlpyt.agents.dqn.atari.atari_dqn_agent import AtariDqnAgent
from rlpyt.models.dqn.atari_catdqn_model import AtariCatDqnModel
from rlpyt.runners.async_rl import AsyncRlEval
from rlpyt.runners.minibatch_rl import MinibatchRlEval

from rlpyt.utils.logging import logger
from rlpyt.utils.logging.context import logger_context
from rlpyt.experiments.configs.atari.dqn.atari_dqn import configs


from nce_model import AtariCatDqnModel_nce
from nce_algo import CategoricalDQN_nce
from nce_override import log_diagnostics_custom, _log_infos, make_env

import os
from gym import envs

import torch

from RLDIM.utils import set_seed
from RLDIM.loggers import Logger

from decimal import Decimal

import uuid
import numpy as np

import multiprocessing

"""
Code for "Deep InfoMax and Reinforcement Learning"

To execute:
python procgen_parallel_gpu.py  --lambda_LL "0" --lambda_GL "0" --lambda_LG "0" --lambda_GG "1" --experiment-name "test" --env-name "procgen-bigfish-v0.500" \
                           --data_aug "0" --n_step-return "7" --nce-batch-size "256" --horizon "10000" --algo "c51" --n-cpus "8" --weight-save-interval "-1" --n_step-nce "-1" \
                           --frame_stack "3" --ema_moco "0" --nce_loss "InfoNCE_action_loss" --architecture "Mnih" --log-interval-steps=1000
"""


"""
The below config has been taken directly from RLPYT (https://bair.berkeley.edu/blog/2019/09/24/rlpyt/, https://github.com/astooke/rlpyt)
"""
config = dict(
    agent=dict(
        eps_init=0.1,
        eps_final=0.01
        ),
    algo=dict(
        discount=0.99,
        batch_size=256,
        delta_clip=1.,
        learning_rate=2.5e-4,
        target_update_interval=int(312),
        clip_grad_norm=10.,
        min_steps_learn=int(1000),
        target_update_tau=0.95, # tau * new + (1-tau) * old

        double_dqn=True,
        prioritized_replay=True,
        n_step_return=3, # -> 1

        replay_size=int(1e6),
        replay_ratio=8, # -> 8,

        pri_alpha=0.5,
        pri_beta_init=0.4,
        pri_beta_final=1.,
        pri_beta_steps=int(50e6),

        eps_steps=int(1e5)
    ),
    env=dict(
        game=None,
        episodic_lives=False,
        clip_reward=False,
        horizon=int(27e3),
        max_start_noops=0,
        repeat_action_probability=0.,
        frame_skip=1,
        num_img_obs=4
    ),
    eval_env=dict(
        game=None,
        episodic_lives=False,
        horizon=int(27e3),
        clip_reward=False,
        max_start_noops=0,
        repeat_action_probability=0.,
        frame_skip=1,
        num_img_obs=4
    ),
    model=dict(dueling=False),
    optim=dict(),
    runner=dict(
        n_steps=200e6,
        log_interval_steps=1e3,
    ),
    sampler=dict(
        batch_T=7,
        batch_B=32,
        max_decorrelation_steps=1000,
        eval_n_envs=4,
        eval_max_steps=int(125e3),
        eval_max_trajectories=100,
    ),
)

def _apply_cli_overrides(config, args, game):
    """Copy the command-line settings held in ``args`` into ``config`` in place.

    Args:
        config: Nested settings dict; must already contain the ``env``,
            ``eval_env``, ``model``, ``algo`` and ``runner`` sub-dicts.
        args: Parsed command-line namespace.
        game: Game name stored under the ``game`` key of both env dicts, or
            ``None`` to store ``args.env_name`` under the ``id`` key instead.
    """
    if game is not None:
        config['env']['game'] = game
        config['eval_env']['game'] = game
    else:
        config['env']['id'] = args.env_name
        config['eval_env']['id'] = args.env_name

    config['eval_env']['horizon'] = args.horizon
    config['env']['horizon'] = args.horizon

    if 'procgen' in args.env_name:
        # Forward every CLI option whose name contains the task name (the
        # second dash-separated field of the env name, e.g. "leaper") to the
        # training environment only.
        task_name = args.env_name.split('-')[1]
        for key, value in vars(args).items():
            if task_name in key:
                config['env'][key] = value

    model_cfg = config['model']
    model_cfg['architecture'] = args.architecture
    model_cfg['downsample'] = args.downsample
    model_cfg['frame_stack'] = args.frame_stack
    model_cfg['nce_loss'] = args.nce_loss
    model_cfg['algo'] = args.algo
    # The on/off switches arrive as 0/1 integers and are stored as booleans.
    model_cfg['data_aug'] = args.data_aug == 1
    model_cfg['dueling'] = args.dueling == 1

    algo_cfg = config['algo']
    algo_cfg['double_dqn'] = args.double_dqn == 1
    algo_cfg['prioritized_replay'] = args.prioritized_replay == 1
    algo_cfg['n_step_return'] = args.n_step_return
    algo_cfg['learning_rate'] = args.learning_rate

    config['runner']['log_interval_steps'] = args.log_interval_steps
    config['cmd_args'] = vars(args)


def build_and_train(args, game="pong", run_ID=0, config=None):
    """Build the agent, algorithm, sampler and runner, then start training.

    Args:
        args: Parsed command-line namespace (see the argument parser at the
            bottom of this file).
        game: Game name. When not ``None`` it is written to the env configs
            and ``AtariEnv`` is used; when ``None`` the environment is selected
            by ``args.env_name`` and built with ``make_env``.
        run_ID: Run identifier forwarded to ``logger_context``.
        config: Nested settings dict. It is modified in place with the
            command-line overrides.

    Raises:
        ValueError: If ``config`` is ``None`` or ``args.algo`` does not
            contain ``'c51'``.
        NotImplementedError: If ``args.mode`` is not ``'parallel'``; no
            sampler is built for the other modes.
    """
    # Fail fast, before anything is created on disk, instead of hitting an
    # opaque TypeError/NameError further down.
    if config is None:
        raise ValueError("build_and_train() requires a config dict, got None")
    if 'c51' not in args.algo:
        raise ValueError("Unsupported algo %r: only 'c51' is implemented" % (args.algo,))
    if args.mode != 'parallel':
        raise NotImplementedError(
            "Mode %r is not implemented: only 'parallel' builds a sampler" % (args.mode,))

    _apply_cli_overrides(config, args, game)

    agent = AtariCatDqnAgent(
        ModelCls=AtariCatDqnModel_nce,
        model_kwargs=config["model"],
        **config["agent"]
    )
    algo = CategoricalDQN_nce(
        args=config['cmd_args'],
        ReplayBufferCls=None,
        optim_kwargs=config["optim"],
        **config["algo"]
    )

    # Change these inputs to match local machine and desired parallelism.
    affinity = make_affinity(
        n_cpu_core=args.n_cpus,
        n_gpu=1,
        n_socket=1
        # hyperthread_offset=0
    )
    # Depending on your system, you might want to comment out or un-comment
    # the following chunk of code.
    import psutil
    process = psutil.Process()
    # psutil documents an empty list as "reset to all eligible CPUs".
    process.cpu_affinity([])
    cpus = tuple(process.cpu_affinity())
    affinity['all_cpus'] = affinity['master_cpus'] = cpus
    # One single-CPU tuple per worker, with the CPU list repeated twice.
    affinity['workers_cpus'] = tuple((cpu,) for cpu in cpus + cpus)

    sampler = GpuSampler(
        # Keyed on ``game`` so the env class always matches the env kwargs
        # ("game" vs "id") written by _apply_cli_overrides.
        EnvCls=AtariEnv if game is not None else make_env,
        env_kwargs=config["env"],
        CollectorCls=GpuWaitResetCollector,
        TrajInfoCls=AtariTrajInfo,
        eval_env_kwargs=config["eval_env"],
        **config["sampler"]
    )

    # str() keeps this working when run_ID is the int that argparse declares.
    path = os.path.join(args.output_dir, args.env_name, 'run_' + str(args.run_ID))
    os.makedirs(path, exist_ok=True)

    experiment = Experiment(api_key='YOUR_API_KEY', auto_output_logging=False, project_name='procgen', workspace="YOUR_USERNAME", disabled=True)
    experiment.set_name(args.experiment_name)
    experiment.log_parameters(config)

    # The custom logger and logging hooks are installed on the runner class itself.
    MinibatchRlEval.TF_logger = Logger(path, use_TFX=True, params=config, comet_experiment=experiment, disable_local=True)  # set disable_local=False to enable Comet.ml
    MinibatchRlEval.log_diagnostics = log_diagnostics_custom
    MinibatchRlEval._log_infos = _log_infos

    runner = MinibatchRlEval(
        algo=algo,
        agent=agent,
        sampler=sampler,
        affinity=affinity,
        **config["runner"]
    )

    # Extra optimisation-info fields appended to the algorithm's own list.
    extra_fields = ['lossNCE', 'repeatingActions', 'averageN'] + ['action%d' % i for i in range(15)]
    runner.algo.opt_info_fields = tuple(list(runner.algo.opt_info_fields) + extra_fields)

    name = args.mode + "_value_based_nce_" + args.env_name
    log_dir = os.path.join(args.output_dir, args.env_name)

    # 'last' when weight saving is set to -1, otherwise save every ``gap`` log intervals.
    snapshot_mode = 'last' if args.weight_save_interval == -1 else 'gap'  # set 'all' to save every it, 'gap' for every X it
    snapshot_gap = int(args.weight_save_interval // config['runner']['log_interval_steps'])
    if snapshot_mode == 'gap':
        # A save interval shorter than the log interval floors to 0, which is
        # not a usable period; save at every log interval instead.
        snapshot_gap = max(1, snapshot_gap)
    logger.set_snapshot_gap(snapshot_gap)

    with experiment.train(), logger_context(log_dir, run_ID, name, config, snapshot_mode=snapshot_mode):
        runner.train()


def build_arg_parser():
    """Create the command-line parser for this training script.

    Returns:
        argparse.ArgumentParser: Parser with every option of the script
        registered. Integer options use integer defaults so that parsed values
        have the same type whether or not the flag is given (argparse does not
        apply ``type`` to non-string defaults).
    """
    import argparse

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--experiment-name', default='c51_procgen')
    parser.add_argument('--env-name', help='Procgen game', default='procgen-bigfish-v0.500')
    parser.add_argument('--n-cpus', help='number of cpus as workers', type=int, default=8)
    parser.add_argument('--mode', help='mode (serial, async or parallel)', type=str, default='parallel', choices=['serial', 'async', 'parallel'])
    parser.add_argument('--output-dir', type=str, default='results/')
    parser.add_argument('--log-interval-steps', type=int, default=int(100e3))
    parser.add_argument('--algo', help='Algo type (C51)', default='c51', choices=['c51'])
    ### C51 / Rainbow
    parser.add_argument('--dueling', type=int, default=0)
    parser.add_argument('--double-dqn', type=int, default=0)
    parser.add_argument('--n_step-return', type=int, default=7)
    parser.add_argument('--prioritized-replay', type=int, default=0)
    parser.add_argument('--learning-rate', type=float, default=2.5e-4)
    ### NCE
    parser.add_argument('--nce-batch-size', type=int, default=256, help='NCE batch size')
    parser.add_argument('--weight-save-interval', type=int, help='How often to save weights (default: every 500k steps). If set to -1, only best weight will be saved', default=500000)
    parser.add_argument('--architecture', type=str, default='Mnih', help='Network architecture')
    parser.add_argument('--frame_stack', type=int, default=4, help='Framestack (number of frames)')
    parser.add_argument('--run_ID', type=int, default=0, help='To start multiple runs with the same parameters')
    parser.add_argument('--downsample', type=int, default=1, help='Downsample (yes 1 or no 0)')
    parser.add_argument('--lambda_LL', type=float, default=0)
    parser.add_argument('--lambda_LG', type=float, default=0)
    parser.add_argument('--lambda_GL', type=float, default=0)
    parser.add_argument('--lambda_GG', type=float, default=0)
    parser.add_argument('--nce_loss', type=str, default='InfoNCE_action_loss')
    parser.add_argument('--score_fn', type=str, choices=['nce_scores_log_softmax'], default='nce_scores_log_softmax')
    ### Misc
    parser.add_argument('--decay-aux', type=int, default=0)  # Legacy
    parser.add_argument('--horizon', type=int, default=int(27e3))
    parser.add_argument('--nce-within-trajectory', type=int, default=0)  # Legacy
    parser.add_argument('--n_step-nce', type=int, default=1)
    parser.add_argument('--data_aug', help='Augment data batches?', type=int, default=0)
    parser.add_argument('--ema_moco', help='Stop gradient at psi_tp1(phi.detach())?', type=int, default=0)
    ### Procgen env-specific args
    # Legacy code - experimenting with leaper task distribution (not in paper)
    parser.add_argument('--leaper__monster_radius', type=str, default='0.25')
    parser.add_argument('--leaper__log_radius', type=str, default='0.45')
    parser.add_argument('--leaper__goal_reward', type=str, default='10')
    parser.add_argument('--leaper__nstep', type=str, default='5')
    parser.add_argument('--leaper__min_car_speed', type=str, default='0.05')
    parser.add_argument('--leaper__max_car_speed', type=str, default='0.2')
    parser.add_argument('--leaper__min_log_speed', type=str, default='0.05')
    parser.add_argument('--leaper__max_log_speed', type=str, default='0.1')
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    args.batch_size = config['algo']['batch_size']

    # Best-effort lookup of the ``game`` kwarg of the gym spec registered under
    # this env name. Any failure (no matching spec, spec without a ``game``
    # kwarg, different registry API) means "not selected by game name".
    try:
        args.game = next(
            spec for spec in envs.registry.all() if spec.id == args.env_name
        )._kwargs['game']
    except Exception:
        args.game = None

    # NOTE: this unconditionally replaces any --run_ID given on the command line.
    args.run_ID = str(uuid.uuid1())
    # NOTE(review): the seed is only recorded on ``args`` here; nothing in this
    # file passes it to a seeding function - confirm it is applied elsewhere.
    seed = np.random.randint(1000000, size=1)[0]
    args.seed = str(seed)
    build_and_train(
        args=args,
        game=args.game,
        run_ID=args.run_ID,
        config=config
    )
