import torch
from torch_geometric.loader import DataLoader
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from gnn import GNN
from sklearn.metrics import roc_auc_score
from collections import Counter

import os
from tqdm import tqdm
import argparse
import time
import numpy as np
import pickle

### importing OGB
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator

### importing loss (NOTE: the AUCLoss_multiLabel class defined further down in this file shadows this import)
from losses import AUCLoss_multiLabel

# Legacy tensor type used to allocate the auxiliary optimisation tensors.
# Fall back to the CPU type when CUDA is absent: the training device also
# falls back to 'cpu' in that case, and `.type(torch.cuda.FloatTensor)`
# would raise on a machine without CUDA.
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

# Binary cross-entropy on raw logits; used for the cross-entropy training steps.
cls_criterion = torch.nn.BCEWithLogitsLoss()
# Mean-squared error for regression targets (not referenced by the code visible here).
reg_criterion = torch.nn.MSELoss()

def set_all_seeds(SEED):
    """Seed PyTorch and NumPy and switch cuDNN into deterministic mode.

    Args:
        SEED: integer seed given to both ``torch.manual_seed`` and
            ``np.random.seed``. Python's built-in ``random`` module is not
            seeded here.
    """
    cudnn = torch.backends.cudnn
    # Reproducibility over speed: no autotuned kernel selection.
    cudnn.benchmark = False
    cudnn.deterministic = True

    np.random.seed(SEED)
    torch.manual_seed(SEED)

def zero_grad(model):
    """Zero, in place, every existing ``.grad`` of ``model``'s parameters.

    Parameters whose ``.grad`` is ``None`` are skipped, so this never
    allocates new gradient tensors.
    """
    for param in model.parameters():
        grad = param.grad
        if grad is None:
            continue
        grad.data.zero_()


def proj_sca(x, bound):
    """Clip the scalar ``x`` into ``[0, bound]``.

    The upper bound is checked first: anything above ``bound`` maps to
    ``bound``, then anything below zero maps to ``0``, otherwise ``x`` is
    returned unchanged. Because of that order, a negative ``bound`` does not
    give a true projection (e.g. ``x=-5, bound=-1`` yields ``0``).
    """
    return bound if x > bound else (0 if x < 0 else x)

class AUCLoss_multiLabel():
    """Per-task pieces of a square-loss AUC objective for multi-label data.

    For task ``t`` with ratio ``p = imratio[t]`` the pieces are

        g1 = (1-p) * mean((h - a_t)^2 * [y==1]) + p * mean((h - b_t)^2 * [y==0])
        g2 = -2(1-p) * mean(h * [y==1]) + 2p * mean(h * [y==0]) + 2p(1-p) * m
        g3 = p(1-p) * alpha_t^2

    and the ``*_grad*`` methods return their closed-form derivatives with
    respect to ``a_t``, ``b_t`` and ``alpha_t``.

    Args:
        imratio: indexable of per-task ratios ``p`` (``main`` builds it from
            the training labels).
        m: margin constant used in ``g2`` (default 1.0).
    """

    def __init__(self, imratio, m=1.0):
        self.p = imratio
        self.m = m

    @staticmethod
    def _pos_mask(targets):
        """Float indicator of entries whose target equals 1."""
        return (1 == targets).float()

    @staticmethod
    def _neg_mask(targets):
        """Float indicator of entries whose target equals 0."""
        return (0 == targets).float()

    def g1(self, outputs, a, b, targets, task=0):
        """Weighted squared deviation of scores from ``a[task]`` (positives) and ``b[task]`` (negatives)."""
        weight = self.p[task]
        sq_pos = (outputs - a[task]) ** 2 * self._pos_mask(targets)
        sq_neg = (outputs - b[task]) ** 2 * self._neg_mask(targets)
        return (1 - weight) * torch.mean(sq_pos) + weight * torch.mean(sq_neg)

    def g1_grad_a(self, outputs, a, targets, task=0):
        """Derivative of ``g1`` with respect to ``a[task]``."""
        weight = self.p[task]
        centred = (outputs - a[task]) * self._pos_mask(targets)
        return -2 * (1 - weight) * torch.mean(centred)

    def g1_grad_b(self, outputs, b, targets, task=0):
        """Derivative of ``g1`` with respect to ``b[task]``."""
        weight = self.p[task]
        centred = (outputs - b[task]) * self._neg_mask(targets)
        return -2 * weight * torch.mean(centred)

    def g2(self, outputs, targets, task=0):
        """Margin term multiplied by ``alpha[task]`` in the objective."""
        weight = self.p[task]
        pos_term = -2 * (1 - weight) * torch.mean(outputs * self._pos_mask(targets))
        neg_term = 2 * weight * torch.mean(outputs * self._neg_mask(targets))
        margin_term = 2 * weight * (1-weight) * self.m
        return pos_term + neg_term + margin_term

    def g3(self, alpha, task=0):
        """Quadratic penalty ``p(1-p) * alpha[task]**2``."""
        weight = self.p[task]
        return weight * (1 - weight) * alpha[task] ** 2

    def g3_grad(self, alpha, task=0):
        """Derivative of ``g3`` with respect to ``alpha[task]``."""
        weight = self.p[task]
        return 2 * weight * (1 - weight) * alpha[task]


def eval_rocauc(y_true, y_pred):
    """Return the ROC-AUC averaged over the tasks (columns) where it is defined.

    A column contributes only if it contains at least one target equal to 1
    and at least one equal to 0. Within a contributing column, NaN targets are
    treated as unlabelled and dropped before scoring.

    Args:
        y_true: 2-D array of targets (0/1, NaN = unlabelled), one column per task.
        y_pred: 2-D array of scores with the same shape.

    Raises:
        RuntimeError: if no column has both a positive and a negative label.
    """
    scores = []

    for col in range(y_true.shape[1]):
        labels = y_true[:, col]
        has_pos = np.sum(labels == 1) > 0
        has_neg = np.sum(labels == 0) > 0
        if not (has_pos and has_neg):
            continue
        # NaN never equals itself, so this keeps exactly the labelled rows.
        keep = labels == labels
        scores.append(roc_auc_score(labels[keep], y_pred[keep, col]))

    if not scores:
        raise RuntimeError('No positively labeled data available. Cannot compute ROC-AUC.')

    return sum(scores) / len(scores)


def _ce_step(model, batch, pred, selected, step_size):
    """Take one plain gradient step on BCE restricted to the ``selected`` tasks.

    Args:
        model: network whose parameters are updated in place.
        batch: current batch; NaN entries of ``batch.y`` are treated as unlabelled.
        pred: ``model(batch)`` for this batch.
        selected: array of task (column) indices contributing to the loss.
        step_size: learning rate already multiplied by ``args.beta_ct``.
    """
    y = batch.y
    task_mask = torch.zeros(y.shape[1], dtype=torch.bool, device=y.device)
    task_mask[torch.as_tensor(selected, dtype=torch.long, device=y.device)] = True
    ## ignore nan targets (unlabeled) and every entry of a non-selected task.
    is_labeled = (y == y) & task_mask
    if not is_labeled.any():
        # Nothing labelled among the selected tasks: skip the step.
        return

    loss_ce = cls_criterion(pred.to(torch.float32)[is_labeled], y.to(torch.float32)[is_labeled])
    zero_grad(model)
    grads_ce = torch.autograd.grad(loss_ce, model.parameters(), retain_graph=False)
    for g, w in zip(grads_ce, model.parameters()):
        w.data = w.data - step_size * g.data


def _auc_step(model, batch, pred, selected, loss_fn, lr, beta, task_batch_size):
    """Take one stochastic primal-dual step on the multi-label AUC objective.

    The model weights and the per-task auxiliaries ``a``/``b`` descend along
    momentum-averaged gradients (buffers ``z_w_list``, ``z_a``, ``z_b``).
    The per-task dual ``alpha`` ascends its gradient and is clamped to
    ``[0, 999]``. These tensors are module-level globals created in ``main``.
    Only tasks in ``selected`` with at least one labelled entry contribute.
    """
    pred_sig = torch.sigmoid(pred).to(torch.float32)
    y_float = batch.y.to(torch.float32)

    # One slot per task, so each task's auxiliaries receive only their own
    # partial derivative. Summing them into a single scalar would broadcast
    # the same update to every task and keep all a/b/alpha entries identical.
    grads_a = torch.zeros_like(a)
    grads_b = torch.zeros_like(b)
    grads_alp = torch.zeros_like(alpha)
    loss_auc = 0
    n_used = 0

    for task_idx in selected.tolist():
        ## ignore nan targets (unlabeled) when computing training loss.
        is_labeled_i = batch.y[:, task_idx] == batch.y[:, task_idx]
        if not is_labeled_i.any():
            continue
        n_used += 1

        y_pred_i = pred_sig[:, task_idx][is_labeled_i]
        y_true_i = y_float[:, task_idx][is_labeled_i]
        g2_i = loss_fn.g2(y_pred_i, y_true_i, task=task_idx)

        # The manual updates below need only the values, not an autograd graph.
        grads_a[task_idx] = loss_fn.g1_grad_a(y_pred_i, a, y_true_i, task=task_idx).detach()
        grads_b[task_idx] = loss_fn.g1_grad_b(y_pred_i, b, y_true_i, task=task_idx).detach()
        grads_alp[task_idx] = (g2_i - loss_fn.g3_grad(alpha, task=task_idx)).detach()
        loss_auc = loss_auc + loss_fn.g1(y_pred_i, a, b, y_true_i, task=task_idx) \
            + alpha[task_idx] * g2_i \
            - loss_fn.g3(alpha, task=task_idx)

    if n_used == 0:
        # No selected task is labelled in this batch: there is no loss to
        # differentiate, so leave every variable unchanged.
        return

    # The loss averages over the task mini-batch, and so do its partial derivatives.
    loss_auc = loss_auc / task_batch_size
    grads_a = grads_a / task_batch_size
    grads_b = grads_b / task_batch_size
    grads_alp = grads_alp / task_batch_size

    # Weight gradients are taken before a/b/alpha change, so every variable is
    # updated from the same iterate.
    zero_grad(model)
    grads_w = torch.autograd.grad(loss_auc, model.parameters(), retain_graph=False)

    z_a.data = (1 - beta) * z_a + beta * grads_a
    a.data = a - lr * z_a

    z_b.data = (1 - beta) * z_b + beta * grads_b
    b.data = b - lr * z_b

    alpha.data = torch.clamp(alpha.data + lr * grads_alp, 0, 999)

    for g, z_w, w in zip(grads_w, z_w_list, model.parameters()):
        z_w.data = (1 - beta) * z_w + beta * g.data
        w.data = w.data - lr * z_w


def train(model, device, loader, task_type, imratio, lr_decay=1):
    """Run one training epoch over ``loader``.

    When ``args.method == 'ct'``, even-numbered steps take a plain BCE
    gradient step (step size ``lr * args.beta_ct``). All other steps take a
    momentum primal-dual step on the AUC objective. Each step uses a random
    subset of ``args.task_BATCH_SIZE`` tasks.

    Relies on module-level state created in ``main``: ``args`` and the
    optimisation tensors ``a``, ``b``, ``alpha``, ``z_a``, ``z_b``, ``z_w_list``.

    Args:
        model: GNN being trained; its parameters are updated in place.
        device: device the batches are moved to.
        loader: iterable of PyG batches exposing ``x``, ``batch`` and ``y``.
        task_type: unused; kept for interface compatibility.
        imratio: per-task ratios, one entry per label column.
        lr_decay: divisor applied to ``args.lr`` for this epoch.
    """
    model.train()

    lr = args.lr / lr_decay
    # Task ids; re-shuffled in place each step to draw a random task subset.
    label_set = np.arange(len(imratio))
    loss_fn = AUCLoss_multiLabel(imratio=imratio, m=1)

    for step, batch in enumerate(tqdm(loader, desc="Iteration")):
        batch = batch.to(device)

        # Skip batches with a single node, or where batch.batch[-1] == 0
        # (presumably a single-graph batch).
        if batch.x.shape[0] == 1 or batch.batch[-1] == 0:
            continue

        pred = model(batch)
        np.random.shuffle(label_set)
        selected = label_set[:args.task_BATCH_SIZE]

        if step % 2 == 0 and args.method == 'ct':
            _ce_step(model, batch, pred, selected, lr * args.beta_ct)
        else:
            _auc_step(model, batch, pred, selected, loss_fn, lr, args.beta, args.task_BATCH_SIZE)


def eval(model, device, loader):
    """Score ``model`` on every batch of ``loader`` and return ``eval_rocauc`` of the result.

    The model is put in eval mode and run without gradient tracking. Batches
    containing a single node are skipped. The raw model outputs (no sigmoid
    applied) are passed as scores. Note that this function shadows the
    built-in ``eval``.
    """
    model.eval()
    true_chunks = []
    pred_chunks = []

    for batch in tqdm(loader, desc="Iteration"):
        batch = batch.to(device)
        if batch.x.shape[0] == 1:
            continue

        with torch.no_grad():
            scores = model(batch)

        true_chunks.append(batch.y.view(scores.shape).detach().cpu())
        pred_chunks.append(scores.detach().cpu())

    y_true = torch.cat(true_chunks, dim = 0).numpy()
    y_pred = torch.cat(pred_chunks, dim = 0).numpy()
    return eval_rocauc(y_true, y_pred)

def _parse_args():
    """Define the command-line interface and return the parsed options."""
    # Training settings
    parser = argparse.ArgumentParser(description='GNN baselines on ogbgmol* data with Pytorch Geometrics')
    parser.add_argument('--gpu_id', type=str, default='0',
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--gnn', type=str, default='gin',
                        help='GNN gin, gin-virtual, or gcn, or gcn-virtual (default: gin)')
    parser.add_argument('--drop_ratio', type=float, default=0.5,
                        help='dropout ratio (default: 0.5)')
    parser.add_argument('--num_layer', type=int, default=5,
                        help='number of GNN message passing layers (default: 5)')
    parser.add_argument('--emb_dim', type=int, default=300,
                        help='dimensionality of hidden units in GNNs (default: 300)')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='input batch size for training (default: 32)')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--num_workers', type=int, default=0,
                        help='number of workers (default: 0)')
    parser.add_argument('--dataset', type=str, default="ogbg-molpcba",
                        help='dataset name (default: ogbg-molpcba)')
    parser.add_argument('--feature', type=str, default="full",
                        help='full feature or simple feature')
    parser.add_argument('--filename', type=str, default='results_main_pyg_ct_rand',
                        help='filename to output result (default: results_main_pyg_ct_rand)')
    parser.add_argument('--SEED', default=123, type=int, help='random seed')
    parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
    parser.add_argument('--beta', default=0.9, type=float,
                        help='momentum averaging weight for the AUC steps')
    parser.add_argument('--beta_ct', default=0.9, type=float,
                        help='learning-rate multiplier for the cross-entropy steps')
    parser.add_argument('--decay_point', default=1000, type=int,
                        help='lr is divided by 10 from epoch decay_point - 1 onwards')
    parser.add_argument('--task_BATCH_SIZE', default=10, type=int,
                        help='number of tasks sampled per training step')
    parser.add_argument('--method', default='ct', type=str,
                        help="'ct' alternates cross-entropy and AUC steps; anything else uses AUC steps only")
    return parser.parse_args()


def _compute_imratio(labels, num_tasks):
    """Return, per task, the fraction of labelled training targets equal to 1.

    ``labels`` holds 0/1 targets with NaN marking unlabelled entries. NaNs are
    excluded from both counts, because the AUC loss averages only over
    labelled samples. A task with no labelled entry gets a ratio of 0.
    """
    imratio = []
    for i in range(num_tasks):
        column = labels[:, i]
        n_pos = (column == 1).sum()
        n_labeled = (~torch.isnan(column)).sum()
        imratio.append(n_pos / n_labeled.clamp(min=1))
    return imratio


def _build_model(opts, num_tasks, device):
    """Instantiate the GNN variant named by ``opts.gnn`` and move it to ``device``.

    Raises:
        ValueError: if ``opts.gnn`` is not a supported variant.
    """
    # --gnn value -> (gnn_type passed to GNN, whether to use a virtual node)
    variants = {
        'gin': ('gin', False),
        'gin-virtual': ('gin', True),
        'gcn': ('gcn', False),
        'gcn-virtual': ('gcn', True),
    }
    try:
        gnn_type, virtual_node = variants[opts.gnn]
    except KeyError:
        raise ValueError('Invalid GNN type') from None
    return GNN(gnn_type=gnn_type, num_tasks=num_tasks, num_layer=opts.num_layer,
               emb_dim=opts.emb_dim, drop_ratio=opts.drop_ratio,
               virtual_node=virtual_node).to(device)


def main():
    """Train a GNN on an OGB graph-property dataset and report ROC-AUC.

    Parses the CLI and builds loaders and the model. Initialises the per-task
    AUC variables that ``train`` reads through module globals. Runs
    ``args.epochs`` epochs, printing train/valid/test ROC-AUC after each. At
    the end, prints the scores at the epoch with the best validation ROC-AUC
    and saves them with ``torch.save`` unless ``--filename`` is empty.
    """
    # train() reads all of these through module globals.
    global args, alpha, a, b, z_a, z_b, z_w_list
    args = _parse_args()

    # Honour --gpu_id when CUDA is available, otherwise run on the CPU.
    if torch.cuda.is_available():
        device = torch.device('cuda:{}'.format(args.gpu_id))
    else:
        device = torch.device('cpu')

    set_all_seeds(args.SEED)

    ### automatic dataloading and splitting
    dataset = PygGraphPropPredDataset(name = args.dataset)

    # Any --feature value other than 'simple' keeps the full features.
    if args.feature == 'simple':
        print('using simple feature')
        # only retain the top two node/edge features
        dataset.data.x = dataset.data.x[:,:2]
        dataset.data.edge_attr = dataset.data.edge_attr[:,:2]

    split_idx = dataset.get_idx_split()

    train_loader = DataLoader(dataset[split_idx["train"]], batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers)
    valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=args.batch_size, shuffle=False, num_workers = args.num_workers)
    test_loader = DataLoader(dataset[split_idx["test"]], batch_size=args.batch_size, shuffle=False, num_workers = args.num_workers)

    ###### Compute imratio
    num_tasks = dataset.num_tasks
    labels = train_loader.dataset.data.y[split_idx["train"]]
    imratio = _compute_imratio(labels, num_tasks)

    model = _build_model(args, num_tasks, device)

    valid_curve = []
    test_curve = []
    train_curve = []

    ## Initials: per-task auxiliaries/dual and momentum buffers, allocated on
    ## the training device so CPU-only runs work too.
    alpha = torch.zeros(num_tasks, device=device)
    a = torch.zeros(num_tasks, device=device)
    b = torch.zeros(num_tasks, device=device)
    z_a = torch.zeros(num_tasks, device=device)
    z_b = torch.zeros(num_tasks, device=device)
    z_w_list = [torch.zeros_like(w) for w in model.parameters()]

    for epoch in range(1, args.epochs + 1):
        print("=====Epoch {}".format(epoch))
        print('Training...')

        # Divide the learning rate by 10 from epoch (decay_point - 1) onwards.
        lr_decay = 10 if epoch >= (args.decay_point - 1) else 1
        train(model, device, train_loader, dataset.task_type, imratio, lr_decay=lr_decay)

        print('Evaluating...')
        train_perf = eval(model, device, train_loader)
        valid_perf = eval(model, device, valid_loader)
        test_perf = eval(model, device, test_loader)

        print({'Train': train_perf, 'Validation': valid_perf, 'Test': test_perf})

        train_curve.append(train_perf)
        valid_curve.append(valid_perf)
        test_curve.append(test_perf)

    best_val_epoch = np.argmax(np.array(valid_curve))
    best_train = max(train_curve)

    print('TEST_AUC = ')
    print(test_curve)
    print('Best validation score: {}'.format(valid_curve[best_val_epoch]))
    print('Test score: {}'.format(test_curve[best_val_epoch]))

    if not args.filename == '':
        torch.save({'Val': valid_curve[best_val_epoch], 'Test': test_curve[best_val_epoch], 'Train': train_curve[best_val_epoch], 'BestTrain': best_train}, args.filename)

if __name__ == '__main__':
    # Run training only when executed as a script, never on import.
    main()