import numpy as np
import torch
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GINConv, global_add_pool

class NetGIN_readout(torch.nn.Module):
    """Stack of GIN layers followed by a pooled readout.

    Each layer is ``GINConv -> ReLU -> BatchNorm1d``; every layer after the
    first adds its own input to the ReLU output before normalisation. The
    result is passed through a linear layer, pooled with ``global_add_pool``
    using ``batch``, and returned as log-probabilities.

    Args:
        n_classes: Output size of the final linear layer.
        n_features_node: Size of the last dimension of the input ``x``.
        width: Hidden size used by every GIN layer and batch norm.
        depth: Number of GIN layers applied in ``forward``.
    """

    def __init__(self, n_classes, n_features_node=6, width=5, depth=3):
        super().__init__()
        self.n_classes = n_classes
        self.depth = depth
        self.conv = torch.nn.ModuleList()
        self.bn = torch.nn.ModuleList()

        # The first layer maps the raw features to `width`. It is created
        # unconditionally, so it exists even if `depth` is smaller than 1.
        self.conv.append(GINConv(Sequential(Linear(n_features_node, width), ReLU(), Linear(width, width))))
        self.bn.append(torch.nn.BatchNorm1d(width))
        # The remaining layers keep the hidden size constant.
        for _ in range(1, self.depth):
            self.conv.append(GINConv(Sequential(Linear(width, width), ReLU(), Linear(width, width))))
            self.bn.append(torch.nn.BatchNorm1d(width))
        self.fc = Linear(width, n_classes)

    def forward(self, x, edge_index, edge_attr, batch):
        """Return log-softmax scores after pooling over ``batch``.

        Args:
            x: Input features; presumably (num_nodes, n_features_node) --
                TODO confirm against the data pipeline.
            edge_index: Connectivity passed to every ``GINConv``.
            edge_attr: Accepted for interface compatibility; not used.
            batch: Assignment vector handed to ``global_add_pool``.

        Returns:
            Log-softmax over the last dimension of the pooled output.
        """
        for ilayer in range(self.depth):
            h = F.relu(self.conv[ilayer](x, edge_index))
            if ilayer > 0:
                # Residual connection: only possible once the features
                # already have size `width`, i.e. from the second layer on.
                h = h + x
            x = self.bn[ilayer](h)

        x = self.fc(x)
        x = global_add_pool(x, batch)
        return F.log_softmax(x, dim=-1)
    

class NetGIN(torch.nn.Module):
    """Stack of GIN layers that scores a selected subset of input rows.

    Every layer is ``GINConv -> ReLU -> BatchNorm1d``, with a residual
    connection on all layers but the first. After the final linear layer,
    only the rows whose *input* feature in column 1 is non-zero are kept
    and returned as log-probabilities.

    Args:
        n_classes: Output size of the final linear layer.
        n_features_node: Size of the last dimension of the input ``x``.
        width: Hidden size used by every GIN layer and batch norm.
        depth: Number of GIN layers applied in ``forward``.
    """

    def __init__(self, n_classes, n_features_node=6, width=5, depth=3):
        super().__init__()
        self.n_classes = n_classes
        self.depth = depth
        self.conv = torch.nn.ModuleList()
        self.bn = torch.nn.ModuleList()

        # Input size of each layer: raw features first, then `width`.
        # A non-positive repeat count yields no extra entries, so at least
        # one layer is always built.
        layer_inputs = [n_features_node] + [width] * (depth - 1)
        for in_dim in layer_inputs:
            mlp = Sequential(Linear(in_dim, width), ReLU(), Linear(width, width))
            self.conv.append(GINConv(mlp))
            self.bn.append(torch.nn.BatchNorm1d(width))

        self.fc = Linear(width, n_classes)

    def forward(self, x, edge_index, edge_attr, batch):
        """Return log-softmax scores for rows flagged by input column 1.

        Args:
            x: Input features; indexed as ``x[:, 1]`` so it must be 2-D
                with at least two columns.
            edge_index: Connectivity passed to every GIN layer.
            edge_attr: Accepted for interface compatibility; not used.
            batch: Accepted for interface compatibility; not used.

        Returns:
            Log-softmax over the last dimension, one row per input row
            whose column-1 feature is non-zero.
        """
        hidden = x
        for idx in range(self.depth):
            activated = F.relu(self.conv[idx](hidden, edge_index))
            # The first layer changes the feature size, so it has no skip.
            pre_norm = activated if idx == 0 else activated + hidden
            hidden = self.bn[idx](pre_norm)

        scores = self.fc(hidden)
        # Selection is driven by the original input, not the hidden state.
        selected_rows = x[:, 1].nonzero(as_tuple=True)
        return F.log_softmax(scores[selected_rows], dim=-1)

def train(model, train_loader, device, epoch, optimizer):
    """Run one pass over ``train_loader`` and return the averaged loss.

    The learning rate is set from ``epoch`` via ``lr_scheduler`` before any
    batch is processed. Each batch's NLL loss is weighted by
    ``data.num_graphs`` and the total is divided by
    ``len(train_loader.dataset)``.

    Args:
        model: Module called as ``model(x, edge_index, edge_attr, batch)``
            whose output is fed to ``F.nll_loss``.
        train_loader: Iterable of batches exposing ``x``, ``edge_index``,
            ``edge_attr``, ``batch``, ``y``, ``num_graphs`` and ``to()``;
            must also expose ``dataset``.
        device: Target passed to ``data.to``.
        epoch: Epoch index forwarded to ``lr_scheduler``.
        optimizer: Optimizer stepped once per batch.

    Returns:
        Sum of per-batch weighted losses divided by the dataset length.

    Raises:
        ZeroDivisionError: If ``train_loader.dataset`` is empty.
    """
    model.train()
    lr_scheduler(epoch, optimizer)
    loss_all = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data.x, data.edge_index, data.edge_attr, data.batch)
        loss = F.nll_loss(output, data.y)
        loss.backward()
        optimizer.step()
        loss_all += loss.item() * data.num_graphs
    # No explicit `del data` here: the local is released on return anyway,
    # and deleting it would fail when the loader yields no batches.
    return loss_all / len(train_loader.dataset)

def test(model, loader, device):
    """Evaluate ``model`` on ``loader`` and return the accuracy ratio.

    Args:
        model: Module called as ``model(x, edge_index, edge_attr, batch)``;
            the index of the largest value along dim 1 is the prediction.
        loader: Iterable of batches exposing ``x``, ``edge_index``,
            ``edge_attr``, ``batch``, ``y`` and ``to()``; must also expose
            ``dataset``.
        device: Target passed to ``data.to``.

    Returns:
        Number of predictions equal to ``data.y`` divided by
        ``len(loader.dataset)``.

    Raises:
        ZeroDivisionError: If ``loader.dataset`` is empty.
    """
    model.eval()
    correct = 0
    # Gradients are never needed during evaluation; disabling autograd
    # avoids building a graph for every batch.
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            output = model(data.x, data.edge_index, data.edge_attr, data.batch)
            pred = output.max(dim=1)[1]  # [1] selects the argmax indices
            correct += pred.eq(data.y).sum().item()
    return correct / len(loader.dataset)

def lr_scheduler(epoch, optimizer, base_lr=0.001, decay=0.983, decay_every=5):
    """Set an exponentially decayed learning rate on every param group.

    The rate is ``base_lr * decay ** (epoch / decay_every)``. The division
    is a true division, so the rate changes with every epoch rather than in
    steps of ``decay_every``.

    Args:
        epoch: Current epoch index.
        optimizer: Object exposing ``param_groups``, a sequence of dicts
            whose ``'lr'`` entry is overwritten.
        base_lr: Learning rate at ``epoch == 0``.
        decay: Multiplicative factor applied once per ``decay_every`` epochs.
        decay_every: Number of epochs over which one ``decay`` is applied.
    """
    # The rate is identical for every group, so compute it once.
    lr = base_lr * (decay ** (epoch / decay_every))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
