"""
Only encode the latest observation.
"""
import torch
import torch.nn as nn
from torch.nn import functional as F


class DummyEncoder(nn.Module):
    """Encode each timestep's observation, action and reward independently.

    Each enabled input stream goes through its own ``nn.Linear`` layer and
    then ``encoder_activation``. The per-stream encodings are concatenated
    along the last dimension. No history is used, so ``forward`` always
    returns ``None`` as its history.

    Args:
        obs_dim: Size of the last dimension of the observation input.
        act_dim: Size of the last dimension of the action input.
        obs_encode_dim: Output size of the observation encoder. Must be > 0.
        act_encode_dim: Output size of the action encoder. A value <= 0
            disables action encoding.
        rew_encode_dim: Output size of the reward encoder, whose input size
            is 1. A value <= 0 disables reward encoding.
        encoder_activation: Activation applied to every linear encoder's
            output. Defaults to ``F.relu``.

    Attributes:
        out_dim: Size of the last dimension of the encodings returned by
            ``forward``. This is the sum of the enabled encoders' output
            sizes.

    Raises:
        ValueError: If ``obs_encode_dim`` is not positive.
    """

    def __init__(
        self,
        obs_dim: int,
        act_dim: int,
        obs_encode_dim: int,
        act_encode_dim: int,
        rew_encode_dim: int,
        encoder_activation=F.relu,
    ):
        super().__init__()
        if obs_encode_dim <= 0:
            raise ValueError('Require obs encoder to be positive integer.')
        # Registration order (obs, act, rew) is kept stable so state_dict
        # keys and ordering do not change.
        self.obs_encoder = nn.Linear(obs_dim, obs_encode_dim)
        if act_encode_dim > 0:
            self.act_encoder = nn.Linear(act_dim, act_encode_dim)
        else:
            self.act_encoder = None
        if rew_encode_dim > 0:
            self.rew_encoder = nn.Linear(1, rew_encode_dim)
        else:
            self.rew_encoder = None
        # A disabled encoder (dim <= 0) adds nothing to the concatenated
        # output. Clamping to 0 stops a negative "disabled" dim from
        # shrinking out_dim below the real encoding width.
        self.out_dim = (
            obs_encode_dim + max(act_encode_dim, 0) + max(rew_encode_dim, 0)
        )
        self.encoder_activation = encoder_activation

    def forward(self, obs_seq, act_seq, rew_seq, history=None):
        """Forward pass to get encodings.

        Args:
            obs_seq: Observations of shape (batch size, seq length, obs dim).
            act_seq: Actions of shape (batch size, seq length, act dim).
                Unused (may be None) when action encoding is disabled.
            rew_seq: Rewards of shape (batch size, seq length, 1).
                Unused (may be None) when reward encoding is disabled.
            history: Ignored. This encoder keeps no history.

        Returns:
            * Encodings of shape (batch size, seq length, out dim)
            * History, which is always None
        """
        encoding = [self.encoder_activation(self.obs_encoder(obs_seq))]
        if self.act_encoder is not None:
            encoding.append(self.encoder_activation(self.act_encoder(act_seq)))
        if self.rew_encoder is not None:
            encoding.append(self.encoder_activation(self.rew_encoder(rew_seq)))
        return torch.cat(encoding, dim=-1), None
