from datasets.utils import *
from modelZoo.utils import get_model
from utils import *

if __name__ == "__main__":
    # Script entry point: build the data loaders, load the model named by
    # `model_name`, wrap it in `OOD_MSP`, and run `attack_and_plot` once per
    # out-of-distribution loader.
    model_name = "Rade2021Helper_R18_extra"
    dataset_name = "cifar10"

    image_size = 32
    batch_size = 128

    # NOTE(review): `seed` is defined but never passed to any RNG in this
    # script, so runs are not seeded from here -- confirm whether
    # torch/numpy/random seeding was intended.
    seed = 2022

    # NOTE(review): this loader is not used below; the call is kept because
    # the helper may have side effects (e.g. preparing the dataset) that are
    # not visible from this file.
    dataloader_train_in = get_trainloader_cifar10(image_size=image_size, batch_size=batch_size)

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # Create the output directories (both "results" and "results/images").
    # Stays best-effort like the original, but only filesystem errors are
    # tolerated (and reported) instead of silently swallowing every exception,
    # including KeyboardInterrupt.
    try:
        os.makedirs("results/images", exist_ok=True)
    except OSError as err:
        print(f"warning: could not create 'results/images': {err}")

    # The first requested dataset ("cifar10") is passed to `attack_and_plot`
    # alongside each of the remaining ones.
    loaders = get_outdist_dataloaders(
        ["cifar10", 'mnist', 'tiny_imagenet', 'places365', 'LSUN', 'iSUN', 'birds', 'flowers', 'coil_100'],
        image_size=image_size, batch_size=batch_size)
    in_loader = loaders[0]

    model = get_model(model_name, dataset_name, device)
    OOD_model = OOD_MSP(model)

    # Presumably the perturbation budget and per-step size of the attack --
    # TODO confirm against the `attack_and_plot` signature. The meaning of the
    # trailing positional arguments (10, 1) is not visible from this file.
    eps = 8 / 255
    eps_step = (eps / 10) * 2.5

    for out_dataset in loaders[1:]:
        attack_and_plot(OOD_model, device, in_loader, out_dataset, eps, eps_step, 10, 1,
                        total_batches=1, plot_images=32)
