import torch


class VectorToContinuousSin(torch.nn.Module):
    """Map binary vectors to a fixed bank of sine features.

    For an input vector ``x`` of length ``vector_length``, output ``k``
    (``k = 0 .. num_functions - 1``) is::

        amplitude * sin(f_k * sum_i(i * x_i) + phase)

    where ``i`` runs over the 1-based positions ``1 .. vector_length`` and
    ``f_k = k + 0.5``.
    """

    def __init__(
        self, vector_length, num_functions, amplitude=1.0, phase=0.0, device=None
    ):
        """
        Initialize the converter with the vector length, number of sine functions, and constant amplitude and phase.
        :param vector_length: Length of the input binary vectors.
        :param num_functions: Number of sine functions to use.
        :param amplitude: Constant amplitude for all sine functions.
        :param phase: Constant phase for all sine functions.
        :param device: Device on which the internal frequency and position
            buffers are created (``None`` uses the default device).
        """
        super(VectorToContinuousSin, self).__init__()
        self.vector_length = vector_length
        self.num_functions = num_functions

        # Both tensors are registered as buffers so that ``module.to(...)``,
        # ``.cuda()`` and ``.double()`` move/convert them with the module
        # (plain tensor attributes are ignored by those calls).
        # ``persistent=False`` keeps them out of ``state_dict`` so existing
        # checkpoints still load without missing/unexpected keys.

        # Frequencies 0.5, 1.5, ..., num_functions - 0.5; shape (num_functions,).
        self.register_buffer(
            "frequencies",
            torch.arange(num_functions, dtype=torch.float32, device=device) + 0.5,
            persistent=False,
        )
        # 1-based position weights 1, 2, ..., vector_length; shape (1, vector_length).
        self.register_buffer(
            "indices",
            torch.arange(
                1, vector_length + 1, dtype=torch.float32, device=device
            ).view(1, -1),
            persistent=False,
        )
        self.amplitude = amplitude
        self.phase = phase

    def forward(self, batch):
        """
        Convert a batch of binary vectors to multiple continuous values using different sine functions.
        :param batch: Binary vectors whose last dimension is ``vector_length``,
            e.g. a 2D tensor of shape ``(batch_size, vector_length)``. Integer,
            bool and floating tensors are accepted.
        :return: Continuous values; for 2D input a tensor of shape
            ``(batch_size, num_functions)``.
        """
        # (vector_length, num_functions): column k holds f_k * [1, 2, ..., vector_length].
        weights = (self.frequencies.view(-1, 1) * self.indices).T
        # torch.matmul requires both operands to share a dtype. Integer/bool
        # inputs (the natural encoding for binary vectors) are promoted to the
        # weight dtype; floating inputs of another precision get matching weights.
        if not torch.is_floating_point(batch):
            batch = batch.to(weights.dtype)
        elif batch.dtype != weights.dtype:
            weights = weights.to(batch.dtype)
        weighted_sums = torch.matmul(batch, weights) + self.phase
        # Apply sine function and scale by amplitude
        return self.amplitude * torch.sin(weighted_sums)


if __name__ == "__main__":
    # Example usage
    vector_length = 5
    num_functions = 3
    converter = VectorToContinuousSin(vector_length, num_functions)

    # Batch of binary vectors. Built as float32 to match the converter's
    # float32 weights: torch.matmul does not mix integer and float operands,
    # and torch.tensor() on Python ints would otherwise produce int64.
    batch_binary_vectors = torch.tensor(
        [[0, 1, 0, 1, 1], [1, 0, 1, 0, 1], [1, 1, 1, 1, 1]], dtype=torch.float32
    )
    continuous_values_batch = converter(batch_binary_vectors)
    print(continuous_values_batch)
