import json
import os
import pdb
import random
import sys

import torch
from tqdm import tqdm

root_dir = '/data/home/username/Experiments/LLM_ensemble'
sys.path.insert(0, root_dir)

from src.model_load import load_tokenizer, load_model
from src.transfer_matrix.common_vocabulary import CommonVocabulary
from src.transfer_matrix.transfer_matrix import ProbabilityTransferMatrix


def probability_transfer_matrix_save(model, anchor_point_index_list, temperature, device_compute, device, save_path):
    """Compute the probability transfer matrix for one model and save it to ``save_path``.

    NOTE: this function reads the module-level ``probability_transfer_matrix_obj``,
    which must be created before the first call.

    Args:
        model: Model forwarded to ``get_final_probability_transfer_matrix``.
        anchor_point_index_list: Anchor points for this model, forwarded as ``anchor_point_list``.
        temperature: Forwarded unchanged to ``get_final_probability_transfer_matrix``.
        device_compute: Forwarded unchanged (presumably the device used for the computation).
        device: Forwarded unchanged (presumably the device the result lives on).
        save_path: Destination handed to ``torch.save``. When it is a filesystem
            path, its parent directory is created if it does not exist yet.
    """
    probability_transfer_matrix = probability_transfer_matrix_obj.get_final_probability_transfer_matrix(
        model=model, anchor_point_list=anchor_point_index_list, temperature=temperature,
        device_compute=device_compute, device=device)

    # torch.save does not create missing directories; make sure the parent exists.
    # File-like objects (also accepted by torch.save) are passed through untouched.
    if isinstance(save_path, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(save_path))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    # The matrix is cast to float32 before it is serialized.
    torch.save(probability_transfer_matrix.float(), save_path)
    

# Local checkpoint directories of the candidate models.
model1_path = "/data3/username/ModelsHub/Nanbeige/Nanbeige-16B-Base"
model2_path = "/data3/username/ModelsHub/01-ai/Yi-6B-hf"
model3_path = "/data3/username/ModelsHub/Skywork/Skywork-13B-base"
model4_path = "/data5/.cache/AI-ModelScope/Mixtral-8x7B-v0___1"
model5_path = "/data3/username/ModelsHub/NousResearch/Llama-2-70b-hf"
model6_path = "/data3/username/ModelsHub/TigerResearch/tigerbot-13b-base-v2"
model7_path = "/data3/username/ModelsHub/mistralai/Mistral-7B-v0.1"
model8_path = "/data3/username/ModelsHub/internlm-20b"

model9_path = "/data3/username/ModelsHub/Llama-2-13b-hf"


# Models taking part in this run. The order fixes the name of the output
# directory and the per-model index used further down in this script.
model_path_list = [
    model1_path,
    model2_path,
    # model3_path,
    # model4_path,
    # model5_path,
    # model6_path,
    # model7_path,
    # model8_path,
    # model9_path,
]

# One short name per model: the last component of its checkpoint path.
probability_transfer_matrix_name_list = [os.path.basename(model_path) for model_path in model_path_list]

# Output directory: "<base>/<name1>_<name2>_...". Joining the names (instead of
# appending "_" in a loop and slicing off the last character) keeps the base
# directory intact even when the list is empty.
probability_transfer_matrix_save_path = (
    "/data2/username/Experiments/LLM_ensemble/probability_transfer_matrix_0514/"
    + "_".join(probability_transfer_matrix_name_list)
)


temperature = 100

# Build the tokenizers directly from model_path_list so that the common
# vocabulary always matches the models that are actually processed later.
# (Loading by the fixed names model1_path, model2_path, ... silently used the
# wrong tokenizers whenever the selection was not "the first N models".)
if len(model_path_list) < 2:
    raise ValueError(
        f"A common vocabulary needs at least two models, got {len(model_path_list)}")

tokenizer_list = [load_tokenizer(model_path) for model_path in model_path_list]

# Report each tokenizer's vocabulary size, in model order.
for tokenizer in tokenizer_list:
    print(len(tokenizer.get_vocab()))

# CommonVocabulary is called with one positional tokenizer per model.
common_vocabulary = CommonVocabulary(*tokenizer_list)

# Tokens shared by the vocabularies held in common_vocabulary.
common_vocab_list = common_vocabulary.get_common_vocab_list(*common_vocabulary.vocabs)
print(f"common_vocab_list:{len(common_vocab_list)}")

# Optional (disabled): use a fixed-seed random subset of the common vocabulary
# as anchor points and store the results in a dedicated sub-directory.
# random_seed = 20
# anchor_point_count = 1000
# random.seed(random_seed)
# common_vocab_list = random.sample(common_vocab_list, anchor_point_count)
# probability_transfer_matrix_save_path += f"/anchor_point_count_{anchor_point_count}random_seed_{random_seed}"

print(probability_transfer_matrix_save_path)
# exist_ok tolerates a pre-existing directory but still surfaces real failures
# (e.g. permission errors), which a bare "except: pass" would hide.
os.makedirs(probability_transfer_matrix_save_path, exist_ok=True)

# Persist the common vocabulary next to the matrices.
common_vocab_json_path = os.path.join(probability_transfer_matrix_save_path, "common_vocab_list.json")
with open(common_vocab_json_path, "w", encoding="utf8") as f:
    json.dump(common_vocab_list, f, ensure_ascii=False)

# Must exist at module level: probability_transfer_matrix_save() reads this global.
probability_transfer_matrix_obj = ProbabilityTransferMatrix()
anchor_point_index_list = probability_transfer_matrix_obj.get_anchor_point_list(
    common_vocab_list=common_vocab_list)

# anchor_point_index_list is indexed by the model's position in model_path_list.
for index, model_path in enumerate(model_path_list):
    # NOTE(review): "balanced_low_0" is presumably a device-map strategy understood
    # by load_model; confirm against src.model_load.
    model = load_model(model_path, "balanced_low_0")
    probability_transfer_matrix_save(model, anchor_point_index_list[index], temperature, device_compute="cuda:0",
                                     device="cuda:0",
                                     save_path=os.path.join(probability_transfer_matrix_save_path,
                                                            f"{os.path.basename(model_path)}.pth"))
    # Drop the finished model before the next one is loaded; otherwise two
    # models are referenced at once while the next load is in progress.
    del model
    torch.cuda.empty_cache()
