# Experiments in this file cover graph classification tasks only.
import torch
from model_dist import GCN, GAT, GIN,GraphSAGE,TransformerNet
from GMT_model.nets_dist import GraphMultisetTransformer, GraphMultisetTransformer_for_OGB
from torch_geometric.loader import DataLoader
import os
from tqdm import tqdm
import torch.nn as nn
from utils import EarlyStopper
import wandb
import ot  
import numpy as np
import torch._dynamo
# Ask TorchDynamo to suppress compilation errors (falling back to eager
# execution) rather than raising them. Set once at import time.
torch._dynamo.config.suppress_errors = True

# Dataset names that `data_loader` splits by position (80/10/10 slices).
TUData = [
    "PROTEINS",
    "IMDB-BINARY",
    "REDDIT-BINARY",
    "COLLAB",
    "NCI1",
    "NCI109",
    "MUTAG",
    "DD",
    "PTC_MR",
    "ENZYMES",
]

# Dataset names whose splits come from `dataset.get_idx_split()`.
OGB_Data = [
    "ogbg-molhiv",
    "ogbg-molpcba",
]


    

class RankNetLoss(nn.Module):
    """
    RankNet-style pairwise ranking loss implemented as a PyTorch nn.Module.

    The ordering found in one tensor (``previous_layer``) is used as the target
    for the pairwise differences of another tensor (``current_layer``). A
    boolean off-diagonal mask is applied to drop diagonal entries.
    """
    def __init__(self):
        super(RankNetLoss, self).__init__()

    def forward(self, previous_layer: torch.Tensor, current_layer: torch.Tensor) -> torch.Tensor:
        """
        Parameters:
            - previous_layer: Tensor whose pairwise differences along dim 0 define the
              targets: sign(previous[i] - previous[j]) is mapped from {-1, 0, 1} to
              {0, 0.5, 1}.
            - current_layer: Tensor whose pairwise differences along dim 0 are passed
              through a sigmoid and scored against those targets. Its first dimension
              defines ``n``.

        Returns:
            - Scalar tensor: the sum of the masked binary cross-entropy terms divided
              by ``mask.sum() * n`` (i.e. ``n * (n - 1) * n``).

        NOTE(review): the expected shapes are not visible in this file. ``mask`` is
        applied both as ``mask.unsqueeze(2)`` (n, n, 1) and as a plain (n, n) tensor;
        against a 3-D operand these broadcast over different axes. Confirm the
        intended input shapes against the callers.
        NOTE(review): for n == 1 the denominator is zero, so the result is not a
        finite number.
        """
        n = current_layer.size(0)
        
        # True everywhere except on the diagonal.
        mask = ~torch.eye(n, dtype=torch.bool, device=current_layer.device)

        # Difference between every pair of rows of current_layer; the mask (with a
        # trailing singleton axis) zeroes the entries it marks as diagonal.
        pairwise_diffs = current_layer.unsqueeze(1) - current_layer.unsqueeze(0)
        pairwise_diffs = pairwise_diffs * mask.unsqueeze(2).float()

        # Sigmoid of each pairwise difference.
        sigmoid_diffs = torch.sigmoid(pairwise_diffs)
        # Targets: 1 where previous[i] > previous[j], 0 where smaller, 0.5 on ties,
        # then multiplied by the (n, n) mask.
        label_matrix = previous_layer.unsqueeze(1) - previous_layer.unsqueeze(0)
        label_matrix = label_matrix.sign()
        label_matrix = (1 + label_matrix) / 2
        label_matrix = label_matrix * mask.float()

        # Element-wise binary cross-entropy; 1e-15 keeps log() away from zero.
        losses = -label_matrix * torch.log(sigmoid_diffs + 1e-15) - (1 - label_matrix) * torch.log(1 - sigmoid_diffs + 1e-15)
        # NaN terms and terms where the mask is False are set to zero.
        losses = torch.where(torch.isnan(losses) | ~mask, torch.zeros_like(losses), losses)
        # mask.sum() is the number of off-diagonal pairs; it is further scaled by n.
        mean_loss = losses.sum() / (mask.sum()*n)
        return mean_loss

# Split proportions: share of the data used for training and for validation.
train_splits = 0.8
validate_splits = 0.1

# Module-level cross-entropy criterion instance.
criterion = torch.nn.CrossEntropyLoss()

# Module-level path string for saved models.
save_path = "./model/"

    
def get_model(model_name, dataset, device, config, loss_module="RankNetLoss"):
    """Build the model selected by ``model_name``.

    Parameters:
        - model_name: one of 'GCN', 'GAT', 'GIN', 'GraphSAGE', 'GTransformer', 'GMT'.
        - dataset: object providing ``num_features`` and ``num_classes``. For 'GMT'
          it must also provide ``name`` and be iterable over items that expose
          ``num_nodes``.
        - device: device the loss module is moved to. The model itself is not
          moved here.
        - config: mapping with keys 'dropout', 'hidden_size', 'num_layers' and
          'reg_term'; 'heads' is read only for 'GMT'.
        - loss_module: name of the auxiliary loss. Only "RankNetLoss" is supported.

    Returns:
        - The constructed model instance.

    Raises:
        - ValueError: if ``loss_module`` is not supported, if ``model_name`` is not
          implemented, or if 'GMT' is requested for a dataset listed in neither
          TUData nor OGB_Data.
    """
    dropout_ratio = config['dropout']
    hidden_size = config['hidden_size']
    num_layers = config['num_layers']
    reg_term = config['reg_term']

    # Fail with a clear message instead of hitting an unbound `loss_func` later.
    if loss_module != "RankNetLoss":
        raise ValueError(f"Unsupported loss_module: {loss_module!r}")
    loss_func = RankNetLoss().to(device)
    print("Using RankNetLoss")

    # Keyword arguments shared by every message-passing model below.
    shared_kwargs = dict(
        num_layers=num_layers,
        dropout=dropout_ratio,
        reg_term=reg_term,
        loss_module=loss_func,
    )

    if model_name == 'GCN':
        model = GCN(dataset.num_features, hidden_size, dataset.num_classes, **shared_kwargs)
    elif model_name == 'GAT':
        model = GAT(dataset.num_features, hidden_size, dataset.num_classes, heads=1, **shared_kwargs)
    elif model_name == 'GIN':
        model = GIN(dataset.num_features, hidden_size, dataset.num_classes, **shared_kwargs)
    elif model_name == 'GraphSAGE':
        model = GraphSAGE(dataset.num_features, hidden_size, dataset.num_classes, **shared_kwargs)
    elif model_name == 'GTransformer':
        model = TransformerNet(dataset.num_features, hidden_size, dataset.num_classes, **shared_kwargs)
    elif model_name == 'GMT':
        in_tu = dataset.name in TUData
        in_ogb = dataset.name in OGB_Data
        # Previously an unlisted dataset left `model` unbound and crashed at return.
        if not (in_tu or in_ogb):
            raise ValueError(f"GMT is not configured for dataset {dataset.name!r}")

        # Mean node count over the whole dataset, rounded up (one-element array).
        avg_num_nodes = np.ceil([np.mean([data.num_nodes for data in dataset])])
        if in_tu:
            model = GraphMultisetTransformer(
                dataset.num_features, hidden_size, dataset.num_classes, config['heads'],
                avg_num_nodes=avg_num_nodes, reg_term=reg_term, loss_module=loss_func,
            )
        else:
            model = GraphMultisetTransformer_for_OGB(
                dataset.num_features, hidden_size, dataset.num_classes,
                num_heads=config['heads'], avg_num_nodes=avg_num_nodes,
                reg_term=reg_term, loss_module=loss_func,
            )
    else:
        raise ValueError("This Model is not implemented")
    print(model_name, " Training...")
    return model

def train(model, data_loader, optimizer, device, task_type):
    """Run one optimisation pass over every batch in ``data_loader``.

    Parameters:
        - model: module called as ``model(batch)`` returning ``(out, pooled_outputs)``
          and exposing ``loss(out, y, pooled_outputs, task_type)``.
        - data_loader: iterable of batches with ``x``, ``y`` and ``batch`` attributes.
        - optimizer: optimizer stepped once per processed batch.
        - device: device the model and each batch are moved to.
        - task_type: forwarded unchanged to ``model.loss``.

    Returns:
        - None. The model parameters are updated in place.
    """
    model = model.to(device)
    model.train()
    print('Training...')

    progress = tqdm(enumerate(data_loader), total=len(data_loader))
    for _, batch in progress:
        batch = batch.to(device)

        # Skip batches with a single node row, or whose last batch-assignment
        # entry is 0 (presumably a batch holding one graph -- confirm).
        if batch.x.shape[0] == 1 or batch.batch[-1] == 0:
            continue

        optimizer.zero_grad()

        out, pooled_outputs = model(batch)
        # model.loss returns a pair; only its first element is optimised here.
        batch_loss, _ = model.loss(out, batch.y, pooled_outputs, task_type)

        batch_loss.backward()
        optimizer.step()


def test(model, loader, device, evaluator=None, task_type=None):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Parameters:
        - model: module called as ``model(batch)`` returning ``(out, pooled_outputs)``
          and exposing ``loss(out, y, pooled_outputs, task_type)``, which returns
          ``(loss, dist_loss)``.
        - loader: sized iterable of batches with a ``y`` attribute.
        - device: device each batch is moved to.
        - evaluator: optional object with ``eval({"y_true", "y_pred"})`` returning a
          mapping. When given, raw model outputs are scored and the value under the
          mapping's first key is reported.
        - task_type: forwarded unchanged to ``model.loss``.

    Returns:
        - Tuple ``(score, mean_loss, mean_dist_loss)``. ``score`` is the evaluator
          metric, or the fraction of argmax predictions equal to ``y`` when no
          evaluator is given. Both losses are averaged over the number of batches.
    """
    model.eval()

    loss_sum = 0
    dist_loss_sum = 0
    sample_count = 0
    use_evaluator = evaluator is not None

    with torch.no_grad():
        predictions = []
        targets = []
        print("Testing...")

        for batch in tqdm(loader, total=len(loader)):
            batch = batch.to(device)

            out, pooled_outputs = model(batch)
            batch_loss, batch_dist_loss = model.loss(out, batch.y, pooled_outputs, task_type)

            dist_loss_sum += batch_dist_loss.item()
            loss_sum += batch_loss.item()

            # The evaluator receives raw outputs; plain accuracy needs class indices.
            predictions.append(out if use_evaluator else out.argmax(dim=1))
            targets.append(batch.y)

            sample_count += batch.y.size(0)

        predictions = torch.cat(predictions)
        targets = torch.cat(targets)
        if use_evaluator:
            metrics = evaluator.eval({"y_true": targets, "y_pred": predictions})
            score = metrics[list(metrics.keys())[0]]
        else:
            score = torch.sum(predictions == targets).item()

    if not use_evaluator:
        score = score / sample_count

    batch_count = len(loader)
    return score, loss_sum / batch_count, dist_loss_sum / batch_count


def train_model_dist(model_name,dataset,dataloaders,config,patience=30,
                min_delta=0.005,device=7,wandb_record=True,save_model=False,
                seed=12345, evaluator=None,task_type=None,
                ):
    """Train a model with early stopping and report the selected test score.

    Parameters:
        - model_name: model identifier passed to ``get_model``.
        - dataset: dataset object; ``dataset.name`` is printed and the object is
          passed to ``get_model``.
        - dataloaders: sequence ``[train_loader, valid_loader, test_loader]``.
        - config: mapping with 'epochs', 'learning_rate' and 'weight_decay', plus
          the keys read by ``get_model``.
        - patience, min_delta: forwarded to ``EarlyStopper``.
        - device: CUDA device index, used when CUDA is available; otherwise the
          CPU is used.
        - wandb_record: when true, log the early-stopping test score to wandb.
        - save_model: when true, ensure ``./dist_model/<model_name>/`` exists.
          NOTE: no checkpoint is written by this function itself.
        - seed: accepted for interface compatibility; not used in this function.
        - evaluator, task_type: forwarded to ``test`` (and ``task_type`` to ``train``).

    Returns:
        - None.
    """
    num_epochs = config['epochs']

    device = torch.device("cuda:" + str(device)) if torch.cuda.is_available() else torch.device("cpu")
    print("Model: ",model_name, "Dataset: ",dataset.name )
    print("Train on device:",device)

    if save_model:
        # makedirs creates the missing './dist_model/' parent as well; a plain
        # mkdir on the nested path fails when the parent does not exist yet.
        os.makedirs('./dist_model/' + model_name + '/', exist_ok=True)

    train_loader, valid_loader, test_loader = dataloaders
    model = get_model(model_name, dataset, device=device, config=config)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay'],
        betas=(0.9, 0.999),
    )
    early_stopper = EarlyStopper(patience=patience, min_delta=min_delta,file_path="./rank_model/",saved=False)

    # A separate loop variable keeps the configured epoch count from being shadowed.
    for epoch in range(num_epochs):
        train(model, train_loader, optimizer, device, task_type)
        train_acc, train_loss, train_dist_loss = test(model, train_loader, device, evaluator, task_type)
        test_acc, _, test_dist_loss = test(model, test_loader, device, evaluator, task_type)
        val_acc, validation_loss, valid_dist_loss = test(model, valid_loader, device, evaluator, task_type)

        # The stopper is driven by the validation score; it also receives the
        # test score and the model.
        if early_stopper.early_stop(val_acc, epoch, test_acc, model):
            break

        print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Train Loss:{train_loss:.4f}, Test Acc: {test_acc:.4f}', "Test_dist_loss: ",test_dist_loss,"Validation Loss: ",validation_loss,"acc_early_stop: ",early_stopper.test_acc_record)
    print("Early stopping at Epoch: %d, Test Acc: %f"%(early_stopper.epoch_counter, early_stopper.test_acc_record))

    if wandb_record:
        wandb.log({"test_acc_of_early_stop": early_stopper.test_acc_record})
    


def data_loader(dataset_name, dataset, batch_size):
    """Create the train/validation/test loaders for ``dataset``.

    Parameters:
        - dataset_name: name looked up in TUData and OGB_Data to pick the split rule.
        - dataset: sliceable dataset; for OGB_Data names it must also provide
          ``get_idx_split()``.
        - batch_size: batch size used for all three loaders.

    Returns:
        - ``[train_loader, valid_loader, test_loader]``. Only the training loader
          shuffles.

    Raises:
        - Exception: if ``dataset_name`` is in neither TUData nor OGB_Data.
    """
    is_tu = dataset_name in TUData
    if not is_tu and dataset_name not in OGB_Data:
        print("Error")
        raise Exception("Error in load_dataset")

    if is_tu:
        # Positional split with no shuffling here: the first 80% is training,
        # the next 10% is test, and the remaining tail is validation.
        train_fraction = 0.8
        valid_fraction = 0.1
        total = len(dataset)
        train_end = int(total * train_fraction)
        test_end = int(total * (train_fraction + valid_fraction))

        train_part = dataset[:train_end]
        test_part = dataset[train_end:test_end]
        valid_part = dataset[test_end:]
    else:
        # Use the index split supplied by the dataset itself.
        split_idx = dataset.get_idx_split()
        train_part = dataset[split_idx["train"]]
        valid_part = dataset[split_idx["valid"]]
        test_part = dataset[split_idx["test"]]

    return [
        DataLoader(train_part, batch_size=batch_size, shuffle=True),
        DataLoader(valid_part, batch_size=batch_size, shuffle=False),
        DataLoader(test_part, batch_size=batch_size, shuffle=False),
    ]

