import torch
import numpy as np
from torch.autograd import grad
from torch import sum, mul
# This module calculates the ReLU neural tangent kernel (NTK) of a fully-connected network (FCNN) with NTK initialization, for an input X given as an N x d data matrix.


def first_kernel(X, weight_std=1.0, bias_std=1.0):
    """Return the first-layer (input) covariance kernel, an N x N tensor.

    Entry (i, j) is ``weight_std**2 * <x_i, x_j> / d + bias_std**2`` where
    ``d`` is the number of columns of ``X``.
    """
    n_features = X.shape[1]
    gram = X.mm(X.t())  # N x N inner products
    weight_var = weight_std * weight_std
    return weight_var * gram / n_features + bias_std * bias_std


def pointwise_first_kernel(X, weight_std=1.0, bias_std=1.0):
    """Return the diagonal of the first-layer kernel, a length-N tensor.

    Entry i is ``weight_std**2 * ||x_i||**2 / d + bias_std**2`` where ``d`` is
    the number of columns of ``X``. This is the self-similarity of each row,
    computed without forming the N x N Gram matrix.

    The squared norm is computed as a row-wise sum of squares rather than as
    ``norm(x) ** 2``. That avoids a sqrt/square round-trip, which introduces
    round-off, so the result agrees with the diagonal of the full kernel.
    """
    d = X.shape[1]
    squared_norms = (X * X).sum(dim=1)  # N
    return weight_std * weight_std * squared_norms / d + bias_std * bias_std


def norm_mul(Sigma):
    """Return the normalizer ``sqrt(Sigma_ii) * sqrt(Sigma_jj)`` for a kernel.

    For a 2-D (N x N) kernel this is the outer product of the square roots of
    the diagonal. For a 1-D (pointwise, length-N) kernel, whose entries are
    the self-similarities ``Sigma_ii``, the product ``sqrt(Sigma_ii)**2`` is
    ``Sigma`` itself, so it is returned unchanged.

    Raises:
        ValueError: if ``Sigma`` is neither 1-D nor 2-D.
    """
    if Sigma.dim() == 2:
        root_diag = torch.sqrt(Sigma.diag().reshape(-1, 1))  # N x 1
        return root_diag.mm(root_diag.t())
    if Sigma.dim() == 1:
        return Sigma
    raise ValueError(
        f"norm_mul expects a 1-D or 2-D tensor, got a {Sigma.dim()}-D tensor"
    )


def theta(Sigma):
    """Return the angles between inputs induced by the kernel ``Sigma``.

    For a 2-D (N x N) kernel, entry (i, j) is
    ``acos(Sigma_ij / (sqrt(Sigma_ii) * sqrt(Sigma_jj)))``. The normalized
    cosine is clamped to [-1, 1]: round-off can push it slightly outside that
    interval (e.g. above 1 on the diagonal, or below -1 for antipodal inputs
    when ``bias_std == 0``), where ``acos`` would return NaN.

    For a 1-D (pointwise) kernel every input is paired with itself, so the
    angle is ``acos(1) == 0`` for every entry.

    Raises:
        ValueError: if ``Sigma`` is neither 1-D nor 2-D.
    """
    if Sigma.dim() == 2:
        # N x N
        cosine = Sigma / norm_mul(Sigma)
        return torch.acos(torch.clamp(cosine, -1.0, 1.0))
    if Sigma.dim() == 1:
        # N
        return torch.acos(torch.ones_like(Sigma))
    raise ValueError(
        f"theta expects a 1-D or 2-D tensor, got a {Sigma.dim()}-D tensor"
    )


def J_1(u):
    """Return ``sin(u) + (pi - u) * cos(u)`` elementwise for a tensor angle ``u``."""
    cos_part = (np.pi - u) * torch.cos(u)
    return torch.sin(u) + cos_part


def J_0(u):
    """Return ``pi - u``; works for tensors and plain Python numbers."""
    half_turn = np.pi
    return half_turn - u


def Tau_Sigma(Sigma):
    """Return the ReLU expectation kernel for pre-activation kernel ``Sigma``.

    Computes ``norm_mul(Sigma) * J_1(theta(Sigma)) / (2 * pi)``, the degree-1
    arc-cosine kernel.
    """
    scale = norm_mul(Sigma)
    angle = theta(Sigma)
    return scale * J_1(angle) / (2 * np.pi)


def Tau_Sigma_prime(Sigma):
    """Return ``J_0(theta(Sigma)) / (2 * pi)``, the ReLU-derivative kernel."""
    two_pi = 2 * np.pi
    return J_0(theta(Sigma)) / two_pi


def Next_Sigma(Sigma, weight_std=1.0, bias_std=1.0):
    """Return the next layer's Gaussian-process kernel.

    ``weight_std**2 * Tau_Sigma(Sigma) + bias_std**2``.
    """
    weight_var = weight_std * weight_std
    bias_var = bias_std * bias_std
    return weight_var * Tau_Sigma(Sigma) + bias_var


def Next_Theta(Theta, Sigma, weight_std=1.0, bias_std=1.0):
    """Return the next layer's neural tangent kernel.

    ``Next_Sigma(Sigma) + weight_std**2 * Theta * Tau_Sigma_prime(Sigma)``,
    where ``Theta`` and ``Sigma`` are the current layer's NTK and GP kernel.
    """
    sigma_next = Next_Sigma(Sigma, weight_std, bias_std)
    derivative_term = weight_std * weight_std * Theta * Tau_Sigma_prime(Sigma)
    return sigma_next + derivative_term


def Next(Sigma, Theta, weight_std=1.0, bias_std=1.0):
    """Advance both kernels by one hidden layer.

    Returns ``(next_Sigma, next_Theta)``, where ``next_Sigma`` is the next
    GP kernel and ``next_Theta = next_Sigma + weight_std**2 * Theta *
    Tau_Sigma_prime(Sigma)`` is the next NTK.
    """
    next_Sigma = Next_Sigma(Sigma, weight_std, bias_std)
    weight_var = weight_std * weight_std
    next_Theta = next_Sigma + weight_var * Theta * Tau_Sigma_prime(Sigma)
    return next_Sigma, next_Theta


def ReLU_NTK(X, weight_std=1.0, bias_std=1.0, hidden_layer_num=5, pointwise=False):
    """Return the analytic ReLU NTK of ``X`` after ``hidden_layer_num`` layers.

    With ``pointwise=True`` only the per-sample (diagonal) values are
    computed, giving a length-N tensor; otherwise the full N x N kernel is
    returned.
    """
    initial_kernel = pointwise_first_kernel if pointwise else first_kernel
    Sigma = initial_kernel(X, weight_std, bias_std)
    # The first layer's NTK equals its GP kernel.
    Theta = Sigma

    for _ in range(hidden_layer_num):
        Sigma, Theta = Next(Sigma, Theta, weight_std, bias_std)

    return Theta


def empirical_kernel(x, init_net):
    """Return the empirical kernel value of ``init_net`` at input ``x``.

    Computes the squared Euclidean norm of the gradient of the network output
    with respect to its parameters: ``fc_in_w``, ``fc_in_b``, ``fc_out_w``,
    ``fc_out_b`` and the first ``hidden_layer_num - 1`` entries of
    ``fc_hidden_w`` / ``fc_hidden_b``.

    Only a single scalar output is handled efficiently with PyTorch autograd,
    so ``x`` is expected to produce a 1x1 output. Because ``grad_outputs`` is
    all ones, a larger output yields the gradient of the *sum* of its entries.

    Args:
        x: network input; presumably a single sample (TODO: confirm shape).
        init_net: callable module exposing the tensors named above, the
            sequences ``fc_hidden_w`` / ``fc_hidden_b`` (assumed to hold at
            least ``hidden_layer_num - 1`` tensors each), and an integer
            ``hidden_layer_num``.

    Returns:
        The kernel value as a Python float.
    """
    y0 = init_net(x)  # expected 1x1
    ones = torch.ones_like(y0)

    # Only first-order gradients are needed. Keep the forward graph alive
    # across the three grad calls with retain_graph instead of building a
    # differentiable graph of the gradients (create_graph=True).
    fc_in_w_g, fc_in_b_g, fc_out_w_g, fc_out_b_g = grad(
        outputs=y0,
        inputs=(
            init_net.fc_in_w,
            init_net.fc_in_b,
            init_net.fc_out_w,
            init_net.fc_out_b,
        ),
        grad_outputs=ones,
        retain_graph=True,
    )
    fc_hidden_w_g = grad(
        outputs=y0, inputs=init_net.fc_hidden_w, grad_outputs=ones, retain_graph=True
    )
    # Last backward pass: the graph may be freed afterwards.
    fc_hidden_b_g = grad(
        outputs=y0, inputs=init_net.fc_hidden_b, grad_outputs=ones
    )

    def _squared_norm(t):
        # Sum of squared entries of a gradient tensor.
        return (t * t).sum()

    g = (
        _squared_norm(fc_in_w_g)
        + _squared_norm(fc_in_b_g)
        + _squared_norm(fc_out_w_g)
        + _squared_norm(fc_out_b_g)
    )

    for i in range(init_net.hidden_layer_num - 1):
        g = g + _squared_norm(fc_hidden_w_g[i]) + _squared_norm(fc_hidden_b_g[i])

    return g.item()
