import numpy as np
import torch
import torch.nn as nn


class FourierMapping(nn.Module):
    """Map coordinates to sinusoidal (Fourier) features.

    Each coordinate is projected onto a set of frequencies. The projection is
    encoded as ``torch.cat([cos(proj), sin(proj)], dim=-1)``, so the output
    has ``2 * ff_dim`` channels (all cosines first, then all sines).

    Args:
        ff_type: One of
            - ``"deterministic_transinr"``: ``ff_dim // input_dim`` frequencies
              log-spaced in ``[1, ff_sigma]``, shared by every input axis.
            - ``"deterministic_transinr_range"``: as above, but log-spaced in
              ``[ff_sigma_min, ff_sigma]``.
            - ``"random_gaussian"``: a dense ``(input_dim, ff_dim)`` projection
              matrix, ``torch.randn(...) * ff_sigma``.
            - ``"deterministic_transinr_nerf"``: frequencies ``2 ** t`` with
              ``t`` linearly spaced in ``[0, ff_sigma]``.
            For every type except ``"deterministic_transinr_nerf"``, the
            projection is multiplied by pi before the cos/sin.
        input_dim: Size of the last axis of the input coordinates.
        ff_dim: Number of projected features; must be divisible by
            ``input_dim`` so the output really has ``2 * ff_dim`` channels.
        ff_sigma: Upper frequency/scale parameter; its meaning depends on
            ``ff_type`` (see above).
        trainable: Whether the frequency tensor ``ff_linear`` receives gradients.
        ff_sigma_min: Lower frequency bound; required for
            ``"deterministic_transinr_range"``.
    """

    def __init__(self, ff_type, input_dim, ff_dim, ff_sigma, trainable=False, ff_sigma_min=None):
        super().__init__()

        assert ff_type in [
            "deterministic_transinr",
            "deterministic_transinr_range",
            "random_gaussian",
            "deterministic_transinr_nerf",
        ], f"unsupported ff_type: {ff_type!r}"
        # Every type (including the NeRF variant, which previously skipped this
        # check) only yields exactly ``ff_dim`` projected features when ``ff_dim``
        # is a multiple of ``input_dim``.
        assert ff_dim % input_dim == 0, f"ff_dim ({ff_dim}) must be divisible by input_dim ({input_dim})"

        self.ff_type = ff_type
        self.input_dim = input_dim
        self.ff_dim = ff_dim
        self.ff_sigma = ff_sigma
        self.ff_sigma_min = ff_sigma_min

        # Deterministic types share one frequency vector across all input axes.
        num_freqs = ff_dim // input_dim

        if ff_type == "deterministic_transinr":
            freqs = torch.exp(torch.linspace(0, np.log(ff_sigma), num_freqs))
        elif ff_type == "deterministic_transinr_range":
            assert ff_sigma_min is not None, "ff_sigma_min is required for ff_type='deterministic_transinr_range'"
            freqs = torch.exp(torch.linspace(np.log(ff_sigma_min), np.log(ff_sigma), num_freqs))
        elif ff_type == "random_gaussian":
            freqs = torch.randn(input_dim, ff_dim) * ff_sigma  # scaled Gaussian projection matrix
        elif ff_type == "deterministic_transinr_nerf":
            freqs = 2 ** torch.linspace(0, ff_sigma, num_freqs)
        else:
            raise NotImplementedError

        # Registered as a Parameter even when frozen, so it shows up in
        # parameters()/state_dict() and follows the module across devices.
        self.ff_linear = nn.Parameter(freqs, requires_grad=trainable)

    def extra_repr(self):
        repr_str = f"ff_type={self.ff_type}, input_dim={self.input_dim}, ff_dim={self.ff_dim}, ff_sigma={self.ff_sigma}, ff_sigma_min={self.ff_sigma_min}"
        return repr_str

    def forward(self, coord):
        """Encode coordinates as Fourier features.

        Args
            coord (torch.Tensor) : `coord.shape == (B, -1, input_dim)`
        Returns
            ff_features (torch.Tensor) : `ff_features.shape == (B, -1, 2*ff_dim)`
        """

        if self.ff_type in ["deterministic_transinr", "deterministic_transinr_range", "deterministic_transinr_nerf"]:
            # Outer product of every coordinate with the shared frequency vector:
            # (..., input_dim, 1) @ (1, num_freqs) -> (..., input_dim, num_freqs).
            # Flattening groups the features per input axis.
            fourier_features = torch.matmul(coord.unsqueeze(-1), self.ff_linear.unsqueeze(0))
            fourier_features = fourier_features.view(*coord.shape[:-1], -1)
        else:
            # Dense projection: (..., input_dim) @ (input_dim, ff_dim) -> (..., ff_dim).
            fourier_features = torch.matmul(coord, self.ff_linear)

        # The NeRF variant intentionally uses raw 2**t frequencies without pi.
        if not self.ff_type == "deterministic_transinr_nerf":
            fourier_features = fourier_features * np.pi

        fourier_features = [torch.cos(fourier_features), torch.sin(fourier_features)]
        fourier_features = torch.cat(fourier_features, dim=-1)
        return fourier_features
