import torch
import numpy as np
from tqdm import tqdm
from collections.abc import Iterable
import torch.nn.functional as F
from typing import Union


def sigmoid(x, alpha=0.4):
    """Logistic function with slope ``alpha``: ``1 / (1 + exp(-alpha * x))``.

    Args:
        x (torch.Tensor): Input values (the scaled input must provide ``.exp()``).
        alpha (float, optional): Slope factor applied to ``x``. Defaults to 0.4.

    Returns:
        torch.Tensor: Element-wise logistic response of ``x``.
    """
    damped = (-alpha * x).exp()
    return 1 / (1 + damped)


@torch.jit.script
def linspace(start: torch.Tensor, stop: torch.Tensor, num: int) -> torch.Tensor:
    """Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to stop, inclusive.
    Replicates the multi-dimensional behaviour of numpy.linspace in PyTorch.

    Args:
        start (torch.Tensor): Tensor of starting values.
        stop (torch.Tensor): Tensor of ending values.
        num (int): Number of samples to generate. For num == 1 only the start values are returned.

    Returns:
        torch.Tensor: linspace(start, stop, num)
    """
    # With num == 1 the step fraction would be 0 / 0 = NaN; clamping the
    # denominator yields a single row equal to `start`, as numpy does.
    denominator = max(num - 1, 1)
    steps = torch.arange(num, dtype=torch.float32, device=start.device) / denominator
    # Append one trailing axis per dimension of `start` so `steps` broadcasts against it.
    for _ in range(start.ndim):
        steps = steps.unsqueeze(-1)

    out = start[None] + steps * (stop - start)[None]
    return out


def _count_samples_per_bin(batch, bins, chunk_size=None):
    """Count, for every dimension, how many rows of ``batch`` fall into each bin.

    Bin ``i`` is the half-open interval ``[bins[i], bins[i + 1])``.

    Args:
        batch (torch.Tensor): Samples; the first axis enumerates the samples.
        bins (torch.Tensor): Bin edges; the first axis enumerates the edges.
        chunk_size (int, optional): Number of bins compared at once. None compares
        all bins in a single pass. Defaults to None.

    Returns:
        torch.Tensor: float32 counts with one row per bin.
    """
    lower_edges, upper_edges = bins[:-1], bins[1:]
    if chunk_size is None:
        chunk_size = max(lower_edges.shape[0], 1)
    counts = []
    for lower, upper in zip(
        lower_edges.split(chunk_size), upper_edges.split(chunk_size)
    ):
        inside = torch.logical_and(
            batch[None] < upper[:, None], batch[None] >= lower[:, None]
        )
        # Reduce over the sample axis right away so the boolean mask of a
        # chunk can be freed before the next chunk is compared.
        counts.append(inside.to(torch.float32).sum(dim=1))
    return torch.cat(counts, dim=0)


@torch.no_grad()
def batch_histogram(
    samples: torch.Tensor,
    n_bins: int = 100,
    batch_size: int = 200,
    device: str = "cpu",
    disable_tqdm: bool = False,
    reduce_memory: bool = False,
    bins_min_max: tuple = None,
    probability: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Generates a histogram (density, bins) for multiple dimensions for batch tensor shaped [n_samples, n_dims].

    Args:
        samples (torch.Tensor): Input samples. A re-iterable collection of sample batches
        (e.g. a list of tensors) is accepted as well and is always processed batch by batch.
        n_bins (int, optional): Number of bins. Defaults to 100.
        batch_size (int, optional): Batch size used when a tensor with more than 1000 samples is split. Defaults to 200.
        device (str, optional): Cuda device or CPU. Defaults to "cpu".
        disable_tqdm (bool, optional): Whether to show progress bars (if True: disabled). Defaults to False.
        reduce_memory (bool, optional): Kept for backward compatibility. Batches are always accumulated
        one at a time, so this flag no longer changes the computation. Defaults to False.
        bins_min_max (tuple, optional): Set min and max (tensors) for the bins instead of calculating them
        from the samples. Samples outside this range are collected in two extra bins reaching to -inf / +inf,
        which are dropped from the returned density. Defaults to None.
        probability (bool, optional): If True, the density is not divided by the bin length. Defaults to False.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: density, bins. Density has n_bins rows. Bins has n_bins + 1 rows
        when they are derived from the samples and n_bins + 3 rows (including the -inf / +inf edges)
        when bins_min_max is given. When bins_min_max is given, the returned density is additionally
        normalized by the per-dimension sum over all bins including the two overflow bins.

    Raises:
        ValueError: If samples is an empty collection of batches.
    """
    eps = 1e-8
    is_tensor = isinstance(samples, torch.Tensor)
    if is_tensor:
        samples = samples.to(device)

    if bins_min_max is None:
        if is_tensor:
            lower = samples.min(dim=0)[0]
            upper = samples.max(dim=0)[0]
        else:
            lower = upper = None
            for batch in samples:
                batch_min = batch.min(dim=0)[0]
                batch_max = batch.max(dim=0)[0]
                lower = batch_min if lower is None else torch.minimum(lower, batch_min)
                upper = batch_max if upper is None else torch.maximum(upper, batch_max)
            if lower is None:
                raise ValueError("samples must contain at least one batch")
        # eps widens the last bin so the maximum itself falls inside [lo, hi).
        bins = linspace(lower, upper + eps, num=n_bins + 1)
    else:
        lower, upper = bins_min_max[0], bins_min_max[1]
        n_dims = samples.shape[1] if is_tensor else next(iter(samples)).shape[1]
        bins = linspace(lower, upper + eps, num=n_bins + 1)[:, None].repeat(1, n_dims)
        # Surround the finite edges with -inf / +inf so out-of-range samples
        # land in dedicated overflow bins.
        overflow = torch.full(
            (1, bins.shape[-1]), float("inf"), dtype=bins.dtype, device=bins.device
        )
        bins = torch.cat([-overflow, bins, overflow], dim=0)

    # All finite bins share one width. Index 1 -> 2 skips a possible -inf edge;
    # with only two edges (n_bins == 1, no overflow bins) use the single interval.
    if bins.shape[0] > 2:
        bin_length = bins[2:3] - bins[1:2] + 1e-12
    else:
        bin_length = bins[1:2] - bins[0:1] + 1e-12

    bins = bins.to(device)
    chunk_size = 200 if n_bins > 1000 else None
    if not is_tensor or samples.shape[0] > 1000:
        batches = samples.split(batch_size) if is_tensor else samples
        counts = None
        n_seen = 0
        for batch in tqdm(batches, disable=disable_tqdm):
            batch = batch.to(device)
            batch_counts = (
                _count_samples_per_bin(batch, bins, chunk_size).cpu().to(torch.float64)
            )
            counts = batch_counts if counts is None else counts + batch_counts
            n_seen += batch.shape[0]
        if counts is None:
            raise ValueError("samples must contain at least one batch")
        # Dividing the accumulated counts by the total number of samples weights
        # every sample equally, even when the last batch is smaller than the rest.
        density = (counts / n_seen).to(torch.float32)
    else:
        density = _count_samples_per_bin(samples, bins, chunk_size) / samples.shape[0]

    if not probability:
        density /= bin_length.to(density.device)
    if bins_min_max is not None:
        # Normalize by the mass of all bins, then drop the two overflow bins.
        density = density[1:-1] / density.sum(dim=0)[None]

    return density, bins


class ExponentialMovingAverageKernel(torch.nn.Module):
    """Moving average kernel for smoothing the density.

    Note: the filter weights are uniform (``1 / kernel_size``), i.e. every
    channel is smoothed independently with a plain moving average and
    replicate padding at the borders.
    """

    def __init__(
        self,
        dim,
        kernel_size=10,
    ):
        super().__init__()
        conv = torch.nn.Conv1d(
            dim, dim, kernel_size, padding="same", padding_mode="replicate"
        )
        # Channel i only reads channel i: the [dim, dim] slice of the weight is
        # diagonal for each kernel tap, with every diagonal entry equal to 1 / kernel_size.
        identity = torch.eye(dim, dtype=conv.weight.dtype, device=conv.weight.device)
        conv.weight.data = (identity[:, :, None] / kernel_size).repeat(
            1, 1, kernel_size
        )
        conv.bias.data = torch.zeros_like(conv.bias)
        for parameter in (conv.bias, conv.weight):
            parameter.requires_grad = False
        self.kernel = conv

    @torch.no_grad()
    def forward(self, densities, bins=None):
        """Smooth ``densities`` along its first axis.

        Args:
            densities (torch.Tensor): Tensor whose second axis has ``dim`` entries
            (it is transposed before being passed to the convolution).
            bins (torch.Tensor, optional): Bin edges. If given, the smoothed result is
            rescaled so that each column times the bin width ``bins[1] - bins[0]`` sums to one.
            Defaults to None.

        Returns:
            torch.Tensor: Smoothed densities with the same layout as the input.
        """
        smoothed = self.kernel(densities.T).T
        if bins is None:
            return smoothed
        bin_width = bins[1] - bins[0]
        mass = (smoothed * bin_width).sum(dim=0)
        smoothed /= mass[None]
        return smoothed


class ConsecutiveLinear(torch.nn.Module):
    """Linear layer evaluated as a sequence of smaller linear layers.

    The rows of ``weight`` (and the matching entries of ``bias``) are split into
    chunks of at most ``out_size`` outputs; ``forward`` applies the chunks one
    after another and concatenates their outputs along the last axis.
    """

    def __init__(self, in_size, weight, bias, out_size=1000):
        """
        Args:
            in_size (int): Number of input features.
            weight (torch.Tensor): Full weight matrix shaped [n_outputs, in_size].
            bias (torch.Tensor): Full bias vector shaped [n_outputs].
            out_size (int, optional): Maximum number of outputs per chunk. Defaults to 1000.
        """
        super().__init__()
        self.linears = torch.nn.ModuleList()
        for w, b in zip(weight.split(out_size, dim=0), bias.split(out_size)):
            # Size the layer after the chunk itself: the last chunk may hold
            # fewer than out_size rows.
            layer = torch.nn.Linear(in_size, w.shape[0], bias=True)
            layer.weight.data = w
            layer.bias.data = b
            self.linears.append(layer)

    @torch.no_grad()
    def forward(self, x):
        # Concatenate along the feature axis; dim=-1 also covers inputs that
        # are not exactly two-dimensional.
        return torch.cat([layer(x) for layer in self.linears], dim=-1)


@torch.no_grad()
def fc_perturbed(
    model: torch.nn.Module,
    latents: torch.Tensor,
    lin: torch.nn.Module = None,
    device: str = "cpu",
    perturbation_distance: float = 2.1,
    n_repeats: int = 50,
    concatenate_logits_latents: bool = False,
    noise_proportional: bool = True,
    constant_length: bool = True,
    batch_size: int = 200,
    input_return_list: bool = False,
    apply_softmax: bool = False,
    reduce_memory: bool = False,
    ablation_noise_only: bool = False,
) -> tuple[Union[torch.Tensor, tuple], torch.nn.Module]:
    """Creates the perturbed fully connected layer (curly H in the paper)
    and calculates the perturbed logits from the penultimate latents.

    Args:
        model (torch.nn.Module): Instance of the neural network. Its ``fc`` layer provides the weights and bias.
        latents (torch.Tensor): The penultimate latents. Either a single tensor or a tuple/list
        ``(train_latents, test_latents, ood_latents)`` where ``ood_latents`` is a dict.
        lin (torch.nn.Module, optional): Precalculated perturbed fully connected layer. If given, no new
        perturbations are drawn. Defaults to None.
        device (str, optional): Cuda device or CPU. Defaults to "cpu".
        perturbation_distance (float, optional): (Relative) perturbation distance (delta in the paper). Defaults to 2.1.
        n_repeats (int, optional): The number of repeats (r in the paper). Defaults to 50.
        concatenate_logits_latents (bool, optional): Concatenate the output of logits and latents. Defaults to False.
        noise_proportional (bool, optional): Force a constant ratio between noise and class projections controlled by
        perturbation_distance. Defaults to True.
        constant_length (bool, optional): Normalize the noise to have length perturbation_distance. Defaults to True.
        batch_size (int, optional): Batch size. Defaults to 200.
        input_return_list (bool, optional): If True, each latent set is an iterable of batches and the
        results are returned as lists of per-batch tensors. Only used for tuple/list latents. Defaults to False.
        apply_softmax (bool, optional): Whether to apply softmax to the output. Defaults to False.
        reduce_memory (bool, optional): Reduces memory, but slows down the computation. Defaults to False.
        ablation_noise_only (bool, optional): Use weight independent random projections. Defaults to False.

    Returns:
        tuple: noise_logits, lin. noise_logits is a tensor for tensor input and a
        ``(train, test, ood_dict)`` tuple for tuple/list input.
    """
    is_split = isinstance(latents, (tuple, list))
    if is_split:
        train_latents, test_latents, ood_latents = latents

    def build_lin():
        weight = model.fc.weight.data
        bias = model.fc.bias.data
        # Row norms of the unmodified weights; taken before a possible
        # ablation zeroes the weights.
        row_norms = model.fc.weight.norm(dim=1)[:, None]
        perturbations = []
        for _ in range(n_repeats):
            noise = torch.randn_like(weight)
            if constant_length:
                noise = F.normalize(noise, dim=1)
            noise = perturbation_distance * noise
            if noise_proportional:
                noise = noise * row_norms
            perturbations.append(noise)
        if ablation_noise_only:
            weight = 0 * weight
            bias = 0 * bias
        # Stack the n_repeats perturbed copies of the layer along the output axis.
        stacked_weight = torch.cat([weight + noise for noise in perturbations], dim=0)
        stacked_bias = torch.cat([bias for _ in range(n_repeats)], dim=0)
        in_features = stacked_weight.shape[1]
        if reduce_memory:
            layer = ConsecutiveLinear(in_features, stacked_weight, stacked_bias)
        else:
            layer = torch.nn.Linear(in_features, stacked_weight.shape[0])
            layer.weight.data = stacked_weight
            layer.bias.data = stacked_bias
        return layer

    # Perturbations are only drawn when no layer was handed in; a reused layer
    # neither costs the sampling time nor advances the random number generator.
    if lin is None:
        lin = build_lin()
    lin = lin.to(device)

    def project(x):
        out = lin(x.to(device))
        if apply_softmax:
            out = out.softmax(dim=-1)
        return out.cpu()

    def project_tensor(tensor):
        return torch.cat([project(x) for x in tensor.split(batch_size)], dim=0)

    def project_batches(batches):
        return [project(x) for x in batches]

    def join(latent, logits):
        # The logits live on the CPU; move them next to the latents first.
        return torch.cat([latent, logits.to(latent.device)], dim=-1)

    if not is_split:
        noise_logits = project_tensor(latents)
        if concatenate_logits_latents:
            noise_logits = join(latents, noise_logits)
        return noise_logits, lin

    run = project_batches if input_return_list else project_tensor
    train_noise_logits = run(train_latents)
    test_noise_logits = run(test_latents)
    ood_noise_logits = {k: run(v) for k, v in ood_latents.items()}

    if not concatenate_logits_latents:
        return (train_noise_logits, test_noise_logits, ood_noise_logits), lin

    if input_return_list:

        def merge(latent_set, logit_set):
            return [join(lat, log) for lat, log in zip(latent_set, logit_set)]

    else:
        merge = join

    noise_logits = (
        merge(train_latents, train_noise_logits),
        merge(test_latents, test_noise_logits),
        {k: merge(ood_latents[k], v) for k, v in ood_noise_logits.items()},
    )
    return noise_logits, lin


def calculate_density(
    latents: torch.Tensor,
    min_train: torch.Tensor,
    max_train: torch.Tensor,
    n_bins: int = 100,
    device: str = "cpu",
    verbose: bool = False,
) -> torch.Tensor:
    """Calculate the density of each sample over the latent dimensions as described in the paper.

    The latents are processed in chunks of 100 rows. Each chunk is transposed
    before it is handed to ``batch_histogram``, so the histogram is taken over
    the entries of every row.

    Args:
        latents (torch.Tensor): Penultimate latents.
        min_train (torch.Tensor): Train set minimum (lower end of the bin range).
        max_train (torch.Tensor): Train set maximum (upper end of the bin range).
        n_bins (int, optional): Number of bins. Defaults to 100.
        device (str, optional): Cuda device or CPU. Defaults to "cpu".
        verbose (bool, optional): Whether to show the progress bar. Defaults to False.

    Returns:
        torch.Tensor: density (per-chunk histograms concatenated along the last axis, on the CPU)
    """
    # Multiplying with a zero-dimensional one yields tensor-valued bin limits
    # located on the device of the latents.
    lower = min_train.to(latents.device) * torch.ones(1)[0]
    upper = max_train.to(latents.device) * torch.ones(1)[0]

    chunks = latents.split(100, dim=0)
    per_chunk = [
        batch_histogram(
            chunk.T,
            n_bins=n_bins,
            disable_tqdm=True,
            device=device,
            probability=False,
            bins_min_max=(lower, upper),
        )[0].cpu()
        for chunk in tqdm(chunks, disable=not verbose)
    ]
    return torch.cat(per_chunk, dim=-1)


def kldiv(
    p: torch.Tensor,
    q: torch.Tensor,
    eps: float = 1e-8,
    kernel: torch.nn.Module = None,
    symmetric: bool = True,
    device: str = "cpu",
) -> torch.Tensor:
    """Calculate the Kullback-Leibler divergence after smoothing and normalizing the densities.
    Mapping of shapes: (n_bins, n_samples_0),(n_bins, n_samples_1) -> (n_samples_0,n_samples_1)

    Args:
        p (torch.Tensor): p density tensor.
        q (torch.Tensor): q density tensor.
        eps (float, optional): Epsilon added to p and q to prevent zero entries. Defaults to 1e-8.
        kernel (torch.nn.Module, optional): Smoothing kernel applied to q only. If None, q is used
        without smoothing. Defaults to None.
        symmetric (bool, optional): Calculate KLD(p,q) + KLD(q,p) if True. Defaults to True.
        device (str, optional): Cuda device or CPU. Defaults to "cpu".

    Returns:
        torch.Tensor: uncertainty
    """
    p = p.to(device)
    q = q.to(device)
    if kernel is not None:
        kernel.to(device)
        q = kernel(q)
    # Out-of-place arithmetic: the tensors passed in by the caller stay untouched.
    p_ = p + eps
    q_ = q + eps
    # Normalize every column to sum to one over the bins.
    p_ = p_ / p_.sum(dim=0)
    q_ = q_ / q_.sum(dim=0)

    # log(q / p) for every (p column, q column) pair: [n_bins, n_samples_0, n_samples_1].
    log_ratio = q_.log()[:, None] - p_.log()[..., None]
    forward_kl = -(p_[..., None] * log_ratio).sum(dim=0)
    if not symmetric:
        return forward_kl
    reverse_kl = (q_[:, None] * log_ratio).sum(dim=0)
    return forward_kl + reverse_kl


def calculate_uncertainty(
    density: torch.Tensor,
    density_train_mean: torch.Tensor,
    kernel: torch.nn.Module,
    eps: float = 1e-8,
    symmetric: bool = True,
    device: str = "cpu",
) -> torch.Tensor:
    """Calculate the score (uncertainties) for given densities and the training set mean density.

    The densities are processed in chunks of 100 columns; a shorter final chunk
    is padded with ones to 100 columns and the padded results are discarded.

    Args:
        density (torch.Tensor): Given density tensor (one column per sample).
        density_train_mean (torch.Tensor): Density mean of the training set.
        kernel (torch.nn.Module): Smoothing kernel.
        eps (float, optional): Epsilon added to the densities to prevent zero entries. Defaults to 1e-8.
        symmetric (bool, optional): Calculate KLD(p,q) + KLD(q,p) if True. Defaults to True.
        device (str, optional): Cuda device or CPU. Defaults to "cpu".

    Returns:
        torch.Tensor: uncertainty (first row of the divergence matrix, one value per column of density)
    """
    # Fixed chunk width; the smoothing kernels created in this file are built
    # with 100 channels, so every chunk handed to the kernel has 100 columns.
    chunk_size = 100
    uncertainty = []
    for dens_t in density.split(chunk_size, dim=-1):
        n_valid = dens_t.shape[-1]
        if n_valid < chunk_size:
            # new_ones keeps the padding on the device and in the dtype of the chunk.
            padding = dens_t.new_ones(dens_t.shape[0], chunk_size - n_valid)
            dens_t = torch.cat([dens_t, padding], dim=-1)

        uncertainty.append(
            kldiv(
                density_train_mean,
                dens_t,
                kernel=kernel,
                eps=eps,
                symmetric=symmetric,
                device=device,
            )[:, :n_valid]
        )
    return torch.cat(uncertainty, dim=-1)[0]


@torch.no_grad()
def calculate_WeiPerKLDiv_score(
    model: torch.nn.Module,
    latents: Iterable[torch.Tensor],
    densities: torch.Tensor = None,
    n_bins: int = 100,
    perturbation_distance: float = 2.1,
    n_repeats: int = 100,
    smoothing: int = 20,
    smoothing_perturbed: int = 20,
    epsilon: float = 0.01,
    lambda_1: float = 1,
    lambda_2: float = 1,
    symmetric: bool = True,
    device: str = "cpu",
    verbose: bool = False,
    reduce_fc_perturbed_memory: bool = False,
    ablation_noise_only: bool = False,
    train_densities: Iterable[torch.Tensor] = None,
    **params,
) -> tuple:
    """Apply the WeiPerKLDiv score to the given latents.

    Args:
        model (torch.nn.Module): Instance of the neural network. It is switched to eval mode.
        latents (Iterable[torch.Tensor]): Train latents, test latents, ood latents (a dict).
        densities (torch.Tensor, optional): If the densities are already calculated,
        they can be referenced here (same nested structure as the returned densities). Defaults to None.
        n_bins (int, optional): Number of bins. Defaults to 100.
        perturbation_distance (float, optional): (Relative) perturbation distance (delta in the paper). Defaults to 2.1.
        n_repeats (int, optional): The number of repeats (r in the paper). Defaults to 100.
        smoothing (int, optional): Smoothing size (s_1 in the paper) for the density. Defaults to 20.
        smoothing_perturbed (int, optional): Smoothing size (s_2 in the paper) for the perturbed density. Defaults to 20.
        epsilon (float, optional): Epsilon added to the densities to prevent zero entries. Defaults to 0.01.
        lambda_1 (float, optional): Lambda_1 (as in the paper). Defaults to 1.
        lambda_2 (float, optional): Lambda_2 (as in the paper). Defaults to 1.
        symmetric (bool, optional): Calculate KLD(p,q) + KLD(q,p) if True. Defaults to True.
        device (str, optional): Cuda device or CPU. Defaults to "cpu".
        verbose (bool, optional): Whether to print progress messages. Defaults to False.
        reduce_fc_perturbed_memory (bool, optional): Reduces memory, but slows down the computation. Defaults to False.
        ablation_noise_only (bool, optional): Use weight independent random projections. Defaults to False.
        train_densities (Iterable[torch.Tensor], optional): If the densities of the train set are already calculated,
        they can be referenced here. Defaults to None.
        **params: Additional keyword arguments are accepted and ignored.

    Returns:
        tuple: uncertainties, densities, latents_noise, train_densities, W_tilde.
        latents_noise and W_tilde are None when precalculated densities are passed in.
    """
    model.eval()
    # Defined up-front so the return statement is valid on both paths.
    latents_noise = None
    W_tilde = None
    if latents is not None:
        latents_train, latents_test, latents_ood = latents

    if densities is None:
        if verbose:
            print("Calculate perturbed logits...")
        latents_noise, W_tilde = fc_perturbed(
            model,
            latents,
            lin=None,
            device="cpu",
            perturbation_distance=perturbation_distance,
            n_repeats=n_repeats,
            concatenate_logits_latents=False,
            noise_proportional=True,
            constant_length=True,
            batch_size=200,
            reduce_memory=reduce_fc_perturbed_memory,
            input_return_list=isinstance(latents[0], list),
            ablation_noise_only=ablation_noise_only,
        )
        latents_noise_train, latents_noise_test, latents_noise_ood = latents_noise

        if verbose:
            print("Evaluate density...")

        if train_densities is None:
            min_train, max_train = latents_train.min(), latents_train.max()
            min_noise_train, max_noise_train = (
                latents_noise_train.min(),
                latents_noise_train.max(),
            )

            densities_train = calculate_density(
                latents_train,
                min_train,
                max_train,
                n_bins=n_bins,
                device=device,
            )
            densities_train_mean = densities_train.mean(dim=-1)[:, None]
            densities_noise_train = calculate_density(
                latents_noise_train,
                min_noise_train,
                max_noise_train,
                n_bins=n_bins,
                device=device,
            )
            densities_noise_train_mean = densities_noise_train.mean(dim=-1)[:, None]
            train_densities = (
                densities_train_mean,
                densities_noise_train_mean,
                (min_train, max_train),
                (min_noise_train, max_noise_train),
            )
        else:
            (
                densities_train_mean,
                densities_noise_train_mean,
                (min_train, max_train),
                (min_noise_train, max_noise_train),
            ) = train_densities

        def latent_density(values):
            # Histogram over the bin range of the train latents.
            return calculate_density(
                values, min_train, max_train, n_bins=n_bins, device=device
            )

        def noise_density(values):
            # Histogram over the bin range of the perturbed train logits.
            return calculate_density(
                values, min_noise_train, max_noise_train, n_bins=n_bins, device=device
            )

        densities_test = latent_density(latents_test)
        densities_noise_test = noise_density(latents_noise_test)
        densities_ood = {k: latent_density(v) for k, v in latents_ood.items()}
        densities_noise_ood = {
            k: noise_density(v) for k, v in latents_noise_ood.items()
        }

        def calculate_noise_pred(logits):
            # Maximum softmax probability of each of the n_repeats perturbed
            # heads; the heads are laid out one after another along the last axis.
            n_classes = logits.shape[-1] // n_repeats
            logits = logits.to(device)
            per_head = logits[:, : n_repeats * n_classes].reshape(
                logits.shape[0], n_repeats, n_classes
            )
            return per_head.softmax(dim=-1).max(dim=-1)[0]

        msp_test = calculate_noise_pred(latents_noise_test).mean(dim=-1)
        msp_ood = {
            k: calculate_noise_pred(v).mean(dim=-1)
            for k, v in latents_noise_ood.items()
        }
        densities = (
            (densities_train_mean, densities_noise_train_mean),
            (densities_test, densities_noise_test, msp_test),
            (densities_ood, densities_noise_ood, msp_ood),
        )
    else:
        (
            (densities_train_mean, densities_noise_train_mean),
            (densities_test, densities_noise_test, msp_test),
            (densities_ood, densities_noise_ood, msp_ood),
        ) = densities

    # The kernels are built with 100 channels, matching the chunks of 100
    # columns that calculate_uncertainty feeds them.
    kernel = ExponentialMovingAverageKernel(100, smoothing).to(device)
    kernel_noise = ExponentialMovingAverageKernel(100, smoothing_perturbed).to(device)

    def latent_uncertainty(dens):
        return calculate_uncertainty(
            dens,
            densities_train_mean,
            kernel,
            eps=epsilon,
            device=device,
            symmetric=symmetric,
        )

    def noise_uncertainty(dens):
        return calculate_uncertainty(
            dens,
            densities_noise_train_mean,
            kernel_noise,
            eps=epsilon,
            device=device,
            symmetric=symmetric,
        )

    uncertainties_test = latent_uncertainty(densities_test)
    uncertainties_noise_test = noise_uncertainty(densities_noise_test)
    uncertainties_ood = {k: latent_uncertainty(v) for k, v in densities_ood.items()}
    uncertainties_noise_ood = {
        k: noise_uncertainty(v) for k, v in densities_noise_ood.items()
    }

    # Final score: negated sum of both divergences minus the weighted mean
    # maximum softmax probability of the perturbed heads.
    uncertainties_test = -(
        uncertainties_test + lambda_1 * uncertainties_noise_test - lambda_2 * msp_test
    )
    uncertainties_ood = {
        k: -(v + lambda_1 * uncertainties_noise_ood[k] - lambda_2 * msp_ood[k])
        for k, v in uncertainties_ood.items()
    }
    uncertainties = (uncertainties_test, uncertainties_ood)
    return uncertainties, densities, latents_noise, train_densities, W_tilde
