import torch
import torch.nn as nn
import transformers
from diffuser.models.bert import BertModel
import warnings
# NOTE: this is a process-wide side effect: importing this module silences
# every DeprecationWarning for the whole interpreter, not just this file.
warnings.simplefilter("ignore", category=DeprecationWarning)

# Default torch device string for this module: CUDA when available, else CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

class StateEncoderTransformer(nn.Module):
    """Encode a sequence of states into one representation with a BERT-style transformer.

    Each position's state is linearly embedded. It is then summed with a learned
    timestep embedding and, optionally, a task embedding, layer-normalised, and
    passed through ``BertModel``. The per-position outputs are sum-pooled over
    the sequence, excluding masked positions. The pooled vector is projected to
    ``output_size`` by a single head (``repre_type='vec'``) or by a mean head
    plus a clamped second head (``repre_type='dist'``).

    Args:
        state_dim: Size of the last dimension of ``states``.
        act_dim: Action size. It is only used to build ``embed_action``, which
            ``forward`` does not currently use.
        hidden_size: Transformer hidden size, also used as the embedding size.
        output_size: Size of the projected representation.
        max_ep_len: Number of rows in the timestep embedding table, so timestep
            values must lie in ``[0, max_ep_len)``. The task embedding table
            reuses this size, so task ids must be below it too.
        repre_type: ``'vec'`` or ``'dist'``. Selects the output head(s).
        **kwargs: Forwarded to ``transformers.BertConfig``.
    """

    def __init__(
        self,
        state_dim,
        act_dim,
        hidden_size,
        output_size,
        max_ep_len=4096,
        repre_type='vec',
        **kwargs
    ):
        super().__init__()
        self.state_dim = state_dim
        self.act_dim = act_dim
        self.hidden_size = hidden_size
        self.output_size = output_size

        config = transformers.BertConfig(
            vocab_size=1,  # unused: inputs are fed via ``inputs_embeds``, never token ids
            hidden_size=hidden_size,
            **kwargs
        )
        # Project-local BertModel (diffuser.models.bert). Per the original note it
        # presumably differs from the HuggingFace model by dropping the built-in
        # positional embeddings, because timestep embeddings are added here
        # instead. TODO confirm against diffuser.models.bert.
        self.transformer = BertModel(config)

        # Creation order below is kept as-is so parameter initialisation and
        # state_dict keys stay identical to existing checkpoints.
        self.embed_timestep = nn.Embedding(max_ep_len, hidden_size)
        self.embed_state = torch.nn.Linear(self.state_dim, hidden_size)
        # Not used by forward(); kept so parameter names / checkpoints stay compatible.
        self.embed_action = torch.nn.Linear(self.act_dim, hidden_size)
        # NOTE(review): the task table is sized by max_ep_len, not by a task count.
        self.embed_task = nn.Embedding(max_ep_len, hidden_size)

        self.embed_ln = nn.LayerNorm(hidden_size)
        self.repre_type = repre_type
        if self.repre_type == 'vec':
            self.to_phi = nn.Linear(self.hidden_size, self.output_size)
        elif self.repre_type == 'dist':
            self.to_phi_mean = nn.Linear(self.hidden_size, self.output_size)
            self.to_phi_std = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, states, timesteps, attention_mask=None, task_ids=None):
        """Encode ``states`` into a pooled representation.

        Args:
            states: Float tensor of shape ``(batch, seq, state_dim)``.
            timesteps: Integer tensor of shape ``(batch, seq)`` holding indices
                into the timestep embedding table.
            attention_mask: Optional ``(batch, seq)`` tensor. It is 1 where a
                position may be attended to and 0 otherwise, and defaults to all
                ones. When it is 2-D, masked positions are also excluded from
                the sequence sum.
            task_ids: Optional integer tensor of task-table indices, broadcastable
                against the state embeddings (e.g. ``(batch, seq)``). When it is
                ``None``, no task embedding is added.

        Returns:
            For ``'vec'``, a ``(batch, output_size)`` tensor. For ``'dist'``, a
            ``(mean, std)`` tuple, where ``std`` is the second head's raw output
            clamped to ``[-5, 2]`` (presumably a log-scale; TODO confirm with
            callers).

        Raises:
            ValueError: If ``self.repre_type`` is neither ``'vec'`` nor ``'dist'``.
        """
        batch_size, seq_length = states.shape[0], states.shape[1]
        if attention_mask is None:
            # 1 = may be attended to, 0 = masked. Build it on the inputs' device
            # so placement follows the caller rather than a global default.
            attention_mask = torch.ones(
                (batch_size, seq_length), dtype=torch.long, device=states.device
            )

        # Timestep embeddings play the role of positional embeddings.
        embeddings = self.embed_state(states) + self.embed_timestep(timesteps)
        if task_ids is not None:
            embeddings = embeddings + self.embed_task(task_ids)
        stacked_inputs = self.embed_ln(embeddings)

        transformer_outputs = self.transformer(
            inputs_embeds=stacked_inputs,
            attention_mask=attention_mask,
        )
        x = transformer_outputs["last_hidden_state"]  # (batch, seq, hidden_size)

        # Sum-pool over the sequence. With a per-position (2-D) mask, zero out
        # masked positions first so padding does not leak into the representation.
        if attention_mask.dim() == 2:
            x = x * attention_mask.unsqueeze(-1).to(x.dtype)
        x = x.sum(dim=1)

        if self.repre_type == 'vec':
            return self.to_phi(x)
        if self.repre_type == 'dist':
            std = torch.clamp(self.to_phi_std(x), min=-5, max=2)
            return self.to_phi_mean(x), std
        raise ValueError(
            f"unsupported repre_type {self.repre_type!r}; expected 'vec' or 'dist'"
        )