from collections import OrderedDict

import numpy as np
import torch
import torch.optim as optim
from torch import nn as nn

import rlkit.torch.pytorch_util as ptu
from rlkit.core.eval_util import create_stats_ordered_dict
from rlkit.torch.torch_rl_algorithm import TorchTrainer


class NStepISTrainer(TorchTrainer):
    """
    Trainer that alternates between a policy phase and a Q-function phase.

    Policy phase (``update_q_phase`` is False): ``policy`` is updated to
    maximise ``qf1(obs, pi(obs))``. When ``kl_reg`` is set, a sampled estimate
    of ``log pi - log pi_data`` (averaged over ``n_actions`` policy samples per
    observation) is added as a penalty scaled by ``alpha``. ``target_policy``
    tracks ``policy`` through soft updates.

    Q phase (``update_q_phase`` is True): ``qf1`` / ``qf2`` are regressed with
    an MSE loss onto
    ``reward_scale * r + (1 - done) * discount * w * min(target_qf1, target_qf2)``
    where the next action is sampled from ``policy_data`` and ``w`` is the
    exponential of a log-probability difference between ``target_policy`` and
    ``policy_data`` clipped to ``[-log_is_clip_value, log_is_clip_value]``.

    The phase flips after ``n_train_steps_policy`` / ``n_train_steps_qf`` calls
    to ``train_from_torch``; on every flip all optimizers are rebuilt and, if
    ``update_policy_data`` is set, ``policy_data`` is overwritten from
    ``policy``.
    """
    def __init__(
            self,
            env,
            policy,
            target_policy,
            policy_data,
            qf1,
            qf2,
            target_qf1,
            target_qf2,

            update_q_first=False,
            kl_reg=True,
            n_actions=10,
            alpha=1.0,
            discount=0.99,
            reward_scale=1.0,
            log_is_clip_value=1.0,

            n_train_steps_policy=int(1e5),
            n_train_steps_qf=int(1e5),
            qf_lr=1e-4,
            policy_lr=1e-4,
            optimizer_class=optim.Adam,

            soft_target_tau=5e-3,
            target_update_period=2,
            update_policy_data=False,
            normalize_q=False,
    ):
        """
        Store the networks and hyper-parameters and build the optimizers.

        :param env: environment handle; only stored on the trainer.
        :param policy: policy network being optimised.
        :param target_policy: slowly-updated copy of ``policy`` used in the Q phase.
        :param policy_data: policy used to sample next actions in the Q phase
            and as the reference in the KL penalty.
        :param qf1, qf2: Q-networks being optimised.
        :param target_qf1, target_qf2: slowly-updated copies of the Q-networks.
        :param update_q_first: start in the Q phase instead of the policy phase.
        :param kl_reg: add the sampled KL penalty to the policy loss.
        :param n_actions: policy samples per observation for the KL estimate.
        :param alpha: weight of the KL penalty.
        :param discount: discount factor in the Q target.
        :param reward_scale: multiplier applied to rewards in the Q target.
        :param log_is_clip_value: symmetric clip bound on the log importance ratio.
        :param n_train_steps_policy: length of a policy phase, in train steps.
        :param n_train_steps_qf: length of a Q phase, in train steps.
        :param qf_lr: learning rate for both Q optimizers.
        :param policy_lr: learning rate for the policy optimizer.
        :param optimizer_class: optimizer constructor, called as
            ``optimizer_class(params, lr=...)``.
        :param soft_target_tau: interpolation factor passed to the soft updates.
        :param target_update_period: soft updates run when the total step
            count is a multiple of this value.
        :param update_policy_data: copy ``policy`` into ``policy_data`` at
            every phase switch.
        :param normalize_q: in the KL-regularised policy loss, divide the Q
            term by the (detached) mean absolute Q value.
        """
        super().__init__()
        self.env = env
        self.policy = policy
        self.target_policy = target_policy
        self.policy_data = policy_data
        self.qf1 = qf1
        self.qf2 = qf2
        self.target_qf1 = target_qf1
        self.target_qf2 = target_qf2
        self.soft_target_tau = soft_target_tau
        self.target_update_period = target_update_period

        self.update_q_first = update_q_first
        self.update_q_phase = self.update_q_first
        # Value of _n_train_steps_total at the most recent phase switch.
        self.last_phase_steps = 0
        print('self.update_q_first: \t', self.update_q_first)
        print('self.update_q_phase: \t', self.update_q_phase)
        print('self.last_phase_steps: \t', self.last_phase_steps)
        self.update_policy_data = update_policy_data
        print('self.update_policy_data: \t', self.update_policy_data)
        self.kl_reg = kl_reg
        print('self.kl_reg: \t', self.kl_reg)
        self.normalize_q = normalize_q
        print('self.normalize_q: \t', self.normalize_q)

        self.n_actions = n_actions
        self.alpha = alpha
        print('self.alpha: \t', self.alpha)

        self.qf_criterion = nn.MSELoss()

        self.policy_lr = policy_lr
        self.qf_lr = qf_lr
        self.optimizer_class = optimizer_class

        self.reset_policy_optimizer()
        self.reset_qf_optimizer()

        self.discount = discount
        self.reward_scale = reward_scale
        self.eval_statistics = OrderedDict()
        self.eval_wandb = OrderedDict()
        self._n_train_steps_total = 0

        self._n_train_steps_policy = n_train_steps_policy
        self._n_train_steps_qf = n_train_steps_qf
        print('self._n_train_steps_policy: \t', self._n_train_steps_policy)
        print('self._n_train_steps_qf: \t', self._n_train_steps_qf)
        self._need_to_update_eval_statistics = True

        self.log_is_clip_value = log_is_clip_value
        print("self.log_is_clip_value: \t", self.log_is_clip_value)

        self.discrete = False

    def train_from_torch(self, batch):
        """
        Run one gradient step of whichever phase is currently active.

        :param batch: mapping with the keys ``'rewards'``, ``'terminals'``,
            ``'observations'``, ``'actions'`` and ``'next_observations'``.
        """
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        if not self.update_q_phase:
            policy_loss, kl, q1_pred, q2_pred = self._train_policy(obs, actions)
            # No Q update happens in this phase; report zeros for the Q losses.
            qf_loss = qf1_loss = qf2_loss = torch.tensor(0.)
        else:
            qf1_loss, qf2_loss, qf_loss, q1_pred, q2_pred = self._train_qf(
                rewards, terminals, obs, actions, next_obs,
            )
            # No policy update happens in this phase; report zeros.
            policy_loss = torch.tensor(0.)
            kl = torch.tensor(0.)

        if self._need_to_update_eval_statistics:
            # Cleared here and re-armed in end_epoch, so statistics are
            # recorded for a single batch per epoch.
            self._need_to_update_eval_statistics = False
            self._record_eval_statistics(
                policy_loss, kl, qf1_loss, qf2_loss, qf_loss, q1_pred, q2_pred,
            )

        self._n_train_steps_total += 1
        self._maybe_switch_phase()

    def _train_policy(self, obs, actions):
        """Policy-phase step; returns (policy_loss, kl, q1_pred, q2_pred)."""
        # Q predictions on the batch actions are only reported, never
        # optimised in this phase, so no autograd graph is needed for them.
        with torch.no_grad():
            q1_pred = self.qf1(obs, actions)
            q2_pred = self.qf2(obs, actions)

        new_obs_actions, *_ = self.policy(
            obs, reparameterize=True, return_log_prob=True,
        )

        if self.kl_reg:
            # Repeat each observation n_actions times (obs is indexed as 2-D
            # below) so the KL term is averaged over several policy samples.
            obs_stack = torch.unsqueeze(obs, 1).repeat(1, self.n_actions, 1).reshape((-1, obs.shape[1]))
            new_obs_actions_stack, _, _, log_pi_stack, *_ = self.policy(
                obs_stack, reparameterize=True, return_log_prob=True,
            )
            log_pi = torch.mean(log_pi_stack.reshape((-1, self.n_actions)), dim=1)

            log_pi_data_stack = self.policy_data.log_prob(obs_stack, new_obs_actions_stack)
            log_pi_data = torch.mean(log_pi_data_stack.reshape((-1, self.n_actions)), dim=1)

            kl = (log_pi - log_pi_data).mean()
            qf1_pi = self.qf1(obs, new_obs_actions)
            if self.normalize_q:
                # Detached scale: it rescales the Q term without contributing
                # gradients of its own.
                avg_qf = qf1_pi.abs().mean().detach()
                policy_loss = self.alpha * kl - qf1_pi.mean() / avg_qf
            else:
                policy_loss = self.alpha * kl - qf1_pi.mean()
        else:
            policy_loss = -1 * self.qf1(obs, new_obs_actions).mean()
            kl = torch.tensor(0.)

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(
                self.policy, self.target_policy, self.soft_target_tau
            )

        return policy_loss, kl, q1_pred, q2_pred

    def _train_qf(self, rewards, terminals, obs, actions, next_obs):
        """Q-phase step; returns (qf1_loss, qf2_loss, qf_loss, q1_pred, q2_pred)."""
        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)

        # The regression target is a constant w.r.t. the Q parameters, so it is
        # built without an autograd graph.
        with torch.no_grad():
            next_actions, _, _, log_prob_data, *_ = self.policy_data(
                next_obs, reparameterize=False, return_log_prob=True,
            )
            # NOTE(review): log_prob_pi is the log-probability of the action
            # target_policy samples in this call, not of `next_actions` drawn
            # from policy_data above. If the ratio is meant to be evaluated at
            # the same action, target_policy.log_prob(next_obs, next_actions)
            # may be what is wanted -- confirm with the algorithm's authors.
            _, _, _, log_prob_pi, *_ = self.target_policy(
                next_obs, reparameterize=False, return_log_prob=True,
            )
            log_ratio = torch.clip(
                log_prob_pi - log_prob_data,
                -1 * self.log_is_clip_value,
                self.log_is_clip_value,
            )
            importance_sampling = torch.exp(log_ratio)

            target_q_values = torch.min(
                self.target_qf1(next_obs, next_actions),
                self.target_qf2(next_obs, next_actions),
            )
            q_target = self.reward_scale * rewards \
                + (1. - terminals) * self.discount * importance_sampling * target_q_values

        qf1_loss = self.qf_criterion(q1_pred, q_target)
        qf2_loss = self.qf_criterion(q2_pred, q_target)
        qf_loss = qf1_loss + qf2_loss

        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(
                self.qf1, self.target_qf1, self.soft_target_tau
            )
            ptu.soft_update_from_to(
                self.qf2, self.target_qf2, self.soft_target_tau
            )

        return qf1_loss, qf2_loss, qf_loss, q1_pred, q2_pred

    def _record_eval_statistics(
            self, policy_loss, kl, qf1_loss, qf2_loss, qf_loss, q1_pred, q2_pred,
    ):
        """Store the losses and Q predictions of the current step for reporting."""
        self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(policy_loss))
        self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
        self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
        self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q1 Predictions',
            ptu.get_numpy(q1_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q2 Predictions',
            ptu.get_numpy(q2_pred),
        ))

        if self.kl_reg:
            self.eval_statistics['KL'] = np.mean(ptu.get_numpy(kl))

        # Keys missing from eval_statistics are reported as 0.
        # NOTE(review): the bare 'Q1 Predictions' / 'Q2 Predictions' keys are
        # presumably never produced by create_stats_ordered_dict (the loop
        # below reads suffixed keys), so they likely always log 0 -- kept for
        # compatibility with existing dashboards; confirm before removing.
        for k in ['Policy Loss', 'Q1 Predictions', 'Q2 Predictions', 'QF1 Loss', 'QF2 Loss', 'QF Loss', 'KL']:
            self.eval_wandb[k] = self.eval_statistics.get(k, 0)

        for k1 in ['Q1 Predictions', 'Q2 Predictions']:
            for k2 in ['Mean', 'Std', 'Max', 'Min']:
                k = ' '.join([k1, k2])
                self.eval_wandb[k] = self.eval_statistics.get(k, 0)

    def _maybe_switch_phase(self):
        """Flip between policy and Q phases once the current phase has run its step budget."""
        steps_in_phase = int(self._n_train_steps_total - self.last_phase_steps)
        if self.update_q_phase:
            phase_length = self._n_train_steps_qf
        else:
            phase_length = self._n_train_steps_policy
        if steps_in_phase % phase_length != 0:
            return

        self.update_q_phase = not self.update_q_phase
        if self.update_q_phase:
            print("Updating Q starts!!")
        else:
            print("Updating Policy starts!!")
        self.last_phase_steps = self._n_train_steps_total
        print('policy optimizer and qf optimizer are reset')
        if self.update_policy_data:
            ptu.copy_model_params_from_to(self.policy, self.policy_data)
        # Every phase starts with fresh optimizer state.
        self.reset_policy_optimizer()
        self.reset_qf_optimizer()

    def reset_policy_optimizer(self):
        """Create a new optimizer over the current policy's parameters."""
        self.policy_optimizer = self.optimizer_class(
            self.policy.parameters(),
            lr=self.policy_lr,
        )

    def reset_qf_optimizer(self):
        """Create new optimizers over the current qf1 / qf2 parameters."""
        self.qf1_optimizer = self.optimizer_class(
            self.qf1.parameters(),
            lr=self.qf_lr,
        )
        self.qf2_optimizer = self.optimizer_class(
            self.qf2.parameters(),
            lr=self.qf_lr,
        )

    def get_diagnostics(self):
        """Return the ``(eval_statistics, eval_wandb)`` pair of ordered dicts."""
        return (self.eval_statistics, self.eval_wandb)

    def end_epoch(self, epoch):
        """Re-arm statistics collection for the next epoch."""
        self._need_to_update_eval_statistics = True

    @property
    def networks(self):
        """All networks held by the trainer."""
        return [
            self.policy,
            self.target_policy,
            self.policy_data,
            self.qf1,
            self.qf2,
            self.target_qf1,
            self.target_qf2,
        ]

    def get_snapshot(self):
        """Return every network the trainer needs to resume training."""
        return dict(
            policy=self.policy,
            # Needed by the Q phase for the importance weights.
            target_policy=self.target_policy,
            policy_data=self.policy_data,
            qf1=self.qf1,
            qf2=self.qf2,
            target_qf1=self.target_qf1,
            target_qf2=self.target_qf2,
        )

    def set_snapshot(self, snapshot):
        """
        Restore networks from a snapshot produced by ``get_snapshot``.

        Snapshots without a ``'target_policy'`` entry are accepted; the
        current target policy is kept in that case.
        """
        self.policy = snapshot['policy']
        if 'target_policy' in snapshot:
            self.target_policy = snapshot['target_policy']
        self.policy_data = snapshot['policy_data']
        self.qf1 = snapshot['qf1']
        self.qf2 = snapshot['qf2']
        self.target_qf1 = snapshot['target_qf1']
        self.target_qf2 = snapshot['target_qf2']
        # The optimizers hold references to the parameters they were built
        # with; rebuild them so they step the restored networks.
        self.reset_policy_optimizer()
        self.reset_qf_optimizer()
