#%%
from pathlib import Path
from argparse import ArgumentParser

import torch
import torch.nn.parallel
from torch import nn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torch.nn.functional as F

from tqdm import tqdm

import sys
sys.path.append('.')
from experiments.imagenet_discrete import test_loaders, load_model
from experiments.imagenet_ddu import get_train_loader

from image_uncertainty.models.duq import (
    MultiLinearCentroids, LinearCentroids, benchmark
)
from image_uncertainty.utils.evaluate_ood import get_auroc_ood_dl


def parse_args():
    """Parse the command-line options of the ImageNet DUQ experiment.

    Returns:
        An ``argparse.Namespace`` holding one attribute per option registered
        below, plus ``gpu``, which is unconditionally set to ``True``.
    """
    # Default locations of the in-distribution validation set and of ImageNet-O.
    default_ind_dir = "/gpfs/gpfs0/datasets/ImageNet/ILSVRC2012/val"
    default_ood_dir = '/gpfs/gpfs0/k.fedyanin/space/imagenet_o'

    # (flags, add_argument keyword arguments); registered in this exact order,
    # which is also the order shown by --help.
    option_table = [
        (('--net',), {'type': str, 'default': 'resnet50'}),
        (('-b',), {'type': int, 'default': 32}),
        (('--data-folder',), {'type': str, 'default': default_ind_dir}),
        (('--ood-folder',), {'type': str, 'default': default_ood_dir}),
        (('--ood-name',), {'type': str, 'default': 'imagenet-o'}),
        (('--gamma',), {'type': float, 'default': 0.999}),
        (('--length-scale',), {'type': float, 'default': 1.0}),
        (('--epochs',), {'type': int, 'default': 4}),
        (('--subsample',), {'action': 'store_true', 'default': False}),
        (('--pretrained-head',), {'action': 'store_true', 'default': False}),
        (('--architecture',), {
            'default': "linear",
            'choices': ["linear", "multilinear"],
            'help': "Pick an duq variant (default: linear)",
        }),
    ]

    parser = ArgumentParser()
    for flags, options in option_table:
        parser.add_argument(*flags, **options)

    namespace = parser.parse_args()
    # GPU usage is not exposed as a command-line option.
    namespace.gpu = True
    return namespace


# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hard-coded filesystem locations of the ImageNet-O and ImageNet-R OOD sets.
IMAGENET_O_FOLDER = '/gpfs/gpfs0/k.fedyanin/space/imagenet_o'
IMAGENET_R_FOLDER = '/gpfs/gpfs0/k.fedyanin/space/imagenet_r'


def _build_duq_model(args, feature_extractor, num_classes, embedding_size, centroid_size):
    """Create the DUQ model chosen by ``args.architecture`` around ``feature_extractor``.

    With ``--pretrained-head`` the head weights are restored from
    ``experiments/checkpoint/duq/imagenet_<architecture>_head.pth`` before the
    real backbone is attached, and the model is switched to eval mode.

    Returns:
        The (still CPU-resident) ``LinearCentroids`` / ``MultiLinearCentroids`` model.
    """
    klass = LinearCentroids if args.architecture == 'linear' else MultiLinearCentroids

    def make(extractor):
        return klass(
            num_classes=num_classes,
            gamma=args.gamma,
            embedding_size=embedding_size,
            features=centroid_size,
            feature_extractor=extractor,
            batch_size=args.b,
            sigma=args.length_scale,
        )

    if not args.pretrained_head:
        return make(feature_extractor)

    # Build around a parameter-free identity extractor so the saved state dict
    # (presumably head-only) loads without backbone keys, then swap in the backbone.
    model = make(nn.Identity())
    checkpoint_path = f'experiments/checkpoint/duq/imagenet_{args.architecture}_head.pth'
    # The model is still on the CPU here; mapping to CPU lets a GPU-saved
    # checkpoint load on any host.
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    model.feature_extractor = feature_extractor
    model.eval()
    return model


def main():
    """Train a DUQ model on ImageNet and report OOD detection on ImageNet-O/-R.

    Builds the train/validation/OOD loaders and loads the ``--net`` backbone
    with its ``linear`` attribute replaced by an identity. It then wraps the
    backbone in a DUQ centroid model and trains with binary cross-entropy
    against one-hot labels. After every batch it calls
    ``model.update_embeddings`` under ``torch.no_grad()``.

    Every 500 batches it runs ``benchmark`` and prints the accuracy/AUROC from
    ``get_auroc_ood_dl`` against the ``--ood-folder`` set. After every epoch it
    prints the same metrics against ImageNet-R.

    Example invocations:
        python experiments/imagenet_duq.py --net='spectral' -b=64 --length-scale=2 --gamma=0.9995
        python experiments/imagenet_duq.py --net='spectral' -b=4 --length-scale=2 --gamma=0.9995 --architecture=multilinear
    """
    args = parse_args()

    train_loader = get_train_loader(batch_size=args.b, subsample=args.subsample)
    # --ood-folder defaults to the same path as IMAGENET_O_FOLDER; honour it
    # instead of silently ignoring the flag.
    val_loader, ood_loader_o = test_loaders(
        args.data_folder, args.ood_folder, args.b, subsample=args.subsample
    )
    _, ood_loader_r = test_loaders(
        args.data_folder, IMAGENET_R_FOLDER, args.b, subsample=args.subsample
    )

    feature_extractor = load_model(args.net)
    # Replace the backbone's `linear` attribute with an identity so it emits
    # features instead of logits (presumably the final classifier — confirm per --net).
    feature_extractor.linear = nn.Identity()
    feature_extractor.eval()

    # NOTE(review): hard-coded for every --net; presumably the ResNet-50
    # penultimate width — confirm for other backbones.
    model_output_size = 2048
    centroid_size = 64
    epochs = args.epochs
    milestones = [1, 5, 10]
    learning_rate = 3e-5
    weight_decay = 1e-4
    num_classes = 1000
    # Number of mini-batches whose gradients are accumulated per optimizer
    # step (1 = step every batch; e.g. 20 emulates a 20x larger batch).
    batch_grad = 1

    model = _build_duq_model(
        args, feature_extractor, num_classes, model_output_size, centroid_size
    )
    # Use the module-level device (CUDA when available) rather than a hard .cuda().
    model = model.to(device)

    optimizer = torch.optim.SGD(
        model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=weight_decay
    )
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=milestones, gamma=0.2
    )

    optimizer.zero_grad()
    for epoch in range(epochs):
        for i, (x, y) in enumerate(tqdm(train_loader)):
            # NOTE(review): this also puts the backbone back in train mode
            # (undoing feature_extractor.eval() above) — confirm intended.
            model.train()
            if i % batch_grad == 0:
                optimizer.zero_grad()
            x, y = x.to(device), y.to(device)

            # NOTE(review): no gradient penalty is computed in this loop, so
            # input gradients look unused — kept in case the model relies on it.
            x.requires_grad_(True)

            y_pred = model(x)
            y = F.one_hot(y, num_classes).float()
            loss = F.binary_cross_entropy(y_pred, y, reduction="mean")

            # Scale so the accumulated gradient is the mean over the effective
            # batch; a no-op while batch_grad == 1.
            (loss / batch_grad).backward()
            if (i + 1) % batch_grad == 0:
                optimizer.step()

            with torch.no_grad():
                model.eval()
                model.update_embeddings(x, y)

            if (i + 1) % 500 == 0:
                benchmark(val_loader, model, epoch, loss.item(), num_classes)
                accuracy, auroc = get_auroc_ood_dl(val_loader, ood_loader_o, model)
                print(f'OOD {args.ood_name}', accuracy, auroc)
        scheduler.step()
        accuracy, auroc = get_auroc_ood_dl(val_loader, ood_loader_r, model)
        print('OOD Imagenet-r', accuracy, auroc)


# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
