import os
import pickle as pkl

import torch
import wandb
from gym import Env

import src
from src import set_rng_seed
from src.spaces import BatchBox

# Reward modes selectable through the `bound_type` argument of the environments
# below: per-step lower bound, per-step upper bound, or a single reward that is
# only paid out once the episode has terminated.
LOWER, UPPER, TERMINAL = range(3)


class AdaptiveIntervDesignEnv(Env):
    """Sequential experimental design MDP for causal "do" experiments.

    Each step runs one batch of hard interventions on ``model`` under a fixed
    ground-truth parameter sample and rewards the agent with the change in a
    contrastive bound computed from ``l`` additional parameter samples.
    """

    def __init__(
        self,
        model,
        budget,
        l,
        batch_size=1,
        bound_type=LOWER,
        num_initial_obs=0,
        zero_bias=True,
        action_space=None,
        observation_space=None,
    ):
        """
        Build the environment (hard interventions with intervention targets and values).

        args:
            model: experiment model. It must expose ``d`` (number of nodes),
                ``n_parallel``, ``reset(n_parallel=...)``,
                ``sample_theta(n, zero_bias=...)``, ``run_experiment(design, theta)``
                and ``get_likelihoods(y, design, thetas)``.
            budget (int): number of design steps per episode; the episode ends
                once ``budget * batch_size + num_initial_obs`` entries have
                been recorded in the history.
            l (int): number of contrastive parameter samples; ``l + 1`` samples
                are drawn and sample 0 is used as the ground truth.
            batch_size (int): number of experiments per step.
            bound_type (int): one of ``LOWER``, ``UPPER`` or ``TERMINAL``.
            num_initial_obs (int): number of all-zero designs run during
                ``reset`` before the agent acts.
            zero_bias (bool): forwarded to ``model.sample_theta``.
            action_space (gym.Space): optional override; defaults to a
                ``BatchBox`` in [-1, 1] of shape ``(1, batch_size, 2 * d)``.
            observation_space (gym.Space): optional override; defaults to a
                ``BatchBox`` in [-10, 10] of shape ``(d, 2)``.
        """
        num_nodes = model.d
        if action_space is None:
            self.action_space = BatchBox(
                low=-1.0, high=1.0, shape=(1, batch_size, 2 * num_nodes)
            )
        else:
            self.action_space = action_space
        if observation_space is None:
            self.observation_space = BatchBox(
                low=torch.as_tensor([[-10.0] * num_nodes, [-10.0] * num_nodes]).T,
                high=torch.as_tensor([[10.0] * num_nodes, [10.0] * num_nodes]).T,
            )
        else:
            self.observation_space = observation_space
        self.model = model
        self.n_parallel = model.n_parallel
        self.l = l
        self.num_initial_obs = num_initial_obs
        self.budget = budget
        self.bound_type = bound_type
        # Running sums of per-sample log-likelihoods and the previous
        # log-sum-exp of those sums; allocated in `reset`.
        self.log_products_upper = None
        self.last_logsumprod_upper = None
        self.log_products_lower = None
        self.last_logsumprod_lower = None
        self.history = []
        self.batch_size = batch_size
        self.thetas = None
        self.theta0 = None
        self.zero_bias = zero_bias
        # When True, `get_reward` tracks both bounds and returns the
        # complementary one as its second value.
        self.eval = False

    def reset(self, n_parallel=1):
        """Start a new episode with freshly sampled parameters.

        args:
            n_parallel (int): number of environments simulated in parallel.

        returns:
            ``(obs, reward, done, info)`` with ``done`` always ``False``.
            ``reward`` is zero unless ``num_initial_obs > 0``, in which case it
            is the reward of the initial all-zero designs and ``info`` holds
            their outcomes under ``"y"`` (plus ``"reward_c"`` in eval mode).
        """
        self.model.reset(n_parallel=n_parallel)
        self.n_parallel = n_parallel
        self.history = []
        # Lower bound uses all l + 1 samples, upper bound only the l contrastive ones.
        self.log_products_lower = torch.zeros((self.l + 1, self.n_parallel))
        self.last_logsumprod_lower = torch.logsumexp(self.log_products_lower, dim=0)
        self.log_products_upper = torch.zeros((self.l, self.n_parallel))
        self.last_logsumprod_upper = torch.logsumexp(self.log_products_upper, dim=0)
        self.thetas = self.model.sample_theta(self.l + 1, zero_bias=self.zero_bias)
        wandb.log(
            {"Environment/graph_density": self.thetas["graph"].sum((-1, -2)).mean()}
        )
        # thetas is a dict of tensors, so pick sample 0 from every entry.
        self.theta0 = {k: v[0] for k, v in self.thetas.items()}
        reward = torch.zeros(self.n_parallel)
        info = {}
        if self.num_initial_obs > 0:
            design_shape = (
                tuple(self.action_space.shape[:-3])
                + (self.n_parallel, self.num_initial_obs)
                + self.action_space.shape[-1:]
            )
            design = torch.zeros(*design_shape)
            y = self.model.run_experiment(design, self.theta0)
            half = self.action_space.shape[-1] // 2
            # NOTE(review): the squeeze only has an effect when
            # num_initial_obs == 1; whether the result still stacks with `y`
            # depends on the shape returned by run_experiment - confirm.
            self.history.extend(
                torch.stack(
                    [y, design[..., :half].squeeze(dim=-2)],
                    dim=-1,
                ).unbind(1)
            )
            reward, reward_c = self.get_reward(y, design)
            info = {"y": y.squeeze()}
            if reward_c is not None:
                info["reward_c"] = reward_c
        return self.get_obs(), reward, False, info

    def step(self, action):
        """Run one batch of experiments and return the resulting transition.

        args:
            action: design tensor (or array-like); the first half of the last
                axis is thresholded at 0 to obtain the intervention mask that
                is stored in the observation.

        returns:
            ``(obs, reward, done, info)`` where ``done`` is a boolean tensor
            shaped like ``reward`` and ``info["y"]`` holds the squeezed
            outcomes (plus ``"reward_c"`` in eval mode).
        """
        design = torch.as_tensor(action)
        y = self.model.run_experiment(design, self.theta0)
        half = self.action_space.shape[-1] // 2
        # Store outcomes together with the hard-intervention indicator.
        intervened = (design[..., :half] > 0).to(y.dtype)
        self.history.extend(torch.stack([y, intervened], dim=-1).unbind(1))
        obs = self.get_obs()
        reward, reward_c = self.get_reward(y, design)
        done = self.terminal()
        done = done * torch.ones_like(reward, dtype=torch.bool)
        info = {"y": y.squeeze()}
        if reward_c is not None:
            info["reward_c"] = reward_c
        return obs, reward, done, info

    def get_obs(self):
        """Return the history stacked along dim -3, or an empty tensor if none."""
        if self.history:
            return torch.stack(self.history, dim=-3)
        return torch.zeros(
            (self.n_parallel, 0, *self.observation_space.shape[-2:]),
        )

    def terminal(self):
        """Return True once the history holds the full budget of experiments."""
        return len(self.history) >= self.budget * self.batch_size + self.num_initial_obs

    def get_reward(self, y, design):
        """Update the bound accumulators and compute the reward for this step.

        args:
            y: outcomes returned by ``model.run_experiment``.
            design: the designs that produced ``y``.

        returns:
            ``(reward, reward_c)``. ``reward`` follows ``bound_type``:
            the per-step increment of the lower bound (``LOWER``), of the
            upper bound (``UPPER``), or, for ``TERMINAL``, zero until the
            episode is terminal and the full lower bound afterwards.
            ``reward_c`` is the complementary bound's reward in eval mode and
            ``None`` otherwise.

        raises:
            ValueError: if ``bound_type`` is not supported outside eval mode.
        """
        with torch.no_grad():
            # get_likelihoods: l+1 x n_parallel x B (per the original annotation);
            # summing the last axis gives one log-likelihood per sample.
            log_probs = self.model.get_likelihoods(y, design, self.thetas).sum(dim=-1)
        log_prob0 = log_probs[0]
        terminal_mode = self.bound_type == TERMINAL
        reward_lower = None
        reward_upper = None

        if self.bound_type == LOWER or terminal_mode or self.eval:
            # Accumulate exactly once per call, even when eval mode and
            # TERMINAL are combined.
            self.log_products_lower += log_probs
            logsumprod_lower = torch.logsumexp(self.log_products_lower, dim=0)
            if not terminal_mode:
                reward_lower = (
                    log_prob0 + self.last_logsumprod_lower - logsumprod_lower
                )
            elif self.terminal():
                reward_lower = (
                    self.log_products_lower[0]
                    - logsumprod_lower
                    + torch.log(torch.as_tensor(self.l + 1.0))
                )
            else:
                reward_lower = torch.zeros(self.n_parallel)
            self.last_logsumprod_lower = logsumprod_lower

        if self.bound_type == UPPER or self.eval:
            # The upper bound excludes the ground-truth sample (index 0).
            self.log_products_upper += log_probs[1:]
            logsumprod_upper = torch.logsumexp(self.log_products_upper, dim=0)
            reward_upper = log_prob0 + self.last_logsumprod_upper - logsumprod_upper
            self.last_logsumprod_upper = logsumprod_upper

        if self.bound_type in (LOWER, TERMINAL):
            reward, reward_c = reward_lower, reward_upper
        else:
            reward, reward_c = reward_upper, reward_lower
        if reward is None:
            raise ValueError(f"Unsupported bound_type: {self.bound_type!r}")
        return reward, (reward_c if self.eval else None)

    def render(self, mode="human"):
        """Rendering is not supported; this is a no-op."""
        pass


class AdaptiveIntervDesignEnvEval(AdaptiveIntervDesignEnv):
    """Evaluation variant whose parameters and initial data are sampled once.

    The parameter samples (and, if requested, the initial observations) are
    drawn in ``__init__`` after seeding the RNG, and ``reset`` reuses them, so
    every episode runs against the same ground truth.
    """

    def __init__(
        self,
        model,
        budget,
        l,
        batch_size=1,
        bound_type=LOWER,
        num_initial_obs=0,
        zero_bias=True,
        data_seed=0,
        save_path=None,
        action_space=None,
        observation_space=None,
    ):
        """
        Build the evaluation environment (hard interventions with intervention targets and values).

        args:
            model: experiment model; see the parent class for the attributes
                and methods it must provide. ``noise_type`` is additionally
                read when ``save_path`` is given.
            budget (int): number of design steps per episode.
            l (int): number of contrastive parameter samples (``l + 1`` are drawn).
            batch_size (int): number of experiments per step.
            bound_type (int): one of ``LOWER``, ``UPPER`` or ``TERMINAL``.
            num_initial_obs (int): number of all-zero designs run once at
                construction and replayed on every ``reset``.
            zero_bias (bool): forwarded to ``model.sample_theta``.
            data_seed (int): seed passed to ``set_rng_seed`` before sampling.
            save_path (str): if given, the ground-truth parameters are pickled
                to ``{save_path}/theta0.pkl`` (directory created if missing).
            action_space (gym.Space): optional override of the default space.
            observation_space (gym.Space): optional override of the default space.
        """
        super().__init__(
            model=model,
            budget=budget,
            l=l,
            batch_size=batch_size,
            bound_type=bound_type,
            num_initial_obs=num_initial_obs,
            zero_bias=zero_bias,
            action_space=action_space,
            observation_space=observation_space,
        )
        self.data_seed = data_seed
        self.eval = True
        set_rng_seed(data_seed)
        self.thetas = self.model.sample_theta(self.l + 1, zero_bias=zero_bias)
        # thetas is a dict of tensors, so pick sample 0 from every entry.
        self.theta0 = {k: v[0] for k, v in self.thetas.items()}
        if save_path is not None:
            theta0 = {k: v.cpu().numpy() for k, v in self.theta0.items()}
            theta0["noise_type"] = self.model.noise_type
            os.makedirs(save_path, exist_ok=True)
            # Context manager so the file handle is flushed and closed.
            with open(f"{save_path}/theta0.pkl", "wb") as f:
                pkl.dump(theta0, f)
        # Both stay None when no initial observations are requested.
        self.init_design = None
        self.init_y = None
        if self.num_initial_obs > 0:
            design_shape = (
                tuple(self.action_space.shape[:-3])
                + (self.n_parallel, self.num_initial_obs)
                + self.action_space.shape[-1:]
            )
            self.init_design = torch.zeros(*design_shape)
            self.init_y = self.model.run_experiment(self.init_design, self.theta0)

    def reset(self, n_parallel=1):
        """Start a new episode on the fixed parameters sampled at construction.

        args:
            n_parallel (int): accepted for interface compatibility but not
                used; the value of ``self.n_parallel`` is kept.

        returns:
            ``(obs, reward, done, info)`` with ``done`` always ``False``.
            ``info["reward_c"]`` is always present; ``info["y"]`` only when
            initial observations exist.
        """
        self.history = []
        if self.num_initial_obs > 0:
            half = self.action_space.shape[-1] // 2
            self.history.extend(
                torch.stack(
                    [
                        self.init_y,
                        self.init_design[..., :half].squeeze(dim=-2),
                    ],
                    dim=-1,
                ).unbind(1)
            )
        # Lower bound uses all l + 1 samples, upper bound only the l contrastive ones.
        self.log_products_lower = torch.zeros((self.l + 1, self.n_parallel))
        self.last_logsumprod_lower = torch.logsumexp(self.log_products_lower, dim=0)
        self.log_products_upper = torch.zeros((self.l, self.n_parallel))
        self.last_logsumprod_upper = torch.logsumexp(self.log_products_upper, dim=0)
        if self.init_design is not None:
            reward, reward_c = self.get_reward(self.init_y, self.init_design)
            info = {"y": self.init_y.squeeze()}
        else:
            reward = torch.zeros(self.n_parallel)
            reward_c = reward.clone()
            info = {}
        info["reward_c"] = reward_c
        return self.get_obs(), reward, False, info
