import torch
import torch.nn as nn


class Ground_Truth(nn.Module):
    """Random sum-of-Gaussian-bumps scalar function, used as a synthetic target.

    The function evaluated by ``forward`` is::

        f(x) = sum_i a_i * exp(-||x - mu_i||^2 / s_i)

    with ``mode_num`` randomly drawn centres ``mu_i``, widths ``s_i`` and
    signed amplitudes ``a_i`` (see ``generate``). The output always has a
    trailing dimension of size 1.

    Args:
        input_dim: dimensionality of an input point (last dim of ``x``).
        mode_num: number of modes (bumps) in the mixture.
        mode_std: standard deviation of each coordinate of a mode centre.
        point_std: base width of each mode. NOTE: it divides the *squared*
            distance directly, so in Gaussian terms it plays the role of
            ``2 * sigma**2`` rather than ``sigma``.
        amplitude: base magnitude of each mode's peak. Each peak is further
            scaled by a random factor ``exp(N(0, 1) / 3)`` and given a random
            sign, so this is not a strict maximum.
        device: device on which the random parameters are sampled.
    """

    def __init__(
        self, input_dim, mode_num, mode_std, point_std, amplitude, device=torch.device('cpu')
    ):
        super(Ground_Truth, self).__init__()

        self.device = device
        self.input_dim = input_dim
        self.mode_num = mode_num
        self.mode_std = mode_std
        self.point_std = point_std
        self.amplitude = amplitude

        self.generate()

    def generate(self):
        """(Re)sample the mixture parameters.

        - ``mode_data``: centres, ``mode_std * N(0, 1)``, shape (mode_num, input_dim)
        - ``point_std_data``: widths, ``point_std * exp(N(0, 1) / 3)``, shape (mode_num,)
        - ``amplitude_data``: ``(+1 or -1) * amplitude * exp(N(0, 1) / 3)``,
          shape (mode_num,)

        The tensors are registered as non-persistent buffers. They therefore
        follow ``.to(...)`` / ``.cuda()`` / ``.double()`` on the module, while
        ``state_dict()`` contents stay unchanged. Sampling always uses
        ``self.device`` (the constructor argument), even if the module has
        since been moved.
        """
        mode_data = self.mode_std * torch.randn(
            size=(self.mode_num, self.input_dim), device=self.device
        )
        point_std_data = torch.exp(
            torch.randn(self.mode_num, device=self.device) / 3
        ) * self.point_std
        # Draw the random +/-1 sign in the same dtype as the other parameters.
        # (``dtype=float`` previously made amplitude_data float64.) The RNG call
        # order is unchanged, so seeded runs sample the same values.
        sign = 2 * torch.randint(
            low=0, high=2, size=(self.mode_num, ), dtype=mode_data.dtype, device=self.device
        ) - 1
        amplitude_data = sign * torch.exp(
            torch.randn(self.mode_num, device=self.device) / 3
        ) * self.amplitude

        # register_buffer may be called again on an existing buffer name, so
        # repeated generate() calls simply overwrite the previous samples.
        for name, value in (
            ('mode_data', mode_data),
            ('point_std_data', point_std_data),
            ('amplitude_data', amplitude_data),
        ):
            self.register_buffer(name, value, persistent=False)

    def forward(self, x):
        """Evaluate the mixture at ``x``.

        Args:
            x: tensor whose last dimension is ``input_dim``, typically
                (batch, input_dim).

        Returns:
            Tensor of shape ``x.shape[:-1] + (1,)``. When ``mode_num == 0``
            this is all zeros.
        """
        # (..., 1, D) - (M, D) -> (..., M, D): squared distance to every mode at
        # once. This materialises a (..., mode_num, input_dim) intermediate.
        sq_dist = (x.unsqueeze(-2) - self.mode_data).pow(2).sum(dim=-1)
        weights = torch.exp(-sq_dist / self.point_std_data)
        return (weights * self.amplitude_data).sum(dim=-1, keepdim=True)
