
import numpy as np
import torch
import torch.nn as nn
import math
# Module-level switch read by sample_correlated_gaussian and ConcatCritic to
# choose between their sequence variant (truthy) and their plain variant.
Sequence = True

def sampled_k_order_data(batch_size=32, seq_length=25, order_k=5, noise_std=0.1):
    """Sample synthetic sequences whose state depends linearly on the last k steps.

    Every time step carries five features:

    * features 0-2: the state. The first ``order_k`` steps are uniform random
      in [0, 1); every later step is the flattened window of the previous
      ``order_k`` states multiplied by a fixed transition matrix.
    * feature 3: a fixed weighted sum (0.5, 0.3, 0.2) of features 0-2.
    * feature 4: independent Gaussian noise with standard deviation
      ``noise_std``.

    Args:
        batch_size: number of sequences to generate.
        seq_length: total number of time steps generated per sequence.
        order_k: number of past steps each new state depends on. The built-in
            transition matrix has 15 rows, so only ``order_k == 5`` is
            supported once ``seq_length > order_k``.
        noise_std: standard deviation of the noise feature.

    Returns:
        ``(x, y)`` float tensors: ``x`` holds all steps but the last, shape
        ``(batch_size, seq_length - 1, 5)``; ``y`` is the last step, shape
        ``(batch_size, 5)``.

    Raises:
        ValueError: if ``seq_length < order_k``, or if states have to be
            propagated with an ``order_k`` the transition matrix does not fit.
    """
    num_feat = 3
    transition = np.array([[-0.05093365,  0.29281012,  0.18562611],
       [ 0.33428532,  0.40349844, -0.33467879],
       [ 0.14762354,  0.30575826, -0.26174898],
       [-0.44782906, -0.29130051, -0.08115543],
       [-0.32102355, -0.34755047,  0.3520845 ],
       [ 0.29529524, -0.11832884,  0.14952023],
       [ 0.23306076,  0.3170936 , -0.39420355],
       [-0.20894906, -0.34532153,  0.16065137],
       [-0.26019758, -0.22963394, -0.09091722],
       [ 0.07084283, -0.40970881, -0.43793993],
       [ 0.42431915,  0.23318924,  0.11397961],
       [-0.2561352 ,  0.05819257, -0.35414398],
       [ 0.46721657,  0.35555137, -0.01805047],
       [-0.15490611, -0.39558058, -0.0217171 ],
       [-0.17647525,  0.11403893,  0.16993299]])
    weights = np.array([[0.5, 0.3, 0.2]])

    # Fail early with a readable message instead of an opaque NumPy
    # broadcast / matmul shape error further down.
    if seq_length < order_k:
        raise ValueError(
            "seq_length (%d) must be at least order_k (%d)" % (seq_length, order_k)
        )
    if seq_length > order_k and order_k * num_feat != transition.shape[0]:
        raise ValueError(
            "order_k=%d is incompatible with the fixed %dx%d transition matrix; "
            "order_k must be %d"
            % (order_k, transition.shape[0], transition.shape[1],
               transition.shape[0] // num_feat)
        )

    data = np.zeros((batch_size, seq_length, num_feat + 2))

    # Seed the first k steps with uniform random states.
    data[:, :order_k, :num_feat] = np.random.rand(batch_size, order_k, num_feat)
    for i in range(order_k):
        # Feature 3 is a linear combination of the state features.
        data[:, i, 3] = np.sum(data[:, i, :num_feat] * weights, axis=1)
        # Feature 4 is pure Gaussian noise.
        data[:, i, 4] = np.random.normal(0, noise_std, batch_size)

    # Generate the remaining steps according to the k-th order dependency.
    for i in range(order_k, seq_length):
        # New state = (previous k states, flattened) @ transition.
        window = data[:, i - order_k:i, :num_feat].reshape(batch_size, order_k * num_feat)
        data[:, i, :num_feat] = np.dot(window, transition)
        data[:, i, 3] = np.sum(data[:, i, :num_feat] * weights, axis=1)
        data[:, i, 4] = np.random.normal(0, noise_std, batch_size)

    # Modify the length or mask features here if needed.
    x = torch.from_numpy(data[:, :-1]).type(torch.float)
    y = torch.from_numpy(data[:, -1]).type(torch.float)
    return x, y

def sample_correlated_gaussian(rho=0.5, dim=20, batch_size=128, cubic=None):
    """Draw a batch of (x, y) pairs whose correlation is controlled by ``rho``.

    When the module-level ``Sequence`` flag is truthy, ``x`` is a sequence of
    9 noise steps shifted by -0.5, shape ``(batch_size, 9, dim)``, and ``y``
    (shape ``(batch_size, dim)``) mixes a fresh noise step with the sum of
    ``x`` over time scaled by ``rho / sqrt(9)``, plus 1.

    Otherwise ``x`` and ``y`` are both ``(batch_size, dim)`` with
    ``y = rho * x + sqrt(1 - rho**2) * eps``.

    If ``cubic`` is anything other than ``None`` (including ``False``), ``y``
    is cubed element-wise before being returned.
    """
    if not Sequence:
        x, eps = torch.chunk(torch.randn(batch_size, 2 * dim), 2, dim=1)
        y = rho * x + torch.sqrt(torch.tensor(1. - rho**2).float()) * eps
    else:
        # TODO: shift with 2, mi->20
        T = 10
        eps = torch.randn((T * batch_size, dim)).view(batch_size, T, dim)
        # The first T-1 steps (shifted) form the observed sequence; the last
        # noise step is only used to build y.
        x = (eps[:, :T - 1] - 0.5).float()
        y = math.sqrt(1 - rho**2) * eps[:, -1] + rho * torch.sum(x, dim=1) / math.sqrt(T - 1) + 1

    if cubic is None:
        return x, y
    return x, y ** 3


def rho_to_mi(dim, rho):
    """Return the ground-truth mutual information for correlation ``rho``.

    Computes ``-0.5 * log(1 - rho**2) * dim``.
    """
    log_term = np.log(1 - rho**2)
    return -0.5 * log_term * dim


def mi_to_rho(dim, mi):
    """Return the Gaussian correlation ``rho`` for a ground-truth mutual information.

    Computes ``sqrt(1 - exp(-2 / dim * mi))``.
    """
    exponent = -2.0 / dim * mi
    return np.sqrt(1 - np.exp(exponent))


def mi_schedule(n_iter):
    """Build a staircase schedule of mutual-information targets.

    ``n_iter`` evenly spaced values from 0.5 to just under 5.5 are rounded to
    the nearest integer and multiplied by 20, giving a non-decreasing float32
    array of length ``n_iter``.
    """
    ramp = np.linspace(0.5, 5.5 - 1e-9, n_iter)
    levels = np.round(ramp)
    return (levels * 20).astype(np.float32)


def mlp(dim, hidden_dim, output_dim, layers, activation):
    """Build a multi-layer perceptron as an ``nn.Sequential``.

    The network is an input projection ``dim -> hidden_dim``, followed by
    ``layers`` hidden ``hidden_dim -> hidden_dim`` blocks (each Linear plus
    activation), and a final ``hidden_dim -> output_dim`` Linear without
    activation.

    ``activation`` is looked up by name; only ``'relu'`` is known, and any
    other name raises ``KeyError``.
    """
    act_cls = {
        'relu': nn.ReLU
    }[activation]

    modules = [nn.Linear(dim, hidden_dim), act_cls()]
    for _ in range(layers):
        modules.append(nn.Linear(hidden_dim, hidden_dim))
        modules.append(act_cls())
    modules.append(nn.Linear(hidden_dim, output_dim))

    return nn.Sequential(*modules)


class SeparableCritic(nn.Module):
    """Separable critic whose score is the inner product of g(x) and h(y)."""

    def __init__(self, dim, hidden_dim, embed_dim, layers, activation, **extra_kwargs):
        """Create the two embedding MLPs; ``extra_kwargs`` are accepted and ignored."""
        super().__init__()
        # Both encoders map dim -> embed_dim with the same architecture.
        self._g = mlp(dim, hidden_dim, embed_dim, layers, activation)
        self._h = mlp(dim, hidden_dim, embed_dim, layers, activation)

    def forward(self, x, y):
        """Return the score matrix with entry [i, j] = <h(y_i), g(x_j)>."""
        y_embed = self._h(y)
        x_embed = self._g(x)
        return y_embed @ x_embed.t()


class ConcatCritic(nn.Module):
    """Critic that scores every (x_i, y_j) pair of a batch.

    Two scoring modes exist, chosen at construction time by the module-level
    ``Sequence`` flag:

    * sequence mode (``seq_proj`` present): an LSTM summarises the sequence
      ``x`` into a prediction ``yhat`` (its last output step) and the score is
      the negative squared Euclidean distance between ``yhat_i`` and ``y_j``.
    * concat mode (no ``seq_proj``): ``x_i`` and ``y_j`` are concatenated and
      a single MLP ``_f`` outputs the scalar score.

    In both modes the returned matrix has entry ``[i, j]`` = score of
    ``(x_i, y_j)``.
    """

    def __init__(self, dim, hidden_dim, layers, activation, **extra_kwargs):
        """Build the pair-scoring MLP and, in sequence mode, the LSTM encoder.

        ``extra_kwargs`` are accepted and ignored.
        """
        super(ConcatCritic, self).__init__()
        # output is scalar score
        self._f = mlp(dim * 2, hidden_dim, 1, layers, activation)
        if Sequence:
            self.seq_proj = nn.LSTM(input_size=dim, hidden_size=dim, num_layers=layers, batch_first=True)

    def forward(self, x, y):
        """Return the pairwise score matrix for the batch.

        Args:
            x: in sequence mode a batch-first sequence tensor fed to the LSTM
                (presumably ``(batch, seq_len, dim)``); in concat mode a 2-D
                ``(batch, dim)`` tensor.
            y: 2-D tensor of targets, presumably ``(batch, dim)``.

        Returns:
            Tensor whose entry ``[i, j]`` scores the pair ``(x_i, y_j)``.
        """
        # Dispatch on how this instance was built rather than re-reading the
        # global flag, so the forward pass always matches the modules that
        # actually exist on the instance.
        seq_proj = getattr(self, "seq_proj", None)
        if seq_proj is not None:
            output, _ = seq_proj(x)
            yhat = output[:, -1, :]  # last time step summarises the sequence
            return -torch.cdist(yhat, y)**2

        # Concat mode: previously this path crashed because ``yhat`` was never
        # assigned. Score all pairs with the MLP as the class contract states.
        n_x, n_y = x.size(0), y.size(0)
        x_rep = x.unsqueeze(1).expand(-1, n_y, -1)  # [i, j] -> x_i
        y_rep = y.unsqueeze(0).expand(n_x, -1, -1)  # [i, j] -> y_j
        pairs = torch.cat((x_rep, y_rep), dim=-1).reshape(n_x * n_y, -1)
        return self._f(pairs).reshape(n_x, n_y)


def log_prob_gaussian(x):
    """Return the standard-normal log density of ``x`` summed over the last dimension."""
    standard_normal = torch.distributions.Normal(0., 1.)
    return standard_normal.log_prob(x).sum(-1)
