import math
from typing import Optional

import torch
from gpytorch.functions import MaternCovariance
from torch import nn


def sq_dist(x1, x2):
    """Return the matrix of pairwise squared Euclidean distances.

    Rows of ``x1`` index the output's second-to-last dimension and rows of
    ``x2`` its last dimension. Both inputs are first shifted by the mean of
    ``x1`` over dim -2 (which leaves distances unchanged). The distances are
    then obtained from a single matmul of augmented matrices:
    ``[-2a, |a|^2, 1] @ [b, 1, |b|^2]^T = |a|^2 - 2 a.b + |b|^2``.
    Negative round-off results are clamped to zero.
    """
    centre = x1.mean(-2, keepdim=True)
    a = x1 - centre
    # After centring, x1 and x2 must still agree in every dim except -2.
    b = x2 - centre

    a_sq = a.pow(2).sum(dim=-1, keepdim=True)
    b_sq = b.pow(2).sum(dim=-1, keepdim=True)

    lhs = torch.cat([-2.0 * a, a_sq, torch.ones_like(a_sq)], dim=-1)
    rhs = torch.cat([b, torch.ones_like(b_sq), b_sq], dim=-1)
    gram = lhs @ rhs.transpose(-2, -1)

    # The quadratic expansion can go slightly negative through cancellation.
    return gram.clamp_min_(0)


def dist(x1, x2):
    """Return the matrix of pairwise Euclidean distances between rows of x1 and x2.

    Squared distances come from ``sq_dist``. They are floored at 1e-30
    before the square root, so the result is never exactly zero.
    """
    squared = sq_dist(x1, x2)
    floored = torch.clamp_min(squared, 1e-30)
    return torch.sqrt(floored)


class Kernel(nn.Module):
    """Base class for covariance kernels with a learnable lengthscale and scale.

    Both hyperparameters are stored in log-space (``log_lengthscale`` and
    ``log_scale``). They are exposed through the ``lengthscale`` and
    ``scale`` properties, which exponentiate them so the values stay
    positive. Subclasses implement :meth:`forward`; callers invoke the
    kernel as ``kernel(x1, x2)``.

    Args:
        ard_num_dims: Number of independent lengthscales (one per input
            dimension). ``None`` uses a single shared lengthscale.
        init_lengthscale: Initial lengthscale; must be positive because
            ``math.log`` rejects non-positive values.
        init_scale: Initial output scale; must be positive.
    """

    def __init__(
        self,
        ard_num_dims: Optional[int] = None,
        init_lengthscale: float = 1e0,
        init_scale: float = 1e0,
    ):
        super().__init__()

        self.ard_num_dims = ard_num_dims

        lengthscale_num_dims = 1 if ard_num_dims is None else ard_num_dims
        # Shape (1, D) broadcasts against inputs whose last dim is D.
        self.log_lengthscale = nn.Parameter(
            torch.ones(1, lengthscale_num_dims) * math.log(init_lengthscale)
        )
        self.log_scale = nn.Parameter(torch.ones(1) * math.log(init_scale))

    @property
    def lengthscale(self):
        """Lengthscale of shape (1, D), where D is ``ard_num_dims`` or 1 when it is None."""
        return self.log_lengthscale.exp()

    @lengthscale.setter
    def lengthscale(self, value: float):
        """Set the lengthscale (broadcast to shape (1, D)).

        This installs a new ``nn.Parameter`` object with the current dtype and
        device. An optimizer created earlier still references the old one.
        """
        lengthscale_num_dims = 1 if self.ard_num_dims is None else self.ard_num_dims
        ref = self.log_lengthscale
        ones = torch.ones(1, lengthscale_num_dims, dtype=ref.dtype, device=ref.device)
        self.log_lengthscale = nn.Parameter((ones * value).log())

    # Alias kept for backward compatibility with the old misspelled property name.
    lenthscale = lengthscale

    @property
    def scale(self):
        """Output scale of shape (1,)."""
        return self.log_scale.exp()

    @scale.setter
    def scale(self, value: float):
        """Set the output scale, keeping the (1,) shape used by ``__init__``.

        Like the lengthscale setter, this installs a new ``nn.Parameter``
        with the current dtype and device.
        """
        ref = self.log_scale
        ones = torch.ones(1, dtype=ref.dtype, device=ref.device)
        self.log_scale = nn.Parameter((ones * value).log())

    def forward(
        self, x1: torch.Tensor, x2: torch.Tensor, diag: bool = False, **params
    ) -> torch.Tensor:
        """Compute the covariance between ``x1`` and ``x2``; implemented by subclasses."""
        raise NotImplementedError

    def covar_dist(
        self,
        x1: torch.Tensor,
        x2: torch.Tensor,
        diag: bool = False,
        square_dist: bool = False,
        **params,
    ) -> torch.Tensor:
        """Return Euclidean (or squared Euclidean) distances between inputs.

        Args:
            x1, x2: Inputs whose last dimension is the feature dimension.
            diag: If True, return only the row-wise distances between
                matching rows of ``x1`` and ``x2`` (they must broadcast).
                Otherwise return the full pairwise distance matrix.
            square_dist: If True, return squared distances.
            **params: Accepted and ignored. This lets subclasses forward the
                extra keyword arguments they receive from ``__call__``.
        """
        if diag:
            res = torch.linalg.norm(x1 - x2, dim=-1)  # 2-norm by default
            return res.pow(2) if square_dist else res
        dist_func = sq_dist if square_dist else dist
        return dist_func(x1, x2)

    def __call__(
        self,
        x1: torch.Tensor,
        x2: Optional[torch.Tensor] = None,
        diag: bool = False,
        **params,
    ) -> torch.Tensor:
        """Normalize the inputs and dispatch to :meth:`forward`.

        1-D inputs are promoted to column vectors of shape (n, 1). When
        ``x2`` is omitted, ``x1`` is used for both arguments.

        Raises:
            RuntimeError: If ``x1`` and ``x2`` differ in their last dimension.
        """
        x1_, x2_ = x1, x2

        # Give x1_ and x2_ a last dimension, if necessary
        if x1_.ndimension() == 1:
            x1_ = x1_.unsqueeze(1)
        if x2_ is not None:
            if x2_.ndimension() == 1:
                x2_ = x2_.unsqueeze(1)
            if x1_.size(-1) != x2_.size(-1):
                raise RuntimeError(
                    "x1_ and x2_ must have the same number of dimensions!"
                )

        if x2_ is None:
            x2_ = x1_

        return self.forward(x1_, x2_, diag=diag, **params)


class RBFKernel(Kernel):
    """Squared-exponential (RBF) kernel.

    ``k(x1, x2) = scale * exp(-0.5 * ||x1 / lengthscale - x2 / lengthscale||^2)``
    """

    def forward(
        self,
        x1: torch.Tensor,
        x2: torch.Tensor,
        diag: bool = False,
        **params,
    ):
        ls = self.lengthscale
        # Squared distance between lengthscale-normalized inputs.
        scaled_sq = self.covar_dist(
            x1.div(ls), x2.div(ls), diag=diag, square_dist=True, **params
        )
        return self.scale * torch.exp(-0.5 * scaled_sq)


class PeriodicKernel(Kernel):
    """Periodic (exp-sine-squared) kernel with optional per-dimension hyperparameters.

    ``k(x1, x2) = scale * exp(-2 * sum_i sin^2(pi * (x1_i - x2_i) / p_i) / l_i)``

    Here ``p`` is the period and ``l`` the lengthscale. Each has shape (1, D),
    with D = ``ard_num_dims``, or (1, 1) shared across all input dimensions
    when ``ard_num_dims`` is None.

    Args:
        init_period: Initial period; must be positive.
        **kwargs: Forwarded to :class:`Kernel`.
    """

    def __init__(
        self,
        init_period: float = 1e0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        period_num_dims = 1 if self.ard_num_dims is None else self.ard_num_dims
        self.log_period = nn.Parameter(
            torch.ones(1, period_num_dims) * math.log(init_period)
        )

    @property
    def period(self):
        """Period of shape (1, D)."""
        return self.log_period.exp()

    @period.setter
    def period(self, value: float):
        """Set the period via a new ``nn.Parameter`` with the current dtype and device."""
        period_num_dims = 1 if self.ard_num_dims is None else self.ard_num_dims
        ref = self.log_period
        ones = torch.ones(1, period_num_dims, dtype=ref.dtype, device=ref.device)
        self.log_period = nn.Parameter((ones * value).log())

    def forward(self, x1, x2, diag=False, **params):
        """Evaluate the kernel.

        ``diag=True`` returns shape (..., n). Otherwise the result has shape
        (..., n, m) for inputs of shape (..., n, D) and (..., m, D).
        ``**params`` is accepted for interface compatibility and ignored.
        """
        # Rescale so one period corresponds to an angle of pi; sin^2 has period pi.
        x1_ = x1.div(self.period / math.pi)
        x2_ = x2.div(self.period / math.pi)

        # Per-dimension differences. A single Euclidean norm would make every
        # per-dimension lengthscale act on the same quantity and defeat ARD.
        # sin^2 is even, so no abs() is needed.
        if diag:
            diff = x1_ - x2_  # (..., n, D)
        else:
            diff = x1_.unsqueeze(-2) - x2_.unsqueeze(-3)  # (..., n, m, D)

        # The (1, D) lengthscale broadcasts over the trailing feature dim.
        exp_term = diff.sin().pow(2.0).div(self.lengthscale).mul(-2.0).sum(dim=-1)

        return self.scale * exp_term.exp()


class MaternKernel(Kernel):
    """Matérn kernel with smoothness ``nu`` restricted to 0.5, 1.5 or 2.5.

    The closed-form expression is evaluated directly when inputs require
    gradients, when more than one ARD lengthscale is used, or when
    ``diag=True``. Otherwise the computation is delegated to gpytorch's
    ``MaternCovariance`` autograd function.

    Raises:
        RuntimeError: If ``nu`` is not one of the supported values.
    """

    def __init__(self, nu: Optional[float] = 2.5, **kwargs):
        if nu not in {0.5, 1.5, 2.5}:
            raise RuntimeError("nu expected to be 0.5, 1.5, or 2.5")
        super().__init__(**kwargs)

        self.nu = nu

    def forward(self, x1: torch.Tensor, x2: torch.Tensor, diag=False, **params):
        multi_ard = self.ard_num_dims is not None and self.ard_num_dims > 1
        closed_form = x1.requires_grad or x2.requires_grad or multi_ard or diag

        if not closed_form:
            matern = MaternCovariance.apply(
                x1,
                x2,
                self.lengthscale,
                self.nu,
                lambda a, b: self.covar_dist(a, b, **params),
            )
            return self.scale * matern

        # Centre on the per-feature mean of x1 (broadcast back to x1's rank)
        # before scaling by the lengthscale; distances are unaffected.
        flat = x1.reshape(-1, x1.size(-1))
        centre = flat.mean(0)[(None,) * (x1.dim() - 1)]
        ls = self.lengthscale
        r = self.covar_dist((x1 - centre).div(ls), (x2 - centre).div(ls), diag=diag, **params)

        decay = torch.exp(-math.sqrt(self.nu * 2) * r)

        if self.nu == 0.5:
            poly = 1
        elif self.nu == 1.5:
            poly = 1 + math.sqrt(3) * r
        elif self.nu == 2.5:
            poly = 1 + math.sqrt(5) * r + 5.0 / 3.0 * r**2
        return self.scale * poly * decay
