import torch 
import torch.nn as nn 

import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample
from .BBP import Clamp

class PyroNN(PyroModule):
    """Bayesian fully-connected regression network built with Pyro.

    Every weight and bias of the network is replaced by a ``PyroSample`` drawn
    from ``prior`` (expanded to that parameter's shape), and observations are
    modelled as ``Normal(net(x), scale_data)`` inside a ``"data"`` plate.
    A plain (non-Bayesian) PyTorch copy of the same layer stack can be built
    with :meth:`set_deterministic` and trained by MSE via
    :meth:`fit_deterministic`.

    Args:
        input_dim: Number of input features.
        W: Width of every hidden layer.
        L: Number of hidden ``Linear + ReLU`` layers. Values below 1 still
            produce a single hidden layer.
        scale_data: Fixed observation-noise standard deviation. When ``None``
            the noise scale is itself a latent variable with a
            ``Gamma(0.1, 0.1)`` prior.
        cuda: When ``True``, tensors created by this class (noise scale,
            default prior, data plate, deterministic net) live on ``"cuda"``.
        prior: Pyro distribution used as the prior for every network weight
            and bias. Defaults to a standard Normal on the selected device.
        F: When not ``None``, a ``Clamp(F)`` module is appended after the
            output layer of the Bayesian network.
    """

    def __init__(self, input_dim, W, L, scale_data=None, cuda=False, prior=None, F=None):
        super().__init__()
        device = torch.device("cuda" if cuda else "cpu")
        if prior is None:
            # Previously a None prior crashed on `prior.expand`; fall back to
            # the conventional standard-Normal weight prior instead.
            prior = dist.Normal(
                torch.tensor(0.0, device=device), torch.tensor(1.0, device=device)
            )

        layers = self._layer_stack(
            PyroModule[nn.Linear], PyroModule[nn.ReLU], input_dim, W, L
        )
        if F is not None:
            layers.append(PyroModule[Clamp](F))
        net = PyroModule[nn.Sequential](*layers)

        for module in net.modules():
            # Snapshot the parameters first: setattr below swaps each
            # nn.Parameter for a PyroSample, mutating the module as we go.
            for name, param in list(module.named_parameters(recurse=False)):
                # to_event(dim) makes the whole tensor a single sample site.
                setattr(module, name, PyroSample(
                    prior=prior.expand(param.shape).to_event(param.dim())))

        self.net = net
        self.device = device
        self.input_dim = input_dim
        self.W = W
        self.L = L
        self.F = F
        if scale_data is not None:
            # as_tensor avoids the copy-construct warning for tensor inputs and
            # the float cast keeps an integer scale from producing an int tensor.
            self.scale_data = torch.as_tensor(
                scale_data, dtype=torch.get_default_dtype(), device=device)
        else:
            self.scale_data = PyroSample(
                dist.Gamma(0.1, torch.tensor(0.1, device=device)))

    @staticmethod
    def _layer_stack(linear, activation, input_dim, width, depth):
        """Return the ordered layer list shared by both network variants.

        The list is: ``linear(input_dim, width)``, ``activation()``, then
        ``depth - 1`` further ``linear(width, width)``/``activation()`` pairs
        (none when ``depth <= 1``), and finally ``linear(width, 1)``.
        """
        layers = [linear(input_dim, width), activation()]
        for _ in range(depth - 1):
            layers += [linear(width, width), activation()]
        layers.append(linear(width, 1))
        return layers

    def forward(self, x, y=None):
        """Run the generative model on inputs ``x``.

        Args:
            x: Input batch; ``len(x)`` is used as the data-plate size
                (presumably shape ``(N, input_dim)`` — TODO confirm).
            y: Optional targets, conditioned on as the ``"obs"`` site when
                given (presumably shape ``(N,)``).

        Returns:
            The value of the ``"obs"`` sample site.
        """
        # squeeze(-1) drops only the single output unit; a bare squeeze() would
        # also collapse the batch dimension when N == 1.
        mu = self.net(x).squeeze(-1)
        # Read the (possibly latent) noise scale outside the plate so it stays a
        # single global sample site rather than one per data point.
        scale_data = self.scale_data
        with pyro.plate("data", len(x), device=self.device):
            return pyro.sample("obs", dist.Normal(mu, scale_data), obs=y)

    def set_deterministic(self, lr=1e-3):
        """Build a plain PyTorch copy of the layer stack and its optimiser.

        Creates ``net_deterministic`` (on ``self.device``), an Adam optimiser
        with learning rate ``lr``, a StepLR scheduler (gamma 0.8 every 1000
        steps) and an MSE loss.

        NOTE(review): unlike the Bayesian net, no ``Clamp(F)`` is appended
        here — confirm this is intended.
        """
        self.lr = lr
        layers = self._layer_stack(nn.Linear, nn.ReLU, self.input_dim, self.W, self.L)
        self.net_deterministic = nn.Sequential(*layers)
        self.net_deterministic.to(self.device)
        self.optimizer = torch.optim.Adam(
            self.net_deterministic.parameters(), lr=self.lr)
        # fit_deterministic never calls scheduler.step(); presumably the caller
        # steps it — verify against the training loop.
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=1000, gamma=0.8, last_epoch=-1)
        self.loss_fn = nn.MSELoss()

    def fit_deterministic(self, x, y):
        """Take one optimiser step minimising MSE of ``net_deterministic`` on (x, y).

        Requires :meth:`set_deterministic` to have been called first.

        Returns:
            The scalar loss value before the step, as a Python float.
        """
        self.optimizer.zero_grad()
        # The network outputs (N, 1); match y's shape so MSELoss compares
        # element-wise instead of broadcasting (N, 1) vs (N,) into (N, N).
        y_pred = self.net_deterministic(x).reshape(y.shape)
        loss = self.loss_fn(y_pred, y)
        loss.backward()
        self.optimizer.step()
        return loss.item()
