import abc
import dataclasses
import enum
import functools
import os
from typing import Generic, Mapping, TypeVar

import chex
import jax
import jax.numpy as jnp
import numpy as np
import tqdm
from absl import flags
from chex._src.pytypes import PRNGKey
from clu import metric_writers as clu_metric_writers
from clu import periodic_actions
from etils import epath
from matplotlib import pyplot as plt

from tabular_mvdrl import plotting
from tabular_mvdrl.envs.mrp import MarkovRewardProcess
from tabular_mvdrl.types import (
    MRPTransitionBatch,
    TaggedMetrics,
)
from tabular_mvdrl.utils.discrete_distributions import (
    DiscreteDistribution,
    ProbabilityMetric,
)

# Command-line configuration shared by every trainer in this module. Each
# holder is read lazily via ``.value`` once absl has parsed the flags.

# Metric sinks to log to; ``Trainer.metric_writer`` builds one writer per entry.
_METRIC_WRITERS = flags.DEFINE_multi_enum(
    name="metric-writer",
    default=["cli"],
    enum_values=["cli", "aim"],
    help="Metric writer to use for logging",
)
# Experiment name; used as the Aim experiment and in output file names.
_EXPERIMENT = flags.DEFINE_string(
    name="experiment",
    default="tabular_mvdrl",
    help="Experiment name for logging",
)
# NOTE(review): not read by the code visible in this file, which writes to a
# hard-coded "results" directory instead -- confirm whether that is intended.
_WORKDIR = epath.DEFINE_path(
    name="workdir",
    default="results",
    help="Base directory for experiment artifacts",
)

# Gates writing figures and the ``.npz`` results archive to disk.
_SAVE = flags.DEFINE_bool(name="save", default=True, help="Save data to disk")

# Type of the training state threaded through ``Trainer``.
StateT = TypeVar("StateT")


class EvalStage(enum.IntFlag):
    """Bit flags naming the stages at which an evaluation may be run.

    Every member occupies its own bit, so stages can be combined with ``|``
    and queried with ``in``.
    """

    IGNORE = 1 << 0
    WARMUP = 1 << 1
    MID_TRAINING = 1 << 2
    POST_TRAINING = 1 << 3


@dataclasses.dataclass(frozen=True, kw_only=True)
class Trainer(Generic[StateT], abc.ABC):
    env: MarkovRewardProcess
    seed: int
    num_steps: int
    write_metrics_interval_steps: int
    eval_interval_steps: int
    num_eval_mc_samples: int = 1000

    @property
    def id_suffix(self) -> str:
        return f"{self.env.__class__.__name__}-v{self.seed}"

    @property
    def identifier(self) -> str:
        return "AbstractTrainer"

    @property
    @abc.abstractmethod
    def state(self) -> StateT: ...

    @functools.cached_property
    def metric_writer(self) -> clu_metric_writers.MetricWriter:
        writer_enums = _METRIC_WRITERS.value
        if len(writer_enums) == 0:
            writer_enums = ["cli"]

        writers: list[clu_metric_writers.MetricWriter] = []

        for writer in _METRIC_WRITERS.value:
            match writer:
                case "cli":
                    writers.append(clu_metric_writers.LoggingWriter())
                case "aim":
                    from tabular_mvdrl.utils.metric_writers.aim_writer import AimWriter

                    writers.append(
                        AimWriter(experiment=_EXPERIMENT.value, log_system_params=True)
                    )
                case "tensorboard" | "comet" | "wandb":
                    raise NotImplementedError(
                        f"Metric writer {writer} not yet supported."
                    )
                case _:
                    raise ValueError(f"Unknown metric writer: {writer}")
        return clu_metric_writers.MultiWriter(writers)

    @abc.abstractmethod
    def train_step(
        self, key: chex.PRNGKey, state: StateT, batch: MRPTransitionBatch, **kwargs
    ) -> StateT: ...

    def mid_training_eval(self, rng: chex.PRNGKey, state: StateT) -> TaggedMetrics:
        return {"scalars": {}}

    def post_training_eval(
        self, rng: chex.PRNGKey, state: StateT
    ) -> Mapping[str, chex.Scalar]:
        return {"scalars": {}}

    def return_distribution(self, state: StateT, i: int) -> DiscreteDistribution:
        locs = state.apply_fn(state.params, jnp.array(i))
        return DiscreteDistribution.empirical_from(locs)

    def train(self) -> None:
        rng = jax.random.PRNGKey(self.seed)
        state = self.state
        train_step = self.train_step

        report_progress = periodic_actions.ReportProgress(
            num_train_steps=self.num_steps, writer=self.metric_writer
        )

        def _write_metrics(step: int, t: float | None = None) -> None:
            nonlocal state
            del t
            self.metric_writer.write_scalars(step, state.metrics.compute())
            state = state.replace(metrics=state.metrics.empty())

        def _eval(step: int, t: float | None = None) -> None:
            nonlocal state
            del t
            rng = jax.random.fold_in(jax.random.PRNGKey(self.seed), step)
            eval_metrics = self.mid_training_eval(rng, state)
            self.metric_writer.write_scalars(step, eval_metrics["scalars"])

        callbacks = [
            # TODO: maybe save state
            periodic_actions.PeriodicCallback(
                every_steps=self.write_metrics_interval_steps,
                callback_fn=_write_metrics,
            ),
            periodic_actions.PeriodicCallback(
                every_steps=self.eval_interval_steps, callback_fn=_eval
            ),
            report_progress,
        ]

        @jax.jit
        def _train_step(i: int, state: StateT) -> StateT:
            key = jax.random.fold_in(rng, i)
            key, batch_key, train_key = jax.random.split(key, 3)
            batch_keys = jnp.array(jax.random.split(batch_key, self.env.num_states))
            cumulants, next_states = jax.vmap(
                self.env.sample_from_state, in_axes=(0, 0)
            )(batch_keys, jnp.arange(self.env.num_states))
            batch = MRPTransitionBatch(
                o_t=jnp.arange(self.env.num_states), r_t=cumulants, o_tp1=next_states
            )
            state = train_step(train_key, state, batch)
            return state

        for step in tqdm.tqdm(range(self.num_steps)):
            state = _train_step(step, state)
            for callback in callbacks:
                callback(step)

        post_training_results = self.post_training_eval(
            jax.random.fold_in(rng, self.num_steps), state
        )
        self.metric_writer.write_scalars(
            self.num_steps, post_training_results.pop("scalars", {})
        )


@dataclasses.dataclass(frozen=True, kw_only=True)
class MVDRLTransferTrainer(Trainer[StateT]):
    """Trainer that evaluates learned DSFs against Monte Carlo references.

    Two evaluations are available, each enabled per stage through an
    ``EvalStage`` mask: ``eval_dsf`` compares the predicted distributions with
    Monte Carlo ones directly, and ``eval_return`` compares the return
    distributions obtained by pushing both through a fixed set of linear
    reward functions.

    Note: ``mc_dsf`` reads ``self.discount``, which is not declared on this
    class or on ``Trainer`` as visible here; concrete subclasses must provide it.
    """

    # Metric comparing predicted and Monte Carlo DSFs in ``eval_dsf``.
    dsf_metric: ProbabilityMetric
    # Metric comparing predicted and Monte Carlo returns in ``eval_return``.
    return_metric: ProbabilityMetric
    # Stages at which ``eval_dsf`` runs (``EvalStage.IGNORE`` matches neither
    # the mid- nor the post-training stage, i.e. disables it).
    dsf_eval_stage: EvalStage = EvalStage.POST_TRAINING
    # Stages at which ``eval_return`` runs.
    return_eval_stage: EvalStage = EvalStage.POST_TRAINING
    # Total number of reward functions built by ``eval_reward_functions``.
    num_eval_reward_fns: int = 100
    # Number of leading reward functions that get a figure in ``eval_return``.
    num_demo_reward_fns: int = 5
    # Upper bound on the number of states compared in ``eval_dsf``.
    max_eval_states: int = 5

    @property
    def identifier(self) -> str:
        """Return a name for this trainer; subclasses override it."""
        return "AbstractTransferTrainer"

    @functools.cached_property
    def eval_reward_functions(self) -> chex.Array:
        """Return the fixed reward-weight vectors used for return evaluation.

        The first ``env.reward_dim`` rows are the identity; the remaining
        ``num_eval_reward_fns - env.reward_dim`` rows are drawn uniformly from
        ``[-1, 1)`` with a fixed key, so they are identical across runs.
        """
        base_reward_fns = jnp.eye(self.env.reward_dim)
        random_reward_fns = jax.random.uniform(
            jax.random.PRNGKey(42),
            shape=(self.num_eval_reward_fns - self.env.reward_dim, self.env.reward_dim),
            maxval=1.0,
            minval=-1.0,
        )
        return jnp.concatenate([base_reward_fns, random_reward_fns], axis=0)

    @functools.cached_property
    def mc_dsf(self) -> list[DiscreteDistribution]:
        """Return Monte Carlo reference distributions, one per state.

        Draws ``num_eval_mc_samples`` samples of ``env.monte_carlo_return`` from
        every state using a fixed key and ``self.discount``, then wraps each
        state's samples with ``DiscreteDistribution.empirical_from``.

        NOTE(review): the value is produced by ``jax.vmap`` and is sliced with
        ``tree_map`` in ``eval_dsf``, so it is presumably a batched
        ``DiscreteDistribution`` rather than a Python list -- confirm.
        """
        num_states = self.env.num_states
        # The first key of the split is not used.
        _, *mc_keys = jax.random.split(
            jax.random.PRNGKey(42), 1 + num_states * self.num_eval_mc_samples
        )
        mc_keys = jnp.reshape(
            jnp.array(mc_keys), (num_states, self.num_eval_mc_samples, -1)
        )
        # Outer vmap walks over states, inner vmap over samples of one state.
        monte_carlo_returns = jax.vmap(
            jax.vmap(self.env.monte_carlo_return, in_axes=(0, None, None)),
            in_axes=(0, 0, None),
        )(mc_keys, jnp.arange(num_states), self.discount)
        return jax.vmap(DiscreteDistribution.empirical_from)(
            jnp.squeeze(monte_carlo_returns)
        )

    @functools.cached_property
    def mc_returns(self) -> list[DiscreteDistribution]:
        """Return ``mc_dsf`` pushed through every evaluation reward function.

        The outer ``vmap`` runs over reward functions and the inner one over
        the batch axis of ``mc_dsf``.
        """
        dsf = self.mc_dsf
        rs = self.eval_reward_functions
        mc_returns = jax.vmap(
            jax.vmap(DiscreteDistribution.pushforward_linear, in_axes=(0, None)),
            in_axes=(None, 0),
        )(dsf, rs)
        return mc_returns

    def eval_dsf(
        self, rng: chex.PRNGKey, state: StateT, final: bool = False
    ) -> TaggedMetrics:
        """Compare predicted and Monte Carlo DSFs on the first few states.

        Args:
            rng: Unused.
            state: Training state passed to ``return_distribution``; its
                ``step`` tags the written images.
            final: When true (and ``--save`` is set, and ``env.reward_dim`` is
                2), the figure for state 0 is also saved as a PDF under
                ``results/<experiment>/``.

        Returns:
            ``{"scalars": {"eval__mmd": loss}}`` where ``loss`` is the mean of
            ``dsf_metric`` over the evaluated states.
        """
        num_states = min(self.max_eval_states, self.env.num_states)
        eta_mc = jax.tree_util.tree_map(lambda x: x[:num_states], self.mc_dsf)
        eta_pred = jax.vmap(self.return_distribution, in_axes=(None, 0))(
            state, jnp.arange(num_states)
        )
        pred_tag = self.identifier.split(":")[0]
        loss = jnp.mean(jax.vmap(self.dsf_metric)(eta_pred, eta_mc))
        # Figures are only produced for 1-D and 2-D rewards.
        if self.env.reward_dim == 1:
            for idx in range(num_states):
                fig = plotting.plot_univariate_empirical_distributions(
                    {
                        "MC": jax.tree_util.tree_map(lambda x: x[idx], eta_mc),
                        pred_tag: jax.tree_util.tree_map(lambda x: x[idx], eta_pred),
                    }
                )
                self.metric_writer.write_images(
                    state.step.item(), {f"dsf-distributions/{idx}": fig}
                )
                plt.close(fig)
        elif self.env.reward_dim == 2:
            for idx in range(num_states):
                fig = plotting.plot_bivariate_empirical_distributions_kde(
                    {
                        "MC": jax.tree_util.tree_map(lambda x: x[idx], eta_mc),
                        pred_tag: jax.tree_util.tree_map(lambda x: x[idx], eta_pred),
                    },
                    scatter_tag=pred_tag,
                )
                self.metric_writer.write_images(
                    state.step.item(), {f"dsf-distributions/{idx}": fig}
                )
                # Only the first state's final figure is persisted to disk.
                if ((idx == 0) and final) and _SAVE.value:
                    os.makedirs(f"results/{_EXPERIMENT.value}", exist_ok=True)
                    tag_components = self.identifier.split(":")
                    _id = "_".join(tag_components)
                    fig.savefig(f"results/{_EXPERIMENT.value}/dsf-{_id}.pdf")
                plt.close(fig)
        return {"scalars": {"eval__mmd": loss}}

    def eval_return(
        self, rng: chex.PRNGKey, state: StateT, final: bool = False
    ) -> TaggedMetrics:
        """Compare predicted and Monte Carlo return distributions.

        Both the predicted DSFs (for every state) and ``mc_dsf`` are pushed
        through ``eval_reward_functions`` and compared with ``return_metric``.

        Args:
            rng: Unused.
            state: Training state passed to ``return_distribution``; its
                ``step`` tags the written images.
            final: When true and ``--save`` is set, the distances are merged
                into ``results/<experiment>.npz`` under ``self.identifier`` and
                the demo figures are saved as PDFs.

        Returns:
            ``{"scalars": ...}`` with the max, mean and min of the distances.
        """
        mc_returns = self.mc_returns
        predicted_dsf = jax.vmap(self.return_distribution, in_axes=(None, 0))(
            state, jnp.arange(self.env.num_states)
        )
        predicted_returns = jax.vmap(
            jax.vmap(DiscreteDistribution.pushforward_linear, in_axes=(0, None)),
            in_axes=(None, 0),
        )(predicted_dsf, self.eval_reward_functions)
        return_distances = jax.vmap(self.return_metric)(mc_returns, predicted_returns)

        # Only touch the archive when it is actually going to be rewritten.
        if final and _SAVE.value:
            data_path = epath.Path("results") / f"{_EXPERIMENT.value}.npz"
            # np.savez_compressed does not create missing parent directories.
            os.makedirs(data_path.parent, exist_ok=True)
            data = {}
            resolved_path = data_path.resolve()
            if resolved_path.exists():
                # Copy the arrays out, then close the archive's file handle.
                with np.load(resolved_path) as archive:
                    data = dict(archive)
            data[self.identifier] = np.array(return_distances)
            np.savez_compressed(data_path, **data)

        scalar_results = {
            "eval__returndist_worst": jnp.max(return_distances),
            "eval__returndist_mean": jnp.mean(return_distances),
            "eval__returndist_best": jnp.min(return_distances),
        }
        for i in range(self.num_demo_reward_fns):
            # Index [i, 0]: i-th reward function, first entry of the state axis.
            mc_returns_i = jax.tree_util.tree_map(lambda x: x[i, 0], mc_returns)
            predicted_returns_i = jax.tree_util.tree_map(
                lambda x: x[i, 0], predicted_returns
            )
            fig = plotting.plot_univariate_empirical_distributions(
                {"MC": mc_returns_i, "Prediction": predicted_returns_i}
            )
            self.metric_writer.write_images(
                state.step.item(), {f"return-distributions/{i}": fig}
            )
            if _SAVE.value and final:
                os.makedirs(f"results/{_EXPERIMENT.value}", exist_ok=True)
                _id = self.identifier.split(":")[0]
                fig.savefig(f"results/{_EXPERIMENT.value}/reward{i}-{_id}.pdf")
            plt.close(fig)
        return {"scalars": scalar_results}

    def _eval(
        self, rng: chex.PRNGKey, state: StateT, stage: EvalStage
    ) -> TaggedMetrics:
        """Run whichever evaluations are enabled for ``stage`` and merge scalars.

        Returns:
            ``{"scalars": ...}`` holding the union of the DSF and return
            scalars; an evaluation that is not enabled contributes nothing.
        """
        dsf_results = {}
        return_results = {}
        if stage in self.dsf_eval_stage:
            rng, key = jax.random.split(rng)
            dsf_results = self.eval_dsf(
                key, state, final=stage is EvalStage.POST_TRAINING
            )

        if stage in self.return_eval_stage:
            rng, key = jax.random.split(rng)
            return_results = self.eval_return(
                key, state, final=stage is EvalStage.POST_TRAINING
            )

        all_results = {}
        dsf_scalars = dsf_results.pop("scalars", {})
        return_scalars = return_results.pop("scalars", {})
        all_results["scalars"] = dsf_scalars | return_scalars
        return all_results

    def mid_training_eval(self, rng: jax.Array, state: StateT) -> TaggedMetrics:
        """Run the evaluations enabled for ``EvalStage.MID_TRAINING``."""
        return self._eval(rng, state, EvalStage.MID_TRAINING)

    def post_training_eval(self, rng: PRNGKey, state: StateT) -> TaggedMetrics:
        """Run the evaluations enabled for ``EvalStage.POST_TRAINING``."""
        return self._eval(rng, state, EvalStage.POST_TRAINING)
