from typing import List
import torch
import torch.nn as nn
import numpy as np
import exp_utils as PQ


class EnsembleModel(nn.Module):
    """Ensemble of models that splits each batch across its elite members.

    On every forward call the batch is randomly permuted, split into
    nearly-equal chunks (one per elite), each chunk is run through one elite
    model, and the outputs are reassembled in the original row order.
    """

    def __init__(self, models: List[nn.Module]):
        super().__init__()
        self.models = nn.ModuleList(models)
        self.n_models = len(models)
        self.n_elites = self.n_models
        self.elites = []
        self.recompute_elites()

    def recompute_elites(self):
        """Mark every member model as an elite (indices ``0 .. len(models) - 1``)."""
        self.elites = list(range(len(self.models)))

    def forward(self, states, actions):
        """Predict next states, dispatching random disjoint sub-batches to elites.

        Args:
            states: Indexable batch (indexed with integer numpy arrays); its
                first dimension is the batch dimension.
            actions: Batch aligned row-for-row with ``states``.

        Returns:
            Tensor of per-row predictions, in the same row order as ``states``.

        Raises:
            ValueError: If there are no elite models to dispatch to.
        """
        if not self.elites:
            raise ValueError("EnsembleModel.forward: no elite models to evaluate")

        n = len(states)
        perm = np.random.permutation(n)
        # inv_perm undoes perm: concatenating chunks of perm yields rows in perm
        # order, and indexing that result by inv_perm restores the input order.
        inv_perm = np.argsort(perm)

        chunks = np.array_split(perm, len(self.elites))
        # When n < len(elites), array_split produces empty chunks; skip them so
        # models are never asked to process an empty batch.
        assignments = [
            (model_idx, indices)
            for model_idx, indices in zip(self.elites, chunks)
            if len(indices) > 0
        ]
        if not assignments:
            # n == 0: still run one model so the (empty) output has the
            # model's trailing shape instead of torch.cat failing on [].
            assignments = [(self.elites[0], perm)]

        next_states = [
            self.models[model_idx](states[indices], actions[indices])
            for model_idx, indices in assignments
        ]
        return torch.cat(next_states, dim=0)[inv_perm]


class EnsembleUncertainty(nn.Module):
    """Pessimistic Lyapunov value over the elite models of an ensemble.

    Every elite model predicts a next state for the whole batch; each
    prediction is scored by ``lyapunov`` and the element-wise maximum across
    models is returned.
    """

    def __init__(self, ensemble, lyapunov, buffer=None):
        # NOTE(review): ``buffer`` is accepted but not used by this class.
        super().__init__()
        self.ensemble = ensemble
        self.lyapunov = lyapunov
        PQ.log.warning(f"[Ensemble Uncertainty]: models = {ensemble.elites}")

    def forward(self, states, actions, add=False):
        """Return the max over elite models of ``lyapunov(model(states, actions))``.

        ``add`` is accepted for interface compatibility and is ignored.
        """
        elite_models = [self.ensemble.models[k] for k in self.ensemble.elites]

        # Phase 1: every elite predicts next states for the full batch.
        predictions = []
        for model in elite_models:
            predictions.append(model(states, actions))

        # Phase 2: score each prediction with the Lyapunov function.
        scores = []
        for predicted in predictions:
            scores.append(self.lyapunov(predicted))

        # Worst case across models; a std-based bonus term (scores.std(dim=0))
        # was considered at some point but is not applied.
        return torch.stack(scores).max(dim=0).values
