import numpy as np
import torch
import torch.nn as nn

    
class MLP(nn.Module):
    """Fully connected network with ``D`` linear layers and powered activations.

    Layer ``fc1`` maps ``d -> m`` and layers ``fc2 .. fc{D-1}`` map ``m -> m``;
    all of these have a bias. The readout ``fc{D}`` maps ``m -> d_out`` and has
    no bias. Each hidden pre-activation ``z`` becomes ``activation(z) ** p``,
    and the readout output is divided by ``sqrt(m)``.

    Args:
        m: Hidden width.
        d: Input dimension.
        d_out: Output dimension.
        p: Exponent applied (via ``torch.pow``) to every hidden activation.
        D: Total number of linear layers; must be at least 2.
        activation: Module applied before the power. NOTE: the default
            ``nn.ReLU()`` is evaluated once at definition time, so every MLP
            built with the default shares that same module instance.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``D < 2``.
    """

    def __init__(self, m, d=2, d_out=1, p=3, D=2, activation=nn.ReLU(), **kwargs):
        super().__init__()
        if D < 2:
            # With D == 1 the readout would overwrite fc1 (wrong in_features)
            # and forward() would apply that layer twice, so reject it early.
            raise ValueError(f"MLP requires D >= 2 linear layers, got D={D}")
        self.d = d
        self.m = m
        self.p = p
        self.D = D
        self.d_out = d_out

        # Layers live in attributes fc1..fc{D} (not a ModuleList) so that the
        # state_dict keys stay 'fc{l}.weight' / 'fc{l}.bias'.
        self.fc1 = nn.Linear(in_features=d, out_features=m, bias=True)
        for l in range(2, D):
            setattr(self, f'fc{l}', nn.Linear(in_features=m, out_features=m, bias=True))
        setattr(self, f'fc{D}', nn.Linear(in_features=m, out_features=d_out, bias=False))
        self.activation = activation

        self.initialize()

    def forward(self, x):
        """Return the network output for ``x``.

        ``nn.Linear`` acts on the last dimension, so ``x`` must have ``d``
        features there; the result has ``d_out`` features there.
        """
        out = x
        for l in range(1, self.D):
            out = getattr(self, f'fc{l}')(out)
            out = torch.pow(self.activation(out), self.p)
        out = getattr(self, f'fc{self.D}')(out)
        # Readout is scaled by 1/sqrt(m), the hidden width.
        return out / np.sqrt(self.m)

    def initialize(self):
        """Re-draw every parameter in place from NumPy's global RNG.

        Hidden layers ``fc1 .. fc{D-1}``: weight, then bias, each sampled from
        ``U(-sqrt(3 / (fan_in + 1)), sqrt(6 / (fan_in + 1)))``.
        Readout ``fc{D}`` (weight only): ``U(-sqrt(3 / fan_in), sqrt(6 / fan_in))``.

        Samples are drawn in the order fc1.weight, fc1.bias, fc2.weight, ...,
        fc{D}.weight, so ``np.random.seed`` (not ``torch.manual_seed``)
        determines the result.
        """
        for l in range(1, self.D):
            layer = getattr(self, f'fc{l}')
            # NOTE(review): fan_in + 1 presumably accounts for the bias input — confirm.
            denom = layer.in_features + 1
            self._assign(layer.weight,
                         self._uniform(denom, (layer.out_features, layer.in_features)))
            self._assign(layer.bias, self._uniform(denom, layer.out_features))

        readout = getattr(self, f'fc{self.D}')
        self._assign(readout.weight,
                     self._uniform(readout.in_features,
                                   (readout.out_features, readout.in_features)))

    @staticmethod
    def _uniform(denom, size):
        """Sample an array of shape ``size`` from ``U(-sqrt(3/denom), sqrt(6/denom))``."""
        # NOTE(review): the interval is asymmetric (low uses 3, high uses 6),
        # so samples have a positive mean. Preserved from the original —
        # confirm it is intentional and not a typo for a symmetric bound.
        return np.random.uniform(low=-np.sqrt(3 / denom),
                                 high=np.sqrt(6 / denom),
                                 size=size)

    @staticmethod
    def _assign(param, values):
        """Copy NumPy ``values`` into ``param`` in place, keeping its dtype and device."""
        with torch.no_grad():
            param.copy_(torch.as_tensor(values, dtype=param.dtype, device=param.device))
