import torch.nn as nn
import torch

import numpy as np
import torch.nn.functional


def newton_step(x, model, u, loss_fun):
    """
    Apply one damped Newton update to three parameters of `model`.

    The parameters touched are `model[0].weight`, `model[0].bias` and
    `model[2].weight`. Each one is treated on its own: its Hessian and
    gradient of the loss are computed, `1e-3 * I` is added to the Hessian,
    and the step is obtained as the least-squares solution of
    `(H + 1e-3 * I) step = -g`. All three steps are computed first and only
    then added in place to the parameters.

    The loss (evaluated as `loss_fun(x, model(u))`) is printed before and
    after the update, followed by a separator line.

    :param x: First argument handed to `loss_fun`
    :param model: Indexable model exposing `model[0].weight`, `model[0].bias`
                  and `model[2].weight`
    :param u: Input fed to `model`
    :param loss_fun: Callable invoked as `loss_fun(x, model(u))`; must return
                     a scalar tensor
    """
    current_loss = loss_fun(x, model(u))
    print(current_loss.item())

    targets = (model[0].weight, model[0].bias, model[2].weight)
    steps = []
    flat_grad = None
    for position, param in enumerate(targets):
        if position > 0:
            # gradient() below runs with its default retain_graph, so a new
            # forward pass is made before handling the next parameter.
            current_loss = loss_fun(x, model(u))

        hess = hessian(current_loss, param)
        flat_grad = gradient(current_loss, param)
        hess_np = hess.cpu().detach().numpy()
        grad_np = flat_grad.cpu().detach().numpy()

        damped = hess_np + 1e-3 * np.eye(hess_np.shape[0])
        solution = np.linalg.lstsq(damped, -grad_np, rcond=-1)
        steps.append(solution[0])

    # The last gradient tensor supplies the dtype/device for every update.
    for param, step in zip(targets, steps):
        param.data += torch.from_numpy(step).to(flat_grad).view(param.shape)

    current_loss = loss_fun(x, model(u))
    print(current_loss.item())
    print("------------")


def gradient(outputs, inputs, grad_outputs=None, retain_graph=None, create_graph=False):
    '''
    Return the gradient of `outputs` with respect to `inputs` as one flat vector.

    `inputs` may be a single tensor or an iterable of tensors. The per-input
    gradients are flattened and concatenated in the order of `inputs`; an
    input that `outputs` does not depend on contributes zeros.
    `grad_outputs`, `retain_graph` and `create_graph` are forwarded to
    `torch.autograd.grad`.

    gradient(x.sum(), x)
    gradient((x * y).sum(), [x, y])
    '''
    wrt = [inputs] if torch.is_tensor(inputs) else list(inputs)

    raw_grads = torch.autograd.grad(
        outputs,
        wrt,
        grad_outputs,
        allow_unused=True,
        retain_graph=retain_graph,
        create_graph=create_graph,
    )

    pieces = []
    for grad, tensor in zip(raw_grads, wrt):
        if grad is None:
            # autograd reports unused inputs as None; substitute zeros
            grad = torch.zeros_like(tensor)
        pieces.append(grad.contiguous().view(-1))
    return torch.cat(pieces)


def hessian(output, inputs, out=None, allow_unused=False, create_graph=False):
    '''
    Return the dense Hessian of the scalar `output` with respect to `inputs`.

    `inputs` may be a single tensor or an iterable of tensors; the result is
    an (n, n) matrix where n is the total number of elements over all inputs,
    ordered as the flattened inputs are concatenated. Only the upper triangle
    is differentiated; each row is mirrored into the matching column.

    If `out` is given, the entries are added into it in place and it is
    returned; otherwise a zero matrix is allocated from `output`.
    `allow_unused` is forwarded to the first-order `torch.autograd.grad`
    call, `create_graph` to the second-order `gradient` calls.

    hessian((x * y).sum(), [x, y])
    '''
    assert output.ndimension() == 0

    inputs = [inputs] if torch.is_tensor(inputs) else list(inputs)

    total = sum(p.numel() for p in inputs)
    if out is None:
        out = output.new_zeros(total, total)

    diag = 0  # index of the current row/column in the full matrix
    for start, param in enumerate(inputs):
        remaining = inputs[start:]
        remaining_numel = sum(t.numel() for t in remaining)

        [first_order] = torch.autograd.grad(output, param, create_graph=True,
                                            allow_unused=allow_unused)
        if first_order is None:
            first_order = torch.zeros_like(param)
        first_order = first_order.contiguous().view(-1)

        for k in range(param.numel()):
            entry = first_order[k]
            if entry.requires_grad:
                # second derivatives w.r.t. this element and everything after it
                tail = gradient(entry, remaining, retain_graph=True,
                                create_graph=create_graph)[k:]
            else:
                # entry is constant in the inputs: its second derivatives vanish
                tail = entry.new_zeros(remaining_numel - k)

            out[diag, diag:].add_(tail.type_as(out))  # upper-triangular row
            if diag + 1 < total:
                out[diag + 1:, diag].add_(tail[1:].type_as(out))  # mirrored column
            diag += 1

    return out


def init_weights_gaussian(linear_layer):
    """
    Overwrite the weights of `linear_layer` in place with N(0, 1) samples.

    Only `linear_layer.weight` is modified; the bias is not touched.
    """
    print("Initiallizing with Unit Gaussian")
    weights = linear_layer.weight.data
    weights.normal_(mean=0, std=1.0)


def init_weights_one_over_sqrt_n(linear_layer):
    """
    Fill `linear_layer` in place with uniform samples from [-b, b], b = 1/sqrt(n).

    n is the size of the weight's second dimension. The weight is always
    re-initialised; the bias is re-initialised with the same bound when the
    layer has one.
    """
    print("Initiallizing with 1/sqrt(n)")
    fan_in = linear_layer.weight.size(1)
    bound = 1. / np.sqrt(fan_in)
    linear_layer.weight.data.uniform_(-bound, bound)

    bias = linear_layer.bias
    if bias is None:
        return
    bias.data.uniform_(-bound, bound)


def init_weights_one_over_n(linear_layer):
    """
    Fill `linear_layer` in place with uniform samples from [-b, b], b = 1/n.

    n is the size of the weight's second dimension. The weight is always
    re-initialised; the bias is re-initialised with the same bound when the
    layer has one.
    """
    print("Initiallizing with 1/n")
    fan_in = linear_layer.weight.size(1)
    bound = 1. / fan_in
    linear_layer.weight.data.uniform_(-bound, bound)

    bias = linear_layer.bias
    if bias is None:
        return
    bias.data.uniform_(-bound, bound)


def init_weights_uniform(linear_layer):
    """
    Overwrite the weights of `linear_layer` in place with samples from U(-1, 1).

    Only `linear_layer.weight` is modified; the bias is not touched.
    """
    print("Initiallizing with Uniform")
    bound = 1.
    weights = linear_layer.weight.data
    weights.uniform_(-bound, bound)


def init_weights_gaussian_one_over_sqrt_n(linear_layer):
    """
    Fill `linear_layer` in place with N(0, std^2) samples, std = 1/sqrt(n).

    n is the size of the weight's second dimension. The weight is always
    re-initialised; the bias is re-initialised with the same standard
    deviation when the layer has one.

    The standard deviation is 1/sqrt(n), as the function name and the
    "normal-one-over-sqrt-n" init option state (it was previously computed
    as 1/n).
    """
    print("Initiallizing with Gaussian 1/sqrt(n)")
    stdv = 1. / np.sqrt(linear_layer.weight.size(1))
    linear_layer.weight.data.normal_(mean=0, std=stdv)
    if linear_layer.bias is not None:
        linear_layer.bias.data.normal_(mean=0, std=stdv)


class PolyReLU(nn.Module):
    """Activation module computing `relu(x) ** deg` elementwise."""

    def __init__(self, deg):
        """
        :param deg: Exponent applied to the rectified input
        """
        super().__init__()
        self.deg = deg

    def forward(self, x):
        """Rectify `x`, then raise the result to the power `self.deg`."""
        rectified = nn.functional.relu(x)
        return torch.pow(rectified, self.deg)


def make_model(in_dim, out_dim, num_hidden_units, last_layer_bias=False,
               device='cpu', use_double=False, init='default', relu_deg=1):
    """
    Generate an MLP with PolyReLU nonlinearities and len(num_hidden_units) hidden layers where each layer, i has
    num_hidden_units[i] neurons. Every hidden linear layer is followed by PolyReLU(relu_deg); the final linear
    layer has no activation. An empty num_hidden_units yields a single linear layer from in_dim to out_dim.
    :param in_dim: Input size of the first linear layer
    :param out_dim: Output size of the last linear layer
    :param num_hidden_units: A list of integers representing the number of neurons per hidden layer.
                             A single non-iterable value is treated as one hidden layer of that size.
    :param last_layer_bias: Whether to use a bias in the last layer of the network
    :param device: The device on which to store the model
    :param use_double: If true, the model is converted to double precision
    :param init: Name of the initialization applied to every linear layer. One of "default" (keep the
                 layers' own initialization), "normal", "uniform", "one-over-n", "one-over-sqrt-n",
                 "normal-one-over-sqrt-n"
    :param relu_deg: Exponent applied to the ReLU output in each hidden activation
    :return: A torch.nn.Module representing the MLP
    :raises ValueError: If init is not one of the names listed above
    """

    if init == "default":
        init_fun = None
    elif init == "normal":
        init_fun = init_weights_gaussian
    elif init == "uniform":
        init_fun = init_weights_uniform
    elif init == "one-over-n":
        init_fun = init_weights_one_over_n
    elif init == "one-over-sqrt-n":
        init_fun = init_weights_one_over_sqrt_n
    elif init == "normal-one-over-sqrt-n":
        init_fun = init_weights_gaussian_one_over_sqrt_n
    else:
        raise ValueError("Invalid value for init argument")

    # Accept either a sequence of layer sizes or a single size. Probing with
    # list() rather than indexing [0] keeps an empty sequence valid (no hidden
    # layers) instead of failing with IndexError.
    try:
        hidden_sizes = list(num_hidden_units)
    except TypeError:
        hidden_sizes = [num_hidden_units]

    last_out = in_dim
    units = []
    for width in hidden_sizes:
        units.append(nn.Linear(last_out, width))
        units.append(PolyReLU(relu_deg))
        last_out = width
    units.append(nn.Linear(last_out, out_dim, bias=last_layer_bias))

    if init_fun is not None:
        for unit in units:
            if isinstance(unit, nn.Linear):
                init_fun(unit)

    model = nn.Sequential(*units).to(device)
    if use_double:
        return model.double()
    return model


def model_from_state(state):
    """
    Build a model with `make_model` from the configuration stored in `state`.

    Only the configuration is read; the returned model is newly constructed
    by `make_model`, and no weights are loaded here.

    :param state: Mapping with the keys 'in_dim', 'out_dim',
                  'num_hidden_units', 'last_layer_bias', 'device',
                  'use_double' and 'init'. An optional 'relu_deg' key is
                  forwarded as well; when it is missing, 1 is used, which is
                  the default of `make_model`.
    :return: The model returned by `make_model`
    """
    # 'relu_deg' is optional so that states lacking the key keep working.
    try:
        relu_deg = state['relu_deg']
    except KeyError:
        relu_deg = 1

    return make_model(state['in_dim'], state['out_dim'],
                      state['num_hidden_units'],
                      last_layer_bias=state['last_layer_bias'],
                      device=state['device'],
                      use_double=state['use_double'],
                      init=state['init'],
                      relu_deg=relu_deg)
