import time
import collections
import copy
import logging

import dm_pix as pix
import haiku as hk
import jax
from jax import lax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as onp
import optax
import os
import pickle
import ray
from ray import tune
import rlax
import tree

from algorithms import utils
from algorithms.actor import Actor, ActorOutput
from algorithms.haiku_nets import torso_network

# Outputs of one network unroll (see ActorCriticAuxNet.__call__). The `state` field
# carries the core's output features, not the recurrent state.
_AGENT_OUTPUT_FIELDS = ['state', 'logits', 'value', 'aux_proj', 'aux_pred']
AgentOutput = collections.namedtuple('AgentOutput', _AGENT_OUTPUT_FIELDS)

# Diagnostics produced by the A2C loss; grad_norm/update_norm are filled in by the
# update step after the optimizer has run.
A2CLog = collections.namedtuple(
    'A2CLog',
    'entropy value ret pg_loss baseline_loss state_norm theta_norm grad_norm update_norm',
)

# Diagnostics produced by the auxiliary (CURL) loss; the norms are filled in by the
# update step.
AuxLog = collections.namedtuple('AuxLog', 'aux_loss grad_norm update_norm')


class ActorCriticAuxNet(hk.RNNCore):
    """Actor-critic network with an auxiliary (CURL-style) projection head.

    A torso built by ``torso_network`` is optionally followed by a recurrent core.
    The resulting features feed two separately parameterized ReLU MLP heads whose
    hidden sizes are ``head_layers``:

    * an auxiliary head producing ``aux_proj`` (a linear layer of width
      ``aux_dim``) and ``aux_pred`` (a bias-free ``aux_dim x aux_dim`` linear map
      applied to ``aux_proj``);
    * a policy/value head producing ``num_actions`` logits and a scalar value.

    Args:
        num_actions: number of discrete actions.
        aux_dim: width of the auxiliary projection and prediction layers.
        torso_type: torso identifier passed to ``torso_network``.
        torso_kwargs: keyword arguments passed to ``torso_network``.
        use_rnn: if True, a 512-unit GRU (orthogonal recurrent-weight init) wrapped
            in ``hk.ResetCore`` processes the torso features. If False, the torso
            output is used directly and ``state`` is passed through unchanged.
        head_layers: hidden sizes of the ReLU MLP in each head.
        stop_ac_grad: if True, ``lax.stop_gradient`` is applied to the input of the
            policy/value head. The actor-critic outputs then send no gradient into
            the torso/core; only the auxiliary head does.
        scale: stored as ``self._scale`` but not used anywhere in this class.
        name: optional Haiku module name.
    """

    def __init__(self, num_actions, aux_dim, torso_type, torso_kwargs, use_rnn, head_layers, stop_ac_grad, scale,
                 name=None):

        super(ActorCriticAuxNet, self).__init__(name=name)
        self._num_actions = num_actions
        self._aux_dim = aux_dim
        self._torso_type = torso_type
        self._torso_kwargs = torso_kwargs
        self._use_rnn = use_rnn
        if use_rnn:
            core = hk.GRU(512, w_h_init=hk.initializers.Orthogonal())
        else:
            # Placeholder core so initial_state() works in both modes; __call__ skips
            # it entirely when use_rnn is False.
            core = hk.IdentityCore()
        self._core = hk.ResetCore(core)
        self._head_layers = head_layers
        self._stop_ac_grad = stop_ac_grad
        self._scale = scale  # NOTE(review): unused in this class.

    def __call__(self, timesteps, state):
        """Run the network over a sequence of timesteps.

        Args:
            timesteps: structure with ``observation``, ``action_tm1``, ``reward`` and
                ``first`` fields. ``hk.dynamic_unroll`` treats the leading axis as
                time. Presumably the inputs are unbatched here, because the ``Agent``
                wrapper vmaps over the batch; confirm this for new callers.
            state: recurrent state for the core. It is returned untouched when
                ``use_rnn`` is False.

        Returns:
            ``(AgentOutput, next_state)``. ``AgentOutput.state`` holds the core's
            *output features* (used for the state-norm diagnostic), not the
            recurrent state.
        """
        torso_net = torso_network(self._torso_type, **self._torso_kwargs)
        torso_output = torso_net(timesteps.observation)

        if self._use_rnn:
            # Core input: one-hot action_tm1, the reward as a column, and torso features.
            core_input = jnp.concatenate([
                hk.one_hot(timesteps.action_tm1, self._num_actions),
                timesteps.reward[:, None],
                torso_output
            ], axis=1)
            # ResetCore resets the recurrent state at steps where `first` is set.
            should_reset = timesteps.first
            core_output, next_state = hk.dynamic_unroll(self._core, (core_input, should_reset), state)
        else:
            core_output, next_state = torso_output, state

        # Auxiliary head. Haiku auto-names these unnamed hk.Linear modules by creation
        # order, so reordering layers in this method renames parameters and breaks
        # existing checkpoints.
        aux_head = []
        for dim in self._head_layers:
            aux_head.append(hk.Linear(dim))
            aux_head.append(jax.nn.relu)
        aux_input = hk.Sequential(aux_head)(core_output)
        aux_proj = hk.Linear(self._aux_dim)(aux_input)
        aux_pred = hk.Linear(self._aux_dim, with_bias=False)(aux_proj)

        # Policy/value head with its own MLP weights. Optionally cut its gradient path
        # into the torso/core.
        main_head = []
        if self._stop_ac_grad:
            main_head.append(lax.stop_gradient)
        for dim in self._head_layers:
            main_head.append(hk.Linear(dim))
            main_head.append(jax.nn.relu)
        h = hk.Sequential(main_head)(core_output)
        logits = hk.Linear(self._num_actions)(h)
        value = hk.Linear(1)(h)

        agent_output = AgentOutput(
            state=core_output,
            logits=logits,
            value=value.squeeze(-1),
            aux_proj=aux_proj,
            aux_pred=aux_pred,
        )
        return agent_output, next_state

    def initial_state(self, batch_size):
        """Return the wrapped core's initial state (``Agent.init`` passes ``None``)."""
        return self._core.initial_state(batch_size)


class Agent(object):
    """Functional wrapper that turns ``ActorCriticAuxNet`` into pure Haiku functions.

    * ``init(rngkey)`` creates parameters from a dummy single-step input.
    * ``initial_state(batch_size)`` builds the recurrent state.
    * ``step(rngkey, params, timesteps, states)`` (jitted) runs one acting step for
      a batch of environments and samples an action per environment.
    * ``unroll(params, timesteps, state)`` runs the network over one time-major
      sequence.
    """

    def __init__(self, ob_space, action_space, aux_dim, torso_type, torso_kwargs, head_layers, use_rnn, stop_ac_grad,
                 scale):
        self._ob_space = ob_space
        num_actions = action_space.n

        def build_net():
            # Shared factory so every transform sees an identically configured network.
            return ActorCriticAuxNet(
                num_actions=num_actions,
                aux_dim=aux_dim,
                torso_type=torso_type,
                torso_kwargs=torso_kwargs,
                use_rnn=use_rnn,
                head_layers=head_layers,
                stop_ac_grad=stop_ac_grad,
                scale=scale,
            )

        def initial_state_fn(batch_size):
            return build_net().initial_state(batch_size)

        def forward_fn(inputs, state):
            return build_net()(inputs, state)

        # Only the apply half is needed for the initial state; it creates no parameters.
        self._initial_state_apply_fn = hk.without_apply_rng(hk.transform(initial_state_fn)).apply
        forward = hk.without_apply_rng(hk.transform(forward_fn))
        self._init_fn = forward.init
        self._apply_fn = forward.apply
        self.step = jax.jit(self._step)

    def init(self, rngkey):
        """Initialize network parameters using an all-zeros, length-1 dummy sequence."""
        def zeros_with_time_axis(spec):
            return jnp.zeros(spec.shape, spec.dtype)[None]

        observation = tree.map_structure(zeros_with_time_axis, self._ob_space)
        rnn_state = self.initial_state(None)
        sample_input = ActorOutput(
            rnn_state=rnn_state,
            action_tm1=jnp.zeros((1,), dtype=jnp.int32),
            reward=jnp.zeros((1,), dtype=jnp.float32),
            discount=jnp.zeros((1,), dtype=jnp.float32),
            first=jnp.zeros((1,), dtype=jnp.float32),
            observation=observation,
        )
        return self._init_fn(rngkey, sample_input, rnn_state)

    def initial_state(self, batch_size):
        """Return the network's initial recurrent state for ``batch_size``."""
        return self._initial_state_apply_fn(None, batch_size)

    def _step(self, rngkey, params, timesteps, states):
        """Act for a batch of environments.

        ``timesteps`` and ``states`` carry a leading batch axis [B, ...]. Returns
        ``(new_rngkey, actions [B], AgentOutput [B, ...], next_states)``.
        """
        rngkey, sample_key = jrandom.split(rngkey)
        # Add a length-1 time axis: [B, ...] -> [B, 1, ...].
        sequences = tree.map_structure(lambda x: x[:, None, ...], timesteps)
        batched_apply = jax.vmap(self._apply_fn, in_axes=(None, 0, 0))
        outputs, next_states = batched_apply(params, sequences, states)
        # Drop the time axis again: [B, 1, ...] -> [B, ...].
        outputs = tree.map_structure(lambda x: x.squeeze(axis=1), outputs)
        samples = hk.multinomial(sample_key, outputs.logits, num_samples=1)
        actions = samples.squeeze(axis=-1)
        return rngkey, actions, outputs, next_states

    def unroll(self, params, timesteps, state):
        """Run the network over one time-major sequence; outputs are [T, ...]."""
        return self._apply_fn(params, timesteps, state)


def gen_a2c_update_fn(agent, opt_update, gamma, vf_coef, entropy_reg, use_mask):
    """Create the A2C update function.

    Args:
        agent: object exposing ``unroll(params, timesteps, state)`` (e.g. ``Agent``).
        opt_update: optax-style ``update(grads, opt_state)`` function.
        gamma: discount factor, multiplied into the per-step ``trajs.discount``.
        vf_coef: weight of the value (baseline) loss.
        entropy_reg: weight of ``rlax.entropy_loss`` (the negative entropy).
        use_mask: if True, per-step losses are weighted by ``trajs.discount[:, :-1]``.
            Otherwise every step gets weight 1.

    Returns:
        ``a2c_update(theta, opt_state, trajs) -> (new_theta, new_opt_state, A2CLog)``.
        The fields of ``trajs`` are laid out as [B, T + 1, ...].
    """

    def a2c_loss(theta, trajs):
        # Each sequence starts from the recurrent state recorded at its first step.
        start_states = tree.map_structure(lambda x: x[:, 0], trajs.rnn_state)
        unroll_batch = jax.vmap(agent.unroll, (None, 0, 0))
        out, _ = unroll_batch(theta, trajs, start_states)  # [B, T + 1, ...]
        all_values = out.value
        values = all_values[:, :-1]
        logits = out.logits[:, :-1]

        # Discounted returns, bootstrapped from the value at the final step.
        step_discounts = trajs.discount[:, 1:] * gamma
        returns = jax.vmap(rlax.discounted_returns)(
            trajs.reward[:, 1:], step_discounts, all_values[:, -1])
        advantages = returns - values

        weights = trajs.discount[:, :-1]
        if not use_mask:
            weights = jnp.ones_like(weights)

        # Logits at step t are paired with action_tm1 at step t + 1.
        actions = trajs.action_tm1[:, 1:]
        pg_loss = jax.vmap(rlax.policy_gradient_loss)(logits, actions, advantages, weights)
        ent_loss = jax.vmap(rlax.entropy_loss)(logits, weights)
        value_error = values - lax.stop_gradient(returns)
        baseline_loss = 0.5 * jnp.mean(jnp.square(value_error) * weights, axis=1)
        loss = jnp.mean(pg_loss + vf_coef * baseline_loss + entropy_reg * ent_loss)

        log = A2CLog(
            entropy=-ent_loss,
            value=all_values,
            ret=returns,
            pg_loss=pg_loss,
            baseline_loss=baseline_loss,
            state_norm=jnp.sqrt(jnp.sum(jnp.square(out.state), axis=-1)),
            theta_norm=optax.global_norm(theta),
            grad_norm=0.,  # filled in by a2c_update
            update_norm=0.,  # filled in by a2c_update
        )
        return loss, log

    def a2c_update(theta, opt_state, trajs):
        grads, log = jax.grad(a2c_loss, has_aux=True)(theta, trajs)
        updates, new_opt_state = opt_update(grads, opt_state)
        log = log._replace(
            grad_norm=optax.global_norm(grads),
            update_norm=optax.global_norm(updates),
        )
        return optax.apply_updates(theta, updates), new_opt_state, log

    return a2c_update


def gen_curl_update_fn(agent, opt_update, ema_rate, pad=2):
    """Create the CURL (contrastive) auxiliary update function.

    Two independently augmented views of each trajectory are encoded:

    * the anchor view, using the online parameters ``theta`` (its ``aux_pred``);
    * the target view, using the EMA parameters ``theta_ema`` (its ``aux_proj``,
      with gradients stopped).

    Every anchor is scored against every target with a dot product. A softmax
    cross-entropy with identity labels then treats the matching (same sequence,
    same step) pair as the positive.

    Args:
        agent: object exposing ``unroll(params, timesteps, state)`` (e.g. ``Agent``).
        opt_update: optax-style ``update(grads, opt_state)`` function for ``theta``.
        ema_rate: step size for ``optax.incremental_update``. The new target is
            ``ema_rate * new_theta + (1 - ema_rate) * theta_ema``.
        pad: pixels removed from each spatial border by the random crop and then
            restored as zero padding, so augmented observations keep the input's
            height and width. The default of 2 reproduces the previous hard-coded
            80x80 crop padded back to 84x84 for 84x84 frames.

    Returns:
        ``curl_update(rngkey, theta, theta_ema, opt_state, trajs)`` returning
        ``(rngkey, new_theta, new_theta_ema, new_opt_state, AuxLog)``.

    Raises:
        ValueError: if ``pad`` is negative.
    """
    if pad < 0:
        raise ValueError('pad must be non-negative, got {}'.format(pad))

    def random_shift(rngkey, obs):
        """Randomly crop H and W by ``2 * pad`` and zero-pad back to the input size."""
        # The pad spec assumes obs is [B, T, H, W, C], as in the original code.
        crop_shape = obs.shape[:2] + (obs.shape[2] - 2 * pad, obs.shape[3] - 2 * pad) + obs.shape[-1:]
        # NOTE(review): pix.random_crop appears to draw a single offset for the whole
        # array, so every frame in the batch gets the same crop (CURL usually crops
        # per sample). Confirm this is intended.
        cropped = pix.random_crop(rngkey, obs, crop_shape)
        return jnp.pad(cropped, ((0, 0), (0, 0), (pad, pad), (pad, pad), (0, 0)),
                       mode='constant', constant_values=0)

    def merge_batch_time(x):
        """Drop the final time step, then merge the batch and time axes."""
        x = x[:, :-1]
        return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])

    def curl_loss(theta, theta_ema, trajs, rngkey):
        rnn_states = tree.map_structure(lambda t: t[:, 0], trajs.rnn_state)
        obs = trajs.observation
        k1, k2 = jax.random.split(rngkey)
        batched_unroll = jax.vmap(agent.unroll, (None, 0, 0))

        anchor_trajs = trajs._replace(observation=random_shift(k1, obs))
        agent_output, _ = batched_unroll(theta, anchor_trajs, rnn_states)  # [B, T, ...]
        pred = merge_batch_time(agent_output.aux_pred)

        target_trajs = trajs._replace(observation=random_shift(k2, obs))
        target_output, _ = batched_unroll(theta_ema, target_trajs, rnn_states)  # [B, T, ...]
        target = jax.lax.stop_gradient(merge_batch_time(target_output.aux_proj))

        # Positives lie on the diagonal: anchor i matches target i.
        logits = jnp.matmul(pred, target.T)
        labels = jnp.eye(logits.shape[0])
        aux_loss = jnp.mean(optax.softmax_cross_entropy(logits, labels))

        aux_log = AuxLog(
            aux_loss=aux_loss,
            grad_norm=0.,  # filled in by curl_update
            update_norm=0.,  # filled in by curl_update
        )
        return aux_loss, aux_log

    def curl_update(rngkey, theta, theta_ema, opt_state, trajs):
        rngkey, subkey = jax.random.split(rngkey)
        grads, logs = jax.grad(curl_loss, has_aux=True)(theta, theta_ema, trajs, subkey)
        updates, new_opt_state = opt_update(grads, opt_state)
        logs = logs._replace(
            grad_norm=optax.global_norm(grads),
            update_norm=optax.global_norm(updates),
        )
        new_theta = optax.apply_updates(theta, updates)
        new_theta_ema = optax.incremental_update(new_theta, theta_ema, ema_rate)
        return rngkey, new_theta, new_theta_ema, new_opt_state, logs

    return curl_update


def _make_vec_env(config):
    """Build the vectorized environment selected by ``config['env_id']``.

    Supported ids are ``'maze'``, ``'procgen/<name>'``, ``'dmlab/<name>'`` and
    Atari ids ending in ``'NoFrameskip-v4'``. Each environment package is imported
    lazily, so only the selected backend has to be installed.

    Returns:
        ``(envs, frame_skip, use_mask)``. ``frame_skip`` converts agent steps into
        environment frames for the frame counter. ``use_mask`` is forwarded to
        ``gen_a2c_update_fn``.

    Raises:
        KeyError: if ``config['env_id']`` matches none of the supported families.
    """
    env_id = config['env_id']
    if env_id == 'maze':
        import environments.maze.vec_env_utils as maze_vec_env
        envs = maze_vec_env.make_vec_env(
            config['nenvs'],
            config['seed'],
            env_kwargs=config['env_kwargs'],
        )
        return envs, 1, True
    if env_id.startswith('procgen/'):
        import environments.procgen.vec_env_utils as procgen_vec_env
        envs = procgen_vec_env.make_vec_env(
            env_id[len('procgen/'):],
            config['nenvs'],
            env_kwargs=config['env_kwargs'],
        )
        return envs, 1, False
    if env_id.startswith('dmlab/'):
        import environments.dmlab.vec_env_utils as dmlab_vec_env
        # DMLab renders on the GPU that Ray assigned to this trial.
        gpu_id = ray.get_gpu_ids()[0]
        env_kwargs = copy.deepcopy(config['env_kwargs'])
        env_kwargs['gpuDeviceIndex'] = gpu_id
        envs = dmlab_vec_env.make_vec_env(
            env_id[len('dmlab/'):], config['cache'], config['noop_max'], config['nenvs'], config['seed'], env_kwargs)
        return envs, 4, True
    if env_id.endswith('NoFrameskip-v4'):
        import environments.atari.vec_env_utils as atari_vec_env
        from vec_env import VecFrameStack
        envs = atari_vec_env.make_vec_env(
            env_id,
            config['nenvs'],
            config['seed'],
        )
        if not config['use_rnn']:
            # Feed-forward agents get the last 4 frames stacked instead of an RNN.
            envs = VecFrameStack(envs, 4)
        return envs, 4, True
    raise KeyError('Unsupported env_id: {!r}'.format(env_id))


def _make_optimizer(opt_type, opt_kwargs, max_grad_norm):
    """Build an optax optimizer from config, optionally clipping by global norm.

    Args:
        opt_type: ``'adam'`` or ``'rmsprop'``.
        opt_kwargs: keyword arguments for the optimizer. For ``'adam'`` it must
            contain ``learning_rate``; the remaining keys go to
            ``optax.scale_by_adam``. The dict is not mutated.
        max_grad_norm: if > 0, gradients are clipped to this global norm first.

    Raises:
        KeyError: for an unknown ``opt_type``.
    """
    if opt_type == 'adam':
        adam_kwargs = opt_kwargs.copy()
        learning_rate = adam_kwargs.pop('learning_rate')
        opt = optax.chain(
            optax.scale_by_adam(**adam_kwargs),
            optax.scale(-learning_rate),
        )
    elif opt_type == 'rmsprop':
        opt = optax.rmsprop(**opt_kwargs)
    else:
        raise KeyError('Unsupported optimizer type: {!r}'.format(opt_type))
    if max_grad_norm > 0:
        opt = optax.chain(
            optax.clip_by_global_norm(max_grad_norm),
            opt,
        )
    return opt


class Experiment(tune.Trainable):
    """Ray Tune trainable running A2C with a CURL auxiliary loss.

    Each call to ``step`` performs ``config['log_interval']`` iterations of
    rollout -> A2C update -> CURL update and returns a dict of statistics.
    """

    def setup(self, config):
        """Build environments, agent, optimizers and jitted update functions."""
        self._config = config
        platform = jax.lib.xla_bridge.get_backend().platform
        logging.warning("Running on %s", platform)
        self._envs, self._frame_skip, use_mask = _make_vec_env(config)
        self._nsteps = config['nsteps']

        # NOTE: drawn from NumPy's global RNG, so it is not derived from config['seed'].
        jax_seed = onp.random.randint(2 ** 31 - 1)
        self._rngkey = jrandom.PRNGKey(jax_seed)

        scale = 1.

        agent = Agent(
            ob_space=self._envs.observation_space,
            action_space=self._envs.action_space,
            aux_dim=config['aux_dim'],
            torso_type=config['torso_type'],
            torso_kwargs=config['torso_kwargs'],
            use_rnn=config['use_rnn'],
            head_layers=config['head_layers'],
            stop_ac_grad=config['stop_ac_grad'],
            scale=scale,
        )
        self._actor = Actor(self._envs, agent, self._nsteps)

        a2c_opt_init, a2c_opt_update = _make_optimizer(
            config['a2c_opt_type'], config['a2c_opt_kwargs'], config['max_a2c_grad_norm'])
        aux_opt_init, aux_opt_update = _make_optimizer(
            config['aux_opt_type'], config['aux_opt_kwargs'], config['max_aux_grad_norm'])

        a2c_update_fn = gen_a2c_update_fn(
            agent=agent,
            opt_update=a2c_opt_update,
            gamma=config['gamma'],
            vf_coef=config['vf_coef'],
            entropy_reg=config['entropy_reg'],
            use_mask=use_mask,
        )

        aux_update_fn = gen_curl_update_fn(
            agent=agent,
            opt_update=aux_opt_update,
            ema_rate=config['ema_rate'],
        )
        self._a2c_update_fn = jax.jit(a2c_update_fn)
        self._aux_update_fn = jax.jit(aux_update_fn)

        self._rngkey, subkey = jrandom.split(self._rngkey)
        self._theta = agent.init(subkey)
        # The CURL target network starts as an exact copy of the online parameters.
        self._theta_ema = self._theta
        self._a2c_opt_state = a2c_opt_init(self._theta)
        self._aux_opt_state = aux_opt_init(self._theta)

        self._epinfo_buf = collections.deque(maxlen=100)
        self._num_iter = 0
        self._num_frames = 0
        self._tstart = time.time()

    def step(self):
        """Run ``log_interval`` training iterations and return logging statistics.

        Loss diagnostics come from the last iteration only. Episode statistics
        average over the 100 most recent finished episodes.
        """
        t0 = time.time()
        rngkey = self._rngkey
        theta = self._theta
        theta_ema = self._theta_ema
        num_frames_this_iter = 0
        for _ in range(self._config['log_interval']):
            rngkey, trajs, epinfos = self._actor.rollout(rngkey, theta)
            self._epinfo_buf.extend(epinfos)

            trajs = jax.device_put(trajs)
            theta, self._a2c_opt_state, a2c_log = self._a2c_update_fn(
                theta, self._a2c_opt_state, trajs)
            rngkey, theta, theta_ema, self._aux_opt_state, aux_log = self._aux_update_fn(
                rngkey, theta, theta_ema, self._aux_opt_state, trajs)

            self._num_iter += 1
            num_frames_this_iter += self._config['nenvs'] * self._nsteps * self._frame_skip
        self._rngkey = rngkey
        self._theta = theta
        self._theta_ema = theta_ema
        self._num_frames += num_frames_this_iter

        a2c_log = jax.device_get(a2c_log)
        aux_log = jax.device_get(aux_log)
        ev = utils.explained_variance(a2c_log.value[:, :-1].flatten(), a2c_log.ret.flatten())
        log = {
            'label': self._config['label'],
            'episode_return': onp.mean([epinfo['r'] for epinfo in self._epinfo_buf]),
            'episode_length': onp.mean([epinfo['l'] for epinfo in self._epinfo_buf]),
            'entropy': a2c_log.entropy.mean(),
            'explained_variance': ev,
            'pg_loss': a2c_log.pg_loss.mean(),
            'baseline_loss': a2c_log.baseline_loss.mean(),
            'value_mean': a2c_log.value.mean(),
            'value_std': a2c_log.value.std(),
            'return_mean': a2c_log.ret.mean(),
            'return_std': a2c_log.ret.std(),
            'state_norm': onp.mean(a2c_log.state_norm),
            'a2c_grad_norm': a2c_log.grad_norm,
            'a2c_update_norm': a2c_log.update_norm,
            'param_norm': a2c_log.theta_norm,
            'aux_loss': aux_log.aux_loss.mean(),
            'aux_grad_norm': aux_log.grad_norm,
            'aux_update_norm': aux_log.update_norm,
            'num_iterations': self._num_iter,
            'num_frames': self._num_frames,
            'fps': num_frames_this_iter / (time.time() - t0),
        }
        return log

    def _save(self, tmp_checkpoint_dir):
        """Pickle params, optimizer states and the EMA target params.

        The checkpoint tuple is ``(theta, a2c_opt_state, aux_opt_state, theta_ema)``.
        ``theta_ema`` is appended last so the first three entries keep the
        earlier layout.
        """
        theta = jax.device_get(self._theta)
        a2c_opt_state = jax.device_get(self._a2c_opt_state)
        aux_opt_state = jax.device_get(self._aux_opt_state)
        theta_ema = jax.device_get(self._theta_ema)
        checkpoint_path = os.path.join(tmp_checkpoint_dir, 'model.chk')
        with open(checkpoint_path, 'wb') as checkpoint_file:
            pickle.dump((theta, a2c_opt_state, aux_opt_state, theta_ema), checkpoint_file)
        return checkpoint_path

    def _restore(self, checkpoint):
        """Load a checkpoint written by ``_save`` (3- or 4-entry layout).

        Only load trusted files: pickle can execute arbitrary code.
        """
        with open(checkpoint, 'rb') as checkpoint_file:
            saved = pickle.load(checkpoint_file)
        theta, a2c_opt_state, aux_opt_state = saved[:3]
        # Older checkpoints lack theta_ema. Fall back to the online params, which is
        # also how setup() initializes the EMA target.
        theta_ema = saved[3] if len(saved) > 3 else theta
        self._theta = theta
        self._theta_ema = theta_ema
        self._a2c_opt_state = a2c_opt_state
        self._aux_opt_state = aux_opt_state


if __name__ == '__main__':
    # Local debug run: A2C + CURL on Atari Breakout with a feed-forward torso.
    config = dict(
        label='a2c-curl',
        env_id='BreakoutNoFrameskip-v4',
        env_kwargs={},

        # Network architecture.
        torso_type='atari_shallow',
        torso_kwargs=dict(dense_layers=()),
        use_rnn=False,
        head_layers=(512,),
        stop_ac_grad=True,

        # NOTE(review): 'scale_gradient' and 'lambda_' are not read anywhere in this file.
        scale_gradient=True,

        # Rollout and A2C loss.
        nenvs=16,
        nsteps=20,
        gamma=0.99,
        lambda_=1.0,
        vf_coef=0.5,
        entropy_reg=0.01,

        # A2C optimizer.
        a2c_opt_type='rmsprop',
        a2c_opt_kwargs=dict(learning_rate=7e-4, decay=0.99, eps=1e-5),
        max_a2c_grad_norm=0.5,

        # CURL auxiliary loss and its optimizer.
        aux_dim=128,
        aux_opt_type='adam',
        aux_opt_kwargs=dict(learning_rate=7e-4, b1=0., b2=0.99, eps_root=1e-5),
        max_aux_grad_norm=0.,
        ema_rate=0.001,

        log_interval=100,
        seed=0,
    )
    stop_criteria = {'num_frames': 200 * 10 ** 6}
    trial_resources = {'gpu': 1}
    analysis = tune.run(
        Experiment,
        name='debug',
        config=config,
        stop=stop_criteria,
        resources_per_trial=trial_resources,
    )
