from typing import List,Dict

import torch
from disentangle.AbstractTrainer import AbstractLitModule
from models.encoders import MLP
from torch.distributions import Normal


class AdaLitModule(AbstractLitModule):
    """AdaGVAE-style variational module with mean-averaged shared latents.

    An MLP encoder maps each flattened trajectory to ``param_dim`` features,
    from which two small MLP heads produce the mean and log-variance of a
    diagonal Gaussian posterior. When two views are encoded together, the
    posterior means of the shared latent dimensions are replaced by the
    average over the two views. The decoder is an MLP that receives the
    latent code concatenated with the first ``n_iv_steps`` states of the
    trajectory and outputs ``n_step`` states.

    Attributes such as ``input_dim``, ``param_dim``, ``n_views``, ``n_step``,
    ``state_dim``, ``n_iv_steps``, ``dct_layer``, ``state_transform`` and
    ``misc`` are read here but not assigned; presumably they are set by
    ``AbstractLitModule`` -- confirm against the base class.
    """

    def __init__(
        self,
        state_dim: int,
        n_step: int,
        n_iv_steps: int,
        param_dim: int,
        n_views: int,
        learning_rate: float = 1e-5,
        eval_metrics=None,
        code_sharing: Dict[int, List[int]] = None,
        factor_type="discrete",
        device="cuda:0",
        **kwargs,
    ):
        """Build the encoder, decoder, posterior heads and the prior.

        Args:
            state_dim: Size of a single state vector.
            n_step: Number of time steps the decoder outputs.
            n_iv_steps: Number of leading time steps fed to the decoder as
                conditioning.
            param_dim: Dimensionality of the latent code.
            n_views: Number of views stacked in a training batch.
            learning_rate: Forwarded to the base class.
            eval_metrics: Forwarded to the base class; ``None`` means an
                empty list.
            code_sharing: Mapping with exactly one entry; its key is stored
                as ``shared_dims``.
            factor_type: Forwarded to the base class.
            device: Device on which the prior tensors are initially created.
            **kwargs: ``hidden_dim`` (default 1024) sets the width of every
                MLP; other keys are ignored here.

        Raises:
            ValueError: If ``code_sharing`` is ``None`` or does not contain
                exactly one entry.
        """
        super().__init__(
            state_dim,
            n_step,
            n_iv_steps,
            param_dim,
            n_views,
            learning_rate,
            # A fresh list per instance avoids sharing one mutable default.
            [] if eval_metrics is None else eval_metrics,
            code_sharing,
            factor_type,
        )

        # AdaGVAE-specific settings: exactly one augmented view is supported.
        # An explicit raise (rather than ``assert``) survives ``python -O``
        # and gives a clear error when the default ``None`` is left in place.
        if code_sharing is None or len(code_sharing) != 1:
            raise ValueError("only one augmented view")

        hidden_dim = kwargs.get("hidden_dim", 1024)

        self.encoder = MLP(input_dim=self.input_dim, output_dim=param_dim, hidden_dim=hidden_dim, num_layers=5)

        self.decoder = MLP(
            input_dim=param_dim + state_dim * n_iv_steps,
            output_dim=state_dim * n_step,
            hidden_dim=hidden_dim,
            num_layers=5,
        )

        self.mean_head = MLP(input_dim=param_dim, output_dim=param_dim, hidden_dim=hidden_dim, num_layers=2)

        self.logvar_head = MLP(input_dim=param_dim, output_dim=param_dim, hidden_dim=hidden_dim, num_layers=2)

        self.loss = torch.nn.MSELoss()
        self.prior = Normal(torch.zeros(param_dim).to(device), torch.ones(param_dim).to(device))
        # NOTE(review): this stores the single *key* of ``code_sharing`` and
        # ``encode`` later indexes it with ``[0]``, so the key must be
        # subscriptable (e.g. a tuple) even though the annotation says ``int``.
        # Confirm the intended schema with the caller.
        self.shared_dims = next(iter(code_sharing))

    def _prior_on(self, device: torch.device) -> Normal:
        """Return the standard-normal prior with its tensors on ``device``."""
        if self.prior.loc.device == device:
            return self.prior
        # The prior is a plain attribute, so it does not follow the module
        # when it is moved; relocate it on demand instead of failing.
        return Normal(self.prior.loc.to(device), self.prior.scale.to(device))

    def encode(self, states: torch.Tensor):
        """Encode trajectories into a list of Gaussian posteriors.

        Args:
            states: Tensor whose leading axis indexes views and whose last
                two axes are flattened together before encoding.

        Returns:
            A one-element list holding a single ``Normal`` over all views
            when the leading axis has size 1; otherwise a two-element list
            with one ``Normal`` per view (only views 0 and 1 are used), in
            which the shared dimensions carry the mean averaged over both
            views.
        """
        if self.dct_layer:
            states = self.state_transform(states)
        flat_states = states.reshape(-1, states.shape[-2] * states.shape[-1])
        # (n_views, batch_size, param_dim)
        params = self.encoder(flat_states).reshape(states.shape[0], -1, self.param_dim)
        means = self.mean_head(params)
        logvars = self.logvar_head(params)
        scales = torch.exp(logvars / 2)  # std = exp(logvar / 2)

        if len(params) == 1:
            return [Normal(means, scales)]

        shared_idx = self.shared_dims[0]
        shared = means.clone()
        avg = (means[0, :, shared_idx] + means[1, :, shared_idx]) / 2
        # Overwrite the shared dimensions of every view with the average.
        shared[:, :, shared_idx] = avg.expand_as(shared[:, :, shared_idx])
        return [Normal(shared[0], scales[0]), Normal(shared[1], scales[1])]

    def decode(self, posteriors: List[torch.distributions.Distribution], states: torch.Tensor):
        """Decode latent codes, conditioned on the first ``n_iv_steps`` states.

        Args:
            posteriors: List returned by ``encode``.
            states: 4-D tensor ``(num_views, batch_size, steps, state_dim)``.

        Returns:
            Tensor reshaped to ``(num_views, batch_size, -1, state_dim)``.
        """
        num_views, batch_size, _, _ = states.shape
        if len(posteriors) > 1:  # training mode
            # Reparameterised samples, shape: [num_views, batch_size, param_dim]
            zs = torch.stack([p.rsample() for p in posteriors], 0)
        else:
            zs = posteriors[0].mean  # shape: [num_views=1, batch_size, param_dim]
        ivps = states[..., : self.n_iv_steps, :]  # shape: [num_views, batch_size, n_iv_steps, state_dim]
        latents = torch.cat([zs, ivps.reshape(zs.shape[0], zs.shape[1], -1)], -1)
        decoded = self.decoder(latents.reshape(-1, latents.shape[-1]))
        return decoded.reshape(num_views, batch_size, -1, self.state_dim)

    def forward(self, states: torch.Tensor):
        """Encode then decode ``states``.

        Returns:
            Tuple of the reconstruction (reshaped to ``states.shape``) and
            the list of posteriors.
        """
        posteriors = self.encode(states)
        u0 = self.decode(posteriors, states)
        return u0.reshape(*states.shape), posteriors

    def training_step(self, batch, batch_idx):
        """Compute and log the reconstruction + KL loss for one batch."""
        states = batch["states"].float().reshape(self.n_views, -1, self.n_step, self.state_dim)
        u0, posteriors = self.forward(states)
        recon = self.loss(u0, states)
        # Use a prior on the same device as the posteriors so the KL term
        # works wherever the trainer has placed the module.
        prior = self._prior_on(posteriors[0].loc.device)
        # NOTE(review): the KL term is summed over views, batch and latent
        # dimensions while the reconstruction term is a mean, so their
        # relative weight depends on batch size -- confirm this is intended.
        D_kl = torch.stack([torch.distributions.kl.kl_divergence(p, prior) for p in posteriors]).sum()
        loss = recon + D_kl
        self.log("recon_loss", recon, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log("kl_loss", D_kl, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the reconstruction loss and collect predicted/true parameters."""
        states = batch["states"].float().reshape(1, -1, self.n_step, self.state_dim)
        u0, posteriors = self.forward(states)
        recon = self.loss(u0, states)
        # ``detach`` keeps the conversion valid even if gradients are enabled.
        # NOTE(review): the bare ``squeeze`` also drops a size-1 batch or
        # latent axis, not just the view axis -- confirm downstream handling.
        self.misc["pred_params"].append(posteriors[0].mean.detach().cpu().numpy().squeeze())
        if "gt_params" in batch:
            gt_params = batch["gt_params"]
            if isinstance(gt_params, dict):
                # Stack the per-factor tensors along a new last axis.
                gt_params = torch.stack(list(gt_params.values()), -1)
            self.misc["gt_params"].append(gt_params.cpu().numpy())
        self.log("val_loss", recon, on_epoch=True, prog_bar=True, logger=True)

    def predict_step(self, batch, batch_id):
        """Return the forecast MSE on the states after the first ``n_step``.

        The first ``n_step`` states are encoded; the decoder is then
        conditioned on the first ``n_iv_steps`` of the remaining states and
        its output is compared with those remaining states.
        """
        states = batch["states"].float()
        if states.dim() == 3:
            # Add the leading view axis that ``encode`` expects; without it
            # the batch axis would be mistaken for the view axis.
            states = states.unsqueeze(0)
        # select the first n_step states as encoder input
        input_states = states[..., : self.n_step, :]
        future_states = states[..., self.n_step :, :]

        posteriors = self.encode(input_states)
        zs = posteriors[0].mean  # shape: [num_views=1, batch_size, param_dim]
        ivps = future_states[..., : self.n_iv_steps, :]  # shape: [num_views, batch_size, n_iv_steps, state_dim]
        latents = torch.cat([zs, ivps.reshape(zs.shape[0], zs.shape[1], -1)], -1)
        u0s = self.decoder(latents.reshape(-1, latents.shape[-1])).reshape(*future_states.shape)
        forecast_loss = self.loss(u0s, future_states)
        return forecast_loss

    def test_step(self, batch, batch_idx):
        """No-op: testing is not implemented for this module."""
        pass
