import os
import pdb
import sys
import time
import json
import torch
import queue
import logging
from tqdm import tqdm
from transformers import LogitsProcessorList
# Directory prepended to sys.path so the ``src.*`` imports below can be resolved.
# It can be overridden with the LLM_ENSEMBLE_ROOT environment variable; when the
# variable is unset or empty, the original hard-coded location is used.
root_dir = os.environ.get('LLM_ENSEMBLE_ROOT') or '/data/home/username/Experiments/LLM_ensemble'
sys.path.insert(0, root_dir)

from src.logits_processor.model_processor_factory import ModelProcessorFactory
from src.post_processing.answer_extract import answer_extract


import argparse
from src.instruction_generate import demon_prompt_generate, task_instruction_generate
from src.main_model_thread import MainModelThread
from src.model_load import load_model
from src.assist_model_thread import AssistModelThread
from src.common_vocabulary import CommonVocabulary
from src.transfer_matrix import ProbabilityTransferMatrix


def _build_arg_parser():
    """Build the command-line parser for this inference script.

    Returns:
        argparse.ArgumentParser: parser with the same option names, short flags
        and defaults as before; only ``--config`` is now mandatory, because
        ``main`` opens it unconditionally.

    ``--device1``..``--device3``, ``--main_temperature``, ``--assist_temperature``,
    ``--min_prob`` and ``--max_prob`` are still accepted for command-line
    compatibility, but ``main`` does not read them.
    """
    parser = argparse.ArgumentParser(description='Process some files.')

    parser.add_argument('--config', required=True, help='path to the JSON config file')
    parser.add_argument('--learning_rate', '-lr', default=0.0, type=float, required=False, help="learning_rate")
    parser.add_argument('--anchor_point_count', '-apc', default=32000, type=int, required=False,
                        help='anchor_point_count')
    parser.add_argument('--learning_epochs_nums', '-len', default=5, type=int, required=False,
                        help='learning_epochs_nums')
    parser.add_argument('--result_save_dir', '-rsd', default="./", type=str, required=False, help='result_save_dir')
    parser.add_argument('--run_mode', '-rm', default="dev", type=str, required=False,
                        help='run_mode: "dev" reads dev_file_path, any other value reads test_file_path')
    parser.add_argument('--logits_processor_mode', '-lpm', default="based_on_probility_transfer_logits_processor",
                        type=str, required=False, help='logits_processor_mode')
    parser.add_argument('--device_compute', '-dp', default="cuda:1", type=str, required=False,
                        help='device_compute')
    # --device0 .. --device3 (short flags -d0 .. -d3), all defaulting to "auto".
    for index in range(4):
        parser.add_argument(f'--device{index}', f'-d{index}', default="auto", type=str, required=False,
                            help=f'device{index}')

    parser.add_argument('--main_temperature', '-mt', default=100, type=float, required=False,
                        help='main_temperature')
    parser.add_argument('--assist_temperature', '-at', default=100, type=float, required=False,
                        help='assist_temperature')
    parser.add_argument('--min_prob', default=0.8, type=float, required=False,
                        help='min_prob')
    parser.add_argument('--max_prob', default=0.9, type=float, required=False,
                        help='max_prob')
    return parser


def _count_existing_results(result_file_path):
    """Return the number of lines already written to ``result_file_path``.

    ``main`` appends exactly one JSON line per processed input example, so this
    count is the index of the next input example to process. A missing file
    means nothing has been processed yet, and 0 is returned.
    """
    try:
        # The result file is written as UTF-8 with ensure_ascii=False, so it is
        # read back the same way rather than with the locale's default encoding.
        with open(result_file_path, 'r', encoding='utf-8') as result_file:
            return sum(1 for _ in result_file)
    except FileNotFoundError:
        return 0


def _resolve_input_device(requested_device, model):
    """Return the device to which tokenized inputs should be moved.

    ``"auto"`` is a device_map value, not a valid torch device string:
    ``Tensor.to("auto")`` raises. For ``"auto"``, ``model.device`` is used
    instead. Any other value is returned unchanged.

    NOTE(review): this assumes the loaded model exposes ``.device``, as Hugging
    Face ``PreTrainedModel`` does. Confirm against ``load_model``.
    """
    if requested_device == "auto":
        return model.device
    return requested_device


def main():
    """Run single-model inference over a JSONL dataset with a custom logits processor.

    Steps:

    - Read the JSON config given by ``--config``: model/file paths, prompt
      templates and result-processing parameters.
    - Load the main model.
    - For each input example (one JSON object per line), build the prompt and
      generate exactly one new token through a logits processor created by
      ``ModelProcessorFactory``.
    - Extract the answer with ``answer_extract``.
    - Append one JSON line per example to
      ``<result_save_dir>/ensemble_lr{lr}_anchor_point_count{apc}_learning_epochs_nums{len}.jsonl``.

    The run is resumable: examples already in the result file (counted by
    lines) are skipped. A ``.process.log`` file with the same name stem is
    written next to the results.
    """
    args = _build_arg_parser().parse_args()

    with open(args.config, 'r', encoding='utf-8') as config_file:
        config_json = json.load(config_file)

    # All required config keys are read here, before the (expensive) model load,
    # so a malformed config fails fast.
    main_model_path = config_json["model_path"]["main_model_path"]
    dev_file_path = config_json["file_path"]["dev_file_path"]
    test_file_path = config_json["file_path"]["test_file_path"]
    demon_file_path = config_json["file_path"]["demon_file_path"]

    prompt_template = config_json["prompt_template"]
    instruction = prompt_template["instruction"]
    instruction_parameter = prompt_template["instruction_parameter"]
    main_model_system_template = prompt_template["main_model_system_template"]
    demon_parameter = prompt_template["demon_parameter"]

    # Token id handed to the logits processor as forced_eos_token_id; 2 when not configured.
    end_token_id = config_json["run_parameter"].get("end_token_id", 2)

    result_process_parameter = config_json["result_process_parameter"]
    early_stop_string_list = result_process_parameter.get("early_stop_string_list")
    split_key_before_list = result_process_parameter["split_key_before"]
    split_key_behind_list = result_process_parameter["split_key_behind"]

    result_save_dir = args.result_save_dir
    learning_rate = args.learning_rate
    anchor_point_count = args.anchor_point_count
    learning_epochs_nums = args.learning_epochs_nums
    logits_processor_mode = args.logits_processor_mode
    device0 = args.device0
    device_compute = args.device_compute
    max_new_tokens = 1  # exactly one new token is generated per example

    input_file_path = dev_file_path if args.run_mode == "dev" else test_file_path

    os.makedirs(result_save_dir, exist_ok=True)

    # The log file and the result file share one hyper-parameter-tagged name stem.
    run_name = (f'ensemble_lr{learning_rate}_anchor_point_count{anchor_point_count}'
                f'_learning_epochs_nums{learning_epochs_nums}')
    logging.basicConfig(filename=os.path.join(result_save_dir, f'{run_name}.process.log'),
                        level=logging.DEBUG)
    logging.info(f'\n【config_json:】{config_json}')
    logging.info(f'\n【result_save_dir:】{result_save_dir}')
    logging.info(f'\n【anchor_point_count:】{anchor_point_count}')
    logging.info(f'\n【learning_rate:】{learning_rate}')
    logging.info(f'\n【learning_epochs_nums:】{learning_epochs_nums}')

    # NOTE: only the main model is used here. The assist-model fields passed to the
    # logits processor below are None / empty.
    main_model, main_model_tokenizer, _ = load_model(main_model_path, "auto")
    input_device = _resolve_input_device(device0, main_model)

    result_file_path = os.path.join(result_save_dir, f'{run_name}.jsonl')
    start_index = _count_existing_results(result_file_path)

    with open(input_file_path, 'r', encoding='utf-8') as input_file:
        # Blank lines carry no example and would make json.loads fail.
        contents = [raw_line for raw_line in input_file if raw_line.strip()]

    try:
        demon_instruction, demon_count = demon_prompt_generate(demon_file_path, demon_parameter)
    except Exception:
        # Best effort: continue without demonstrations, but record why.
        logging.exception('demonstration prompt generation failed; continuing without demonstrations')
        demon_instruction, demon_count = "", 0

    information_key_list = demon_parameter['key']
    start_time = time.time()

    for raw_line in tqdm(contents[start_index:]):
        example = json.loads(raw_line)

        task_instruction = task_instruction_generate(example, instruction_parameter)
        main_model_input = main_model_system_template.format(instruction + demon_instruction + task_instruction)
        # Only the fields listed in demon_parameter['key'] are carried into the output record.
        information_dict = {key: example[key] for key in information_key_list}

        # A new factory, processor and queue list are created for every example,
        # in case the processor keeps per-generation state.
        processor_factory = ModelProcessorFactory()
        logits_processor_instance = processor_factory.create_processor(
            logits_processor_mode,
            learning_rate=learning_rate,
            anchor_point_count=anchor_point_count,
            ensemble_model_output_ids_queue=None,
            assist_model_score_queue_list=[],
            learning_epochs_nums=learning_epochs_nums,
            main_model_probability_transfer_matrix_list=None,
            assist_model_probability_transfer_matrix_list=None,
            result_save_dir=result_save_dir,
            main_model_tokenizer=main_model_tokenizer,
            assist_model_tokenizer=None,
            device=device0,
            device_compute=device_compute,
            early_stop_string_list=early_stop_string_list,
            forced_eos_token_id=end_token_id,
        )
        main_model_logits_processor_list = LogitsProcessorList()
        main_model_logits_processor_list.append(logits_processor_instance)

        main_model_input_ids = main_model_tokenizer(main_model_input, return_tensors="pt",
                                                    add_special_tokens=False).input_ids.to(input_device)
        generate_ids = main_model.generate(
            input_ids=main_model_input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            num_beams=1,
            eos_token_id=main_model_tokenizer.eos_token_id,
            bos_token_id=main_model_tokenizer.bos_token_id,
            pad_token_id=main_model_tokenizer.eos_token_id,
            logits_processor=main_model_logits_processor_list,
        )
        text = main_model_tokenizer.decode(generate_ids[0])

        model_answer, prediction = answer_extract(text, demon_count, split_key_before_list,
                                                  split_key_behind_list)
        prediction = prediction.strip()
        print(information_dict['question'])
        print(prediction)
        model_answer_dict = {'answer': information_dict['answer'],
                             'prediction': prediction, 'main_model_input': main_model_input, 'all': text,
                             'model_answer': model_answer,
                             'question': information_dict['question']}

        # Opened and closed per example so every result is on disk immediately;
        # the file's line count is the resume offset for the next run.
        with open(result_file_path, 'a+', encoding='utf-8') as result_file:
            result_file.write(json.dumps(model_answer_dict, ensure_ascii=False) + '\n')

    # divmod on the truncated total keeps minutes and seconds consistent
    # (int(t / 60) can round up at a float boundary while int(t % 60) does not).
    minutes, seconds = divmod(int(time.time() - start_time), 60)
    logging.info('\nTime taken: {} min {} sec'.format(minutes, seconds))

    print('Time taken: {} min {} sec'.format(minutes, seconds))


if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script, not on import.
    main()
