import torch
from torch import nn


class TransformerEncoder(nn.Module):
    """Encode a fixed-length sequence of states into one parameter vector.

    Every state is embedded by a small two-layer MLP, the embedded sequence is
    run through a two-layer, two-head transformer encoder, and the flattened
    encoder output is projected linearly to ``param_dim``.

    Args:
        n_step: Number of time steps per sequence. It fixes the input width of
            the final projection (``n_step * embedding_dim``), so inputs must
            carry exactly this many steps.
        state_dim: Size of each input state vector.
        param_dim: Size of the output vector.
        embedding_dim: Transformer model width (``d_model``). The encoder is
            built with ``nhead=2``, which requires this value to be even.
        hidden_dim: Width of the hidden layer of the embedding MLP.
        **kwargs: Accepted for call-site convenience and ignored.
    """

    def __init__(self, n_step: int, state_dim: int, param_dim: int, embedding_dim: int, hidden_dim: int, **kwargs):
        super().__init__()
        # batch_first=True is required: forward() feeds the encoder tensors
        # laid out as (bs, n_step, embedding_dim). With the default
        # (batch_first=False) dim 0 is treated as the sequence axis, so
        # attention would mix samples across the batch instead of time steps.
        encoder_layer = nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=2, batch_first=True)

        self.embedding = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),  # input: [..., state_dim]
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, embedding_dim),  # input: [..., hidden_dim]
        )
        self.transformer = nn.Sequential(
            nn.TransformerEncoder(encoder_layer, num_layers=2),  # [bs, n_step, embedding_dim]
            nn.Flatten(),  # [bs, n_step * embedding_dim]
            nn.Linear(n_step * embedding_dim, param_dim),  # [bs, param_dim]
        )

    def forward(self, states: torch.Tensor) -> torch.Tensor:
        """Map a batch of state sequences to parameter vectors.

        Args:
            states: Tensor of shape ``(bs, n_step, state_dim)``.

        Returns:
            Tensor of shape ``(bs, param_dim)``.
        """
        # nn.Linear acts on the last dimension only, so the embedding can be
        # applied without flattening the batch and time axes first.
        emb = self.embedding(states)
        # Collapse everything between the batch and feature axes into a single
        # sequence axis, as the encoder expects a 3-D (bs, seq, emb) input.
        emb = emb.reshape(states.shape[0], -1, emb.shape[-1])
        return self.transformer(emb)  # (bs, param_dim)


class MLP(nn.Module):
    """Fully connected network with LeakyReLU activations.

    The network is ``num_layers`` blocks of ``Linear -> LeakyReLU`` followed
    by a final ``Linear`` output layer without activation.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        output_dim: Size of the last dimension of the output tensor.
        hidden_dim: Width of every hidden layer.
        num_layers: Number of hidden ``Linear -> LeakyReLU`` blocks. With
            ``0`` the model reduces to a single ``Linear(input_dim, output_dim)``.
        **kwargs: Accepted for call-site convenience and ignored.
    """

    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int, num_layers=5, **kwargs):
        super().__init__()
        layers = []
        # Width of the features entering the next layer; starts at the input
        # size and becomes hidden_dim once the first hidden block is added.
        in_features = input_dim
        for _ in range(num_layers):
            layers.append(nn.Linear(in_features, hidden_dim))
            layers.append(nn.LeakyReLU())
            in_features = hidden_dim
        # Output layer: sized from the tracked width rather than hidden_dim so
        # the network is still shape-consistent when there are no hidden blocks.
        layers.append(nn.Linear(in_features, output_dim))
        self.model = nn.Sequential(*layers)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Apply the network to ``inputs``.

        Args:
            inputs: Tensor whose last dimension has size ``input_dim``.

        Returns:
            Tensor with the same leading dimensions as ``inputs`` and a last
            dimension of size ``output_dim``.
        """
        return self.model(inputs)
