import os
import torch
import torch.nn.functional as F
from torch.nn import BatchNorm1d as BatchNorm, Linear, ReLU, Sequential
from torch_geometric.nn import GCNConv, GINConv, GATConv, SAGEConv, TransformerConv, global_add_pool
from utils import *
from utils_dist import *
from GMT_model.nets_dist import GraphMultisetTransformer, GraphMultisetTransformer_for_OGB

# Make CUDA kernel launches synchronous (see the CUDA environment-variable
# documentation) so device-side errors are reported at the offending call.
# NOTE(review): this runs at import time for every importer of this module and
# slows GPU execution; presumably a debugging aid — confirm it is wanted in
# production runs.
os.environ.update(CUDA_LAUNCH_BLOCKING="1")


class GCN(torch.nn.Module):
    """Graph-level GCN classifier: GCNConv stack, sum readout, two-layer head.

    Every conv layer is followed by a ReLU, and that activation is sum-pooled
    per graph with ``global_add_pool``. Only the last layer's pooled vector
    feeds the ``lin1 -> ReLU -> dropout -> lin2`` head. ``forward`` returns
    ``(logits, pooled_per_layer)``.

    Args:
        input_dim: node feature width fed to the first GCNConv.
        hidden_dim: width of every conv layer and of ``lin1``.
        output_dim: width of the logits produced by ``lin2``.
        num_layers: number of GCNConv layers (at least one is always built).
        dropout: dropout probability applied before ``lin2``.
        return_embeds: stored on the instance; not read by ``forward``.
        reg_term: stored as ``alpha`` and forwarded to the loss helper.
        loss_module: stored and forwarded to the loss helper.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, dropout, return_embeds=False, reg_term=0, loss_module=None):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_channels = hidden_dim
        self.alpha = reg_term

        # First layer maps input_dim -> hidden_dim, the rest hidden -> hidden.
        # Submodules are built in the same order as before (convs, lin1, lin2)
        # so seeded initialisation and state_dict layout are unchanged.
        layer_in_dims = [input_dim] + [hidden_dim] * (num_layers - 1)
        self.convs = torch.nn.ModuleList([GCNConv(in_dim, hidden_dim) for in_dim in layer_in_dims])
        self.lin1 = Linear(hidden_dim, hidden_dim)
        self.lin2 = Linear(hidden_dim, output_dim)

        self.dropout = dropout
        self.return_embeds = return_embeds
        # Assigned last: if this is an nn.Module it registers after lin2, as before.
        self.loss_module = loss_module

    def reset_parameters(self):
        """Re-initialise every conv layer, then ``lin1`` and ``lin2``."""
        for module in (*self.convs, self.lin1, self.lin2):
            module.reset_parameters()

    def forward(self, data, edge_weight=None):
        """Return ``(logits, pooled)`` for a batch of graphs.

        ``edge_weight`` is passed through to every GCNConv layer.
        """
        h = data.x.float()
        edge_index, batch = data.edge_index, data.batch

        pooled = []
        for conv in self.convs:
            h = F.relu(conv(h, edge_index, edge_weight=edge_weight))
            pooled.append(global_add_pool(h, batch))

        head = F.relu(self.lin1(pooled[-1]))
        head = F.dropout(head, p=self.dropout, training=self.training)
        return self.lin2(head), pooled

    def loss(self, pred, y, pooled_outputs, task_type=None):
        """Delegate to the module-level ``loss`` function with ``loss_module`` and ``alpha``.

        NOTE(review): that function presumably comes from the ``utils`` /
        ``utils_dist`` star imports; confirm its contract there.
        """
        return loss(pred, y, pooled_outputs, task_type, self.loss_module, self.alpha)

class GIN(torch.nn.Module):
    """Graph-level GIN classifier: GINConv stack, sum readout, two-layer head.

    Each GINConv wraps a small MLP and has a trainable epsilon. Every conv
    output passes through a ReLU and is sum-pooled per graph. The last pooled
    vector feeds ``lin1 -> ReLU -> dropout -> lin2``. ``forward`` returns
    ``(logits, pooled_per_layer)``, like the sibling models.

    Args:
        input_dim: node feature width fed to the first MLP.
        hidden_dim: width of every MLP layer, conv output and ``lin1``.
        output_dim: width of the logits produced by ``lin2``.
        num_layers: number of GINConv layers (at least one is always built).
        dropout: dropout probability applied before ``lin2``.
        return_embeds: stored on the instance; not read by ``forward``.
        reg_term: stored as ``reg_term`` and forwarded to the loss helper.
        loss_module: stored and forwarded to the loss helper.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, dropout, return_embeds=False, reg_term=0.01, loss_module=None):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.num_layers = num_layers
        # Exposed for parity with GCN / GAT / TransformerNet, which all set it.
        self.hidden_channels = hidden_dim
        self.return_embeds = return_embeds
        self.dropout = dropout
        self.reg_term = reg_term

        def make_mlp(in_dim):
            # Builds the per-layer MLP. NOTE(review): this applies ReLU both
            # before and after BatchNorm, which looks unintentional. The
            # Sequential indices are part of the state_dict keys, so changing
            # the layout would break existing checkpoints.
            return Sequential(Linear(in_dim, hidden_dim), ReLU(), BatchNorm(hidden_dim), ReLU(), Linear(hidden_dim, hidden_dim))

        # Each MLP is built immediately before its GINConv, matching the
        # original construction order (and therefore seeded initialisation).
        self.convs.append(GINConv(make_mlp(input_dim), train_eps=True))
        for _ in range(num_layers - 1):
            self.convs.append(GINConv(make_mlp(hidden_dim), train_eps=True))

        self.lin1 = Linear(hidden_dim, hidden_dim)
        self.lin2 = Linear(hidden_dim, output_dim)
        self.loss_module = loss_module

    def reset_parameters(self):
        """Re-initialise every conv layer, then ``lin1`` and ``lin2``."""
        for conv in self.convs:
            conv.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data, edge_weight=None):
        """Return ``(logits, pooled)`` for a batch of graphs.

        ``edge_weight`` is accepted for interface parity with the sibling
        models but is NOT used: the convs are called with ``(x, edge_index)`` only.
        """
        x, edge_index, batch = data.x.to(dtype=torch.float), data.edge_index, data.batch
        layer_outputs = []
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
            layer_outputs.append(x)
        pooled_outputs = [global_add_pool(layer_output, batch) for layer_output in layer_outputs]
        # Only the last layer's readout feeds the prediction head.
        pooled_output = F.relu(self.lin1(pooled_outputs[-1]))
        out = self.lin2(F.dropout(pooled_output, p=self.dropout, training=self.training))
        return out, pooled_outputs

    def loss(self, pred, y, pooled_outputs, task_type=None):
        """Delegate to the module-level ``loss`` function with ``loss_module`` and ``reg_term``.

        NOTE(review): that function presumably comes from the ``utils`` /
        ``utils_dist`` star imports; confirm its contract there.
        """
        return loss(pred, y, pooled_outputs, task_type, self.loss_module, self.reg_term)

class GAT(torch.nn.Module):
    """Graph-level GAT classifier: multi-head GATConv stack, sum readout, two-layer head.

    Each GATConv emits ``hidden_dim * heads`` features. Later layers and
    ``lin1`` are sized accordingly. Every conv output passes through a ReLU
    and is sum-pooled per graph. The last pooled vector feeds
    ``lin1 -> ReLU -> dropout -> lin2``. ``forward`` returns
    ``(logits, pooled_per_layer)``.

    Args:
        input_dim: node feature width fed to the first GATConv.
        hidden_dim: per-head output width of every GATConv, and ``lin1`` output width.
        output_dim: width of the logits produced by ``lin2``.
        num_layers: number of GATConv layers (at least one is always built).
        dropout: dropout probability applied before ``lin2``.
        heads: number of attention heads per GATConv.
        return_embeds: stored on the instance; not read by ``forward``.
        reg_term: stored as ``alpha`` and forwarded to the loss helper.
        loss_module: stored and forwarded to the loss helper.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, dropout, heads=1, return_embeds=False, reg_term=0, loss_module=None):
        super(GAT, self).__init__()
        self.num_layers = num_layers
        self.hidden_channels = hidden_dim
        self.convs = torch.nn.ModuleList([GATConv(input_dim, hidden_dim, heads=heads)])
        for _ in range(num_layers - 1):
            self.convs.append(GATConv(hidden_dim * heads, hidden_dim, heads=heads))
        self.lin1 = Linear(hidden_dim * heads, hidden_dim)
        self.lin2 = Linear(hidden_dim, output_dim)
        self.dropout = dropout
        self.return_embeds = return_embeds
        self.alpha = reg_term
        self.loss_module = loss_module

    def reset_parameters(self):
        """Re-initialise every conv layer, then ``lin1`` and ``lin2``."""
        for conv in self.convs:
            conv.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data, edge_weight=None):
        """Return ``(logits, pooled)`` for a batch of graphs.

        ``edge_weight`` is accepted for interface parity with the sibling
        models but is NOT used: the convs are called with ``(x, edge_index)`` only.
        """
        # Cast to float32 as every sibling model does. Without it, integer or
        # float64 node features hit GATConv's float32 linear projection and
        # raise a dtype-mismatch error.
        x, edge_index, batch = data.x.to(dtype=torch.float), data.edge_index, data.batch
        layer_outputs = []
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
            layer_outputs.append(x)

        pooled_outputs = [global_add_pool(layer_output, batch) for layer_output in layer_outputs]
        # Only the last layer's readout feeds the prediction head.
        pooled_output = F.relu(self.lin1(pooled_outputs[-1]))
        out = self.lin2(F.dropout(pooled_output, p=self.dropout, training=self.training))
        return out, pooled_outputs

    def loss(self, pred, y, pooled_outputs, task_type=None):
        """Delegate to the module-level ``loss`` function with ``loss_module`` and ``alpha``.

        NOTE(review): that function presumably comes from the ``utils`` /
        ``utils_dist`` star imports; confirm its contract there.
        """
        return loss(pred, y, pooled_outputs, task_type, self.loss_module, self.alpha)

class GraphSAGE(torch.nn.Module):
    """Graph-level GraphSAGE classifier: SAGEConv stack, sum readout, two-layer head.

    Every conv output passes through a ReLU and is sum-pooled per graph. The
    last pooled vector feeds ``lin1 -> ReLU -> dropout -> lin2``. ``forward``
    returns ``(logits, pooled_per_layer)``, like the sibling models.

    Args:
        input_dim: node feature width fed to the first SAGEConv.
        hidden_dim: width of every conv layer and of ``lin1``.
        output_dim: width of the logits produced by ``lin2``.
        num_layers: number of SAGEConv layers (at least one is always built).
        dropout: dropout probability applied before ``lin2``.
        return_embed: stored as ``return_embeds``; not read by ``forward``.
        reg_term: stored as ``alpha`` and forwarded to the loss helper.
        loss_module: stored and forwarded to the loss helper.
        return_embeds: sibling-compatible spelling of ``return_embed``. It
            takes precedence when given, so factories that pass
            ``return_embeds=`` to every model class also work here.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, dropout, return_embed=False, reg_term=0, loss_module=None, return_embeds=None):
        super(GraphSAGE, self).__init__()
        self.num_layers = num_layers
        self.dropout = dropout
        self.return_embeds = return_embed if return_embeds is None else return_embeds
        self.alpha = reg_term
        # Exposed for parity with GCN / GAT / TransformerNet, which all set it.
        self.hidden_channels = hidden_dim

        self.convs = torch.nn.ModuleList([SAGEConv(input_dim, hidden_dim)])
        for _ in range(num_layers - 1):
            self.convs.append(SAGEConv(hidden_dim, hidden_dim))

        self.lin1 = Linear(hidden_dim, hidden_dim)
        self.lin2 = Linear(hidden_dim, output_dim)
        self.loss_module = loss_module

    def forward(self, data, edge_weight=None):
        """Return ``(logits, pooled)`` for a batch of graphs.

        ``edge_weight`` is accepted for interface parity with the sibling
        models but is NOT used: the convs are called with ``(x, edge_index)`` only.
        """
        x, edge_index, batch = data.x.to(dtype=torch.float), data.edge_index, data.batch
        layer_outputs = []  # ReLU'd output of each conv layer

        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
            layer_outputs.append(x)

        pooled_outputs = [global_add_pool(layer_output, batch) for layer_output in layer_outputs]
        # Only the last layer's readout feeds the prediction head.
        pooled_output = F.relu(self.lin1(pooled_outputs[-1]))
        out = self.lin2(F.dropout(pooled_output, p=self.dropout, training=self.training))
        return out, pooled_outputs

    def reset_parameters(self):
        """Re-initialise every conv layer, then ``lin1`` and ``lin2``."""
        for conv in self.convs:
            conv.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def loss(self, pred, y, pooled_outputs, task_type=None):
        """Delegate to the module-level ``loss`` function with ``loss_module`` and ``alpha``.

        NOTE(review): that function presumably comes from the ``utils`` /
        ``utils_dist`` star imports; confirm its contract there.
        """
        return loss(pred, y, pooled_outputs, task_type, self.loss_module, self.alpha)

class TransformerNet(torch.nn.Module):
    """Graph-level classifier built from single-head TransformerConv layers.

    Every conv output passes through a ReLU and is sum-pooled per graph with
    ``global_add_pool``. The last pooled vector feeds the
    ``lin1 -> ReLU -> dropout -> lin2`` head. ``forward`` returns
    ``(logits, pooled_per_layer)``.

    Args:
        input_dim: node feature width fed to the first TransformerConv.
        hidden_dim: output width of every conv (one head) and of ``lin1``.
        output_dim: width of the logits produced by ``lin2``.
        num_layers: number of conv layers (at least one is always built).
        dropout: used both as each conv's ``dropout`` argument and before ``lin2``.
        return_embeds: stored on the instance; not read by ``forward``.
        reg_term: stored as ``alpha`` and forwarded to the loss helper.
        loss_module: stored and forwarded to the loss helper.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, dropout, return_embeds=False, reg_term=0, loss_module=None):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_channels = hidden_dim
        self.alpha = reg_term

        # First layer maps input_dim -> hidden_dim, the rest hidden -> hidden.
        # Built in the original order (convs, lin1, lin2) so seeded
        # initialisation and state_dict layout are unchanged.
        layer_in_dims = [input_dim] + [hidden_dim] * (num_layers - 1)
        self.convs = torch.nn.ModuleList([
            TransformerConv(in_dim, hidden_dim, heads=1, concat=True, dropout=dropout)
            for in_dim in layer_in_dims
        ])
        self.lin1 = Linear(hidden_dim, hidden_dim)
        self.lin2 = Linear(hidden_dim, output_dim)

        self.dropout = dropout
        self.return_embeds = return_embeds
        # Assigned last: if this is an nn.Module it registers after lin2, as before.
        self.loss_module = loss_module

    def reset_parameters(self):
        """Re-initialise every conv layer, then ``lin1`` and ``lin2``."""
        for module in (*self.convs, self.lin1, self.lin2):
            module.reset_parameters()

    def forward(self, data, edge_attr=None):
        """Return ``(logits, pooled)`` for a batch of graphs.

        ``edge_attr`` is forwarded to every conv. NOTE(review): the convs are
        built without ``edge_dim``. Per PyG's TransformerConv docs, edge
        features are only used when ``edge_dim`` is set, so this argument
        likely has no effect. Confirm before relying on it.
        """
        h = data.x.float()
        edge_index, batch = data.edge_index, data.batch

        readouts = []
        for conv in self.convs:
            h = F.relu(conv(h, edge_index, edge_attr=edge_attr))
            readouts.append(global_add_pool(h, batch))

        head = F.relu(self.lin1(readouts[-1]))
        logits = self.lin2(F.dropout(head, p=self.dropout, training=self.training))
        return logits, readouts

    def loss(self, pred, y, pooled_outputs, task_type=None):
        """Delegate to the module-level ``loss`` function with ``loss_module`` and ``alpha``.

        NOTE(review): that function presumably comes from the ``utils`` /
        ``utils_dist`` star imports; confirm its contract there.
        """
        return loss(pred, y, pooled_outputs, task_type, self.loss_module, self.alpha)

