import torch_geometric.transforms as T
import warnings
warnings.filterwarnings('ignore')
import torch
import numpy as np
import random
from torch_geometric.datasets import Planetoid
from torch_geometric.datasets import Amazon
from torch_geometric.datasets import WikipediaNetwork
from torch_geometric.datasets import Actor
from torch_geometric.datasets import WebKB
from torch_geometric.datasets import Coauthor
from torch_geometric.datasets import WikiCS
import os
from torch_geometric.utils import dense_to_sparse

# Absolute path of this module; echoed once when the module is imported.
current_file_path = os.path.abspath(__file__)
print(f"Current file path: {current_file_path}")

def Data_Loader(name, root_path='dataset/'):
    """Load a node-classification benchmark dataset by name.

    Every dataset is created with ``T.NormalizeFeatures()`` as its transform.

    Args:
        name: Dataset name, matched case-insensitively. Supported values:
            cora, citeseer, pubmed, computers, photo, cs, physics,
            chameleon, squirrel, film, texas, cornell, wisconsin, wikics.
        root_path: Directory under which the datasets are stored/downloaded.
            Defaults to ``'dataset/'`` (relative to the working directory).

    Returns:
        The loaded PyG dataset object.

    Raises:
        ValueError: If ``name`` is not one of the supported datasets.
    """
    name = name.lower()
    if name in ('cora', 'citeseer', 'pubmed'):
        # split='random' samples the train/val/test masks at load time; call
        # set_seed() beforehand if the split must be reproducible.
        dataset = Planetoid(root_path, name, split='random', num_train_per_class=20,
                            num_val=500, num_test=1000,
                            transform=T.NormalizeFeatures())
    elif name in ('computers', 'photo'):
        dataset = Amazon(root_path, name, transform=T.NormalizeFeatures())
    elif name in ('cs', 'physics'):
        dataset = Coauthor(root_path, name, transform=T.NormalizeFeatures())
    elif name in ('chameleon', 'squirrel'):
        dataset = WikipediaNetwork(root=root_path, name=name, geom_gcn_preprocess=True,
                                   transform=T.NormalizeFeatures())
        # NOTE(review): edge_index is taken from this same geom-gcn-preprocessed
        # copy. If the raw graph was intended, load a second copy with
        # geom_gcn_preprocess=False and copy its edge_index here — TODO confirm.
        # Replace the stored graph with the transformed copy from dataset[0],
        # so the normalized features are what the dataset holds internally.
        dataset.data = dataset[0]
    elif name == 'film':
        dataset = Actor(root=os.path.join(root_path, 'Actor'),
                        transform=T.NormalizeFeatures())
        dataset.name = name
    elif name in ('texas', 'cornell', 'wisconsin'):
        dataset = WebKB(root=root_path, name=name, transform=T.NormalizeFeatures())
    elif name == 'wikics':
        dataset = WikiCS(root=os.path.join(root_path, 'WikiCS'),
                         transform=T.NormalizeFeatures())
    else:
        raise ValueError(f'dataset {name} not supported in dataloader')
    return dataset



def getMatrix(matrix, pr):
    """Return the threshold that keeps the top ``pr`` fraction of non-zero entries.

    Despite its name, this neither modifies nor returns a matrix. It returns
    the k-th largest non-zero value of ``matrix`` as a 0-d tensor, where
    ``k = int(pr * <number of non-zero entries>)`` clamped to ``[1, count]``.
    Entries ``>= threshold`` make up at least the top ``pr`` share of the
    non-zero values (ties can admit more).

    Args:
        matrix: Tensor of any shape; zero entries are ignored.
        pr: Fraction of non-zero entries to keep, expected in ``(0, 1]``.
            Values outside that range are clamped via ``k``.

    Returns:
        0-d tensor holding the threshold value.

    Raises:
        ValueError: If ``matrix`` contains no non-zero entries.
    """
    non_zero_values = matrix[matrix != 0]  # boolean-mask indexing flattens to 1-D
    count = non_zero_values.numel()
    if count == 0:
        raise ValueError('getMatrix: matrix has no non-zero entries')
    # int() truncates, so a small pr can give k == 0. Clamp it so we never
    # index position -1, which would silently return the *smallest* value.
    k = min(max(int(pr * count), 1), count)
    # topk avoids sorting every value. Its results are in descending order,
    # so the last one is the k-th largest.
    return torch.topk(non_zero_values, k).values[-1]


def set_seed(seed=0):
    """Seed every RNG used by the pipeline and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and all
    CUDA devices), then configures cuDNN for reproducible results.

    Args:
        seed: Integer seed applied to every generator. Defaults to 0.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Explicitly seed every GPU; silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and possibly pick different
    # algorithms from run to run, defeating `deterministic`, so it must be off.
    torch.backends.cudnn.benchmark = False

