import random
import pickle
import numpy as np
from tqdm import tqdm
from PIL import Image
import open_clip
import pandas as pd
import wandb
import copy
import torch
import torchvision
import torch.nn as nn
from torch import optim
import torchxrayvision as xrv
from argparse import ArgumentParser
from torch.utils.data import DataLoader, Dataset
from models import DenseNetE2E, ViTE2E

# Fixed seeds for Python's `random` module and PyTorch's default generator.
random.seed(42)
torch.manual_seed(42)

# Use CUDA when PyTorch reports it as available, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

def get_image_dir(modality, dataset_name):
    """Return the directory prefix that image file names are appended to.

    Args:
        modality: Either "xray" or "skin".
        dataset_name: Dataset identifier. Only consulted when ``modality`` is
            "skin"; every x-ray dataset shares the same prefix.

    Returns:
        A path string ending in "/".

    Raises:
        ValueError: If ``modality`` is neither "xray" nor "skin". (Previously
            this case fell through to the ``return`` with ``image_dir`` never
            assigned and surfaced as an ``UnboundLocalError``.)
    """
    if modality == "xray":
        return "../data/datasets/"

    if modality == "skin":
        # Substring checks come first, in this order, so that every
        # "ISIC-*" and "PAD-*" variant maps onto its shared image folder.
        if "ISIC" in dataset_name:
            return "../data/datasets/isic/images/"
        if "PAD" in dataset_name:
            return "../data/datasets/PAD-UFES-20/images/"

        exact_matches = {
            "HAM10000": "../data/datasets/HAM10000/images/",
            "Melanoma": "../data/datasets/Melanoma/",
            "UWaterloo": "../data/datasets/UWaterloo/",
        }
        # Any other skin dataset falls back to the ISIC image folder.
        return exact_matches.get(dataset_name, "../data/datasets/isic/images/")

    raise ValueError(
        f"Unknown modality {modality!r}; expected 'xray' or 'skin'."
    )


class ImageDataset(Dataset):
    """Map-style dataset yielding ``(preprocessed_image, label)`` pairs.

    Args:
        class2images: Mapping from class name to an iterable of image file
            names (each is appended to ``image_dir`` verbatim).
        preprocess: Callable applied to the opened PIL image. It must read
            the pixel data before returning, because the underlying file is
            closed as soon as it returns.
        image_dir: String prefix prepended to every image file name.
        class2label: Mapping from class name to the integer label.

    Raises:
        KeyError: If a class that has at least one image is missing from
            ``class2label``.
    """

    def __init__(self, class2images, preprocess, image_dir, class2label):
        self.preprocess = preprocess
        self.image_paths = []
        self.labels = []
        # Never filled in by this class; kept so code that reads the
        # attribute keeps working.
        self.images = []

        for class_name, images in class2images.items():
            for image in images:
                self.image_paths.append(f"{image_dir}{image}")
                self.labels.append(class2label[class_name])

    def __len__(self):
        """Return the number of (image, label) samples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Load, preprocess and return the sample at position ``idx``."""
        # Open inside a context manager so the file handle is released as
        # soon as preprocessing is done rather than whenever the image object
        # happens to be garbage collected.
        with Image.open(self.image_paths[idx]) as image:
            return self.preprocess(image), self.labels[idx]


# Center crop followed by a resize to 224, using torchxrayvision's transforms.
transform = torchvision.transforms.Compose(
    [
        xrv.datasets.XRayCenterCrop(),
        xrv.datasets.XRayResizer(224),
    ]
)


def densenet_preprocess(image):
    """Turn a PIL image into the tensor fed to the DenseNet models.

    The image is converted to RGB, rescaled with ``xrv.datasets.normalize``
    (max value 255), averaged over its colour axis, given a leading axis,
    passed through the module-level ``transform`` and wrapped as a tensor.

    Args:
        image: A PIL image.

    Returns:
        A ``torch.Tensor`` built from the transformed array.
    """
    rgb_pixels = np.array(image.convert("RGB"))
    rescaled = xrv.datasets.normalize(rgb_pixels, 255)
    # Average away the colour axis, then prepend a length-1 axis.
    single_channel = rescaled.mean(2)[None, ...]
    return torch.from_numpy(transform(single_channel))


def train_model(model, train_loader, val_loader, num_epochs, lr):
    """Train ``model`` and return a copy from its best validation epoch.

    Every epoch makes one pass over ``train_loader`` with Adam and
    cross-entropy loss, logging each batch loss to wandb, and then scores the
    model on ``val_loader`` via ``eval_model``.

    Args:
        model: Module to optimise in place; called as ``model(images)``.
        train_loader: Iterable of ``(images, labels)`` batches supporting
            ``len()``.
        val_loader: Loader passed to ``eval_model`` after each epoch.
        num_epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.

    Returns:
        A deep copy of ``model`` taken right after the epoch with the highest
        validation accuracy (the earliest such epoch on ties). When no epoch
        runs (``num_epochs <= 0``) the untouched ``model`` itself is returned
        so callers never receive ``None``.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    best_val_acc = -float("inf")
    best_model = None

    for epoch in range(num_epochs):
        model.train()
        steps_per_epoch = len(train_loader)
        train_loss = 0.0
        for i, (images, labels) in enumerate(train_loader):
            images = images.to(device=device, dtype=torch.float32)
            # CrossEntropyLoss is given class indices as int64.
            labels = labels.to(device=device, dtype=torch.long)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            train_loss += batch_loss
            wandb.log({"train_loss": batch_loss, "epoch": epoch, "step": epoch * steps_per_epoch + i})

        # Guard the average so an empty train loader reports 0.0 instead of
        # raising ZeroDivisionError.
        mean_loss = train_loss / max(steps_per_epoch, 1)

        val_acc = eval_model(model, val_loader)
        print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {mean_loss}, Val Acc: {val_acc}")
        wandb.log({"val_acc": val_acc, "epoch": epoch})

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # Snapshot now: later epochs keep mutating ``model`` in place.
            best_model = copy.deepcopy(model)

    return best_model if best_model is not None else model


def eval_model(model, dataloader):
    """Return the top-1 accuracy of ``model`` on ``dataloader``, in percent.

    The model is switched to eval mode and left in that mode.

    Args:
        model: Module called as ``model(images)``; the prediction is the
            index of the largest value along dimension 1 of its output.
        dataloader: Iterable of ``(images, labels)`` batches.

    Returns:
        ``100 * correct / total`` as a float.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            # Cast inputs to float32 exactly as train_model does, so a
            # preprocess that yields another dtype behaves the same in both.
            images = images.to(device=device, dtype=torch.float32)
            labels = labels.to(device)
            outputs = model(images)
            # Gradients are already disabled here, so no ``.data`` detour.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return 100 * correct / total


def _load_split(path):
    """Unpickle one split file and close it afterwards."""
    # NOTE: pickle can execute arbitrary code while loading; only use this on
    # split files from a trusted source.
    with open(path, "rb") as split_file:
        return pickle.load(split_file)


def run_dataset(modality, model_name, base_model, dataset_name, batch_size, num_epochs, lr):
    """Fine-tune one end-to-end classifier on a dataset and report accuracies.

    Loads the train/val/test class-to-images pickles under
    ``../data/datasets/{dataset_name}/splits``, trains with ``train_model``,
    and evaluates the returned model on the val and test loaders. The run is
    tracked as a wandb run in the "e2e_classification" project.

    The image preprocessing callable is read from the module-level name
    ``preprocess``, which the script entry point assigns before calling this
    function.

    Args:
        modality: "xray" or "skin"; forwarded to ``get_image_dir``.
        model_name: "vit" or "densenet"; selects the wrapper class.
        base_model: Pretrained backbone. It is deep-copied before wrapping.
        dataset_name: Name of the dataset folder.
        batch_size: Batch size for all three loaders.
        num_epochs: Number of training epochs.
        lr: Learning rate.

    Returns:
        ``(val_acc, ood_acc, gap, average_acc, best_model)`` where ``gap`` is
        the absolute difference and ``average_acc`` the mean of the two
        accuracies, both rounded to two decimals.

    Raises:
        ValueError: If ``model_name`` is not "vit" or "densenet" (previously
            an ``UnboundLocalError`` after the wandb run and the data loaders
            had already been set up).
    """
    if model_name not in ("vit", "densenet"):
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected 'vit' or 'densenet'."
        )

    wandb.init(project="e2e_classification",
               name=f"{modality}_{model_name}_{dataset_name}_{batch_size}_{num_epochs}_{lr}",
               config={
                   "modality": modality,
                   "batch_size": batch_size,
                   "epochs": num_epochs,
                   "model_name": model_name,
                   "dataset_name": dataset_name,
                   "lr": lr,
               })

    print(f"Running {modality} {model_name} on {dataset_name}")
    dataset_dir = f"../data/datasets/{dataset_name}"
    class2images_train = _load_split(f"{dataset_dir}/splits/class2images_train.p")
    class2images_val = _load_split(f"{dataset_dir}/splits/class2images_val.p")
    class2images_test = _load_split(f"{dataset_dir}/splits/class2images_test.p")

    # Label ids follow the key order of the training split.
    class2label = {class_name: i for i, class_name in enumerate(class2images_train)}
    image_dir = get_image_dir(modality, dataset_name)

    dataset_train = ImageDataset(class2images_train, preprocess, image_dir, class2label)
    dataset_test = ImageDataset(class2images_test, preprocess, image_dir, class2label)
    dataset_val = ImageDataset(class2images_val, preprocess, image_dir, class2label)

    train_loader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(dataset_val, batch_size=batch_size, shuffle=False, num_workers=4)
    test_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=4)

    # The caller passes the same ``base_model`` object for every dataset.
    # Work on a private copy so this run starts from the pretrained weights.
    # NOTE(review): this assumes the E2E wrappers keep a reference to the
    # backbone rather than copying it, in which case training here would
    # otherwise leak into the next dataset's run; confirm against models.py.
    backbone = copy.deepcopy(base_model)
    wrapper_cls = ViTE2E if model_name == "vit" else DenseNetE2E
    model = wrapper_cls(backbone, len(class2label))
    model.to(device)

    best_model = train_model(model, train_loader, val_loader, num_epochs, lr)

    # Evaluate the selected model on the val split and on the test split.
    val_acc = eval_model(best_model, val_loader)
    ood_acc = eval_model(best_model, test_loader)

    average_acc = round((val_acc + ood_acc) / 2, 2)
    gap = round(abs(val_acc - ood_acc), 2)

    print(f"Ind Acc: {val_acc}, OOD Acc: {ood_acc}, Gap: {gap}, Average: {average_acc}")

    # close wandb
    wandb.finish()
    return val_acc, ood_acc, gap, average_acc, best_model


if __name__ == "__main__":
    # Only the script entry point needs this, so it is imported locally.
    import os

    parser = ArgumentParser()
    # Both selectors are required. The former placeholder defaults
    # ("xray, skin" / "vit, densenet") matched none of the branches below and
    # ended in a NameError once the undefined names were used.
    parser.add_argument("--modality", type=str, required=True, choices=["xray", "skin"])
    parser.add_argument("--model_name", type=str, required=True, choices=["vit", "densenet"])
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--num_epochs", type=int, default=20)
    parser.add_argument("--lr", type=float, default=1e-6)
    args = parser.parse_args()

    modality = args.modality
    model_name = args.model_name
    batch_size = args.batch_size
    num_epochs = args.num_epochs
    lr = args.lr

    # ``preprocess`` must stay a module-level name: run_dataset reads it as a
    # global rather than receiving it as an argument.
    if model_name == "vit":
        if modality == "xray":
            base_model, _, preprocess = open_clip.create_model_and_transforms('ViT-L-14', pretrained="../data/mimic_cxr/clip_model/whyxrayclip.pt")
        elif modality == "skin":
            base_model, _, preprocess = open_clip.create_model_and_transforms("ViT-L-14", pretrained="../data/isic/clip_model/whylesionclip.pt")

    elif model_name == "densenet":
        preprocess = densenet_preprocess
        if modality == "xray":
            base_model = xrv.models.DenseNet(weights="densenet121-res224-mimic_nb")
        elif modality == "skin":
            # Rebuild the 9-class wrapper the checkpoint was saved from, load
            # its weights, then keep only the backbone.
            model = DenseNetE2E(xrv.models.DenseNet(), 9)
            # Load onto the CPU so the checkpoint also opens on machines
            # without the device it was saved from; the model is moved to
            # ``device`` later, inside run_dataset.
            weight = torch.load("../data/isic/densenet_skin.pt", map_location="cpu")
            model.load_state_dict(weight)
            # NOTE(review): attribute name is spelled "denset_model" here;
            # confirm it matches the attribute defined on DenseNetE2E.
            base_model = model.denset_model

    if modality == "xray":
        dataset_lists = ["NIH-gender", "NIH-age", "NIH-pos", "CheXpert-race", "NIH-CheXpert", "pneumonia", "COVID-QU", "NIH-CXR", "open-i", "vindr-cxr"]
    elif modality == "skin":
        dataset_lists = ["ISIC-gender", "ISIC-age", "ISIC-site", "ISIC-color", "ISIC-hospital", "HAM10000", "BCN20000", "PAD-UFES-20", "Melanoma", "UWaterloo"]

    # Create the output folder before training so a missing directory cannot
    # discard the results after every dataset has been run.
    output_dir = "../data/end2end_results"
    os.makedirs(output_dir, exist_ok=True)

    results_dict = {}
    for dataset_name in dataset_lists:
        ind_acc, out_acc, gap, avg, best_model = run_dataset(modality, model_name, base_model, dataset_name, batch_size, num_epochs, lr)
        results_dict[dataset_name] = {"ind_acc": ind_acc, "out_acc": out_acc, "gap": gap, "avg": avg}

    # One row per dataset, one column per metric.
    csv_df = pd.DataFrame.from_dict(results_dict).T

    # Flatten row by row into a single-row frame.
    reshaped_df = pd.DataFrame(csv_df.values.flatten()).T

    # Column names follow the same row-major order: "<dataset>_<metric>".
    new_columns = [f"{row_label}_{col_label}" for row_label in csv_df.index for col_label in csv_df.columns]

    # Assign new column names to reshaped dataframe
    reshaped_df.columns = new_columns

    reshaped_df.to_csv(f"{output_dir}/{model_name}_{modality}.csv")