# from data_config import data_config
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd.functional import jvp
from torch.nn.utils import parameters_to_vector as vec
import numpy as np


def _dense_layer(flat, offset, x, fan_in, fan_out, weight_std, bias_std):
    """Apply one affine layer whose parameters are read from the flat vector ``flat``.

    Starting at ``offset``, reads ``fan_out * fan_in`` weights, laid out
    row-major as a ``(fan_out, fan_in)`` matrix. These are followed by
    ``fan_out`` biases. The weight matrix is scaled by
    ``weight_std / sqrt(fan_in)`` and the bias by ``bias_std`` before
    ``F.linear`` is applied.

    Returns:
        ``(y, next_offset)``: the layer output and the index just past this
        layer's parameters in ``flat``.
    """
    n_weight = fan_out * fan_in
    w = flat[offset:offset + n_weight].reshape(fan_out, fan_in) * weight_std / np.sqrt(fan_in)
    b = flat[offset + n_weight:offset + n_weight + fan_out].reshape(fan_out, ) * bias_std
    return F.linear(x, weight=w, bias=b), offset + n_weight + fan_out


def Network_Forward(
    input,
    weight,
    input_dim,
    hidden_layer_dim,
    hidden_layer_num,
    output_dim,
    weight_std,
    bias_std,
):
    """Run a fully connected ReLU network whose parameters come from one flat vector.

    Layers, in order:
      * input -> hidden (``hidden_layer_dim x input_dim`` weights + biases), then ReLU;
      * ``hidden_layer_num - 1`` hidden -> hidden layers (each followed by ReLU),
        none when ``hidden_layer_num <= 1``;
      * hidden -> output (``output_dim x hidden_layer_dim`` weights + biases),
        with no activation.

    For every layer the parameters are consumed sequentially from ``weight``:
    first the row-major weight matrix, then the bias vector. Each weight matrix
    is multiplied by ``weight_std / sqrt(fan_in)`` and each bias by ``bias_std``.
    Elements of ``weight`` beyond those required are ignored.

    Args:
        input: Tensor whose last dimension is ``input_dim``.
        weight: Flat 1-D parameter tensor (assumed 1-D; only ``numel`` and
            slicing along the first dimension are used).
        input_dim: Number of input features.
        hidden_layer_dim: Width of each hidden layer.
        hidden_layer_num: Number of hidden layers.
        output_dim: Number of outputs.
        weight_std: Multiplier applied to every weight matrix (together with
            ``1 / sqrt(fan_in)``).
        bias_std: Multiplier applied to every bias vector.

    Returns:
        The network output. Its last dimension is ``output_dim``.

    Raises:
        ValueError: If ``weight`` holds fewer elements than the layers need.
    """
    # Mirrors range(hidden_layer_num - 1), which is empty for hidden_layer_num <= 1.
    n_hidden_to_hidden = max(hidden_layer_num - 1, 0)

    # Fail early with a clear message instead of an opaque reshape error.
    required = (
        hidden_layer_dim * (input_dim + 1)
        + n_hidden_to_hidden * hidden_layer_dim * (hidden_layer_dim + 1)
        + output_dim * (hidden_layer_dim + 1)
    )
    if weight.numel() < required:
        raise ValueError(
            f"weight has {weight.numel()} elements but the network needs {required}"
        )

    offset = 0
    out, offset = _dense_layer(
        weight, offset, input, input_dim, hidden_layer_dim, weight_std, bias_std
    )
    out = F.relu(out)

    for _ in range(n_hidden_to_hidden):
        out, offset = _dense_layer(
            weight, offset, out, hidden_layer_dim, hidden_layer_dim, weight_std, bias_std
        )
        out = F.relu(out)

    # Output layer: affine only, no activation.
    out, _ = _dense_layer(
        weight, offset, out, hidden_layer_dim, output_dim, weight_std, bias_std
    )
    return out


# network building
class Net(nn.Module):
    """Fully connected ReLU network whose parameters live in one flat vector.

    All weights and biases are stored in the single 1-D parameter
    ``self.weight``. ``forward`` delegates to ``Network_Forward``, which unpacks
    the vector layer by layer. It scales each weight matrix by
    ``weight_std / sqrt(fan_in)`` and each bias vector by ``bias_std``.

    Flat layout, in order:
        first layer:   ``hidden_layer_dim * input_dim`` weights, ``hidden_layer_dim`` biases
        hidden layers: ``hidden_layer_num - 1`` times
                       [``hidden_layer_dim ** 2`` weights, ``hidden_layer_dim`` biases]
        output layer:  ``output_dim * hidden_layer_dim`` weights, ``output_dim`` biases

    Args:
        input_dim: Number of input features.
        hidden_layer_dim: Width of every hidden layer.
        hidden_layer_num: Number of hidden layers; must be at least 1.
        output_dim: Number of outputs.
        weight_std: Std used to initialise weight entries. It is also passed to
            the forward pass as a weight multiplier.
        bias_std: Std used to initialise bias entries. It is also passed to
            the forward pass as a bias multiplier.

    Raises:
        ValueError: If ``hidden_layer_num < 1``.
    """

    def __init__(
        self,
        input_dim=2,
        hidden_layer_dim=1024,
        hidden_layer_num=5,
        output_dim=1,
        weight_std=1.0,
        bias_std=1.0,
    ):
        super(Net, self).__init__()

        # With fewer than one hidden layer the parameter count below is wrong
        # (possibly negative) and the forward pass cannot unpack the vector.
        if hidden_layer_num < 1:
            raise ValueError(
                f"hidden_layer_num must be at least 1, got {hidden_layer_num}"
            )

        self.input_dim = input_dim
        self.hidden_layer_dim = hidden_layer_dim
        self.hidden_layer_num = hidden_layer_num
        self.output_dim = output_dim
        self.weight_std = weight_std
        self.bias_std = bias_std

        # (fan_in, fan_out) of every affine layer, listed in the same order
        # Network_Forward reads the parameters from the flat vector.
        layer_shapes = (
            [(input_dim, hidden_layer_dim)]
            + [(hidden_layer_dim, hidden_layer_dim)] * (hidden_layer_num - 1)
            + [(hidden_layer_dim, output_dim)]
        )
        self.parameter_num = sum(
            fan_in * fan_out + fan_out for fan_in, fan_out in layer_shapes
        )

        self.weight = nn.Parameter(torch.empty(self.parameter_num))

        # NOTE(review): Network_Forward also multiplies the weights by
        # weight_std / sqrt(fan_in) and the biases by bias_std. The effective
        # std is therefore weight_std**2 / sqrt(fan_in) for weights and
        # bias_std**2 for biases. Under the usual NTK parameterisation the raw
        # parameters are drawn from N(0, 1). Confirm the squared scale is
        # intended; it only matters when the stds differ from 1.0.
        #
        # Initialise layer by layer: weights first, then biases. This is the
        # same RNG call order as before, so seeded runs reproduce exactly.
        offset = 0
        for fan_in, fan_out in layer_shapes:
            n_weight = fan_in * fan_out
            nn.init.normal_(self.weight[offset:offset + n_weight], std=weight_std)
            nn.init.normal_(
                self.weight[offset + n_weight:offset + n_weight + fan_out],
                std=bias_std,
            )
            offset += n_weight + fan_out

    def forward(self, input):
        """Evaluate the network on ``input`` using the flat parameter vector."""
        return Network_Forward(
            input=input,
            weight=self.weight,
            input_dim=self.input_dim,
            hidden_layer_dim=self.hidden_layer_dim,
            hidden_layer_num=self.hidden_layer_num,
            output_dim=self.output_dim,
            weight_std=self.weight_std,
            bias_std=self.bias_std,
        )

    def vec(self):
        """Return all module parameters concatenated into one 1-D tensor."""
        return vec(self.parameters())


class Linear_Net(nn.Module):
    """First-order Taylor expansion of a flat-parameter network around fixed weights.

    ``forward`` returns ``f(theta0) + J(theta0) @ (theta - theta0)``, where:
      * ``f`` is ``Network_Forward`` with this module's hyperparameters;
      * ``theta`` is ``self.weight``, the trainable parameter;
      * ``theta0`` is ``init_net.weight``, the expansion point.

    Args:
        target: Module to copy hyperparameters and starting weights from. It
            must expose ``input_dim``, ``hidden_layer_dim``,
            ``hidden_layer_num``, ``output_dim``, ``weight_std``, ``bias_std``,
            ``parameter_num`` and ``weight``. Presumably a ``Net``.
    """

    def __init__(self, target):
        super(Linear_Net, self).__init__()

        self.input_dim = target.input_dim
        self.hidden_layer_dim = target.hidden_layer_dim
        self.hidden_layer_num = target.hidden_layer_num
        self.output_dim = target.output_dim
        self.weight_std = target.weight_std
        self.bias_std = target.bias_std

        self.parameter_num = target.parameter_num

        # Independent copy, so training this module never touches target.weight.
        self.weight = nn.Parameter(target.weight.clone().detach())

    def forward(self, input, init_net):
        """Evaluate the linearised network at ``self.weight``.

        Args:
            input: Network input, forwarded unchanged to ``Network_Forward``.
            init_net: Module whose ``weight`` is the expansion point ``theta0``.

        Returns:
            ``f(theta0) + J(theta0) @ (self.weight - theta0)``. The result is
            differentiable with respect to ``self.weight``.
        """
        # theta0 is a fixed expansion point. Detaching it stops backward from
        # differentiating through init_net. Otherwise create_graph=True keeps
        # init_net.weight in the graph and accumulates unused second-order
        # gradients into init_net.weight.grad. Values and gradients with
        # respect to self.weight are unchanged.
        theta0 = init_net.weight.detach()

        y0, dy = jvp(
            lambda param: Network_Forward(
                input,
                param,
                self.input_dim,
                self.hidden_layer_dim,
                self.hidden_layer_num,
                self.output_dim,
                self.weight_std,
                self.bias_std,
            ),
            inputs=theta0,
            v=self.weight - theta0,
            # Needed so that dy stays differentiable with respect to self.weight (via v).
            create_graph=True,
        )
        return y0 + dy

    def vec(self):
        """Return all module parameters concatenated into one 1-D tensor."""
        return vec(self.parameters())
