from datasets.utils import *
from modelZoo.utils import get_model
from utils import *
import numpy as np  # v 1.19.2
import matplotlib.pyplot as plt  # v 3.3.2
from matplotlib.lines import Line2D
import seaborn as sns
import pandas as pd
import os


def plot_dists(axes, names, address, score_name, model_name, inverse_scores=False, legend_loc="upper left"):
    """Draw three score-histogram panels comparing clean and attacked samples.

    ``address`` must point to a file holding four consecutive ``np.save``
    records, read in this order: in-distribution scores, out-distribution
    scores, attacked in-distribution scores, attacked out-distribution scores
    (the order ``attack_and_save_for_fig2`` writes them).

    Panel layout:
        axes[0] -- clean in-distribution vs clean out-distribution
        axes[1] -- clean in-distribution vs attacked out-distribution
        axes[2] -- attacked in-distribution vs clean out-distribution

    Args:
        axes: indexable with at least three entries; each entry must provide
            ``plot``, ``set_title``, ``legend``, ``set_xlabel``,
            ``set_ylabel`` and ``tick_params`` (presumably matplotlib Axes).
        names: panel tags; ``names[i]`` prefixes the title of ``axes[i]``.
        address: path of the saved score file.
        score_name: x-axis label for every panel.
        model_name: model name inserted into every panel title.
        inverse_scores: if True, all four score arrays are multiplied by -1
            before binning.
        legend_loc: ``loc`` passed to every ``legend`` call.

    NOTE: each array is binned into 51 equal-width bins over its *own*
    min..max range, so two curves in one panel may use different bin widths
    and their heights are not directly comparable.
    """
    with open(str(address), 'rb') as handle:
        clean_in, clean_out, adv_in, adv_out = [np.load(handle) for _ in range(4)]
    print(clean_in.shape, clean_out.shape)

    if inverse_scores:
        clean_in, clean_out, adv_in, adv_out = [arr * -1 for arr in (clean_in, clean_out, adv_in, adv_out)]

    def to_curve(samples):
        # Histogram into 51 bins; return (bin midpoints, per-bin counts).
        counts, edges = np.histogram(samples, bins=51)
        return (edges[1:] + edges[:-1]) / 2.0, counts

    curve_in, curve_out, curve_adv_in, curve_adv_out = map(to_curve, (clean_in, clean_out, adv_in, adv_out))

    font_size = 10
    starred = {"marker": "*", "markevery": 0.1, "markersize": 7}

    # (axis index, title, [(curve, extra plot kwargs), ...], legend labels),
    # listed in the order the panels are drawn.
    panels = (
        (1, f"{names[1]}: {model_name}(attack out data)",
         ((curve_in, {"color": "blue", "marker": ""}),
          (curve_adv_out, {"color": "red", **starred})),
         ["standard in-distribution", "adversarial out-distribution"]),
        (0, f"{names[0]}: {model_name}(without attack)",
         ((curve_in, {"color": "blue"}),
          (curve_out, {"color": "red"})),
         ["standard in-distribution", "standard out-distribution"]),
        (2, f"{names[2]}: {model_name}(attack in data)",
         ((curve_adv_in, {"color": "blue", **starred}),
          (curve_out, {"color": "red", "marker": ""})),
         ["adversarial in-distribution", "standard out-distribution"]),
    )

    for idx, title, curves, labels in panels:
        ax = axes[idx]
        for (centers, counts), style in curves:
            ax.plot(centers, counts, linewidth=2, **style)
        ax.set_title(title, fontsize=font_size)
        # Labels are matched to the lines positionally, in plotting order.
        ax.legend(labels=labels, fontsize=font_size, loc=legend_loc)
        ax.set_xlabel(score_name, fontsize=9)
        ax.set_ylabel('count', fontsize=9)
        ax.tick_params(labelleft=True)


def _collect_scores(OOD_model, device, loader, make_target, epsilon, alpha, attack_iters, restarts, norm,
                    **attack_kwargs):
    """Score every batch of ``loader`` on both perturbed and clean inputs.

    Args:
        OOD_model: callable applied to image batches; its outputs are
            detached and moved to CPU before being stored.
        device: device the images are moved to before attacking/scoring.
        loader: sized iterable of ``(images, labels)`` batches; labels are
            ignored.
        make_target: tensor factory (``torch.zeros`` or ``torch.ones``) called
            as ``make_target(batch_size, device=device)``. Its result is the
            third argument of ``attack_pgd_ood_detection`` -- presumably the
            detector label the attack aims for; confirm against that function.
        epsilon, alpha, attack_iters, restarts, norm: forwarded positionally
            to ``attack_pgd_ood_detection``.
        **attack_kwargs: extra keyword arguments for the attack call.

    Returns:
        ``(clean, attacked)`` numpy arrays: model outputs on the original and
        on the perturbed images, concatenated over batches along dim 0.
    """
    from itertools import islice

    total_batches = len(loader)
    clean, attacked = [], []
    # Never consume more than len(loader) batches.
    for images, _ in tqdm(islice(loader, total_batches), total=total_batches, leave=False):
        images = images.to(device)
        target = make_target(images.shape[0], device=device)
        delta = attack_pgd_ood_detection(OOD_model, images, target, epsilon, alpha, attack_iters, restarts, norm,
                                         **attack_kwargs)
        # Perturbed batch is scored before the clean one.
        attacked.append(OOD_model(images + delta).detach().cpu())
        clean.append(OOD_model(images).detach().cpu())
    return torch.cat(clean).numpy(), torch.cat(attacked).numpy()


def attack_and_save_for_fig2(OOD_model, device, testloader_in, testloader_out, epsilon, alpha, attack_iters, restarts,
                             file_address, norm="l_inf"):
    """Compute clean and attacked detector scores for both datasets and save them.

    In-distribution batches are attacked with a ``torch.zeros`` target and
    out-distribution batches with a ``torch.ones`` target via
    ``attack_pgd_ood_detection``.

    Args:
        OOD_model: callable returning per-sample scores for an image batch.
        device: device used for attacking and scoring.
        testloader_in: ``(name, loader)`` pair for the in-distribution test
            set; the name is unused.
        testloader_out: ``(name, loader)`` pair for the out-distribution test
            set; the name is unused.
        epsilon, alpha, attack_iters, restarts: attack hyper-parameters,
            forwarded to ``attack_pgd_ood_detection``.
        file_address: output path. Four ``np.save`` records are written in
            this order: in, out, attacked in, attacked out.
        norm: norm name forwarded to the attack (default ``"l_inf"``, the
            value previously hard-coded).
    """
    _, loader_in = testloader_in
    _, loader_out = testloader_out

    # NOTE(review): only the in-distribution attack passes FGSM=False
    # explicitly; the out-distribution attack relies on the default of
    # attack_pgd_ood_detection. Confirm that default is False if both are
    # meant to be plain PGD.
    in_dist, in_dist_attacked = _collect_scores(OOD_model, device, loader_in, torch.zeros, epsilon, alpha,
                                                attack_iters, restarts, norm, FGSM=False)
    out_dist, out_dist_attacked = _collect_scores(OOD_model, device, loader_out, torch.ones, epsilon, alpha,
                                                  attack_iters, restarts, norm)

    # Record order is what plot_dists reads back.
    with open(file_address, 'wb') as f:
        for scores in (in_dist, out_dist, in_dist_attacked, out_dist_attacked):
            np.save(f, scores)


if __name__ == "__main__":
    # When False, skip the attacks and only re-plot from the .npy files a
    # previous run wrote (OSAD.npy, ALOEMSP.npy).
    compute = True
    if compute:
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

        dataset_name_in, dataset_name_out = "cifar10", "tiny_imagenet"
        attack_iters = 10
        train_loader_in = get_trainloader_cifar10()

        # Index 0: OpenMax wrapper (saved as OSAD.npy); index 1: MSP on ALOE.
        models = [
            OOD_openMax(get_model("open-set", dataset_name_in, device), train_loader_in, device),
            OOD_MSP(get_model("ALOE", dataset_name_in, device)),
        ]

        epsilon = 8 / 255
        loaders = get_outdist_dataloaders([dataset_name_in, dataset_name_out])
        # Step size: 2.5x the per-iteration share of the epsilon budget.
        alpha = epsilon / attack_iters * 2.5

        for model, out_file in zip(models, ("OSAD.npy", "ALOEMSP.npy")):
            attack_and_save_for_fig2(model, device, loaders[0], loaders[1], epsilon, alpha, attack_iters, 1,
                                     out_file)

    # 2x3 grid: top row (A-C) from ALOEMSP.npy with negated scores,
    # bottom row (D-F) from OSAD.npy.
    fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(18, 8), sharey="row")
    panel_grid = axes.flatten()
    plot_dists(panel_grid[:3], ['A', 'B', 'C'], address="ALOEMSP.npy", score_name="MSP score",
               model_name="ALOE", inverse_scores=True)
    plot_dists(panel_grid[3:], ['D', 'E', 'F'], address="OSAD.npy", score_name="OpenMax score",
               model_name="OSAD")

    plt.tight_layout()
    plt.savefig('fig2.png')