"""This module creates continuous Q-function networks."""

import torch
import torch.nn.functional as F
from garage.torch.modules import MLPModule
from garage.torch.modules.multi_headed_mlp_module import MultiHeadedMLPModule
from torch import nn

from src.modules import AlternateAttention


class AdaptiveMLPQFunction(nn.Module):
    """Q-function that sum-pools an MLP encoding of the experiment history.

    Every element of the history is encoded by an MLP, the encodings are
    summed over dimension ``-2`` and the pooled encoding is concatenated with
    the action before being passed to an emitter MLP with a single output.
    Inputs to the encoder should be of the shape
    (batch_dim, history_length, obs_dim).

    Args:
        env_spec (garage.envs.env_spec.EnvSpec): Environment specification.
        encoder_sizes (list[int]): Output dimension of dense layer(s) for
            the MLP for encoder. For example, (32, 32) means the MLP consists
            of two hidden layers, each with 32 hidden units.
        encoder_nonlinearity (callable): Activation function for intermediate
            dense layer(s) of encoder. It should return a torch.Tensor. Set it
            to None to maintain a linear activation.
        encoder_output_nonlinearity (callable): Activation function for encoder
            output dense layer. It should return a torch.Tensor. Set it to None
            to maintain a linear activation.
        encoding_dim (int): Output dimension of output dense layer for encoder.
        emitter_sizes (list[int]): Output dimension of dense layer(s) for
            the MLP for emitter.
        emitter_nonlinearity (callable): Activation function for intermediate
            dense layer(s) of emitter.
        emitter_output_nonlinearity (callable): Activation function for emitter
            output dense layer.
        **kwargs: Extra keyword arguments forwarded unchanged to BOTH the
            encoder and the emitter ``MLPModule`` constructors.
    """

    def __init__(
        self,
        env_spec,
        encoder_sizes=(32, 32),
        encoder_nonlinearity=nn.ReLU,
        encoder_output_nonlinearity=None,
        encoding_dim=16,
        emitter_sizes=(32, 32),
        emitter_nonlinearity=nn.ReLU,
        emitter_output_nonlinearity=None,
        **kwargs,
    ):
        super().__init__()
        self._env_spec = env_spec
        self._obs_dim = env_spec.observation_space.flat_dim
        self._action_dim = env_spec.action_space.flat_dim

        # Encodes each history element independently.
        self._encoder = MLPModule(
            input_dim=self._obs_dim,
            output_dim=encoding_dim,
            hidden_sizes=encoder_sizes,
            hidden_nonlinearity=encoder_nonlinearity,
            output_nonlinearity=encoder_output_nonlinearity,
            **kwargs,
        )

        # Maps [pooled encoding, action] to a single Q-value.
        self._emitter = MLPModule(
            input_dim=encoding_dim + self._action_dim,
            output_dim=1,
            hidden_sizes=emitter_sizes,
            hidden_nonlinearity=emitter_nonlinearity,
            output_nonlinearity=emitter_output_nonlinearity,
            **kwargs,
        )

    def forward(self, observations, actions, mask=None):
        """Return Q-value(s).

        Args:
            observations (torch.Tensor): History fed to the encoder.
            actions (torch.Tensor): Actions to evaluate. When the action
                space is discrete and the last dimension has size 1, the
                values are treated as indices and one-hot encoded to
                ``action_space.flat_dim`` entries; otherwise they are
                concatenated as given.
            mask (torch.Tensor, optional): Multiplied element-wise with the
                encoding before pooling, so it must broadcast against the
                encoder output (presumably 0/1 flags marking valid history
                entries -- confirm against the caller).

        Returns:
            torch.Tensor: Output of the emitter MLP for the concatenation of
                pooled encoding and actions.
        """
        # Call the modules rather than ``.forward`` so registered hooks run.
        encoding = self._encoder(observations)
        if mask is not None:
            encoding = encoding * mask
        pooled_encoding = encoding.sum(dim=-2)

        if self._env_spec.action_space.is_discrete and actions.shape[-1] == 1:
            # ``F.one_hot`` only accepts int64 indices, while actions often
            # arrive as float tensors; the one-hot result is cast back so the
            # concatenation keeps the encoding's dtype.
            indices = actions.squeeze(dim=-1).long()
            actions = F.one_hot(indices, self._action_dim).to(
                pooled_encoding.dtype
            )

        return self._emitter(torch.cat([pooled_encoding, actions], -1))


class AdaptiveMLPQFunctionDoCausal(nn.Module):
    """Q-function with an ``AlternateAttention`` encoder and an MLP emitter.

    Observations are encoded by ``AlternateAttention``, optionally pooled
    over dimension ``-3`` and passed to an emitter whose architecture depends
    on ``no_value`` and ``is_single_target``:

    * ``no_value=True``: an ``MLPModule`` with one output. Its input width is
      ``encoding_dim`` when ``is_single_target`` is set and
      ``encoding_dim + batch_size`` otherwise.
    * ``is_single_target=True`` and ``no_value=False``: a five-headed
      ``MultiHeadedMLPModule`` that receives the encoding only; actions are
      not used on this path.
    * otherwise: an ``MLPModule`` with one output and input width
      ``encoding_dim + 2 * batch_size``.

    Encoder and emitter are both wrapped in ``nn.DataParallel``.

    NOTE(review): pooling over dimension ``-3`` suggests the encoder output
    is laid out as (..., history_length, n_variables, encoding_dim) --
    confirm against ``AlternateAttention``.

    Args:
        env_spec (garage.envs.env_spec.EnvSpec): Environment specification.
        encoding_dim (int): ``dim`` passed to the attention encoder; also the
            width of the encoding part of the emitter input.
        batch_size (int): Number of action entries folded into the emitter
            input width (see the emitter variants above).
        encoder_widening_factor (int): The encoder's ``feedforward_dim`` is
            ``encoder_widening_factor * encoding_dim``.
        encoder_dropout (float): ``dropout`` passed to the encoder.
        encoder_n_layers (int): ``n_layers`` passed to the encoder.
        encoder_num_heads (int): ``num_heads`` passed to the encoder.
        pooling (str): ``"max"`` or ``"sum"``. A falsy value disables
            pooling; any other value makes ``forward`` raise
            ``NotImplementedError``.
        emitter_sizes (list[int]): Hidden layer sizes of the ``MLPModule``
            emitters. The multi-headed emitter always uses (32, 32).
        emitter_nonlinearity (callable): Activation function for intermediate
            dense layer(s) of the ``MLPModule`` emitters.
        emitter_output_nonlinearity (callable): Activation function for the
            output layer of the ``MLPModule`` emitters.
        is_single_target (bool): Selects the emitter variant (see above).
        no_value (bool): Selects the emitter variant (see above).
        **kwargs: Extra keyword arguments forwarded to the ``MLPModule``
            emitters (not to the multi-headed emitter).
    """

    # ``forward`` splits the multi-headed emitter output into one Q-value
    # head plus four auxiliary heads, so the head count is fixed.
    _N_EMITTER_HEADS = 5

    def __init__(
        self,
        env_spec,
        encoding_dim=16,
        batch_size=1,
        encoder_widening_factor=2,
        encoder_dropout=0.1,
        encoder_n_layers=1,
        encoder_num_heads=8,
        pooling="max",
        emitter_sizes=(32, 32),
        emitter_nonlinearity=nn.ReLU,
        emitter_output_nonlinearity=None,
        is_single_target=False,
        no_value=False,
        **kwargs,
    ):
        super().__init__()
        self._env_spec = env_spec
        self._obs_dim = env_spec.observation_space.shape[-1]
        self._action_dim = env_spec.action_space.flat_dim
        self._pooling = pooling
        self.is_single_target = is_single_target
        self.no_value = no_value

        encoder = AlternateAttention(
            dim=encoding_dim,
            feedforward_dim=encoder_widening_factor * encoding_dim,
            dropout=encoder_dropout,
            n_layers=encoder_n_layers,
            num_heads=encoder_num_heads,
        )
        self._encoder = nn.DataParallel(encoder)

        if no_value:
            emitter = MLPModule(
                input_dim=(
                    encoding_dim if is_single_target else encoding_dim + batch_size
                ),
                output_dim=1,
                hidden_sizes=emitter_sizes,
                hidden_nonlinearity=emitter_nonlinearity,
                output_nonlinearity=emitter_output_nonlinearity,
                **kwargs,
            )
        elif is_single_target:
            emitter = MultiHeadedMLPModule(
                n_heads=self._N_EMITTER_HEADS,
                input_dim=encoding_dim,
                output_dims=self._action_dim,
                hidden_sizes=(32, 32),
                hidden_nonlinearity=torch.nn.ReLU,
                hidden_w_init=nn.init.xavier_uniform_,
                hidden_b_init=nn.init.zeros_,
                output_nonlinearities=None,
                output_w_inits=nn.init.xavier_uniform_,
                output_b_inits=nn.init.zeros_,
                layer_normalization=False,
            )
        else:
            # Neither flag is set here, so the emitter always consumes the
            # encoding plus two values per batch entry and emits one value.
            emitter = MLPModule(
                input_dim=encoding_dim + 2 * batch_size,
                output_dim=1,
                hidden_sizes=emitter_sizes,
                hidden_nonlinearity=emitter_nonlinearity,
                output_nonlinearity=emitter_output_nonlinearity,
                **kwargs,
            )
        self._emitter = nn.DataParallel(emitter)

    def forward(self, observations, actions=None, mask=None):
        """Return Q-value(s).

        Args:
            observations (torch.Tensor): Input to the attention encoder.
            actions (torch.Tensor, optional): Actions to evaluate. Ignored
                when ``is_single_target`` is set and ``no_value`` is not.
                On the other paths, ``None`` feeds the encoding alone to the
                emitter (the only input that matches an emitter built with
                ``no_value=True, is_single_target=True``).
            mask (torch.Tensor, optional): Applied before pooling. For max
                pooling it must be boolean: entries where it is False are
                replaced by the global minimum of the encoding. For sum
                pooling it is multiplied with the encoding.

        Returns:
            torch.Tensor: On the multi-headed path, the Q-value head summed
                over its last dimension, concatenated with one scalar per
                auxiliary head. Otherwise the emitter output summed over
                dimension ``-2``.

        Raises:
            NotImplementedError: If ``pooling`` is neither falsy, ``"max"``
                nor ``"sum"``.
        """
        # Call the modules rather than ``.forward`` so registered hooks run.
        encoding = self._pool(self._encoder(observations), mask)

        if self.is_single_target and not self.no_value:
            return self._multi_head_q_values(encoding)

        emitter_input = encoding
        if actions is not None:
            emitter_input = torch.cat(
                [encoding, self._flatten_actions(actions)], -1
            )
        return self._emitter(emitter_input).sum(-2)

    def _pool(self, encoding, mask):
        """Reduce ``encoding`` over dimension ``-3`` per ``self._pooling``."""
        if not self._pooling:
            return encoding
        if self._pooling == "max":
            if mask is not None:
                # Masked-out entries get the global minimum so they cannot
                # win the max. Done out of place to leave the encoder output
                # (and the autograd graph behind it) untouched.
                min_value = torch.min(encoding).detach()
                encoding = encoding.masked_fill(~mask, min_value)
            return torch.max(encoding, -3).values
        if self._pooling == "sum":
            if mask is not None:
                encoding = encoding * mask
            return torch.sum(encoding, -3)
        raise NotImplementedError(f"{self._pooling} not implemented")

    def _flatten_actions(self, actions):
        """Move the action batch axis into the feature axis for the emitter."""
        if self.no_value:
            # One value per entry.
            actions = actions.unsqueeze(-1)
        else:
            # Two values per entry.
            actions = actions.reshape(*actions.shape[:-1], -1, 2)
        actions = actions.transpose(-2, -3)
        return actions.reshape(*actions.shape[:-2], -1)

    def _multi_head_q_values(self, encoding):
        """Combine the multi-headed emitter outputs into a single tensor."""
        q_vals, *aux_heads = self._emitter(encoding)
        # Each auxiliary head collapses to one scalar per leading index.
        aux_vals = [head.sum(-2, keepdim=True).sum(-1) for head in aux_heads]
        return torch.cat([q_vals.sum(-1), *aux_vals], -1)
