import torch
import torch.nn as nn

from dcrl.utils.model_utils import layer_init


class ObsCriticObsStateCriticModel(nn.Module):
    """Twin critic: one value head sees the observation only, the other sees observation + state.

    Each critic is built from the same pieces:

    * a representation net per input (``obs_space["obs"]`` / ``obs_space["state"]``),
    * an optional ``nn.LSTMCell`` memory (one per critic),
    * a ReLU MLP whose widths are given by ``hidden_size_list``, and
    * a linear head producing a single scalar.

    ``forward`` concatenates both scalars along the last dimension (``value_size == 2``).
    Column 0 is the observation-only critic; column 1 is the observation+state critic.

    Args:
        obs_space: Mapping with ``"obs"`` and ``"state"`` entries, each passed to
            ``get_representation_net_func``.
        recurrent: If True, each critic keeps its own LSTMCell memory.
        hidden_size_list: Widths of the hidden MLP layers. ``None`` means no hidden
            layers (the head is applied directly to the embedding).
        rnn_hidden_size: LSTMCell hidden size. Required when ``recurrent`` is True.
        get_representation_net_func: Callable mapping a space to a
            ``(module, tensor_type, rep_size)`` triple.

    Raises:
        ValueError: If ``recurrent`` is True but ``rnn_hidden_size`` is None.
    """

    def __init__(
        self,
        obs_space,
        recurrent=True,
        hidden_size_list=None,
        rnn_hidden_size=None,
        get_representation_net_func=None,
    ):
        super().__init__()

        if recurrent and rnn_hidden_size is None:
            raise ValueError("rnn_hidden_size must be set when recurrent=True")

        self.recurrent = recurrent
        # The default ``None`` used to crash in ``len()``; treat it as "no hidden layers".
        self.hidden_size_list = [] if hidden_size_list is None else hidden_size_list
        self.rnn_hidden_size = rnn_hidden_size

        # Representation nets: each call returns (module, dtype inputs are cast to, embedding width).
        self.obs_representation_net, self.obs_tensor_type, self.obs_rep_size = get_representation_net_func(obs_space["obs"])
        self.state_representation_net, self.state_tensor_type, self.state_rep_size = get_representation_net_func(
            obs_space["state"]
        )

        self.obs_embedding_size = self.obs_rep_size
        self.obs_state_embedding_size = self.obs_rep_size + self.state_rep_size

        # Memory: one LSTMCell per critic; its hidden output is prepended to the MLP input.
        if self.recurrent:
            self.obs_memory_rnn = self._make_lstm_cell(self.obs_rep_size, self.rnn_hidden_size)
            self.obs_embedding_size += self.rnn_hidden_size

            # NOTE: this cell also consumes only the observation embedding; the state
            # embedding reaches the obs+state critic only through its MLP input.
            self.obs_state_memory_rnn = self._make_lstm_cell(self.obs_rep_size, self.rnn_hidden_size)
            self.obs_state_embedding_size += self.rnn_hidden_size

        # Value nets. Attribute names and creation order match the original layout, so
        # state_dict keys and seeded initialization are unchanged.
        self.obs_value_net, self.obs_value_head = self._make_value_net(self.obs_embedding_size, self.hidden_size_list)
        self.obs_state_value_net, self.obs_state_value_head = self._make_value_net(
            self.obs_state_embedding_size, self.hidden_size_list
        )

    @staticmethod
    def _make_lstm_cell(input_size, hidden_size):
        """Build an ``nn.LSTMCell`` with orthogonal weights (gain 1.0) and zero biases."""
        cell = nn.LSTMCell(input_size, hidden_size)
        for name, param in cell.named_parameters():
            if "bias" in name:
                nn.init.constant_(param, 0)
            elif "weight" in name:
                nn.init.orthogonal_(param, 1.0)
        return cell

    @staticmethod
    def _make_value_net(input_size, hidden_sizes):
        """Build a ReLU MLP over ``hidden_sizes`` and a 1-output linear head.

        Returns ``(mlp, head)``. ``mlp`` is an empty ``nn.Sequential`` (identity) when
        ``hidden_sizes`` is empty. Layers are wrapped with the project's ``layer_init``;
        the head uses ``std=1.0``.
        """
        layers = []
        for width in hidden_sizes:
            layers.append(layer_init(nn.Linear(input_size, width)))
            layers.append(nn.ReLU())
            input_size = width
        mlp = nn.Sequential(*layers)
        head = layer_init(nn.Linear(input_size, 1), std=1.0)
        return mlp, head

    @property
    def memory_size(self):
        """Width of the packed recurrent state: ``[h_obs, c_obs, h_obs_state, c_obs_state]``."""
        return 4 * self.rnn_hidden_size

    @property
    def value_size(self):
        """Number of value outputs (observation-only critic, observation+state critic)."""
        return 2

    def forward(self, obs, memory=None):
        """Evaluate both critics.

        Args:
            obs: Mapping with ``"obs"`` and ``"state"`` tensors. Each is cast to the dtype
                reported by its representation net.
            memory: Packed recurrent state ``[h_obs, c_obs, h_obs_state, c_obs_state]``
                along dim 1, each slice ``rnn_hidden_size`` wide. Required when recurrent;
                ignored (and returned as-is) otherwise.

        Returns:
            ``(value, memory)``. ``value`` has last dimension 2, with columns
            (obs-only critic, obs+state critic). ``memory`` is the updated packed state.

        Raises:
            ValueError: If the model is recurrent and ``memory`` is None.
        """
        x_obs = obs["obs"].type(self.obs_tensor_type)
        x_state = obs["state"].type(self.state_tensor_type)
        obs_embedding = self.obs_representation_net(x_obs)
        state_embedding = self.state_representation_net(x_state)

        if self.recurrent:
            if memory is None:
                raise ValueError("memory is required when the model is recurrent")
            h = self.rnn_hidden_size
            obs_h, obs_c = self.obs_memory_rnn(obs_embedding, (memory[:, :h], memory[:, h : 2 * h]))
            obs_state_h, obs_state_c = self.obs_state_memory_rnn(
                obs_embedding, (memory[:, 2 * h : 3 * h], memory[:, 3 * h :])
            )

            obs_input = torch.cat([obs_h, obs_embedding], dim=1)
            obs_state_input = torch.cat([obs_state_h, obs_embedding, state_embedding], dim=1)

            memory = torch.cat([obs_h, obs_c, obs_state_h, obs_state_c], dim=1)
        else:
            obs_input = obs_embedding
            obs_state_input = torch.cat([obs_embedding, state_embedding], dim=1)

        obs_value = self.obs_value_head(self.obs_value_net(obs_input))
        obs_state_value = self.obs_state_value_head(self.obs_state_value_net(obs_state_input))

        value = torch.cat((obs_value, obs_state_value), dim=-1)

        return value, memory
