import time
import collections
import functools
import logging

import haiku as hk
import jax
from jax import lax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as onp
import optax
import os
import pickle
from ray import tune
import rlax
import tree

from algorithms import td_nets
from algorithms import utils
from algorithms.haiku_nets import torso_network
from algorithms.target_features import get_random_feature_fn
import environments.empty_room.vec_env_utils as empty_room_vec_env

# Result of one forward pass of `ValueAuxNet`.
AgentOutput = collections.namedtuple(
    'AgentOutput',
    [
        'state',     # output of the torso network
        'value',     # value-head output with its trailing axis squeezed away
        'aux_pred',  # output of the auxiliary prediction head
    ],
)

# Diagnostics produced by one policy-evaluation update.
PELog = collections.namedtuple(
    'PELog',
    [
        'value',        # value predictions for every step of the unrolled batch
        'ret',          # bootstrapped discounted returns used as regression targets
        'pe_loss',      # scalar policy-evaluation loss
        'state_norm',   # L2 norm of the torso output along its last axis
        'theta_norm',   # global norm of the parameters
        'grad_norm',    # global norm of the gradients (filled in by the update step)
        'update_norm',  # global norm of the optimiser updates (filled in by the update step)
    ],
)

# Diagnostics produced by one auxiliary TD-network update.
AuxLog = collections.namedtuple(
    'AuxLog',
    [
        'pred',          # auxiliary predictions at time t, batch and time axes merged
        'td_target',     # TD targets for those predictions, batch and time axes merged
        'feature',       # target features returned by the target-feature function
        'mask',          # per-prediction masks looked up by the action taken
        'aux_loss',      # scalar auxiliary loss
        'abs_td_error',  # masked absolute TD error summed over predictions
        'grad_norm',     # global norm of the gradients (filled in by the update step)
        'update_norm',   # global norm of the optimiser updates (filled in by the update step)
    ],
)


# (row, column) offsets of the four grid moves.
DIR = [(-1, 0), (1, 0), (0, -1), (0, 1)]


def compute_ground_truth_value(size, gamma):
    """Solve for the values of a uniform random walk over the grid interior.

    The interior of a ``size`` x ``size`` grid (rows and columns ``1`` to
    ``size - 2``) gives ``(size - 2) ** 2`` states numbered row-major.  Every
    move in ``DIR`` is taken with equal probability, and a coordinate that
    would leave the interior is clamped back onto it.  A reward of 1 is placed
    on the state with flat index ``size - 4`` and the values are computed as
    ``pinv(I - gamma * P) @ P @ r``.

    Args:
        size: side length of the full grid, border included.
        gamma: discount factor.

    Returns:
        A ``(size - 2, size - 2)`` array of state values.
    """
    side = size - 2
    num_states = side ** 2
    reward = onp.zeros((num_states,))
    transition = onp.zeros((num_states, num_states))
    move_prob = 1 / len(DIR)

    def flat_index(row, col):
        # Row-major index of an interior cell given its full-grid coordinates.
        return (row - 1) * side + (col - 1)

    for row in range(1, size - 1):
        for col in range(1, size - 1):
            source = flat_index(row, col)
            for d_row, d_col in DIR:
                # Clamp onto the interior: bumping into the border leaves that coordinate as is.
                next_row = min(max(row + d_row, 1), size - 2)
                next_col = min(max(col + d_col, 1), size - 2)
                transition[source, flat_index(next_row, next_col)] += move_prob
    reward[size - 4] = 1.
    resolvent = jnp.linalg.pinv(onp.eye(num_states) - gamma * transition)
    values = jax.device_get(resolvent @ transition @ reward)
    return values.reshape((side, side))


def render_observation(grid, observation_type):
    """Turn a character grid into the observation the maze environment would emit.

    A pycolab ``Observation`` is assembled by hand from ``grid`` and passed
    through ``MazeEnv._process_outputs`` so that the result goes through the
    same processing as observations produced by the environment itself.

    Args:
        grid: square 2-D NumPy array of single-character strings.
        observation_type: forwarded to ``MazeEnv``.

    Returns:
        The first element of the pair returned by ``MazeEnv._process_outputs``.
    """
    from environments.empty_room.gym_maze import MazeEnv
    from pycolab.rendering import Observation
    env = MazeEnv(size=len(grid), observation_type=observation_type)
    # The board stores the code point of the character in every cell.
    board = onp.array([[ord(cell) for cell in row] for row in grid], dtype=onp.uint8)
    # One boolean mask per character.  The element-wise comparison already
    # yields a `bool` array, which avoids the `onp.bool` alias that was
    # deprecated in NumPy 1.20 and removed in NumPy 1.24.
    layers = {char: grid == char for char in ['+', '0', '#', ' ', 't']}
    ob = Observation(board=board, layers=layers)
    ob, _ = env._process_outputs(ob, None)
    return ob


def generate_test_batch(size, observation_type):
    """Render one observation for every non-wall cell of a freshly generated maze.

    Any ``'+'`` cell in the generated maze is first blanked out.  Then, in
    row-major order, a ``'+'`` is placed on each cell that is not ``'#'``, the
    grid is rendered, and the cell's original character is put back.

    Args:
        size: side length passed to ``generate_maze``.
        observation_type: forwarded to ``render_observation``.

    Returns:
        The rendered observations stacked along a new leading axis.
    """
    from environments.empty_room.pycolab_maze import generate_maze
    cells = onp.array([list(row) for row in generate_maze(size)])
    positions = list(onp.ndindex(size, size))  # row-major (i, j) pairs

    for pos in positions:
        if cells[pos] == '+':
            cells[pos] = ' '

    observations = []
    for pos in positions:
        if cells[pos] == '#':
            continue
        original_char = cells[pos]
        cells[pos] = '+'
        observations.append(render_observation(cells, observation_type))
        cells[pos] = original_char
    return onp.stack(observations, axis=0)


class ValueAuxNet(hk.Module):
    """Shared torso feeding an auxiliary prediction head and a scalar value head.

    Both heads apply ``Linear -> relu`` once per entry of ``head_layers`` to
    the torso output, followed by a final linear layer (``num_pred`` outputs
    for the auxiliary head, one output for the value head).
    """

    def __init__(self, num_pred, torso_type, torso_kwargs, head_layers, stop_value_grad, scale, name=None):
        """Store the architecture settings.

        Args:
            num_pred: width of the auxiliary prediction output.
            torso_type: torso identifier passed to ``torso_network``.
            torso_kwargs: keyword arguments passed to ``torso_network``.
            head_layers: hidden widths used by both heads.
            stop_value_grad: if true, the value head starts with a
                ``stop_gradient`` so its loss does not reach the torso.
            scale: factor applied to the gradient flowing back through the
                auxiliary head's hidden activations.
            name: optional Haiku module name.
        """
        super().__init__(name=name)
        self._num_pred = num_pred
        self._torso_type = torso_type
        self._torso_kwargs = torso_kwargs
        self._head_layers = head_layers
        self._stop_value_grad = stop_value_grad
        self._scale = scale

    def __call__(self, x):
        """Run the network on ``x`` and return an ``AgentOutput``."""
        features = torso_network(self._torso_type, **self._torso_kwargs)(x)

        # Auxiliary branch.
        aux_layers = []
        for width in self._head_layers:
            aux_layers += [hk.Linear(width), jax.nn.relu]
        aux_hidden = hk.Sequential(aux_layers)(features)
        # The two terms add up to `aux_hidden` in the forward pass, but only the
        # first one carries gradient, so the backward signal is multiplied by `scale`.
        with_grad = self._scale * aux_hidden
        without_grad = lax.stop_gradient((1 - self._scale) * aux_hidden)
        aux_hidden = with_grad + without_grad
        aux_pred = hk.Linear(self._num_pred)(aux_hidden)

        # Value branch.
        value_layers = [lax.stop_gradient] if self._stop_value_grad else []
        for width in self._head_layers:
            value_layers += [hk.Linear(width), jax.nn.relu]
        value_hidden = hk.Sequential(value_layers)(features)
        value = hk.Linear(1)(value_hidden)

        return AgentOutput(
            state=features,
            value=value.squeeze(-1),
            aux_pred=aux_pred,
        )


class Agent:
    """Wraps a Haiku-transformed ``ValueAuxNet`` and acts uniformly at random."""

    def __init__(self, ob_space, ac_space, num_pred, torso_type, torso_kwargs, head_layers, stop_value_grad,
                 scale):
        """Create the transformed network functions.

        Args:
            ob_space: observation space; its leaves must expose ``shape`` and ``dtype``.
            ac_space: action space; only ``ac_space.n`` is read.
            num_pred, torso_type, torso_kwargs, head_layers, stop_value_grad,
                scale: forwarded unchanged to ``ValueAuxNet``.
        """
        self._ob_space = ob_space
        self._num_actions = ac_space.n

        def forward(inputs):
            network = ValueAuxNet(
                num_pred=num_pred,
                torso_type=torso_type,
                torso_kwargs=torso_kwargs,
                head_layers=head_layers,
                stop_value_grad=stop_value_grad,
                scale=scale,
            )
            return network(inputs)

        self._init_fn, self._apply_fn = hk.without_apply_rng(hk.transform(forward))

    @functools.partial(jax.jit, static_argnums=(0,))
    def init(self, rngkey):
        """Initialise parameters from an all-zero observation with a leading axis of size 1."""
        batched_zeros = tree.map_structure(
            lambda spec: jnp.zeros(spec.shape, spec.dtype)[None],
            self._ob_space,
        )
        return self._init_fn(rngkey, batched_zeros)

    @functools.partial(jax.jit, static_argnums=(0,))
    def step(self, rngkey, params, observations):
        """Draw one uniformly random action per observation.

        ``params`` is accepted but not used.  Returns the advanced key, the
        sampled actions and an empty tuple.
        """
        batch_size = observations.shape[0]
        next_key, action_key = jrandom.split(rngkey)
        actions = jrandom.randint(action_key, (batch_size,), 0, self._num_actions)
        return next_key, actions, ()

    def unroll(self, params, observation):
        """Apply the network to ``observation`` using ``params``."""
        return self._apply_fn(params, observation)  # [T, ...]


# One environment time step as recorded by `Actor`.
ActorOutput = collections.namedtuple(
    'ActorOutput',
    (
        'action_tm1',   # action that was taken to reach this observation
        'reward',       # reward returned by that environment step
        'discount',     # 1 - terminate flag of that environment step
        'first',        # 1 - discount of the preceding time step
        'observation',  # observation returned by the environment
    ),
)


class Actor:
    """Steps a vectorised environment with an agent and records fixed-length rollouts."""

    def __init__(self, envs, agent, nsteps):
        """Reset the environments and build the initial (dummy) time step.

        Args:
            envs: vectorised environment exposing ``num_envs``, ``reset`` and ``step``.
            agent: object whose ``step(rngkey, params, observation)`` returns
                ``(rngkey, action, agent_output)``.
            nsteps: number of environment steps per rollout.
        """
        self._envs = envs
        self._agent = agent
        self._nsteps = nsteps
        per_env = (self._envs.num_envs,)
        # Everything except the observation is a placeholder for the first step.
        self._timestep = ActorOutput(
            action_tm1=onp.zeros(per_env, dtype=onp.int32),
            reward=onp.zeros(per_env, dtype=onp.float32),
            discount=onp.ones(per_env, dtype=onp.float32),
            first=onp.ones(per_env, dtype=onp.float32),
            observation=self._envs.reset(),
        )

    def rollout(self, rngkey, params):
        """Collect ``nsteps`` transitions, continuing from the previous rollout.

        Returns:
            The advanced RNG key, the ``nsteps + 1`` recorded time steps packed
            by ``utils.pack_namedtuple_onp`` along axis 1 (the first one is the
            last step of the previous rollout), and the truthy ``'episode'``
            entries found in the environment infos.
        """
        current = self._timestep
        trajectory = [current]
        episode_infos = []
        for _ in range(self._nsteps):
            current = jax.device_put(current)
            rngkey, action, unused_agent_output = self._agent.step(rngkey, params, current.observation)
            # Fetching the action to the host here is crucial for higher throughput.
            action = jax.device_get(action)
            observation, reward, terminate, infos = self._envs.step(action)
            # `current` on the right-hand side is still the previous time step.
            current = ActorOutput(
                action_tm1=action,
                reward=reward.astype(onp.float32),
                discount=1. - terminate.astype(onp.float32),
                first=1. - jax.device_get(current.discount),
                observation=observation,
            )
            trajectory.append(current)
            episode_infos.extend(
                episode for episode in (entry.get('episode') for entry in infos) if episode
            )
        self._timestep = current
        return rngkey, utils.pack_namedtuple_onp(trajectory, axis=1), episode_infos


def gen_pe_update_fn(agent, opt_update, gamma):
    """Build the policy-evaluation update step.

    Args:
        agent: object whose ``unroll(theta, observation)`` returns an output
            with ``value`` and ``state`` fields.
        opt_update: optimiser update function, called as ``opt_update(grads, opt_state)``.
        gamma: discount factor multiplied into the per-step discounts.

    Returns:
        ``pe_update(theta, opt_state, trajs) -> (new_theta, new_opt_state, PELog)``.
    """

    def pe_loss(theta, trajs):
        # Regress the values onto bootstrapped discounted returns.
        unroll_batch = jax.vmap(agent.unroll, in_axes=(None, 0))
        learner_output = unroll_batch(theta, trajs.observation)  # [B, T + 1, ...]
        values = learner_output.value

        clipped_rewards = jnp.clip(trajs.reward[:, 1:], 0, 1)
        step_discounts = trajs.discount[:, 1:] * gamma
        returns = jax.vmap(rlax.discounted_returns)(clipped_rewards, step_discounts, values[:, -1])

        # Steps whose own discount is zero do not contribute to the loss.
        valid = trajs.discount[:, :-1]
        squared_error = jnp.square(values[:, :-1] - lax.stop_gradient(returns))
        per_trajectory_loss = 0.5 * jnp.mean(squared_error * valid, axis=1)
        loss = jnp.mean(per_trajectory_loss)

        state_norm = jnp.sqrt(jnp.sum(jnp.square(learner_output.state), axis=-1))
        log = PELog(
            value=values,
            ret=returns,
            pe_loss=loss,
            state_norm=state_norm,
            theta_norm=optax.global_norm(theta),
            grad_norm=0.,  # filled in by pe_update
            update_norm=0.,  # filled in by pe_update
        )
        return loss, log

    def pe_update(theta, opt_state, trajs):
        # One gradient step on the policy-evaluation loss.
        grads, logs = jax.grad(pe_loss, has_aux=True)(theta, trajs)
        updates, new_opt_state = opt_update(grads, opt_state)
        logs = logs._replace(
            grad_norm=optax.global_norm(grads),
            update_norm=optax.global_norm(updates),
        )
        return optax.apply_updates(theta, updates), new_opt_state, logs

    return pe_update


def gen_td_net_update_fn(agent, opt_update, td_mat, td_masks, target_feature_fn):
    """Build the auxiliary TD-network update step.

    Args:
        agent: object whose ``unroll(theta, observation)`` returns an output
            with an ``aux_pred`` field.
        opt_update: optimiser update function, called as ``opt_update(grads, opt_state)``.
        td_mat: matrix applied to the concatenated ``[features, next predictions]``
            vector of every time step.
        td_masks: array indexed by action that yields the per-prediction mask.
        target_feature_fn: called per trajectory as
            ``target_feature_fn(target_feature_params, traj)``.

    Returns:
        ``td_net_update(theta, target_feature_params, opt_state, trajs)
        -> (new_theta, new_opt_state, AuxLog)``.
    """

    def compute_td_target(pred_tp1):
        # Single time step: mix the concatenated vector through the TD matrix.
        return jnp.matmul(td_mat, pred_tp1)

    def merge_batch_and_time(tensor):
        # [B, T, ...] -> [B * T, ...]
        merged = tensor.shape[0] * tensor.shape[1]
        return tensor.reshape((merged,) + tensor.shape[2:])

    def td_net_loss(theta, target_feature_params, trajs):
        unroll_batch = jax.vmap(agent.unroll, in_axes=(None, 0))
        all_pred = unroll_batch(theta, trajs.observation).aux_pred
        pred = all_pred[:, :-1]
        # Zero the bootstrap term where the next step has zero discount (episode boundary).
        pred_tp1 = all_pred[:, 1:] * trajs.discount[:, 1:, None]

        feature_batch_fn = jax.vmap(target_feature_fn, in_axes=(None, 0))
        target_feature = feature_batch_fn(target_feature_params, trajs)
        pred_masks = td_masks[trajs.action_tm1[:, 1:]]
        transition_masks = trajs.discount[:, :-1]

        # Apply the TD matrix at every (batch, time) position, then drop the
        # leading entries that correspond to the features themselves.
        num_features = target_feature.shape[-1]
        stacked = jnp.concatenate([target_feature, pred_tp1], axis=-1)
        mixed = jax.vmap(jax.vmap(compute_td_target, 0), 0)(stacked)
        td_target = mixed[..., num_features:]

        pred = merge_batch_and_time(pred)
        td_target = merge_batch_and_time(td_target)
        pred_masks = merge_batch_and_time(pred_masks)
        transition_masks = merge_batch_and_time(transition_masks)

        abs_td_error = jnp.sum(jnp.abs(pred - td_target) * pred_masks, axis=-1)
        pred_losses = 0.5 * jnp.square(pred - lax.stop_gradient(td_target)) * pred_masks
        aux_loss = jnp.mean(jnp.sum(pred_losses, axis=-1) * transition_masks)

        log = AuxLog(
            pred=pred,
            td_target=td_target,
            feature=target_feature,
            mask=pred_masks,
            aux_loss=aux_loss,
            abs_td_error=abs_td_error,
            grad_norm=0.,  # filled in by td_net_update
            update_norm=0.,  # filled in by td_net_update
        )
        return aux_loss, log

    def td_net_update(theta, target_feature_params, opt_state, trajs):
        # One gradient step on the auxiliary loss.
        grads, logs = jax.grad(td_net_loss, has_aux=True)(theta, target_feature_params, trajs)
        updates, new_opt_state = opt_update(grads, opt_state)
        logs = logs._replace(
            grad_norm=optax.global_norm(grads),
            update_norm=optax.global_norm(updates),
        )
        return optax.apply_updates(theta, updates), new_opt_state, logs

    return td_net_update


class Experiment(tune.Trainable):
    """Ray Tune trainable: policy evaluation with a TD-network auxiliary task.

    Every call to ``step`` runs ``config['log_interval']`` iterations.  Each
    iteration collects one rollout, applies a policy-evaluation update and then
    an auxiliary TD-network update to the same parameters.  After the loop the
    value head is scored against the analytically computed ground-truth values.
    """

    def setup(self, config):
        """Build environments, the agent, both optimisers and the jitted update functions.

        Args:
            config: dict of hyper-parameters (see the keys read below).

        Raises:
            KeyError: if ``target_feature``, ``pe_opt_type`` or ``aux_opt_type``
                names an unsupported option.
        """
        self._config = config
        # NOTE(review): `jax.lib.xla_bridge` is not a documented public path;
        # `jax.default_backend()` looks like the public equivalent -- confirm
        # against the pinned JAX version before switching.
        platform = jax.lib.xla_bridge.get_backend().platform
        logging.warning("Running on %s", platform)
        self._envs = empty_room_vec_env.make_vec_env(
            config['nenvs'],
            config['seed'],
            env_kwargs=config['env_kwargs'],
        )
        self._nsteps = config['nsteps']

        # Fixed evaluation set: one observation per non-wall cell, plus the matching true values.
        self._test_batch = jax.device_put(generate_test_batch(**config['env_kwargs']))
        self._ground_truth_value = compute_ground_truth_value(
            config['env_kwargs']['size'],
            config['gamma'],
        )
        self._test_mask = onp.ones_like(self._ground_truth_value)
        # self._test_mask[0, -1] = 0
        print(self._ground_truth_value.round(2))
        # Loss of the best constant predictor, printed as a reference point for `test_loss`.
        mean_value = self._ground_truth_value.mean()
        constant_prediction_loss = onp.mean((self._ground_truth_value - mean_value) ** 2)
        print(constant_prediction_loss)

        # NOTE(review): the JAX seed comes from NumPy's global RNG rather than
        # from config['seed'] -- confirm whether that is intended.
        jax_seed = onp.random.randint(2 ** 31 - 1)
        self._rngkey = jrandom.PRNGKey(jax_seed)
        if config['target_feature'] == 'touch':
            num_targets = 1
            target_feature_fn = lambda _, traj: (traj.reward[1:, None] == -0.01) + (traj.reward[1:, None] == 0.99)
            self._target_feature_params = None
        elif config['target_feature'] == 'random_feature':
            self._rngkey, subkey = jrandom.split(self._rngkey)
            target_feature_fn, self._target_feature_params, num_targets = get_random_feature_fn(
                rngkey=subkey,
                observation_space=self._envs.observation_space,
                **config['target_feature_kwargs'],
            )
        else:
            raise KeyError('Unknown target_feature: {!r}'.format(config['target_feature']))

        num_actions = self._envs.action_space.n
        num_pred, td_mat, td_masks, self._dep = td_nets.FACTORY[config['td_net_type']](
            num_actions=num_actions, num_targets=num_targets, **config['td_net_kwargs'])
        self._depth = self._dep.max() + 1
        print('{} features, {} predictions in total.'.format(num_targets, num_pred))
        print(td_mat)
        print(td_masks)

        # The auxiliary gradient reaching the shared layers is scaled by
        # 1 / sqrt(active_predictions), where the count depends on the TD-net type.
        if config['td_net_type'] == 'cond_tree_sum':
            active_predictions = num_actions ** (config['td_net_kwargs']['depth'] - 1)
        elif config['td_net_type'] == 'chain_sum':
            active_predictions = config['td_net_kwargs']['length']
        else:
            active_predictions = len(config['td_net_kwargs']['discount_factors']) * num_targets + \
                                 config['td_net_kwargs']['depth'] * config['td_net_kwargs']['repeat']
        scale = 1. / onp.sqrt(active_predictions)

        self._agent = Agent(
            ob_space=self._envs.observation_space,
            ac_space=self._envs.action_space,
            num_pred=num_pred,
            torso_type=config['torso_type'],
            torso_kwargs=config['torso_kwargs'],
            head_layers=config['head_layers'],
            stop_value_grad=config['stop_value_grad'],
            scale=scale,
        )
        self._actor = Actor(self._envs, self._agent, self._nsteps)

        # Policy-evaluation optimiser: Adam with a learning rate that decays
        # linearly to zero over the planned number of updates, or RMSProp.
        if config['pe_opt_type'] == 'adam':
            pe_opt_kwargs = config['pe_opt_kwargs'].copy()
            learning_rate = pe_opt_kwargs.pop('learning_rate')
            schedule_fn = optax.polynomial_schedule(
                init_value=-learning_rate,
                end_value=0.,
                power=1,
                transition_steps=config['max_frames'] / (config['nenvs'] * config['nsteps']),
                transition_begin=1,
            )
            pe_opt = optax.chain(
                optax.scale_by_adam(**pe_opt_kwargs),
                # optax.scale(-learning_rate),
                optax.scale_by_schedule(schedule_fn)
            )
        elif config['pe_opt_type'] == 'rmsprop':
            pe_opt = optax.rmsprop(**config['pe_opt_kwargs'])
        else:
            raise KeyError('Unknown pe_opt_type: {!r}'.format(config['pe_opt_type']))
        if config['max_pe_grad_norm'] > 0:
            pe_opt = optax.chain(
                optax.clip_by_global_norm(config['max_pe_grad_norm']),
                pe_opt,
            )
        pe_opt_init, pe_opt_update = pe_opt

        # Auxiliary optimiser: Adam with a constant learning rate, or RMSProp.
        if config['aux_opt_type'] == 'adam':
            aux_opt_kwargs = config['aux_opt_kwargs'].copy()
            learning_rate = aux_opt_kwargs.pop('learning_rate')
            aux_opt = optax.chain(
                optax.scale_by_adam(**aux_opt_kwargs),
                optax.scale(-learning_rate),
            )
        elif config['aux_opt_type'] == 'rmsprop':
            aux_opt = optax.rmsprop(**config['aux_opt_kwargs'])
        else:
            raise KeyError('Unknown aux_opt_type: {!r}'.format(config['aux_opt_type']))
        if config['max_aux_grad_norm'] > 0:
            aux_opt = optax.chain(
                optax.clip_by_global_norm(config['max_aux_grad_norm']),
                aux_opt,
            )
        aux_opt_init, aux_opt_update = aux_opt

        pe_update_fn = gen_pe_update_fn(
            agent=self._agent,
            opt_update=pe_opt_update,
            gamma=config['gamma'],
        )

        aux_update_fn = gen_td_net_update_fn(
            agent=self._agent,
            opt_update=aux_opt_update,
            td_mat=jax.device_put(td_mat),
            td_masks=jax.device_put(td_masks),
            target_feature_fn=target_feature_fn,
        )
        self._pe_update_fn = jax.jit(pe_update_fn)
        self._aux_update_fn = jax.jit(aux_update_fn)

        # Both optimisers keep their own state over the same parameter tree.
        self._rngkey, subkey = jrandom.split(self._rngkey)
        self._theta = self._agent.init(subkey)
        self._pe_opt_state = pe_opt_init(self._theta)
        self._aux_opt_state = aux_opt_init(self._theta)

        self._num_iter = 0
        self._num_frames = 0
        self._tstart = time.time()

    def step(self):
        """Run ``log_interval`` training iterations and return a metrics dict.

        The policy-evaluation and auxiliary metrics come from the last
        iteration of the loop; ``test_loss`` is the masked mean squared error
        between the value head on the test batch and the ground-truth values.
        """
        t0 = time.time()
        rngkey = self._rngkey
        theta = self._theta
        num_frames_this_iter = 0
        for _ in range(self._config['log_interval']):
            rngkey, trajs, epinfos = self._actor.rollout(rngkey, theta)

            trajs = jax.device_put(trajs)
            theta, self._pe_opt_state, pe_log = self._pe_update_fn(
                theta, self._pe_opt_state, trajs)
            theta, self._aux_opt_state, aux_log = self._aux_update_fn(
                theta, self._target_feature_params, self._aux_opt_state, trajs)

            self._num_iter += 1
            num_frames_this_iter += self._config['nenvs'] * self._nsteps
        self._rngkey = rngkey
        self._theta = theta
        self._num_frames += num_frames_this_iter

        agent_output = self._agent.unroll(self._theta, self._test_batch)
        v_hat = jax.device_get(agent_output.value.reshape(self._ground_truth_value.shape))
        # print(v_hat.round(2))
        test_loss = (onp.square(v_hat - self._ground_truth_value) * self._test_mask).mean()

        # aux_pred = jax.device_get(agent_output.aux_pred[:, 0].reshape(self._ground_truth_value.shape))
        # print(aux_pred.round(2))

        pe_log = jax.device_get(pe_log)
        aux_log = jax.device_get(aux_log)
        log = {
            'label': self._config['label'],
            'test_loss': test_loss,
            'pe_loss': pe_log.pe_loss.mean(),
            'value_mean': pe_log.value.mean(),
            'value_std': pe_log.value.std(),
            'return_mean': pe_log.ret.mean(),
            'return_std': pe_log.ret.std(),
            'state_norm': onp.mean(pe_log.state_norm),
            'pe_grad_norm': pe_log.grad_norm,
            'pe_update_norm': pe_log.update_norm,
            'param_norm': pe_log.theta_norm,
            'aux_loss': aux_log.aux_loss.mean(),
            'aux_grad_norm': aux_log.grad_norm,
            'aux_update_norm': aux_log.update_norm,
            'num_iterations': self._num_iter,
            'num_frames': self._num_frames,
            'fps': num_frames_this_iter / (time.time() - t0),
        }

        # Disabled per-depth diagnostics of the auxiliary predictions, kept for reference.
        # pred_mean = (aux_log.pred * aux_log.mask).sum(axis=0) / (aux_log.mask.sum(axis=0) + 1E-8)
        # target_mean = (aux_log.td_target * aux_log.mask).sum(axis=0) / (aux_log.mask.sum(axis=0) + 1E-8)
        # loss = 0.5 * (onp.square(aux_log.pred - aux_log.td_target) * aux_log.mask).sum(axis=0) / (
        #             aux_log.mask.sum(axis=0) + 1E-8)

        # n = aux_log.pred.shape[1]
        # pred_std = onp.zeros((n,))
        # target_std = onp.zeros((n,))
        # ev = onp.zeros((n,))
        # for i in range(n):
        #     pred = []
        #     target = []
        #     for b in range(aux_log.pred.shape[0]):
        #         if aux_log.mask[b, i]:
        #             pred.append(aux_log.pred[b, i])
        #             target.append(aux_log.td_target[b, i])
        #     if not pred:
        #         continue
        #     pred = onp.array(pred)
        #     target = onp.array(target)
        #     pred_std[i] = pred.std()
        #     target_std[i] = target.std()
        #     if pred.shape[0] == 1:
        #         ev[i] = 0.
        #     else:
        #         ev[i] = utils.explained_variance(pred, target)
        # for d in range(self._depth):
        #     mask = self._dep == d
        #     depth_pred_mean = (pred_mean * mask).sum() / mask.sum()
        #     depth_pred_std = (pred_std * mask).sum() / mask.sum()
        #     depth_target_mean = (target_mean * mask).sum() / mask.sum()
        #     depth_target_std = (target_std * mask).sum() / mask.sum()
        #     depth_loss = (loss * mask).sum()
        #     depth_ev = (ev[mask]).sum() / mask.sum()
        #     log.update({
        #         'depth_{}_pred_mean'.format(d): depth_pred_mean,
        #         'depth_{}_pred_std'.format(d): depth_pred_std,
        #         'depth_{}_td_target_mean'.format(d): depth_target_mean,
        #         'depth_{}_td_target_std'.format(d): depth_target_std,
        #         'depth_{}_loss'.format(d): depth_loss,
        #         'depth_{}_ev'.format(d): depth_ev,
        #     })
        return log

    def _save(self, tmp_checkpoint_dir):
        """Pickle ``(theta, pe_opt_state, aux_opt_state)`` into ``model.chk``.

        Args:
            tmp_checkpoint_dir: directory in which the checkpoint file is written.

        Returns:
            Path of the written checkpoint file.
        """
        theta = jax.device_get(self._theta)
        pe_opt_state = jax.device_get(self._pe_opt_state)
        aux_opt_state = jax.device_get(self._aux_opt_state)
        checkpoint_path = os.path.join(tmp_checkpoint_dir, 'model.chk')
        with open(checkpoint_path, 'wb') as checkpoint_file:
            pickle.dump((theta, pe_opt_state, aux_opt_state), checkpoint_file)
        return checkpoint_path

    def _restore(self, checkpoint):
        """Load the parameters and both optimiser states written by ``_save``.

        Args:
            checkpoint: path of the checkpoint file returned by ``_save``.
        """
        # NOTE: unpickling can execute arbitrary code; only load checkpoints
        # that were produced by `_save`.
        with open(checkpoint, 'rb') as checkpoint_file:
            # Must mirror the 3-tuple written by `_save`.
            theta, pe_opt_state, aux_opt_state = pickle.load(checkpoint_file)
        self._theta = theta
        self._pe_opt_state = pe_opt_state
        self._aux_opt_state = aux_opt_state


if __name__ == '__main__':
    # Frame budget shared by the learning-rate schedule and the stopping criterion.
    max_frames = 10 ** 6

    env_kwargs = {
        'size': 9,
        'observation_type': 'one_hot',
    }
    torso_kwargs = {
        'conv_layers': (),
        'dense_layers': (64, 64, 32,),
        'padding': 'VALID',
    }
    # Alternative target feature:
    #   'target_feature': 'touch', 'target_feature_kwargs': {}
    target_feature_kwargs = {
        'conv_layers': (),
        'dense_layers': (8,),
        'padding': 'VALID',
        'w_init': 'orthogonal',
        'w_init_scale': 2.,
        'delta': True,
        'absolute': True,
        'only_last_channel': False,
    }
    # Alternative TD networks:
    #   'td_net_type': 'cond_tree_sum', 'td_net_kwargs': {'depth': 2, 'balance_by_depth': False}
    #   'td_net_type': 'chain_sum', 'td_net_kwargs': {'length': 3}
    td_net_kwargs = {
        'seed': None,
        'depth': 4,
        'repeat': 8,
        'discount_factors': (0.8,),
    }

    config = {
        'label': 'policy-evaluation',
        'env_kwargs': env_kwargs,

        # Network architecture.
        'torso_type': 'maze_shallow',
        'torso_kwargs': torso_kwargs,
        'head_layers': (),
        'stop_value_grad': True,

        # Data collection.
        'nenvs': 8,
        'nsteps': 8,
        'max_frames': max_frames,
        'gamma': 0.98,

        # Policy-evaluation optimiser.
        'pe_opt_type': 'adam',
        'pe_opt_kwargs': {'learning_rate': 1e-3},
        'max_pe_grad_norm': 0.,

        # Auxiliary task.
        'target_feature': 'random_feature',
        'target_feature_kwargs': target_feature_kwargs,
        'td_net_type': 'mixed_open_loop_planning',
        'td_net_kwargs': td_net_kwargs,
        'aux_update_freq': 1,

        # Auxiliary optimiser.
        'aux_opt_type': 'adam',
        'aux_opt_kwargs': {'learning_rate': 1e-3},
        'max_aux_grad_norm': 0.,

        'log_interval': 100,
        'seed': 0,
    }
    analysis = tune.run(
        Experiment,
        name='debug',
        config=config,
        stop={'num_frames': max_frames},
        resources_per_trial={'cpu': 1},
    )
