# download packages
#!pip install transformers==4.8.2

# import packages
import re
import torch
import random
import pandas as pd
from tqdm import tqdm
import numpy as np
from torch.utils.data import Dataset
from sklearn.metrics import f1_score
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from transformers import (
    TrainingArguments,
    Trainer,
    GPTNeoForCausalLM,
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
    AutoModel,
    PreTrainedModel,
)
from transformers.utils import logging
import argparse

logging.set_verbosity(40)

import pickle
import os
import sklearn
from torch import nn
from tqdm import tqdm


class ClassificationDataset(Dataset):
    """Prompt-formatted binary sentiment dataset for causal-LM experiments.

    Every text is stripped, clipped to ``text_threshold`` characters and wrapped
    as ``"Sentence: <text>\\nLabel: <label>"``.  Labels map ``0 -> "negative"``
    and ``1 -> "positive"``.

    Prompt/tokenisation per mode:
      * ``is_train`` (default): the label word is appended (except for
        ``model_type == "ft_adaptor"``, which stops at ``"Label:"``).  The
        prompt is tokenised as a single string, truncated and padded to
        ``max_length``.
      * ``is_lr_solver``: ``"Sentence: <text>"`` only, tokenised as a
        one-element batch with truncation and no padding.
      * otherwise (evaluation): ``"Sentence: <text>\\nLabel:"`` as a
        one-element batch.  It is padded to ``max_length`` only for
        ``"ft_adaptor"``.

    ``__getitem__`` returns ``(input_ids, attention_mask, label_tokens,
    numeric_label)``.  ``label_tokens`` is ``input_ids`` with every position
    set to -100 except the LAST " positive"/" negative" token.  In a training
    prompt that is the appended label.  Occurrences inside the sentence itself
    are no longer used as LM targets.

    When ``generate_embeddings`` is set, a causal LM is loaded onto CUDA and
    the last hidden layer is pooled per example into ``self.embeddings``.
    ``embed_type`` selects the pooling: ``"mean"`` over positions, or
    ``"last"`` for the final position.

    Raises:
        ValueError: if ``model_type`` is None, or ``embed_type`` is unsupported
            while ``generate_embeddings`` is set.
    """

    def __init__(
        self,
        txt_list,
        label_list,
        tokenizer,
        max_length,
        is_train=True,
        is_lr_solver=False,
        model_name="EleutherAI/gpt-neo-125M",
        generate_embeddings=False,
        embed_type="mean",  # or "last"
        text_threshold=100,
        model_type=None,
    ):
        if model_type is None:
            raise ValueError("model_type must be provided")
        if generate_embeddings and embed_type not in ("mean", "last"):
            raise ValueError(f"Unsupported embed_type: {embed_type!r}")

        self.input_ids = []
        self.attn_masks = []
        self.labels = []
        self.text = []
        self.label_tokens = []
        self.embeddings = []
        self.numeric_labels = []
        map_label = {0: "negative", 1: "positive"}

        # Only load an LM when embeddings are actually requested.  Previously a
        # full model was put on the GPU for every dataset, even pure train/eval
        # sets that never use it.
        model = None
        if generate_embeddings:
            if "pythia" in model_name:
                model = AutoModelForCausalLM.from_pretrained(model_name).cuda()
            else:
                model = GPTNeoForCausalLM.from_pretrained(model_name).cuda()
            model.resize_token_embeddings(len(tokenizer))

        # First sub-token id of each label word as it appears after "Label:".
        positive_token_id_space = tokenizer(" positive").input_ids[0]
        negative_token_id_space = tokenizer(" negative").input_ids[0]

        if is_train and model_type == "ft_adaptor":
            print("ft_adaptor dataset")

        for txt, label in zip(txt_list, label_list):
            clipped = txt.strip()[0:text_threshold]
            self.text.append(clipped)
            if is_train:
                if model_type == "ft_adaptor":
                    prep_txt = f"Sentence: {clipped}\nLabel:"
                else:
                    prep_txt = f"Sentence: {clipped}\nLabel: {map_label[label]}"
                encodings_dict = tokenizer(
                    prep_txt,
                    truncation=True,
                    max_length=max_length,
                    padding="max_length",
                )
            elif is_lr_solver:
                prep_txt = f"Sentence: {clipped}"
                encodings_dict = tokenizer(
                    [prep_txt],
                    truncation=True,
                )
            else:
                prep_txt = f"Sentence: {clipped}\nLabel:"
                if model_type == "ft_adaptor":
                    encodings_dict = tokenizer(
                        [prep_txt],
                        truncation=True,
                        max_length=max_length,
                        padding="max_length",
                    )
                else:
                    encodings_dict = tokenizer(
                        [prep_txt],
                    )

            inp = torch.tensor(encodings_dict["input_ids"])
            self.input_ids.append(inp)
            self.attn_masks.append(torch.tensor(encodings_dict["attention_mask"]))
            self.label_tokens.append(
                self._last_label_token(
                    inp, positive_token_id_space, negative_token_id_space
                )
            )
            self.labels.append(map_label[label])
            self.numeric_labels.append(torch.Tensor([label]))

            if model is not None:
                # Inference only: no autograd graph is needed for the embeddings.
                with torch.no_grad():
                    hidden = model(
                        input_ids=inp.cuda(),
                        return_dict=True,
                        output_hidden_states=True,
                    ).hidden_states[-1]
                hidden = hidden.detach().cpu()
                # assumes hidden is (batch=1, seq, hidden_size) — TODO confirm
                if embed_type == "mean":
                    pooled = torch.mean(hidden, dim=1)
                else:
                    # No padding on the LR-solver path, so the last position is a real token.
                    pooled = hidden[:, -1, :]
                self.embeddings.append(pooled.squeeze(0))

    @staticmethod
    def _last_label_token(input_ids, positive_id, negative_id):
        """Return input_ids masked to -100 everywhere except the last label-word token."""
        targets = torch.full_like(input_ids, -100)
        hits = ((input_ids == positive_id) | (input_ids == negative_id)).nonzero()
        if len(hits) > 0:
            last = tuple(hits[-1].tolist())
            targets[last] = input_ids[last]
        return targets

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return (
            self.input_ids[idx],
            self.attn_masks[idx],
            self.label_tokens[idx],
            self.numeric_labels[idx],
        )

    def __gettext__(self, idx):
        """Return the stripped, clipped source text of sample idx."""
        return self.text[idx]

    def __getembedding__(self, idx):
        """Return the pooled embedding of sample idx (requires generate_embeddings)."""
        return self.embeddings[idx].detach()

    def __getnumericlabel__(self, idx):
        """Return the label of sample idx as a shape-(1,) float tensor."""
        return self.numeric_labels[idx]


class AdaptorModel(PreTrainedModel):
    """Binary classifier: a frozen backbone plus a trainable two-layer MLP head.

    The backbone's last hidden layer is mean-pooled over the non-padding
    positions (all positions when no attention mask is given).  The pooled
    vector is mapped to a single logit per example.
    """

    def __init__(self, config, model_backbone):
        super().__init__(config)
        self.config = config
        self.backbone = model_backbone
        self.num_labels = 1
        # hidden_size -> hidden_size -> 1 logit
        self.lm_head = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, 1),
        )
        # Only the head is trained; the backbone weights stay fixed.
        for param in self.backbone.parameters():
            param.requires_grad = False

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Score a batch.

        Args:
            input_ids: token ids for the backbone.
            attention_mask: optional 0/1 mask, same shape as ``input_ids``.
                Padding positions (0) are excluded from the mean pool and from
                the backbone's attention.
            labels: optional float targets in {0, 1}.  Assumed to have the
                same shape as the logits, e.g. ``(batch, 1)``.

        Returns:
            ``(loss, logits)`` when ``labels`` is given, otherwise ``(logits,)``.
            The loss is mean binary cross-entropy on the raw logits.
        """
        out = self.backbone(
            input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            output_hidden_states=True,
        )
        # assumes (batch, seq, hidden_size) — TODO confirm for non-GPT backbones
        hidden = out["hidden_states"][-1]
        if attention_mask is None:
            pooled = torch.mean(hidden, dim=1)
        else:
            # Inputs are padded to max_length, so an unmasked mean would be
            # dominated by padding positions.
            mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
            pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        logits = self.lm_head(pooled)

        if labels is None:
            return (logits,)
        # Same value as sigmoid + BCELoss(mean), but numerically stable.
        loss = nn.functional.binary_cross_entropy_with_logits(
            logits, labels.to(logits.dtype)
        )
        return (loss, logits)


# Data load function
def load_data(data_path, dataset="sms", input_key="sms", seed=42):
    """Read the train and test CSV splits of *dataset* below *data_path*.

    The training split is ``<dataset>/train_samples_s<seed>.csv``.  The test
    split is ``<dataset>/test_samples_bal.csv`` for datasets whose name
    contains "ag_news", "dbpedia" or "civil_comments", and
    ``<dataset>/test_samples_orig.csv`` otherwise.

    Returns:
        ``((X_train, y_train), (X_test, y_test))``.  X is the *input_key*
        column and y is the "label" column, both as plain lists.
    """
    balanced_tags = ("ag_news", "dbpedia", "civil_comments")
    if any(tag in dataset for tag in balanced_tags):
        test_file = "test_samples_bal.csv"
    else:
        test_file = "test_samples_orig.csv"

    def _read_split(file_name):
        # Load one CSV and return its (inputs, labels) columns as lists.
        frame = pd.read_csv(os.path.join(data_path, f"{dataset}/{file_name}"))
        return frame[input_key].tolist(), frame["label"].tolist()

    train_split = _read_split(f"train_samples_s{seed}.csv")
    test_split = _read_split(test_file)
    return train_split, test_split


def evaluate_adaptor(
    model,
    test_dataset,
    tokenizer,
):
    """Evaluate an AdaptorModel-style classifier one example at a time.

    Args:
        model: callable returning a tuple whose last element is a single logit
            per example (as ``AdaptorModel.forward`` does).
        test_dataset: ``(texts, labels)`` pair with labels in {0, 1}.
        tokenizer: tokenizer passed through to ``ClassificationDataset``.

    Returns:
        dict with per-example ``original_text``, ``predicted_label``,
        ``gt_label``, ``predicted_text`` (same as the predicted label),
        ``predicted_scores`` as ``(P(positive), 1 - P(positive))``, and
        ``accuracy``.  ``accuracy`` is NaN for an empty test set instead of
        raising ZeroDivisionError.
    """
    gt_label, predicted_label, original_text, predicted_text, predicted_scores = (
        [],
        [],
        [],
        [],
        [],
    )
    _ = model.eval()
    map_label = {0: "negative", 1: "positive"}

    test_ds = ClassificationDataset(
        test_dataset[0],
        test_dataset[1],
        tokenizer,
        max_length=128,
        is_train=False,
        model_type="ft_adaptor",
    )
    with torch.no_grad():
        for idx in range(len(test_ds)):
            input_ids, attn_mask, _, label = test_ds[idx]
            sample_outputs = model(
                input_ids.cuda(),
                attention_mask=attn_mask.cuda(),
                labels=label.unsqueeze(0).cuda(),
            )
            # One logit per example -> probability of the positive class.
            pos_prob = torch.sigmoid(sample_outputs[-1]).item()
            pred_label = "positive" if pos_prob >= 0.5 else "negative"

            label = int(label[0].item())
            # Labels outside the binary map are kept as raw ints.
            gt_label.append(map_label.get(label, label))
            predicted_label.append(pred_label)
            # TODO: a bit hacky, maybe the dataset should return the raw prompt
            original_text.append(test_ds.__gettext__(idx))
            predicted_text.append(pred_label)
            predicted_scores.append((pos_prob, 1 - pos_prob))

    num_correct = sum(gt == pred for gt, pred in zip(gt_label, predicted_label))
    return {
        "original_text": original_text,
        "predicted_label": predicted_label,
        "gt_label": gt_label,
        "predicted_text": predicted_text,
        "predicted_scores": predicted_scores,
        "accuracy": num_correct / len(gt_label) if gt_label else float("nan"),
    }


def evaluate_model(
    model,
    test_dataset,
    tokenizer,
):
    """Evaluate a causal LM by greedily generating one token after "Label:".

    Args:
        model: HF causal LM supporting ``generate``.
        test_dataset: ``(texts, labels)`` pair with labels in {0, 1}.
        tokenizer: tokenizer used for prompts and decoding.

    Returns:
        dict with these per-example lists:

        * ``original_text``
        * ``predicted_label`` (the decoded generated token, stripped)
        * ``gt_label``
        * ``predicted_scores``: tuples ``(P(" positive"), P("positive"),
          P(" negative"), P("negative"))`` from the softmax over the first
          generated position, using each word's first sub-token.

        It also contains ``accuracy``, which is NaN for an empty test set.
    """
    gt_label, predicted_label, original_text, predicted_scores = [], [], [], []
    _ = model.eval()
    map_label = {0: "negative", 1: "positive"}
    positive_token_id_no_space = tokenizer("positive").input_ids[0]
    negative_token_id_no_space = tokenizer("negative").input_ids[0]
    positive_token_id_space = tokenizer(" positive").input_ids[0]
    negative_token_id_space = tokenizer(" negative").input_ids[0]

    test_ds = ClassificationDataset(
        test_dataset[0],
        test_dataset[1],
        tokenizer,
        max_length=128,
        is_train=False,
        model_type="ft_model",
    )
    with torch.no_grad():
        for idx in tqdm(range(len(test_ds))):
            input_ids, attn_mask, _, label = test_ds[idx]
            # Greedy decoding; temperature is irrelevant when do_sample=False,
            # so it is no longer passed.
            sample_outputs = model.generate(
                input_ids.cuda(),
                attention_mask=attn_mask.cuda(),
                do_sample=False,
                max_new_tokens=1,
                output_scores=True,
                return_dict_in_generate=True,
            )
            # Decode only the newly generated token rather than re-parsing the
            # whole prompt around its last ":".
            prompt_len = input_ids.shape[-1]
            new_tokens = sample_outputs["sequences"][0, prompt_len:]
            pred_label = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

            probs = torch.softmax(sample_outputs["scores"][0], dim=-1)
            scores = (
                probs[:, positive_token_id_space].item(),
                probs[:, positive_token_id_no_space].item(),
                probs[:, negative_token_id_space].item(),
                probs[:, negative_token_id_no_space].item(),
            )

            label = int(label[0].item())
            # Labels outside the binary map are kept as raw ints.
            gt_label.append(map_label.get(label, label))
            predicted_label.append(pred_label)
            # TODO: a bit hacky, maybe the dataset should return the raw prompt
            original_text.append(test_ds.__gettext__(idx))
            predicted_scores.append(scores)

    num_correct = sum(gt == pred for gt, pred in zip(gt_label, predicted_label))
    return {
        "original_text": original_text,
        "predicted_label": predicted_label,
        "gt_label": gt_label,
        "predicted_scores": predicted_scores,
        "accuracy": num_correct / len(gt_label) if gt_label else float("nan"),
    }


def evaluate_LRSolver(
    solver,
    test_dataset,
    tokenizer,
    x_mean,
    model_name,
):
    """Evaluate a fitted logistic-regression probe on LM embeddings of a test set.

    Embeddings are built by ``ClassificationDataset(generate_embeddings=True)``.
    They are centred with the TRAINING mean ``x_mean`` and L2-normalised per
    row, matching the preprocessing in ``finetune_LR``.

    Returns:
        dict with per-example ``original_text``, ``predicted_label``,
        ``gt_label``, an empty ``predicted_text`` list, ``predicted_scores`` as
        ``(pos_score, neg_score)`` tuples, and ``accuracy``.
    """
    map_label = {0: "negative", 1: "positive"}

    test_ds = ClassificationDataset(
        test_dataset[0],
        test_dataset[1],
        tokenizer,
        max_length=128,
        is_train=False,
        is_lr_solver=True,
        generate_embeddings=True,
        model_name=model_name,
        model_type="lr_solver",
    )

    n = len(test_ds)
    original_text = [test_ds.__gettext__(i) for i in range(n)]
    X = np.stack([test_ds.__getembedding__(i) for i in range(n)])
    # Each numeric label is a shape-(1,) tensor, so stacking gives (n, 1).
    # Flatten to a 1-D int vector so each entry is a hashable dict key.
    # Iterating the (n, 1) array previously raised "unhashable type: ndarray".
    y = np.stack([test_ds.__getnumericlabel__(i) for i in range(n)]).reshape(-1).astype(int)

    X = X - x_mean
    X = X / np.linalg.norm(X, axis=1, keepdims=True)

    numeric_label = solver.predict(X)
    predicted_label = [map_label[int(v)] for v in numeric_label]
    gt_label = [map_label[int(v)] for v in y]

    # predict_proba columns follow solver.classes_ (sorted: 0, 1) -> (pos, neg).
    predicted_scores = [(score[1], score[0]) for score in solver.predict_proba(X)]

    acc = sklearn.metrics.accuracy_score(gt_label, predicted_label)
    print(acc)

    return {
        "original_text": original_text,
        "predicted_label": predicted_label,
        "gt_label": gt_label,
        "predicted_text": [],  # no generated text for the LR probe
        "predicted_scores": predicted_scores,  # list of (pos_score, neg_score)
        "accuracy": acc,
    }


def finetune_LR(
    solver_type: str = "saga",
    penalty: str = "l2",
    dataset="sms",
    eval_dataset="sms",
    input_key_eval="sms",
    input_key_train="sms",
    seed=42,
    data_path="./data",
    save_dir="./ft_outputs",
    do_save=False,
    do_eval=False,
    model_name="EleutherAI/gpt-neo-125M",
):
    """Fit logistic-regression probes on frozen LM embeddings for several shot counts.

    For each k in (4, 8, 16, 32, 48, 64, 128), the first k training examples
    are embedded via ``ClassificationDataset(generate_embeddings=True)``.  The
    embeddings are centred on their mean and L2-normalised, then used to fit a
    ``LogisticRegression`` with a k-specific ``C``.

    ``save_dir`` and ``do_save`` are accepted for interface parity and are not
    used.

    Returns:
        ``results[0][k]``: the ``evaluate_LRSolver`` output when ``do_eval``,
        else None.
    """
    print(f"model_name: {model_name}")
    (X_train, y_train), _ = load_data(data_path, dataset, input_key_train, seed)

    # The evaluation split does not depend on k, so read it once.
    if do_eval:
        _, (X_test, y_test) = load_data(data_path, eval_dataset, input_key_eval, seed)

    tokenizer = AutoTokenizer.from_pretrained(model_name, pad_token="<|pad|>")
    tokenizer.truncation_side = "left"

    k_values = [4, 8, 16, 32, 48, 64, 128]
    # Inverse regularisation strength C, index-aligned with k_values (extra entries unused).
    c = [1e-2, 1e-2, 1e-2, 0.1, 0.2, 0.4, 0.6, 0.7, 1]

    results = {}
    for num_epoch in [0]:  # single pseudo-epoch keeps the finetune_model results layout
        results[num_epoch] = {}
        for k, reg_c in zip(k_values, c):
            print(f"Finetuning LR with seed {seed}, {k} samples")
            results[num_epoch][k] = None

            # ClassificationDataset loads its own LM for the embeddings.  The
            # separate, never-used model previously loaded here per k is gone.
            train_dataset = ClassificationDataset(
                X_train[0:k],
                y_train[0:k],
                tokenizer,
                max_length=128,
                is_train=False,
                is_lr_solver=True,
                generate_embeddings=True,
                model_name=model_name,
                model_type="lr_solver",
            )

            n = len(train_dataset)
            X = np.stack([train_dataset.__getembedding__(i) for i in range(n)])
            # (n, 1) stacked labels -> 1-D int vector, as LogisticRegression expects.
            y = np.stack(
                [train_dataset.__getnumericlabel__(i) for i in range(n)]
            ).reshape(-1).astype(int)

            # Centre on the training mean (reused at eval time), then unit-normalise rows.
            x_mean = X.mean(axis=0)
            X = X - x_mean
            X = X / np.linalg.norm(X, axis=1, keepdims=True)

            lr_solver = LogisticRegression(penalty=penalty, solver=solver_type, C=reg_c)
            lr_solver.fit(X, y)

            if do_eval:
                results[num_epoch][k] = evaluate_LRSolver(
                    lr_solver, (X_test, y_test), tokenizer, x_mean, model_name
                )

    return results


def load_model_and_tokenizer(
    model_name,
    ft_full=False,
    ft_layers=False,
    ft_adaptor=False,
    ft_head=False,
    layers_to_freeze=2,
):
    """Load a tokenizer and a CUDA model configured for one fine-tuning regime.

    Regimes are checked in the order ``ft_layers``, ``ft_head``,
    ``ft_adaptor``.  If none is set, the whole causal LM stays trainable;
    ``ft_full`` is not read.

    Args:
        model_name: HF model id.  The substrings "pythia" / "bloom" select the
            architecture-specific attribute names used below.
        ft_layers: train only the last ``layers_to_freeze`` transformer blocks
            plus the output head.
        ft_adaptor: wrap a frozen ``AutoModel`` backbone in ``AdaptorModel``.
        ft_head: train only the output head.
        layers_to_freeze: despite its name, the number of final blocks that
            stay TRAINABLE under ``ft_layers``.

    Returns:
        ``(tokenizer, model)``.  The tokenizer gets a ``<|pad|>`` pad token
        and left-side truncation; the model's embeddings are resized to the
        tokenizer length.
    """
    model_config = AutoConfig.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name, pad_token="<|pad|>")
    tokenizer.truncation_side = "left"

    if ft_adaptor and not (ft_layers or ft_head):
        # Build only the bare backbone.  Loading the causal LM first, as before,
        # left a second unused model on the GPU.
        print("Finetuning adaptor")
        model_backbone = AutoModel.from_pretrained(model_name)
        model_backbone.resize_token_embeddings(len(tokenizer))
        model = AdaptorModel(config=model_config, model_backbone=model_backbone).cuda()
        return tokenizer, model

    model = AutoModelForCausalLM.from_pretrained(model_name).cuda()
    model.resize_token_embeddings(len(tokenizer))

    if ft_layers:
        print(f"Finetuning last {layers_to_freeze} layers")
        if "pythia" in model_name:
            t_layers = model_config.num_hidden_layers
            blocks, head = model.base_model.layers, model.embed_out
        elif "bloom" in model_name:
            t_layers = model_config.n_layer
            blocks, head = model.transformer.h, model.lm_head
        else:
            t_layers = model_config.num_layers
            blocks, head = model.transformer.h, model.lm_head
        for param in model.parameters():
            param.requires_grad = False
        # Unfreeze exactly the last `layers_to_freeze` blocks.  The old loop
        # sliced from `t_layers - i` with i starting at 0, which unfroze one
        # block fewer than requested.
        first_trainable = max(t_layers - layers_to_freeze, 0)
        for param in blocks[first_trainable:].parameters():
            param.requires_grad = True
        # NOTE(review): the head may be weight-tied to the input embeddings
        # for some architectures, in which case those become trainable too.
        for param in head.parameters():
            param.requires_grad = True
    elif ft_head:
        print("Finetuning last layer")
        # Freeze everything except the output head.
        for param in model.parameters():
            param.requires_grad = False
        head = model.embed_out if "pythia" in model_name else model.lm_head
        for param in head.parameters():
            param.requires_grad = True
    else:
        print("Finetuning full model")
    return tokenizer, model


def finetune_model(
    model_name,
    model=None,
    tokenizer=None,
    dataset="hate",
    eval_dataset="hate",
    input_key_eval="text",
    input_key_train="text",
    seed=42,
    data_path="./data",
    save_dir="./ft_outputs",
    do_save=False,
    do_eval=False,
    lr=1e-4,
    checkpoint_postfix=None,
    epoch_range=(1, 3, 5, 10),
    k_range=(4, 8, 16, 32, 64, 128),
    loss_type="label",  # "label" or "sentence"
    model_type=None,
    ft_full=False,
    ft_layers=False,
    ft_adaptor=False,
    ft_head=False,
    layers_to_freeze=2,
    max_eval=None,
    do_grad_desc=False,
):
    """Fine-tune a fresh model for every (epochs, k) pair on one seed's training split.

    A new model/tokenizer is loaded for each pair via
    ``load_model_and_tokenizer``, so the ``model``/``tokenizer`` arguments are
    not used.  Each run trains on the first k training examples; batch size is
    k when ``do_grad_desc`` is set, else 1.  Metrics are not aggregated here.

    Training targets:

    * ``model_type == "ft_adaptor"``: the numeric label.
    * ``loss_type == "sentence"``: LM loss on the whole prompt, with padding
      positions masked out.
    * otherwise: LM loss on the label token only.

    Returns:
        ``results[num_epoch][k]``: the eval dict from ``evaluate_model`` /
        ``evaluate_adaptor`` when ``do_eval``, else None.
    """
    print(f"do_save={do_save}, do_grad_desc={do_grad_desc}")

    (X_train, y_train), _ = load_data(data_path, dataset, input_key_train, seed)

    # The evaluation split is the same for every run; read it once.
    if do_eval:
        _, (X_eval, y_eval) = load_data(data_path, eval_dataset, input_key_eval, seed)
        if max_eval is not None:
            X_eval, y_eval = X_eval[:max_eval], y_eval[:max_eval]

    # Index into ClassificationDataset items used as training labels.
    if model_type == "ft_adaptor":
        label_idx = 3  # numeric label
    elif loss_type == "sentence":
        label_idx = 0  # input ids (full LM loss)
    else:
        label_idx = 2  # label-token targets

    def collate(batch):
        """Stack dataset items into the dict the HF Trainer feeds to the model."""
        input_ids = torch.stack([f[0] for f in batch])
        attention_mask = torch.stack([f[1] for f in batch])
        if label_idx == 0:
            # Do not train the LM to emit padding: ignore pad positions (-100).
            labels = input_ids.masked_fill(attention_mask == 0, -100)
        else:
            labels = torch.stack([f[label_idx] for f in batch])
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

    results = {}
    for num_epoch in epoch_range:
        results[num_epoch] = {}
        for k in k_range:
            tokenizer, model = load_model_and_tokenizer(
                model_name=model_name,
                ft_full=ft_full,
                ft_layers=ft_layers,
                ft_adaptor=ft_adaptor,
                ft_head=ft_head,
                layers_to_freeze=layers_to_freeze,
            )
            print("Epoch: ", num_epoch)
            results[num_epoch][k] = None

            checkpoint_name = (
                f"{model_name}_{dataset}_s{seed}_k{k}_e{num_epoch}_{checkpoint_postfix}"
            )

            # TODO: remove hard coded max length
            train_dataset = ClassificationDataset(
                X_train[0:k], y_train[0:k], tokenizer, max_length=128, model_type=model_type
            )

            bs = k if do_grad_desc else 1
            training_args = TrainingArguments(
                num_train_epochs=num_epoch,
                logging_steps=10,
                lr_scheduler_type="constant",
                save_strategy="no",
                save_total_limit=1,
                evaluation_strategy="no",
                per_device_train_batch_size=bs,
                per_device_eval_batch_size=1,
                gradient_accumulation_steps=1,
                warmup_steps=0,
                weight_decay=0.01,
                logging_dir="logs",
                learning_rate=lr,
                output_dir=save_dir,
                logging_strategy="epoch",
                report_to="none",
            )

            print(f"start training: epoch {num_epoch}, k {k}, seed {seed}")
            print(model_type)
            print(loss_type)

            trainer = Trainer(
                model=model,
                args=training_args,
                train_dataset=train_dataset,
                data_collator=collate,
            )
            trainer.train()

            if do_save:
                dataset_dir = os.path.join(save_dir, dataset)
                os.makedirs(dataset_dir, exist_ok=True)
                trainer.save_model(os.path.join(dataset_dir, checkpoint_name))

            if do_eval:
                _ = model.eval()
                if model_type != "ft_adaptor":
                    print("testing")
                    eval_outputs = evaluate_model(model, (X_eval, y_eval), tokenizer)
                else:
                    eval_outputs = evaluate_adaptor(model, (X_eval, y_eval), tokenizer)
                print(eval_outputs["accuracy"])
                results[num_epoch][k] = eval_outputs

            # Drop this run's model and optimizer state before the next
            # iteration loads a fresh model; otherwise both coexist on the GPU.
            del trainer, model
            torch.cuda.empty_cache()

    return results


## Experiment drivers
# -------

from collections import defaultdict

# Seed torch's global RNG once, at import time, so runs are repeatable.
torch.manual_seed(42)

# Two-level auto-vivifying store: accuracy[outer][inner] starts as an empty list.
# NOTE(review): not referenced in the visible part of this module.
accuracy = defaultdict(lambda: defaultdict(list))


def call_finetune_model(
    model_name,
    model,
    tokenizer,
    dataset,
    eval_dataset,
    input_key_train,
    input_key_eval,
    lr,
    do_save=False,
    do_eval=False,
    checkpoint_postfix=None,
    epoch_range=(1, 3, 5, 10),
    k_range=(4, 8, 16, 32, 64, 128),
    loss_type="label",
    model_type="ft_adaptor",
    output_dir="./ft_test",
    data_path="./data",
    save_dir="./ft_outputs",
    ft_full=False,
    ft_layers=False,
    ft_adaptor=False,
    ft_head=False,
    layers_to_freeze=2,
    do_grad_desc=False,
    seeds=(42, 69, 128, 512, 1024),
):
    """Run ``finetune_model`` for each seed and pickle the combined results.

    Results are written to
    ``<output_dir>/<dataset>/<model short name>/<model_type>/train_<dataset>_eval_<eval_dataset>_lr<lr>.pkl``
    (with "/" in the file name replaced by "-").  Evaluation now uses
    ``eval_dataset``, which the file name already claimed; previously the
    training dataset was evaluated instead.

    Returns:
        ``{seed: finetune_model results}``.
    """
    final_results = {}
    for seed in seeds:
        final_results[seed] = finetune_model(
            model_name=model_name,
            model=model,
            tokenizer=tokenizer,
            dataset=dataset,
            eval_dataset=eval_dataset,
            input_key_train=input_key_train,
            input_key_eval=input_key_eval,
            seed=seed,
            data_path=data_path,
            save_dir=save_dir,
            do_save=do_save,
            do_eval=do_eval,
            lr=lr,
            # NOTE(review): checkpoint_postfix is intentionally not forwarded;
            # forwarding it would rename existing saved checkpoints.
            checkpoint_postfix=None,
            epoch_range=epoch_range,
            k_range=k_range,
            loss_type=loss_type,
            model_type=model_type,
            ft_full=ft_full,
            ft_layers=ft_layers,
            ft_adaptor=ft_adaptor,
            ft_head=ft_head,
            layers_to_freeze=layers_to_freeze,
            do_grad_desc=do_grad_desc,
        )

    model_name_split = model_name.split("/")[-1]
    output_dir = f"{output_dir}/{dataset}/{model_name_split}/{model_type}/"
    os.makedirs(output_dir, exist_ok=True)
    file_name = f"train_{dataset}_eval_{eval_dataset}_lr{lr}.pkl".replace("/", "-")
    save_path = os.path.join(output_dir, file_name)
    with open(save_path, "wb") as handle:
        pickle.dump(final_results, handle)
    return final_results


def run_epoch_tune(
    model_name,
    dataset,
    eval_dataset,
    input_key_train,
    input_key_eval,
    lr,
    do_save=False,
    do_eval=False,
    checkpoint_postfix=None,
    epoch_range=(1, 3, 5, 10),
    k_range=(4, 8, 16, 32, 64, 128),
    loss_type="label",
    model_type="ft_adaptor",
    output_dir="./ft_test",
    ft_full=True,
    ft_layers=False,
    ft_adaptor=False,
    ft_head=False,
    layers_to_freeze=2,
):
    """Sweep epochs/k for seeds 42, 69 and 128 at a single learning rate ``lr``.

    Data and checkpoints use the fixed paths "./data" and "./ft_outputs".
    Results ``{seed: finetune_model results}`` are pickled to
    ``<output_dir>/HParamSearch_ds<dataset>_m<model_name>_mt<model_type>``
    (with "/" replaced by "-").  ``output_dir`` is created if missing.
    """
    final_results = {}
    for seed in (42, 69, 128):
        final_results[seed] = finetune_model(
            model_name=model_name,
            model=None,
            tokenizer=None,
            dataset=dataset,
            # NOTE(review): evaluation deliberately stays on the training
            # dataset's test split; `eval_dataset` is not used here.
            eval_dataset=dataset,
            input_key_train=input_key_train,
            input_key_eval=input_key_eval,
            seed=seed,
            data_path="./data",
            save_dir="./ft_outputs",
            do_save=do_save,
            do_eval=do_eval,
            lr=lr,
            checkpoint_postfix=None,
            epoch_range=epoch_range,
            k_range=k_range,
            loss_type=loss_type,
            model_type=model_type,
            ft_full=ft_full,
            ft_layers=ft_layers,
            ft_adaptor=ft_adaptor,
            ft_head=ft_head,
            layers_to_freeze=layers_to_freeze,
        )

    file_name = f"HParamSearch_ds{dataset}_m{model_name}_mt{model_type}".replace(
        "/", "-"
    )
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, file_name), "wb") as handle:
        pickle.dump(final_results, handle)


def run_hyperparam_tune(
    model_name,
    dataset,
    eval_dataset,
    input_key_train,
    input_key_eval,
    lr,
    do_save=False,
    do_eval=False,
    checkpoint_postfix=None,
    epoch_range=(1, 3, 5, 10),
    k_range=(4, 8, 16, 32, 64, 128),
    loss_type="label",
    model_type="ft_adaptor",
    output_dir="./ft_test",
    data_path=None,
    save_dir=None,
    seeds=(42, 69, 128, 512, 1024),
    ft_full=True,
    ft_layers=False,
    ft_adaptor=False,
    ft_head=False,
    layers_to_freeze=2,
    do_grad_desc=False,
    lr_grid=None,
):
    """Grid-search the learning rate at the largest k for every seed.

    The ``lr`` argument is not used; the grid is ``lr_grid``, which defaults to
    (1e-3, 8e-6, 1e-5, 3e-5, 1e-4).  The old hard-coded grid listed 1e-3
    twice, so that setting was trained twice and the first result overwritten.
    Duplicate grid entries are skipped.

    ``data_path``/``save_dir`` of None fall back to "./data"/"./ft_outputs".
    Previously None crashed inside ``os.path.join``.

    Results ``{lr: {seed: finetune_model results}}`` are pickled to
    ``<output_dir>/HParamSearch_ds<dataset>_m<model_name>_mt<model_type>``
    (with "/" replaced by "-").
    """
    if data_path is None:
        data_path = "./data"
    if save_dir is None:
        save_dir = "./ft_outputs"
    if lr_grid is None:
        lr_grid = (1e-3, 8e-6, 1e-5, 3e-5, 1e-4)

    print(output_dir)
    final_results = {}
    # dict.fromkeys de-duplicates while keeping the grid order.
    for grid_lr in dict.fromkeys(lr_grid):
        final_results[grid_lr] = {}
        for seed in seeds:
            final_results[grid_lr][seed] = None
            final_results[grid_lr][seed] = finetune_model(
                model_name=model_name,
                model=None,
                tokenizer=None,
                dataset=dataset,
                # NOTE(review): tuning evaluates on the training dataset's
                # test split; `eval_dataset` is not used here.
                eval_dataset=dataset,
                input_key_train=input_key_train,
                input_key_eval=input_key_eval,
                seed=seed,
                data_path=data_path,
                save_dir=save_dir,
                do_save=do_save,
                do_eval=do_eval,
                lr=grid_lr,
                checkpoint_postfix=None,
                epoch_range=epoch_range,
                k_range=[k_range[-1]],
                loss_type=loss_type,
                model_type=model_type,
                max_eval=None,
                ft_full=ft_full,
                ft_layers=ft_layers,
                ft_adaptor=ft_adaptor,
                ft_head=ft_head,
                layers_to_freeze=layers_to_freeze,
                do_grad_desc=do_grad_desc,
            )

    file_name = f"HParamSearch_ds{dataset}_m{model_name}_mt{model_type}".replace(
        "/", "-"
    )
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, file_name), "wb") as handle:
        pickle.dump(final_results, handle)


def call_lr_solver(model_name):
    """Run ``finetune_LR`` for five seeds on the synthetic dataset and pickle the results.

    Training and evaluation both use ``synthetic_noise_0.25_seqlen_12`` with the
    ``saga`` solver and an ``l2`` penalty; data is read from ``./data`` and
    outputs go to ``./ft_outputs`` (``do_save=False``, ``do_eval=True``).

    Args:
        model_name: Model identifier forwarded unchanged to ``finetune_LR``.

    Side effects:
        Writes a pickled ``{seed: results}`` dict to a file in the current
        working directory whose name encodes solver, penalty and datasets.
    """
    dataset = "synthetic_noise_0.25_seqlen_12"
    eval_dataset = "synthetic_noise_0.25_seqlen_12"
    input_key_train = "text"
    input_key_eval = "text"
    solver_type = "saga"
    penalty = "l2"

    # NOTE(review): the "pythia2.8" suffix is hard-coded and does not follow
    # model_name; results for other models will be saved under this name too.
    file_name = f"LRSolver_{solver_type}_Penalty_{penalty}_train_{dataset}_eval_{eval_dataset}_pythia2.8".replace(
        "/", "-"
    )

    final_results = {}
    for seed in [42, 69, 128, 512, 1024]:
        final_results[seed] = finetune_LR(
            model_name=model_name,
            solver_type=solver_type,
            penalty=penalty,
            dataset=dataset,
            eval_dataset=eval_dataset,
            input_key_train=input_key_train,
            input_key_eval=input_key_eval,
            seed=seed,
            data_path="./data",
            save_dir="./ft_outputs",
            do_save=False,
            do_eval=True,
        )

    # Context manager guarantees the file is flushed and closed.
    with open(file_name, "wb") as handle:
        pickle.dump(final_results, handle)


def _str2bool(value):
    """Convert a command-line boolean string such as ``True``/``false``/``1``/``0`` to bool.

    With plain ``action="store"`` argparse keeps the raw string, and any non-empty
    string (including ``"False"``) is truthy, so boolean options must be parsed.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Supported fine-tuning modes. The order is also the precedence used when several
# --ft_* flags are True (same order as the original if/elif chain).
_MODEL_TYPES = ("ft_full", "ft_adaptor", "ft_layers", "ft_head")


def _set_model_type_flags(args, model_type):
    """Set ``args.<mode>`` to True for ``model_type`` and to False for every other mode."""
    for name in _MODEL_TYPES:
        setattr(args, name, name == model_type)


if __name__ == "__main__":
    import ast  # literal_eval parses list-valued options without executing code

    parser = argparse.ArgumentParser(
        description="Fine-tune a causal LM for classification or run hyperparameter/epoch sweeps."
    )
    parser.add_argument("--model_name", type=str, default="EleutherAI/gpt-neo-125m")
    parser.add_argument("--dataset", type=str, default="ag_news")
    parser.add_argument("--eval_dataset", type=str, default="ag_news")
    parser.add_argument("--input_key_train", type=str, default="text")
    parser.add_argument("--input_key_eval", type=str, default="text")
    parser.add_argument("--learning_rate", type=float, default=0.0001)
    parser.add_argument("--checkpoint_postfix", type=str, default="")
    parser.add_argument("--loss_type", type=str, default="label")
    # The misspelled "--layers_to_freeeze" stays as an alias so existing commands keep working.
    parser.add_argument(
        "--layers_to_freeze",
        "--layers_to_freeeze",
        dest="layers_to_freeze",
        type=int,
        default=2,
    )
    parser.add_argument("--save_dir", type=str, default="./outputs")
    parser.add_argument("--data_path", type=str, default="./data")
    # List-valued options are given as Python list literals, e.g. "[5, 10]".
    parser.add_argument("--k_range", type=str, default="[4, 8, 16, 32, 64, 128]")
    parser.add_argument("--epoch_range", type=str, default="[5, 10]")
    parser.add_argument("--seeds", type=str, default="[42, 69]")
    parser.add_argument("--output_dir", type=str, default="./ft_test")
    # These take an explicit value (e.g. "--ft_full True"); parse it so "False" is falsy.
    for flag in ("--ft_full", "--ft_adaptor", "--ft_layers", "--ft_head", "--do_eval", "--do_save"):
        parser.add_argument(flag, type=_str2bool, default=False)
    parser.add_argument("--do_hpsearch", action="store_true", default=False)
    parser.add_argument("--do_epochtune", action="store_true", default=False)
    parser.add_argument("--do_gd", action="store_true", default=False)

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    epoch_range = ast.literal_eval(args.epoch_range)
    k_range = ast.literal_eval(args.k_range)

    if args.do_hpsearch or args.do_epochtune:
        # --do_hpsearch takes precedence when both sweep flags are given.
        if args.do_hpsearch:
            print("Running hyperparam search")
            sweep_extra = {
                "data_path": args.data_path,
                "save_dir": args.save_dir,
                "seeds": ast.literal_eval(args.seeds),
                "do_grad_desc": args.do_gd,
            }
        else:
            # NOTE(review): epoch tuning does not forward data_path, save_dir, seeds or
            # do_grad_desc, so run_hyperparam_tune's defaults apply — confirm intended.
            sweep_extra = {}

        for sweep_type in _MODEL_TYPES:
            # Reset every ft_* flag so exactly one matches the mode being swept.
            _set_model_type_flags(args, sweep_type)
            run_hyperparam_tune(
                model_name=args.model_name,
                dataset=args.dataset,
                eval_dataset=args.eval_dataset,
                input_key_train=args.input_key_train,
                input_key_eval=args.input_key_eval,
                lr=args.learning_rate,
                do_save=args.do_save,
                do_eval=args.do_eval,
                checkpoint_postfix=args.checkpoint_postfix,
                epoch_range=epoch_range,
                k_range=k_range,
                loss_type=args.loss_type,
                model_type=sweep_type,
                output_dir=args.output_dir,
                ft_full=args.ft_full,
                ft_layers=args.ft_layers,
                ft_adaptor=args.ft_adaptor,
                ft_head=args.ft_head,
                layers_to_freeze=args.layers_to_freeze,
                **sweep_extra,
            )

    else:
        # The first mode whose flag is True wins, in _MODEL_TYPES order.
        model_type = next((name for name in _MODEL_TYPES if getattr(args, name)), None)
        if model_type is None:
            parser.error("one of --ft_full/--ft_adaptor/--ft_layers/--ft_head must be True")

        call_finetune_model(
            model_name=args.model_name,
            model=None,
            tokenizer=None,
            dataset=args.dataset,
            eval_dataset=args.eval_dataset,
            input_key_train=args.input_key_train,
            input_key_eval=args.input_key_eval,
            lr=args.learning_rate,
            do_save=args.do_save,
            do_eval=args.do_eval,
            checkpoint_postfix=args.checkpoint_postfix,
            epoch_range=epoch_range,
            k_range=k_range,
            loss_type=args.loss_type,
            model_type=model_type,
            output_dir=args.output_dir,
            data_path=args.data_path,
            save_dir=args.save_dir,
            ft_full=args.ft_full,
            ft_layers=args.ft_layers,
            ft_adaptor=args.ft_adaptor,
            ft_head=args.ft_head,
            layers_to_freeze=args.layers_to_freeze,
            do_grad_desc=args.do_gd,
        )
