import copy
import os, sys
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
BASE_DIR = os.path.dirname(BASE_DIR)
sys.path.append(BASE_DIR)

import argparse
import pickle
from collections import defaultdict

import torch.types
from tqdm import tqdm

from utils.utils import set_seed, get_promt_rouge
from datas.get_data import get_data
from torch.utils.data import DataLoader
from utils.utils import get_model
from utils.utils import get_promt
import torch.nn.functional as F
from utils.utils import compare_retrieval_acc

import numpy as np
import os

from nltk.translate.bleu_score import sentence_bleu
from rouge import Rouge

# Keyword arguments forwarded verbatim to ``model.generate`` for every sample.
# NOTE(review): in Hugging Face transformers, ``temperature`` and ``top_p`` only
# take effect when sampling is enabled (``do_sample=True`` here or in the
# model's generation config); under greedy decoding they are ignored — confirm
# which decoding mode this experiment intends.
model_custom_config = dict(
    max_new_tokens=1000,
    temperature=0.1,
    top_p=0.9,
)

# Template for one prompt-length bucket's results; main() deep-copies it per bucket.
#   "record"       -> list of {"response", "target"} pairs
#   "<metric>"     -> list of per-sample scores for bleu / rouge-1 / rouge-2 / rouge-l
#   "<metric>-mean", "<metric>-var" -> summary statistics, initialised to 0.0
default_dict = {
    "record": [],
    **{metric: [] for metric in ("bleu", "rouge-1", "rouge-2", "rouge-l")},
    **{
        f"{metric}-{stat}": 0.0
        for stat in ("mean", "var")
        for metric in ("bleu", "rouge-1", "rouge-2", "rouge-l")
    },
}

def _iter_samples(dataloader):
    """Yield ``(text, target)`` pairs for every element of every batch (with a tqdm progress bar)."""
    for batch in tqdm(dataloader):
        # Presumably the default collate turns each field into a list of strings;
        # iterate all elements so batch_size > 1 does not silently drop samples.
        yield from zip(batch["text"], batch["target"])


def _score_response(rouge, response, target):
    """Score one generated summary against its reference.

    Returns ``(rouge_scores, bleu)``. ``rouge_scores`` maps "rouge-1", "rouge-2"
    and "rouge-l" to the ``{"f", "p", "r"}`` dicts produced by
    ``Rouge.get_scores``. ``bleu`` is the NLTK sentence BLEU over whitespace tokens.
    A response the ``rouge`` package refuses to score (it raises ``ValueError``,
    e.g. for an empty hypothesis) gets zeros instead of aborting the whole run.
    """
    try:
        rouge_scores = rouge.get_scores(hyps=[response], refs=[target])[0]
    except ValueError:
        rouge_scores = {
            key: {"f": 0.0, "p": 0.0, "r": 0.0}
            for key in ("rouge-1", "rouge-2", "rouge-l")
        }

    # split() rather than split(" "): newlines and repeated spaces must not
    # produce glued or empty tokens.
    hypothesis_tokens = response.split()
    bleu = sentence_bleu([target.split()], hypothesis_tokens) if hypothesis_tokens else 0.0
    return rouge_scores, bleu


def _update_summary_stats(bucket):
    """Recompute the mean/variance fields of one bucket from its per-sample scores."""
    bucket["bleu-mean"] = np.nanmean(np.array(bucket["bleu"]))
    bucket["bleu-var"] = np.nanvar(np.array(bucket["bleu"]))
    for key in ("rouge-1", "rouge-2", "rouge-l"):
        # ROUGE entries are {"f", "p", "r"} dicts; summarise the F-measure.
        f_scores = np.array([score["f"] for score in bucket[key]])
        bucket[f"{key}-mean"] = np.nanmean(f_scores)
        bucket[f"{key}-var"] = np.nanvar(f_scores)


def _save_scores(all_length_score, args):
    """Pickle the scores to ``<cwd>/<log_dir>/<save_file>``, creating the directory if needed."""
    log_dir = os.path.join(os.getcwd(), args.log_dir)
    os.makedirs(log_dir, exist_ok=True)
    with open(os.path.join(log_dir, args.save_file), "wb") as f:
        pickle.dump({"all_length_score": all_length_score}, f)


def main(args):
    """Generate summaries with the model and score them with ROUGE/BLEU, bucketed by prompt length.

    Each sample's prompt is tokenized and assigned to bucket ``len // 1024``.
    A sample is skipped if:
      * its prompt has 13*1024 or more tokens, or
      * its target has 1000 or more tokens, or
      * its bucket already holds 5 samples.
    Results are pickled to ``<log_dir>/<save_file>`` as
    ``{"all_length_score": {bucket_start: stats}}``, where ``stats`` follows
    the ``default_dict`` layout. The file is rewritten after every scored
    sample and once more at the end.
    """
    dataset = get_data(args.dataset, args)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

    device = "auto" if args.cuda == "auto" else torch.device(int(args.cuda))
    tokenizer, model = get_model(args.model_path, device, method=args.method, args=args)
    model.eval()

    # NOTE: computed but currently unused — the hard-coded Llama-2 chat template
    # below is used instead (previously: query = prefix_prompt.format(text)).
    prefix_prompt = get_promt_rouge(args.model_path)

    # Llama-2-chat prompt used for this experiment.
    # NOTE(review): the template starts with a literal "<s>" and the tokenizer
    # presumably prepends its own BOS as well — confirm a doubled BOS is intended.
    prompt_template = (
        "<s>[INST] \n"
        "Please help me summarize the summary from the text entered below, "
        "and require no more than 1000 tokens.\n"
        "\n"
        "\n"
        "{}\n"
        "[/INST]"
    )

    bucket_width = 1024
    max_input_tokens = 13 * bucket_width  # longer prompts are skipped
    max_target_tokens = 1000              # longer references are skipped
    samples_per_bucket = 5
    records = [0] * 20                    # samples taken so far, per bucket index

    # defaultdict(dict) is kept (rather than a lambda factory) so the result stays picklable.
    all_length_score = defaultdict(dict)
    rouge = Rouge()

    for text, target in _iter_samples(dataloader):
        query = prompt_template.format(text)
        input_ids = tokenizer(query, return_tensors="pt").to(model.device).input_ids
        input_token_len = len(input_ids[0])
        target_token_len = len(tokenizer(target, return_tensors="pt").input_ids[0])

        if input_token_len >= max_input_tokens or target_token_len >= max_target_tokens:
            continue
        index = input_token_len // bucket_width
        if records[index] >= samples_per_bucket:
            continue
        records[index] += 1

        total_len = index * bucket_width
        print("input token length: {}, real: {}".format(total_len, input_token_len))

        with torch.no_grad():
            outputs = model.generate(input_ids, **model_custom_config)
        # Decode only the newly generated tokens (assumes a decoder-only model whose
        # output echoes the prompt). Slicing the decoded string by len(query) was
        # misaligned, because decoding does not reproduce the prompt text exactly.
        # It also kept special tokens such as "</s>".
        response = tokenizer.decode(outputs[0][input_token_len:], skip_special_tokens=True)

        bucket = all_length_score.get(total_len)
        if bucket is None:
            bucket = all_length_score[total_len] = copy.deepcopy(default_dict)
        bucket["record"].append({"response": response, "target": target})

        rouge_scores, bleu = _score_response(rouge, response, target)
        for key in ("rouge-1", "rouge-2", "rouge-l"):
            bucket[key].append(rouge_scores[key])
        print("BLEU Score:", bleu)
        bucket["bleu"].append(bleu)
        _update_summary_stats(bucket)

        torch.cuda.empty_cache()
        # Checkpoint after every scored sample so an interrupted run keeps its results.
        _save_scores(all_length_score, args)

    # Final save so the output file exists even when no sample was selected.
    _save_scores(all_length_score, args)



if __name__ == "__main__":
    # Script entry point: parse the CLI options, call set_seed(seed), then run main().
    cli_parser = argparse.ArgumentParser()
    cli_options = (
        # (flag, type, default)
        ("--model_path", str, "/data/persist/models/llama-3b"),
        ("--method", str, "old"),
        ("--dataset", str, "/data/persist/dataset/gov_report/test.txt"),
        ("--save_file", str, "rouge_old_test.pkl"),
        ("--batch_size", int, 1),
        ("--log_dir", str, "../logs"),
        ("--cuda", str, "0"),  # a CUDA device index, or "auto"
        ("--seed", int, 0),
    )
    for flag, flag_type, flag_default in cli_options:
        cli_parser.add_argument(flag, type=flag_type, default=flag_default)

    args = cli_parser.parse_args()
    set_seed(args.seed)
    main(args)

















