import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
import os
import json
import math

from utils import DemosTemplate, EvalTemplate
from src.infl_torch import get_ridge_weights, calc_influence
from datasets import load_dataset

import argparse



# Default label vocabulary: class index <-> answer letter for a 4-class task.
# process_raw_data rebinds both module-level names to match the number of
# classes of the dataset it loads.
map_int_to_abcd = dict(enumerate("ABCD"))

map_abcd_to_int = {letter: index for index, letter in map_int_to_abcd.items()}


def load_model(model_path="lmsys/vicuna-7b-v1.3", tokenizer_path=None, device="cuda"):
    """Load a causal language model and its tokenizer.

    The model is loaded in float16, moved to ``device`` and switched to eval
    mode. ``tokenizer_path`` falls back to ``model_path`` when omitted.

    Note: ``trust_remote_code=True`` allows the checkpoint repository to run
    its own Python code while loading; only point this at trusted checkpoints.

    Returns:
        ``(model, tokenizer)``.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
    model = model.to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(
        model_path if tokenizer_path is None else tokenizer_path
    )
    return model, tokenizer


# Hugging Face hub path and number of classes for each supported
# text-classification dataset (each has a "train" split with "text"/"label").
_HF_CLASSIFICATION_DATASETS = {
    "ag_news": ("ag_news", 4),
    "rotten_tomatoes": ("rotten_tomatoes", 2),
    "sst2": ("SetFit/sst2", 2),
    "sst5": ("SetFit/sst5", 5),
    "subj": ("SetFit/subj", 2),
}

_SENTENCE_SIMILARITY_PATH = "./data/sentence_similarity.json"
_SENTENCE_SIMILARITY_NUM_CLASSES = 6


def _make_label_maps(num_classes):
    """Return ``(int -> letter, letter -> int)`` dicts for classes 0..num_classes-1 -> 'A', 'B', ..."""
    int_to_letter = {i: chr(ord("A") + i) for i in range(num_classes)}
    letter_to_int = {letter: i for i, letter in int_to_letter.items()}
    return int_to_letter, letter_to_int


def _load_hf_classification(hub_path, seed, total_num_data, int_to_letter):
    """Shuffle the HF "train" split with ``seed`` and return its first ``total_num_data`` texts and letter labels."""
    data = load_dataset(hub_path)["train"].shuffle(seed=seed)
    assert len(data) >= total_num_data, f"Dataset size is {len(data)}, which is less than {total_num_data}"
    data_input = data['text'][:total_num_data]
    data_output = [int_to_letter[label] for label in data['label'][:total_num_data]]
    return data_input, data_output


def _load_sentence_similarity(seed, total_num_data, int_to_letter):
    """Load the local sentence-similarity JSON, shuffle it, and return the first ``total_num_data`` examples."""
    with open(_SENTENCE_SIMILARITY_PATH) as f:
        samples = list(json.load(f)['examples'].values())
    assert len(samples) >= total_num_data, f"Dataset size is {len(samples)}, which is less than {total_num_data}"
    # NOTE: this seeds numpy's *global* RNG, so later np.random calls (label
    # flips in perturb_labels, the random-removal baseline in run) continue
    # from this state.
    np.random.seed(seed)
    np.random.shuffle(samples)
    outputs = [sample['output'] for sample in samples]
    # Class indices follow the sorted order of the distinct output strings
    # over the whole file (not just the returned prefix).
    label_index = {label: i for i, label in enumerate(np.unique(outputs).tolist())}
    data_input = [sample['input'] for sample in samples[:total_num_data]]
    data_output = [int_to_letter[label_index[out]] for out in outputs[:total_num_data]]
    return data_input, data_output


def process_raw_data(dataset_name, seed, total_num_data):
    """Load, shuffle and truncate a classification dataset for ICL experiments.

    Side effect: rebinds the module-level ``map_int_to_abcd`` and
    ``map_abcd_to_int`` to a letter vocabulary sized for ``dataset_name``
    (2 to 6 classes). Other functions in this module read those globals.

    Args:
        dataset_name: one of ``ag_news``, ``rotten_tomatoes``, ``sst2``,
            ``sst5``, ``subj`` (Hugging Face) or ``sentence_similarity``
            (local JSON at ``./data/sentence_similarity.json``).
        seed: shuffling seed.
        total_num_data: number of examples to return.

    Returns:
        ``(data_input, data_output)``: the example texts and their labels as
        letters ("A", "B", ...), each of length ``total_num_data``.

    Raises:
        AssertionError: the dataset has fewer than ``total_num_data`` examples.
        NotImplementedError: unknown ``dataset_name``.
    """
    global map_abcd_to_int, map_int_to_abcd
    if dataset_name in _HF_CLASSIFICATION_DATASETS:
        hub_path, num_classes = _HF_CLASSIFICATION_DATASETS[dataset_name]
        map_int_to_abcd, map_abcd_to_int = _make_label_maps(num_classes)
        return _load_hf_classification(hub_path, seed, total_num_data, map_int_to_abcd)
    if dataset_name == "sentence_similarity":
        map_int_to_abcd, map_abcd_to_int = _make_label_maps(_SENTENCE_SIMILARITY_NUM_CLASSES)
        return _load_sentence_similarity(seed, total_num_data, map_int_to_abcd)
    raise NotImplementedError(f"Dataset {dataset_name} not implemented")


def get_data(data_input, data_output, bs=200, num_data=10):
    """Split parallel input/output sequences into ``bs`` disjoint ICL episodes.

    Episode ``i`` uses the ``num_data + 1`` consecutive items starting at
    ``i * (num_data + 1)``: the first ``num_data`` are the demonstrations and
    the last one is the held-out test query. This matches the
    ``bs * (num_data + 1)`` items callers request, so no item is used by two
    episodes.

    Returns:
        A list of ``((demo_inputs, demo_outputs), ([query_input], [query_output]))``.

    Raises:
        ValueError: fewer than ``bs * (num_data + 1)`` items were supplied.
    """
    stride = num_data + 1  # num_data demonstrations + 1 test query
    needed = bs * stride
    available = min(len(data_input), len(data_output))
    if available < needed:
        raise ValueError(
            f"Need {needed} examples for {bs} episodes of {num_data} demos + 1 query, got {available}"
        )
    dataset = []
    for start in range(0, needed, stride):
        query = start + num_data
        demo_data = (data_input[start:query], data_output[start:query])
        eval_data = (data_input[query:query + 1], data_output[query:query + 1])
        dataset.append((demo_data, eval_data))
    return dataset


def get_context(demo_data, eval_data):
    """Build the in-context-learning prompt for one test query.

    ``demo_data`` is rendered with the ``Input:/Output:`` demonstration
    template. The rendered demos are followed by the query text
    ``eval_data[0][0]`` and a trailing ``Output:`` for the model to complete.
    """
    rendered_demos = DemosTemplate("Input: [INPUT]\nOutput: [OUTPUT]").fill(demo_data)
    query_text = eval_data[0][0]
    prompt_builder = EvalTemplate("[full_DEMO]\n\nInput: [INPUT]\nOutput:")
    return prompt_builder.fill(full_demo=rendered_demos, input=query_text)


def get_embedding(hidden_states, label_pos, layer_num, device="cuda"):
    """Gather the rows at ``label_pos`` from batch item 0 of layer ``layer_num`` and move them to ``device``."""
    first_item = hidden_states[layer_num][0]
    selected_rows = first_item[label_pos]
    return selected_rows.to(device)


def get_label_pos(tokenizer, input_ids, model_name="vicuna-7b", position="column", labels=None):
    """Locate the token positions used as per-example features in an ICL prompt.

    A demonstration label is a label-letter token whose token two positions
    earlier is the ``"Output"`` token (the ``Output: X`` pattern).

    Args:
        tokenizer: provides ``convert_tokens_to_ids``.
        input_ids: 1-D sequence of token ids (a list, or anything with
            ``tolist()`` such as a 1-D tensor).
        model_name: selects the word-boundary prefix of label tokens
            (``▁`` for the SentencePiece models, ``Ġ`` for the byte-level BPE ones).
        position: ``"label"`` returns the label-token positions; ``"column"``
            returns the position just before each label (presumably the ``:``).
        labels: label letters to search for; defaults to the keys of the
            module-level ``map_abcd_to_int``.

    Returns:
        Sorted list of positions that also includes ``len(input_ids) - 1``
        (the final/query token).

    Raises:
        NotImplementedError: ``model_name`` is not supported.
        ValueError: ``position`` is neither ``"label"`` nor ``"column"``.
    """
    if model_name in ["vicuna-7b", "Llama-2-7b", "Llama-2-13b", "mistral-7b", "wizardlm-7b"]:
        prefix = "▁"
    elif model_name in ["gpt2-xl", "falcon-7b", "mamba"]:
        prefix = "Ġ"
    else:
        raise NotImplementedError(f"Label_id for model {model_name} not implemented")

    if position == "label":
        shift = 0
    elif position == "column":
        shift = 1
    else:
        raise ValueError(f"position must be 'label' or 'column', got {position!r}")

    if labels is None:
        labels = list(map_abcd_to_int.keys())

    # Work on plain ints: comparing tensor elements one by one is slow.
    ids = input_ids.tolist() if hasattr(input_ids, "tolist") else list(input_ids)
    output_id = tokenizer.convert_tokens_to_ids("Output")

    all_label_pos = []
    for label in labels:
        label_id = tokenizer.convert_tokens_to_ids(f"{prefix}{label}")
        # Start at 2 so ids[i - 2] never wraps around to the end of the prompt.
        all_label_pos += [
            i - shift
            for i in range(2, len(ids))
            if ids[i] == label_id and ids[i - 2] == output_id
        ]
    all_label_pos.append(len(ids) - 1)  # query position
    return sorted(all_label_pos)


def compute_infl(demo_data, eval_data, layer_nums, project_dim=None, alpha=1.0, device="cuda", score="test"):
    """Return the influence of each demonstration on the test query, summed over layers.

    The global ``model`` is run on the ICL prompt. Hidden states are taken at
    the positions from ``get_label_pos`` (one per demo plus the final query
    token). For every layer in ``layer_nums``:

    * weights come from ``get_ridge_weights`` on the demo rows (all but the
      last row) and their integer labels;
    * ``calc_influence`` scores each demo against the query label
      ``eval_data[1][0]``.

    When ``project_dim`` is set and differs from the hidden width, features go
    through a freshly drawn random Gaussian projection (std
    ``sqrt(1/project_dim)``). The projection is redrawn from torch's global RNG
    on every call and every layer.

    Reads module globals ``tokenizer``, ``model``, ``model_name``,
    ``use_label_pos`` and ``map_abcd_to_int``.

    Returns:
        float64 numpy array of length ``len(demo_data[0])``.
    """
    prompt = get_context(demo_data, eval_data)
    token_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    hidden = model(token_ids, output_hidden_states=True,).hidden_states
    feature_pos = get_label_pos(
        tokenizer,
        token_ids[0],
        model_name=model_name,
        position="label" if use_label_pos else "column",
    )

    num_demos = len(demo_data[0])
    total = np.zeros(num_demos)

    for layer in layer_nums:
        feats = get_embedding(hidden, feature_pos, layer, device=device).to(torch.float32)
        targets = torch.tensor([map_abcd_to_int[lbl] for lbl in demo_data[1]]).to(device).to(torch.int64)

        width = feats.shape[1]
        if project_dim is not None and project_dim != width:
            projector = torch.nn.Linear(width, project_dim, bias=False).to(device)
            torch.nn.init.normal_(projector.weight, mean=0.0, std=math.sqrt(1.0 / project_dim))
        else:
            # No projection requested (or it would be square): identity map.
            projector = torch.nn.Identity(width).to(device)

        # feats[:-1] drops the final (query) row, leaving only the demos.
        weights = get_ridge_weights(feats[:-1], targets, len(map_abcd_to_int), projector, device=device).to(torch.float32)
        layer_scores = [
            calc_influence(j, feats, weights, eval_data[1][0], targets, map_abcd_to_int, projector,
                           device=device, score=score, alpha=alpha)
            for j in range(num_demos)
        ]
        total += np.asarray(layer_scores)
    return total



def perturb_labels(demo_data, flip_idx):
    """Return a copy of ``demo_data`` with the labels at ``flip_idx`` flipped.

    Each selected label is replaced by a different class drawn uniformly at
    random (numpy's global RNG) from ``map_abcd_to_int``'s classes. The
    original label is always read from ``demo_data``, which is not modified.

    Returns:
        ``(demo_inputs, new_label_list)``.
    """
    labels = demo_data[1]
    flipped = list(labels)
    for pos in flip_idx:
        current = map_abcd_to_int[labels[pos]]
        alternatives = [c for c in range(len(map_abcd_to_int)) if c != current]
        flipped[pos] = map_int_to_abcd[np.random.choice(alternatives)]
    return (demo_data[0], flipped)


def remove_data(demo_data, remove_idx):
    """Drop the demonstrations at ``remove_idx`` from both inputs and labels.

    Returns a 2-tuple of numpy arrays (``np.delete`` output), not lists.
    """
    return tuple(np.delete(part, remove_idx) for part in (demo_data[0], demo_data[1]))


def check_answer(model, tokenizer, context, eval_data, device="cuda"):
    """Greedily continue ``context`` and report whether the model predicts the gold label.

    Up to 3 new tokens are generated. Only the first whitespace-delimited word
    of the decoded continuation is compared with ``eval_data[1][0]``. Those few
    tokens may run past the label into the following prompt text
    (e.g. ``"A\\n\\nInput"``), which a whole-string comparison would wrongly
    count as incorrect.

    Returns:
        ``True`` when the predicted label matches, else ``False``.
    """
    input_ids = tokenizer.encode(context, return_tensors="pt").to(device)
    output_tokens = model.generate(input_ids, do_sample=False, max_new_tokens=3, pad_token_id=tokenizer.eos_token_id)
    # Decode only the newly generated tokens (generate returns prompt + continuation).
    answer = tokenizer.decode(output_tokens[0][len(input_ids[0]):], skip_special_tokens=True)
    words = answer.split()
    predicted = words[0] if words else ""
    return predicted == eval_data[1][0].strip()


def _check_after_edit(demo_data, eval_data, edit_idx):
    """Return whether the model still answers correctly after editing the demos at ``edit_idx``.

    The module-level ``corrupt`` flag picks the edit: flip the selected labels
    (``perturb_labels``) when set, otherwise delete those demos
    (``remove_data``). Reads the globals ``model``, ``tokenizer`` and ``device``.
    """
    if corrupt:
        edited = perturb_labels(demo_data, edit_idx)
    else:
        edited = remove_data(demo_data, edit_idx)
    return check_answer(model, tokenizer, get_context(edited, eval_data), eval_data, device=device)


@torch.no_grad()
def run(num_remove):
    """Measure ICL accuracy before and after editing ``num_remove`` demonstrations.

    For each of ``num_trials`` seeds, builds ``bs`` episodes and records the
    accuracy of:

    * the original prompt;
    * the prompt with the ``num_remove`` highest-influence demos edited;
    * the prompt with the ``num_remove`` lowest-influence demos edited;
    * the prompt with ``num_remove`` randomly chosen demos edited.

    "Edited" means removed, or label-corrupted when the global ``corrupt`` is
    set. Runs under ``torch.no_grad()``. Reads module globals ``num_trials``,
    ``dataset_name``, ``bs``, ``icl_dataset_size``, ``layer_nums``,
    ``project_dim``, ``device``, ``model`` and ``tokenizer``.

    Returns:
        Four lists (one float per trial): original, remove-high, remove-low,
        remove-random accuracies.
    """
    acc_orig, acc_rem_high, acc_rem_low, acc_rem_random = [], [], [], []

    for seed in range(num_trials):
        data_input, data_output = process_raw_data(dataset_name, seed, bs * (icl_dataset_size + 1))
        dataset = get_data(data_input, data_output, bs=bs, num_data=icl_dataset_size)

        num_success_orig = 0
        num_success_rem_high = 0
        num_success_rem_low = 0
        num_success_rem_random = 0

        for demo_data, eval_data in dataset:
            context = get_context(demo_data, eval_data)
            infls = compute_infl(demo_data, eval_data, layer_nums, project_dim=project_dim, device=device, score="test", alpha=1.0)

            # Unedited prompt.
            if check_answer(model, tokenizer, context, eval_data, device=device):
                num_success_orig += 1

            order = np.argsort(infls)  # ascending influence

            # Highest influence. Use an explicit start index: order[-num_remove:]
            # would select *every* demo when num_remove == 0, since -0 == 0.
            high_idx = order[max(len(order) - num_remove, 0):]
            if _check_after_edit(demo_data, eval_data, high_idx):
                num_success_rem_high += 1

            # Lowest influence.
            low_idx = order[:num_remove]
            if _check_after_edit(demo_data, eval_data, low_idx):
                num_success_rem_low += 1

            # Random baseline.
            random_idx = np.random.choice(len(demo_data[0]), num_remove, replace=False)
            if _check_after_edit(demo_data, eval_data, random_idx):
                num_success_rem_random += 1

        cur_acc_orig = num_success_orig / bs
        cur_acc_rem_high = num_success_rem_high / bs
        cur_acc_rem_low = num_success_rem_low / bs
        cur_acc_rem_random = num_success_rem_random / bs

        acc_orig.append(cur_acc_orig)
        acc_rem_high.append(cur_acc_rem_high)
        acc_rem_low.append(cur_acc_rem_low)
        acc_rem_random.append(cur_acc_rem_random)

        print(f"Acc Original: {cur_acc_orig}")
        print(f"Acc Remove high: {cur_acc_rem_high}")
        print(f"Acc Remove low: {cur_acc_rem_low}")
        print(f"Acc Remove random: {cur_acc_rem_random}")

    return acc_orig, acc_rem_high, acc_rem_low, acc_rem_random


if __name__ == "__main__":
    # Short model name (the --model flag) -> Hugging Face checkpoint.
    model_paths = {
        "vicuna-7b": "lmsys/vicuna-7b-v1.3",
        "Llama-2-7b": "meta-llama/Llama-2-7b-chat-hf",
        "Llama-2-13b": "meta-llama/Llama-2-13b-chat-hf",
        "mamba": "state-spaces/mamba-2.8b-hf",
        "gpt2-xl": "gpt2-xl",
        "mistral-7b": "mistralai/Mistral-7B-v0.1",
        # Used in place of the upstream "tiiuae/falcon-7b" checkpoint.
        "falcon-7b": "OpenBuddy/openbuddy-falcon-7b-v6-bf16",
        "wizardlm-7b": "WizardLM/WizardMath-7B-V1.1",
    }

    parser = argparse.ArgumentParser()
    parser.add_argument("--layer_nums", type=int, nargs="+", default=[15])
    parser.add_argument("--model", type=str, default="vicuna-7b")
    parser.add_argument("--dataset_name", type=str, default="ag_news")
    parser.add_argument("--project_dim", type=int, default=None)
    parser.add_argument("--use_label_pos", action="store_true")
    parser.add_argument("--bs", type=int, default=100)
    parser.add_argument("--num_trials", type=int, default=10)
    parser.add_argument("--corrupt", action="store_true")
    parser.add_argument("--icl_dataset_size", type=int, default=20,
                        help="number of demonstrations per prompt")
    parser.add_argument("--num_remove", type=int, nargs="+", default=[4, 7, 10, 13, 16],
                        help="numbers of demonstrations to remove/corrupt (one run each)")
    parser.add_argument("--device", type=str, default="cuda")
    args = parser.parse_args()

    # Run configuration. These module-level names are read as globals by the
    # functions above (e.g. run, compute_infl).
    device = args.device
    model_name = args.model
    icl_dataset_size = args.icl_dataset_size
    project_dim = args.project_dim
    bs = args.bs
    num_trials = args.num_trials
    dataset_name = args.dataset_name
    save_dir = f"results/cls_llm_{dataset_name}/"
    os.makedirs(save_dir, exist_ok=True)
    layer_nums = args.layer_nums
    corrupt = args.corrupt
    use_label_pos = args.use_label_pos

    if model_name not in model_paths:
        raise NotImplementedError(f"Model {model_name} not implemented")
    model_path = model_paths[model_name]

    model, tokenizer = load_model(model_path=model_path, device=device)

    mode_tag = "corrupt" if corrupt else "remove"
    label_pos_tag = "_use_label_pos" if use_label_pos else ""
    metric_names = ("acc_orig", "acc_rem_high", "acc_rem_low", "acc_rem_random")
    for num_remove in args.num_remove:
        results = run(num_remove)
        print(f"Finished running for removing {num_remove} data points")
        # One .npy file per metric, e.g. acc_orig_vicuna-7b_4_remove.npy
        for metric, values in zip(metric_names, results):
            out_name = f"{metric}_{model_name}_{num_remove}_{mode_tag}{label_pos_tag}.npy"
            np.save(os.path.join(save_dir, out_name), values)
