import random

import libmr
from matplotlib import pyplot as plt
from sklearn.metrics import roc_auc_score

from pgd_attack import *
from OOD_detection import *
import torch
from tqdm import tqdm
import torchvision
from PIL import Image

@torch.no_grad()
def get_clean_accuracy_on_test(model, device, testloader_in, total_batches=15):
    """Print and return the top-1 accuracy of ``model`` on unperturbed test data.

    Args:
        model: Callable mapping a batch of images to per-class scores; the
            predicted class is the index of the maximum along dimension 1.
        device: Device the image and label batches are moved to.
        testloader_in: ``(name, loader)`` pair. ``name`` is only used in the
            printed message; ``loader`` yields ``(images, labels)`` batches.
        total_batches: Maximum number of batches to evaluate. ``None`` means
            ``len(loader)`` batches, i.e. the whole loader.

    Returns:
        Accuracy in percent as a float, or ``nan`` when no sample was
        evaluated (empty loader or ``total_batches == 0``).
    """
    name_in, testloader_in = testloader_in
    total_batches_in = len(testloader_in) if total_batches is None else total_batches

    correct = 0
    total = 0
    progress = tqdm(testloader_in, total=total_batches_in, leave=False)
    for step, (images, labels) in enumerate(progress):
        if step == total_batches_in:
            break
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        # Gradients are already disabled by the decorator, so the legacy
        # ``.data`` accessor is not needed to detach the outputs.
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    # Guard against ZeroDivisionError when nothing was evaluated.
    accuracy = 100 * correct / total if total else float("nan")
    print('Accuracy of the network on clean %s: %.3f %%' % (
        name_in, accuracy))
    return accuracy


def get_adversarial_accuracy_on_test(model, device, testloader_in, epsilon, alpha, attack_iters,
                                     restarts, total_batches=15):
    """Print and return the top-1 accuracy of ``model`` on adversarially perturbed test data.

    For every batch a perturbation is obtained from
    ``attack_pgd_classification`` (called with the ``"l_inf"`` norm argument)
    and the model is evaluated on ``images + delta``.

    Args:
        model: Callable mapping a batch of images to per-class scores.
        device: Device the image and label batches are moved to.
        testloader_in: ``(name, loader)`` pair. ``name`` is only used in the
            printed message; ``loader`` yields ``(images, labels)`` batches.
        epsilon, alpha, attack_iters, restarts: Forwarded unchanged to
            ``attack_pgd_classification``; ``epsilon`` is also printed.
        total_batches: Maximum number of batches to evaluate. With ``None``
            the stop condition never triggers and the whole loader is used.

    Returns:
        Accuracy in percent as a float, or ``nan`` when no sample was
        evaluated (empty loader or ``total_batches == 0``).
    """
    name_in, testloader_in = testloader_in

    correct = 0
    total = 0
    progress = tqdm(testloader_in, total=total_batches, leave=False)
    for step, (images, labels) in enumerate(progress):
        if step == total_batches:
            break
        images = images.to(device)
        labels = labels.to(device)
        delta = attack_pgd_classification(model, images, labels, epsilon, alpha, attack_iters, restarts, "l_inf")
        # Only the attack needs gradients; the evaluation pass does not, so
        # avoid building an autograd graph for it.
        with torch.no_grad():
            outputs = model(images + delta)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    # Guard against ZeroDivisionError when nothing was evaluated.
    accuracy = 100 * correct / total if total else float("nan")
    print('Accuracy of the network on %s with adversarial examples, epsilon=%.4f : %.3f %%' % (
        name_in, epsilon, accuracy))
    return accuracy

def save_image(clean_image, adv_image, clean_score, adv_score, title):
    """Save a clean/adversarial image pair as PNG files under ``results/images``.

    The two files share the stem ``{title}-{clean_score}-{adv_score}`` (with
    the decimal point of each score replaced by ``_``) and end in ``C.png``
    for the clean image and ``A.png`` for the adversarial one.

    Args:
        clean_image: Channel-first image tensor (it is permuted with
            ``(1, 2, 0)``). Values are scaled by 255, so they are presumably
            floats in ``[0, 1]`` -- TODO confirm against the data pipeline.
        adv_image: Perturbed counterpart of ``clean_image``, same layout.
        clean_score: Single-element tensor holding the score of the clean image.
        adv_score: Single-element tensor holding the score of the perturbed image.
        title: Prefix of the generated file names.
    """
    # Local import so this function does not depend on the file-level imports.
    import os

    def _to_uint8_hwc(image):
        # Clamp before scaling: a perturbed image may fall outside [0, 1] if
        # the perturbation is not range-constrained, and casting out-of-range
        # floats to uint8 does not saturate.
        scaled = image.detach().clamp(0, 1).permute(1, 2, 0).cpu().numpy() * 255
        pixels = scaled.astype(np.uint8)
        # PIL cannot build an image from an (H, W, 1) array; drop the channel
        # axis so single-channel images are saved as grayscale.
        if pixels.shape[-1] == 1:
            pixels = pixels[..., 0]
        return pixels

    clean_pixels = _to_uint8_hwc(clean_image)
    adv_pixels = _to_uint8_hwc(adv_image)
    clean_score = str(clean_score.cpu().item()).replace(".", "_")
    adv_score = str(adv_score.cpu().item()).replace(".", "_")

    # Make sure the output directory exists instead of failing on first use.
    os.makedirs("results/images", exist_ok=True)

    Image.fromarray(clean_pixels).save(f"results/images/{title}-{clean_score}-{adv_score}C.png")
    Image.fromarray(adv_pixels).save(f"results/images/{title}-{clean_score}-{adv_score}A.png")

def _collect_ood_scores(OOD_model, device, name, loader, target_factory, epsilon, alpha,
                        attack_iters, restarts, total_batches, plot_images, attack_kwargs):
    """Return ``(clean_scores, attacked_scores)`` of ``OOD_model`` over ``loader`` as numpy arrays."""
    clean_scores, attacked_scores = [], []
    progress = tqdm(loader, total=total_batches, leave=False)
    for step, (images, _labels) in enumerate(progress):
        if step == total_batches:
            break
        images = images.to(device)
        target = target_factory(images.shape[0], device=device)
        delta = attack_pgd_ood_detection(OOD_model, images, target,
                                         epsilon, alpha, attack_iters,
                                         restarts, "l_inf", **attack_kwargs)
        adv_images = images + delta

        # Same evaluation order as before: perturbed batch first, then clean.
        adv_batch_scores = OOD_model(adv_images).detach()
        attacked_scores.append(adv_batch_scores.cpu())

        batch_scores = OOD_model(images).detach()
        clean_scores.append(batch_scores.cpu())

        if step == 0 and plot_images > 0:
            # Never index past the end of the first batch.
            for i in range(min(plot_images, images.shape[0])):
                save_image(images[i], adv_images[i], batch_scores[i], adv_batch_scores[i], name)

    # Every chunk was already moved to the CPU when it was appended.
    return torch.cat(clean_scores).numpy(), torch.cat(attacked_scores).numpy()


def _ood_auroc(out_scores, in_scores):
    """Return the AUROC with out-distribution scores labelled 1 and in-distribution scores labelled 0."""
    labels = np.array([1] * out_scores.shape[0] + [0] * in_scores.shape[0])
    scores = np.concatenate([out_scores, in_scores], axis=0)
    return roc_auc_score(labels, scores)


def attack_and_plot(OOD_model, device, testloader_in, testloader_out, epsilon, alpha, attack_iters, restarts,
                    total_batches=15, FGSM=False, plot_images=0, print_auc=True, in_dist_help=None, eps_initialize=None):
    """Attack an OOD detector and report AUROC for every clean/attacked combination.

    Scores of ``OOD_model`` are collected on clean and perturbed batches of
    both loaders. Perturbations come from ``attack_pgd_ood_detection``, which
    receives an all-zeros target for the in-distribution loader and an
    all-ones target for the out-distribution loader. Four AUROC values are
    then computed with out-distribution samples as the positive class.

    Args:
        OOD_model: Callable mapping a batch of images to one score per sample.
        device: Device the image batches and attack targets are placed on.
        testloader_in: ``(name, loader)`` pair for in-distribution data.
        testloader_out: ``(name, loader)`` pair for out-distribution data.
        epsilon, alpha, attack_iters, restarts: Forwarded unchanged to
            ``attack_pgd_ood_detection``.
        total_batches: Maximum number of batches read from each loader; with
            ``None`` the whole loader is used.
        FGSM, eps_initialize: Forwarded to the in-distribution attack only.
        plot_images: Number of image pairs from the first batch of each loader
            handed to ``save_image`` (capped at the batch size).
        print_auc: When true, print each AUROC right after it is computed.
        in_dist_help: Optional ``(in_dist, in_dist_attacked)`` pair, e.g. the
            last element of a previous return value. When given, the
            in-distribution loader is not evaluated again.

    Returns:
        ``(auroc_clean, auroc_in, auroc_out, auroc_in_out, (in_dist, in_dist_attacked))``
        where the AUROCs correspond to clean/clean, attacked-in/clean-out,
        clean-in/attacked-out and attacked/attacked scores.
    """
    name_in, testloader_in = testloader_in
    name_out, testloader_out = testloader_out

    if in_dist_help is None:
        in_dist, in_dist_attacked = _collect_ood_scores(
            OOD_model, device, name_in, testloader_in, torch.zeros,
            epsilon, alpha, attack_iters, restarts, total_batches, plot_images,
            {"FGSM": FGSM, "eps_initialize": eps_initialize})
    else:
        in_dist, in_dist_attacked = in_dist_help

    # NOTE(review): FGSM and eps_initialize are not forwarded to the
    # out-distribution attack, which therefore runs with the attack's
    # defaults. Kept as is; confirm whether this asymmetry is intended.
    out_dist, out_dist_attacked = _collect_ood_scores(
        OOD_model, device, name_out, testloader_out, torch.ones,
        epsilon, alpha, attack_iters, restarts, total_batches, plot_images,
        {})

    # (in-distribution scores, out-distribution scores) for each reported AUROC.
    score_pairs = (
        (in_dist, out_dist),                    # in clean,    out clean
        (in_dist_attacked, out_dist),           # in attacked, out clean
        (in_dist, out_dist_attacked),           # in clean,    out attacked
        (in_dist_attacked, out_dist_attacked),  # in attacked, out attacked
    )
    aurocs = []
    for in_scores, out_scores in score_pairs:
        value = _ood_auroc(out_scores, in_scores)
        if print_auc:
            print(value)
        aurocs.append(value)
    auroc_clean, auroc_in, auroc_out, auroc_in_out = aurocs

    return auroc_clean, auroc_in, auroc_out, auroc_in_out, (in_dist, in_dist_attacked)
