import itertools

import torch

from dcrl.algos.base import BaseAlgo


class A2CAlgo(BaseAlgo):
    """Actor-critic update rule using separate actor and critic models.

    Both models are optimized jointly by a single ``torch.optim.RMSprop``
    instance on the combined loss
    ``policy_loss - entropy_coef * entropy + value_loss_coef * value_loss``.
    """

    def __init__(
        self,
        envs,
        actor_model,
        critic_model,
        device,
        num_frames=5,
        gamma=0.99,
        lr=1e-3,
        gae_lambda=1.0,
        entropy_coef=0.01,
        value_loss_coef=0.5,
        max_grad_norm=0.5,
        reshape_reward_fn=None,
        reshape_adv_fn=None,
        rmsprop_alpha=0.99,
        rmsprop_eps=1e-5,
    ):
        """Initialize the base algorithm and build the shared optimizer.

        All arguments except ``rmsprop_alpha`` and ``rmsprop_eps`` are
        forwarded positionally to ``BaseAlgo.__init__``; their exact
        semantics are defined there. ``lr`` is additionally used as the
        RMSprop learning rate.

        Args:
            rmsprop_alpha: Passed to ``torch.optim.RMSprop`` as ``alpha``.
            rmsprop_eps: Passed to ``torch.optim.RMSprop`` as ``eps``.
        """
        super().__init__(
            envs,
            actor_model,
            critic_model,
            device,
            num_frames,
            gamma,
            lr,
            gae_lambda,
            entropy_coef,
            value_loss_coef,
            max_grad_norm,
            reshape_reward_fn,
            reshape_adv_fn,
        )

        # Materialize the parameters into a list. A bare ``itertools.chain``
        # is a one-shot iterator: the optimizer would exhaust it on
        # construction, leaving nothing to iterate over later when clipping
        # gradients (clipping would silently become a no-op).
        self.parameters = list(
            itertools.chain.from_iterable(
                model.parameters() for model in (self.actor_model, self.critic_model)
            )
        )
        self.optimizer = torch.optim.RMSprop(self.parameters, lr, alpha=rmsprop_alpha, eps=rmsprop_eps)

    def update_parameters(self, exps):
        """Run one gradient step on the actor and critic from a batch.

        Args:
            exps: Batch of experiences. This method reads ``obs``,
                ``actor_memory``, ``critic_memory``, ``memory_mask``,
                ``action``, ``advantage`` and ``returnn`` from it.

        Returns:
            dict: Scalar logs with keys ``"entropy"``, ``"value"``,
            ``"policy_loss"``, ``"value_loss"`` and ``"grad_norm"`` (the
            total gradient norm measured before clipping).
        """
        # Memories are multiplied by the mask before being fed to the models;
        # the updated memories the models return are not needed here.
        dist, _ = self.actor_model(exps.obs, exps.actor_memory * exps.memory_mask)
        value, _ = self.critic_model(exps.obs, exps.critic_memory * exps.memory_mask)

        entropy = dist.entropy().mean()
        policy_loss = -(dist.log_prob(exps.action) * exps.advantage).mean()
        value_loss = (value - exps.returnn).pow(2).sum(dim=-1).mean()
        loss = policy_loss - self.entropy_coef * entropy + self.value_loss_coef * value_loss

        self.optimizer.zero_grad()
        loss.backward()
        # clip_grad_norm_ returns the total norm of all gradients computed
        # before clipping, and skips parameters whose grad is None, so it
        # serves as both the logged value and the clipping step.
        grad_norm = torch.nn.utils.clip_grad_norm_(self.parameters, self.max_grad_norm)
        self.optimizer.step()

        return {
            "entropy": entropy.item(),
            "value": value.mean().item(),
            "policy_loss": policy_loss.item(),
            "value_loss": value_loss.item(),
            "grad_norm": float(grad_norm),
        }
