import warnings
from typing import List, Optional

import numpy as np
import ot
import torch
from scipy.stats import norm

from .testing import borda, resample_list
from .utils import bootstrap_multiprocessing, pdist


class MVStochasticOrderTesting:

    def __init__(self,
                 scores_list: List[np.ndarray],
                 n_bootstrap: int = 1000,
                 num_workers: int = 1,
                 cost: str = 'logistic',
                 cost_kwargs: Optional[dict] = None,
                 use_sinkhorn: bool = True,
                 sinkhorn_reg: float = 0.1,
                 max_bootstrap_samples: int = int(5e4),
                 use_cuda: bool = False,
                 include_entropy: bool = True,
                 verbose=False,
                 verbose_warnings=False) -> None:
        r"""Multivariate absolute and relative Stochastic Order Testing

        Args:
            scores_list (List[np.ndarray]): List of scores (N-D arrays) to compare, one array per model.
                The length of the first array is used as the sample size of the tests.
            n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 1000.
                If not positive, only the violation matrix is computed and the tests cannot be run.
            num_workers (int, optional): Number of workers for parallelization. Defaults to 1.
                Note: Parallelization is only used for bootstrapping and has to be made compatible with use of GPU if use_cuda is True.
            cost (str, optional): Cost to use for computing violations ('hinge' or 'logistic'). Defaults to 'logistic'.
            cost_kwargs (dict, optional): Additional arguments for the cost (if relevant). Only the key 'beta'
                is read (1.0 is used if the key is absent). Defaults to None, which means {'beta': 8.0}.
            use_sinkhorn (bool, optional): Whether to use Sinkhorn algorithm (otherwise exact EMD). Defaults to True.
            sinkhorn_reg (float, optional): Regularization parameter for Sinkhorn. Defaults to 0.1.
            max_bootstrap_samples (int, optional): Max number of samples to use for bootstrap. Defaults to 5e4.
            use_cuda (bool, optional): Whether to use GPU. Only honoured if `use_sinkhorn` is True and CUDA is available. Defaults to False.
            include_entropy (bool, optional): Whether to include entropy contribution in OT solution in addition to <gamma, C>. Defaults to True.
            verbose (bool, optional): Whether to print progress. Defaults to False.
            verbose_warnings (bool, optional): If False, `warnings.filterwarnings("ignore")` is called,
                which silences warnings for the whole process. Defaults to False.

        Raises:
            ValueError: If `scores_list` is not a list or is empty.
        """
        # Check and cache inputs
        if not isinstance(scores_list, list):
            raise ValueError("Scores should be a list of 1D arrays (one per model)")
        if not scores_list:
            raise ValueError("scores_list must contain at least one array of scores")
        self.scores_list = scores_list
        self.n_samples = len(scores_list[0])
        self.k = len(scores_list)

        self.n_bootstrap = n_bootstrap
        self.num_workers = num_workers
        self.cost = cost
        # A fresh dict per instance: a dict literal in the signature would be
        # shared (and mutable) across all instances relying on the default.
        self.cost_kwargs = {'beta': 8.0} if cost_kwargs is None else cost_kwargs
        self.use_sinkhorn = use_sinkhorn
        self.sinkhorn_reg = sinkhorn_reg
        self.max_bootstrap_samples = max_bootstrap_samples
        self.use_cuda = use_cuda and use_sinkhorn and torch.cuda.is_available()
        self.include_entropy = include_entropy
        self.verbose = verbose

        if not verbose_warnings:
            # NOTE: this is a process-wide filter, not scoped to this object.
            warnings.filterwarnings("ignore")

    def _compute_violations(self) -> None:
        """Compute pairwise violation ratios and their bootstrap standard deviations.

        Sets `self.eps` (k x k) and `self.eps_i` (k x 1, row sums of `eps`
        divided by k - 1). If `n_bootstrap > 0`, also sets `self.sigma_abs`
        and `self.sigma_rel` (both k x k) from the bootstrap replicates.
        """

        @torch.no_grad()
        def get_eps(scores_list: List[np.ndarray], use_sinkhorn: bool, cost: str,
                    use_cuda: bool) -> np.ndarray:
            """Return the k x k matrix of violation ratios for `scores_list`."""
            # Backend used for the reductions on the transport plan
            be = torch if use_cuda else np

            # Select OT method: sinkhorn or emd
            if use_sinkhorn:

                def ot_method(a, b, M):
                    return ot.sinkhorn(a, b, M, reg=self.sinkhorn_reg, numItermax=1000)
            else:

                def ot_method(a, b, M):
                    return ot.emd(a, b, M)

            beta = self.cost_kwargs.get('beta', 1.0)

            # Per-model samples, uniform weights and sample counts, prepared
            # once instead of once per ordered pair.
            prepared = []
            for x in scores_list:
                n = len(x)
                a = np.ones(n) / n
                if use_cuda:
                    x = torch.tensor(x, dtype=torch.float).cuda()
                    a = torch.tensor(a).cuda()
                prepared.append((x, a, n))

            def transport_cost(src, dst, metric, **metric_kwargs):
                """OT cost between two prepared models for the given pdist metric."""
                (xi, ai, ni), (xj, aj, nj) = src, dst
                M = pdist(xi, xj, metric=metric, **metric_kwargs)  # type: ignore
                if isinstance(M, torch.Tensor):
                    M = M.double()
                gamma = ot_method(ai, aj, M)
                value = be.sum(M * gamma).item()
                if use_sinkhorn and self.include_entropy:  # add entropy contribution
                    value += self.sinkhorn_reg * be.sum(
                        gamma * be.log(gamma + 1e-9)).item()  # type: ignore
                    value += self.sinkhorn_reg * (np.log(ni) + np.log(nj))
                return value

            k = len(scores_list)
            eps = np.zeros((k, k))  # diagonal stays at 0.0
            for i in range(k):
                for j in range(k):
                    if i == j:
                        continue
                    src, dst = prepared[i], prepared[j]

                    # Numerator: OT cost under the (asymmetric) violation cost
                    c = transport_cost(src, dst, cost, beta=beta)

                    # Denominator: OT cost under the symmetric counterpart
                    if cost == 'hinge':
                        denom = transport_cost(src, dst, 'euclidean')
                    elif cost == 'logistic':  # Symmetrize logistic cost for denominator
                        denom = transport_cost(src, dst, 'logistic_sym', beta=beta)
                    else:
                        raise ValueError(f"Cost {cost} not supported for denominator")

                    # A (near-)zero denominator falls back to the value 0.5
                    eps[i, j] = c / denom if denom > 1e-9 else 0.5
            return eps

        # Compute violations statistics for original data
        k_m = self.k - 1
        self.eps = get_eps(self.scores_list, self.use_sinkhorn, self.cost,
                           use_cuda=self.use_cuda)  # k x k
        self.eps_i = self.eps.sum(axis=1, keepdims=True) / k_m  # k x 1

        # Compute violations statistics for all bootstraps
        if self.n_bootstrap > 0:

            def eps_bs_fn(seed):
                resampled = resample_list(self.scores_list, seed, self.max_bootstrap_samples)
                return get_eps(resampled, self.use_sinkhorn, self.cost, use_cuda=self.use_cuda)

            eps_bs = np.c_[bootstrap_multiprocessing(eps_bs_fn,
                                                     num_workers=self.num_workers,
                                                     n_bootstrap=self.n_bootstrap,
                                                     desc="bootstrap quantiles violations",
                                                     verbose=self.verbose)]  # B x k x k

            # Compute bootstrap standard deviation for absolute test
            self.sigma_abs = np.std(eps_bs, ddof=1, axis=0)  # k x k

            # Compute bootstrap standard deviations for relative test
            eps_bs_i = eps_bs.sum(axis=2, keepdims=True) / k_m  # B x k x 1
            eps_bs_i = eps_bs_i.swapaxes(0, 1)  # k x B x 1
            self.sigma_rel = np.std(eps_bs_i - eps_bs_i.T, ddof=1, axis=1)  # k x k

    def _require_sigma(self, name: str) -> np.ndarray:
        """Return the bootstrap statistic `name`, or raise if it was never computed."""
        sigma = getattr(self, name, None)
        if sigma is None:
            raise ValueError("Bootstrap standard deviations are not available; "
                             "the tests require n_bootstrap > 0")
        return sigma

    def _get_wins(self,
                  eps_0: np.ndarray,
                  sigma: np.ndarray,
                  alpha: float,
                  tau: float = 0.0) -> np.ndarray:
        """Return the 0/1 matrix of entries of `eps_0` that pass the test threshold.

        The threshold is `sigma * phi / sqrt(n_samples) + tau`, where `phi` is
        the standard normal quantile at `alpha / k**2`.
        """
        phi = norm.ppf(alpha / self.k**2)
        th = 1 / np.sqrt(self.n_samples) * sigma * phi + tau
        wins = (eps_0 <= th).astype(int)
        return wins

    def compute_relative_test(self, alpha: float = 0.05, return_wins=False):
        """Compute relative test

        Args:
            alpha (float, optional): Significance level. Defaults to 0.05.
            return_wins (bool, optional): Whether to return wins instead of ranks. Defaults to False.

        Returns:
            The number of wins per model (np.ndarray of length k) if `return_wins`
            is True, otherwise the result of `borda` applied to the k x k win matrix.

        Raises:
            ValueError: If no bootstrap statistics are available (`n_bootstrap <= 0`).

        Example:
            >>> # each array is assumed to hold one sample per row
            >>> means = np.random.permutation(15)
            >>> scores_list = [m + np.random.randn(100, 2) for m in means]
            >>> mv_test = MVStochasticOrderTesting(scores_list, n_bootstrap=100)
            >>> rank_qs = mv_test.compute_relative_test(alpha=0.05)
        """
        if not hasattr(self, "eps"):
            self._compute_violations()
        sigma_rel = self._require_sigma("sigma_rel")

        eps_0 = self.eps_i - self.eps_i.T
        wins = self._get_wins(eps_0, sigma_rel, alpha)

        if return_wins:
            return wins.sum(axis=1)
        # Compute ranks using Borda
        return borda(wins)

    def compute_absolute_test(self, alpha: float = 0.05, tau: float = 0.25, return_wins=False):
        """Compute absolute test

        Args:
            alpha (float, optional): Significance level. Defaults to 0.05.
            tau (float, optional): Statistical threshold. Defaults to 0.25.
            return_wins (bool, optional): Whether to return wins instead of ranks. Defaults to False.

        Returns:
            The number of wins per model (np.ndarray of length k) if `return_wins`
            is True, otherwise the result of `borda` applied to the k x k win matrix.

        Raises:
            ValueError: If no bootstrap statistics are available (`n_bootstrap <= 0`).

        Example:
            >>> # each array is assumed to hold one sample per row
            >>> means = np.random.permutation(15)
            >>> scores_list = [m + np.random.randn(100, 2) for m in means]
            >>> mv_test = MVStochasticOrderTesting(scores_list, n_bootstrap=100)
            >>> rank_qs = mv_test.compute_absolute_test(tau=0.25)
        """
        if not hasattr(self, "eps"):
            self._compute_violations()
        sigma_abs = self._require_sigma("sigma_abs")

        wins = self._get_wins(self.eps, sigma_abs, alpha, tau)

        if return_wins:
            return wins.sum(axis=1)
        # Compute ranks using Borda
        return borda(wins)