import os
import json
import torch
import random
import pickle
import numpy as np
from tqdm import tqdm
from argparse import ArgumentParser
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Seed the stdlib RNG once at import time. get_mapping and train_binary_model
# call random.seed(0) again right before they shuffle / sample.
random.seed(0)

def _read_report_ids(annotation_path):
    """Return ``(positive, negative)`` report-id lists for one annotation directory.

    If both ``positive.txt`` and ``negative.txt`` exist, each is read as one
    report id per line. Otherwise every file in the directory is treated as a
    per-report answer file whose stripped content is ``yes`` or ``no``; the
    report id is the file name up to the first dot.
    """
    positive_file = annotation_path + "/positive.txt"
    negative_file = annotation_path + "/negative.txt"

    if os.path.exists(positive_file) and os.path.exists(negative_file):
        id_lists = []
        for path in (positive_file, negative_file):
            with open(path, "r") as f:
                # Drop blank lines (e.g. the one produced by a trailing
                # newline); an empty id would raise KeyError in the metadata lookup.
                id_lists.append([line.strip() for line in f.read().split("\n") if line.strip()])
        return id_lists[0], id_lists[1]

    positive, negative = [], []
    # os.listdir returns entries in arbitrary order; sort so the seeded
    # shuffle performed by the caller is reproducible across machines.
    for file in sorted(os.listdir(annotation_path)):
        report_id = file.split(".")[0]
        with open(f"{annotation_path}/{file}", "r") as f:
            answer = f.read().strip()
        if answer == "yes":
            positive.append(report_id)
        elif answer == "no":
            negative.append(report_id)
    return positive, negative


def _collect_images(modality, report_ids, metadata, features):
    """Return the images of ``report_ids`` that have an entry in ``features``.

    For ``xray`` each metadata image entry is unpacked as ``(image, image_type)``
    and only ``AP`` / ``PA`` views are kept. For ``skin`` each entry is the image
    key itself. Any other modality yields an empty list without touching metadata.
    """
    if modality not in ("xray", "skin"):
        return []

    selected = []
    for report_id in report_ids:
        entries = metadata[report_id]["images"]
        if modality == "xray":
            for image, image_type in entries:
                if image_type in ("AP", "PA") and image in features:
                    selected.append(image)
        else:
            selected.extend(image for image in entries if image in features)
    return selected


def get_mapping(modality, annotator, question, metadata, features, max_examples, train_samples):
    """Build the train/val image split for one binary question.

    Args:
        modality: ``"xray"`` or ``"skin"``; selects the annotation directory
            and how ``metadata[report_id]["images"]`` entries are interpreted.
        annotator: Name used in the ``annotations_{annotator}`` directory.
        question: Question name; its annotations live in a sub-directory of that name.
        metadata: Mapping ``report_id -> {"images": [...]}``.
        features: Container supporting ``image in features``; images without
            features are ignored.
        max_examples: Upper bound on positives + negatives kept before splitting.
        train_samples: Upper bound on the total number of training images.

    Returns:
        ``{'0': {'train': [...], 'val': [...]}, '1': {'train': [...], 'val': [...]}}``
        where ``'1'`` holds positive and ``'0'`` negative images, or ``False``
        when the per-class validation size would be below 10.

    Raises:
        KeyError: If an annotated report id is missing from ``metadata``.
    """
    annotation_path = f"../data/binary_{modality}/annotations_{annotator}/{question}"
    positive, negative = _read_report_ids(annotation_path)

    positive_images = _collect_images(modality, positive, metadata, features)
    negative_images = _collect_images(modality, negative, metadata, features)

    random.seed(0)
    random.shuffle(positive_images)
    random.shuffle(negative_images)

    # The minority class gets at most half of max_examples; the majority
    # class fills whatever budget is left.
    if len(positive_images) > len(negative_images):
        negative_images_selected = negative_images[:min(len(negative_images), max_examples // 2)]
        positive_images_selected = positive_images[:max_examples - len(negative_images_selected)]
    else:
        positive_images_selected = positive_images[:min(len(positive_images), max_examples // 2)]
        negative_images_selected = negative_images[:max_examples - len(positive_images_selected)]

    # Validation size per class: 10% of the smaller class, capped at 50.
    val_len = min(int(0.1 * min(len(positive_images_selected), len(negative_images_selected))), 50)

    if val_len < 10:
        print(f"Test length too small for {question}. Skipping ...")
        return False

    mapping = {'0': {}, '1': {}}
    positive_train, positive_val = train_test_split(positive_images_selected, test_size=val_len, random_state=0)
    negative_train, negative_val = train_test_split(negative_images_selected, test_size=val_len, random_state=0)

    # Positives get up to half of the training budget; negatives get the rest.
    mapping['1']['train'] = positive_train[:int(train_samples * 0.5)]
    mapping['1']['val'] = positive_val
    mapping['0']['train'] = negative_train[:(train_samples - len(mapping['1']['train']))]
    mapping['0']['val'] = negative_val

    return mapping


def search_hyperparameters(train_features, val_features, train_labels, val_labels, iterations):
    """Search for a LogisticRegression ``C`` value using validation accuracy.

    A coarse grid over powers of 100 picks a starting ``C``. The interval
    ``[0.1 * C, 10 * C]`` is then narrowed ``iterations`` times: both endpoints
    are fitted and scored, and the endpoint with the lower (or equal, for the
    right one) accuracy is moved to the geometric mean of the interval.

    Args:
        train_features: Training inputs passed to ``LogisticRegression.fit``.
        val_features: Validation inputs passed to ``predict``.
        train_labels: Training labels.
        val_labels: Validation labels compared element-wise with predictions.
        iterations: Number of interval-narrowing rounds.

    Returns:
        ``(final_c, final_acc)``: the arithmetic midpoint of the final interval
        and the validation accuracy of a model trained with it.
    """

    def validation_accuracy(c_value):
        # Fit one balanced logistic regression and score it on the validation set.
        classifier = LogisticRegression(max_iter=1000, C=c_value, class_weight='balanced', n_jobs=16)
        classifier.fit(train_features, train_labels)
        return np.mean(classifier.predict(val_features) == val_labels)

    # Coarse grid, evaluated from the largest C to the smallest.
    candidates = [1e6, 1e4, 1e2, 1, 1e-2, 1e-4, 1e-6]
    grid_scores = [validation_accuracy(candidate) for candidate in candidates]
    start_c = candidates[np.argmax(grid_scores)]

    # One decade on either side of the best grid point.
    lower = start_c * 0.1
    upper = start_c * 10

    for _ in range(iterations):
        lower_score = validation_accuracy(lower)
        upper_score = validation_accuracy(upper)

        midpoint = np.sqrt(lower * upper)  # geometric mean
        if lower_score < upper_score:
            lower = midpoint
        else:
            upper = midpoint

    final_c = (lower + upper) / 2
    final_acc = validation_accuracy(final_c)

    return final_c, final_acc


def train_binary_model(question, mapping, features, save_path):
    """Train, evaluate and save a balanced logistic-regression model for one question.

    Args:
        question: Question name; used for the output directory and file names.
        mapping: Dict with ``mapping['1']`` (positive) and ``mapping['0']``
            (negative), each holding ``'train'`` and ``'val'`` lists of image keys.
        features: Mapping ``image key -> torch.Tensor`` feature vector.
        save_path: Directory under which ``{question}/`` is created.

    Side effects:
        Writes ``{save_path}/{question}/{question}.p`` (pickled model) and then
        ``{save_path}/{question}/{question}_results.txt`` containing
        ``train_score,val_score``.
    """
    model_save_path = f"{save_path}/{question}"

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(model_save_path, exist_ok=True)

    positive_data_train, negative_data_train = mapping['1']['train'], mapping['0']['train']
    positive_data_val, negative_data_val = mapping['1']['val'], mapping['0']['val']

    # downsample to keep the training data balanced
    random.seed(0)
    if len(positive_data_train) > len(negative_data_train):
        positive_data_train = random.sample(positive_data_train, len(negative_data_train))
    else:
        negative_data_train = random.sample(negative_data_train, len(positive_data_train))

    print(f"Positive train: {len(positive_data_train)}, Negative train: {len(negative_data_train)}")
    print(f"Positive val: {len(positive_data_val)}, Negative val: {len(negative_data_val)}")

    # Move every tensor to CPU before stacking so mixed-device features
    # cannot break torch.stack (train and val are handled the same way).
    train_features = torch.stack(
        [features[image].cpu() for image in positive_data_train]
        + [features[image].cpu() for image in negative_data_train]
    )
    val_features = torch.stack(
        [features[image].cpu() for image in positive_data_val]
        + [features[image].cpu() for image in negative_data_val]
    )

    train_labels = np.array([1] * len(positive_data_train) + [0] * len(negative_data_train))
    val_labels = np.array([1] * len(positive_data_val) + [0] * len(negative_data_val))

    model = LogisticRegression(max_iter=1000, class_weight='balanced', n_jobs=16)
    model.fit(train_features, train_labels)

    train_score = model.score(train_features, train_labels)
    val_score = model.score(val_features, val_labels)

    print(f"Saving model for {question} ...")
    print(f"Train score: {train_score}")
    print(f"Test score: {val_score}")

    # Save the model first, inside a context manager so the file is flushed
    # and closed. The results file is written last because the driver script
    # uses its existence to decide that a question is already done.
    with open(f"{model_save_path}/{question}.p", 'wb') as f:
        pickle.dump(model, f)

    with open(f"{model_save_path}/{question}_results.txt", 'w') as f:
        f.write(f"{train_score},{val_score}")


if __name__ == "__main__":
    from argparse import ArgumentTypeError

    def _parse_bool(value):
        """Convert a command-line string such as ``True`` / ``false`` / ``0`` to a bool.

        ``type=bool`` would call ``bool("False")``, which is ``True``, so
        ``--normalize False`` could never disable normalization.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0", ""):
            return False
        raise ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = ArgumentParser()
    # Only these two modalities have a feature/metadata loader below.
    parser.add_argument("--modality", type=str, default="xray", choices=["xray", "skin"])
    parser.add_argument("--annotator", type=str, default="gpt4")
    parser.add_argument("--model_name", type=str, default="whyxrayclip")
    parser.add_argument("--normalize", type=_parse_bool, default=True)
    parser.add_argument("--question_type", type=str, default="findings")
    parser.add_argument("--max_examples", type=int, default=10000)
    parser.add_argument("--train_samples", type=int, default=2000)

    args = parser.parse_args()

    model_name = args.model_name
    modality = args.modality
    annotator = args.annotator

    save_path = f"../data/binary_{modality}/models_{model_name}_{annotator}_{args.train_samples}_sklearn"
    os.makedirs(save_path, exist_ok=True)

    print("Loading features ...")
    if modality == "xray":
        features = torch.load(f'../data/mimic_cxr/mimic_cxr_{model_name}.pt')
        metadata_path = '../data/mimic_cxr/mimic_data.json'
    else:  # "skin" (guaranteed by the argparse choices above)
        features = torch.load(f'../data/isic/isic_{model_name}.pt')
        metadata_path = '../data/isic/isic_data.json'

    with open(metadata_path, 'r') as f:
        metadata = json.load(f)

    if args.normalize:
        print("Normalizing features ...")
        # Scale each feature tensor to unit L2 norm along its last dimension.
        for key in features:
            features[key] /= features[key].norm(dim=-1, keepdim=True)

    question_type = args.question_type
    with open(f"../data/binary_{modality}/questions/{question_type}.txt", "r") as f:
        # One question per line; blank lines are skipped so they cannot
        # turn into an empty question name.
        all_questions = [line.strip() for line in f.read().split("\n") if line.strip()]

    for question in tqdm(all_questions):
        # The results file marks a question as finished, so reruns can resume.
        if os.path.exists(f"{save_path}/{question}/{question}_results.txt"):
            print(f"Model for {question} already exists. Skipping ...")
            continue

        mapping = get_mapping(modality, annotator, question, metadata, features, args.max_examples, args.train_samples)

        # get_mapping returns False when there is too little data to validate on.
        if mapping is False:
            continue

        train_binary_model(
            question=question,
            mapping=mapping,
            features=features,
            save_path=save_path
        )