from datasets import load_dataset
import os
import json

from tqdm import tqdm

from utils.utils import get_tokenizer


def _resolve_path(path):
    """Return *path* unchanged if absolute, otherwise joined onto the current working directory."""
    return path if os.path.isabs(path) else os.path.join(os.getcwd(), path)


def _read_jsonl(file_path, progress=False):
    """Parse a JSON-Lines file into a list of objects (one per non-blank line).

    Blank lines (e.g. a trailing newline at end of file) are skipped rather than
    passed to ``json.loads``, which would raise on them. When *progress* is true
    the lines are wrapped in ``tqdm`` to show a progress bar while parsing.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        lines = f.readlines()
    if progress:
        lines = tqdm(lines)
    return [json.loads(line) for line in lines if line.strip()]


def _require_tokenizer(args, dataset_name):
    """Build the tokenizer for ``args.model_path``; raise ValueError if *args* is missing."""
    if args is None:
        raise ValueError(
            f"get_data: `args` (providing `model_path`) is required for the {dataset_name} dataset"
        )
    return get_tokenizer(args.model_path)


def _patch_next_data(renames, count_tokens, move_to_end=()):
    """Replace ``_SingleProcessDataLoaderIter._next_data`` with a post-processing version.

    The replacement fetches the next batch exactly like the stock implementation
    (``_next_index`` + ``_dataset_fetcher.fetch``, then optional memory pinning),
    and in between it:

    * renames keys according to *renames* (``{old_key: new_key}``, applied in order),
    * sets ``data["token_length"]`` to ``count_tokens(<first text>)``,
    * re-inserts each key in *move_to_end* so it appears last in the dict.

    NOTE: this monkey-patches a private torch class for the whole process; every
    single-process DataLoader created afterwards is affected.
    """
    from torch.utils.data import _utils
    from torch.utils.data.dataloader import _SingleProcessDataLoaderIter

    def new_next_data(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        for old_key, new_key in renames.items():
            data[new_key] = data.pop(old_key)
        text = data["text"]
        # With batching, the default collate turns each field into a list with
        # one entry per sample, so measure the first sample. With
        # batch_size=None the field is the raw string itself.
        first_text = text[0] if isinstance(text, (list, tuple)) else text
        data["token_length"] = count_tokens(first_text)
        for key in move_to_end:
            data[key] = data.pop(key)
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
        return data

    _SingleProcessDataLoaderIter._next_data = new_next_data


def get_data(data_path, args=None):
    """Load the dataset selected by a substring of *data_path*.

    Dispatch (first match wins):

    * ``"pile"``: a fixed parquet shard under ``../datas/download_data``,
      loaded with ``datasets.load_dataset``; returns its ``train`` split.
    * ``"passkey"``: the JSON file at *data_path*, returned as parsed.
    * ``"gov_report"``: the JSON-Lines file at *data_path*, as a list of dicts.
      The DataLoader is patched so batches expose ``report``/``summary`` as
      ``text``/``target`` plus a ``token_length`` entry.
    * ``"longeval-topics"``: the fixed 5/10/15/20/25-topic test-case files,
      concatenated. Batches expose ``prompt``/``topics`` as
      ``text``/``target`` plus ``token_length``.
    * ``"longeval-lines"``: the fixed 200…680-line test-case files,
      concatenated. Batches expose ``prompt``/``expected_number`` as
      ``text``/``target`` plus ``token_length``, with ``random_idx`` kept as
      the last key.

    Relative paths are resolved against the current working directory.

    Args:
        data_path: Path or identifier. Only substring matching is used to pick
            the branch, and only the passkey/gov_report branches read the file
            at this path.
        args: Object with a ``model_path`` attribute. Required by the
            tokenizer-based branches (gov_report, longeval-*).

    Returns:
        The loaded data (a ``datasets`` split for pile, otherwise parsed JSON).

    Raises:
        ValueError: *args* is missing for a branch that needs a tokenizer.
        NotImplementedError: *data_path* matches no known dataset.
    """
    if "pile" in data_path:
        file = "../datas/download_data/pile-deduplicated/train-00000-of-01650-f70471ee3deb09c0.parquet"
        file_path = os.path.join(os.getcwd(), file)
        data = load_dataset("parquet", data_files={'train': file_path})['train']
    elif "passkey" in data_path:
        with open(_resolve_path(data_path), "r", encoding="utf-8") as f:
            data = json.load(f)
    elif "gov_report" in data_path:
        tokenizer = _require_tokenizer(args, "gov_report")
        data = _read_jsonl(_resolve_path(data_path), progress=True)
        _patch_next_data(
            {"report": "text", "summary": "target"},
            lambda text: len(tokenizer.encode(text)),
        )
    elif "longeval-topics" in data_path:
        tokenizer = _require_tokenizer(args, "longeval-topics")
        data = []
        for num_topics in [5, 10, 15, 20, 25]:
            file = f"../datas/longeval/topics/testcases/{num_topics}_topics.jsonl"
            data.extend(_read_jsonl(_resolve_path(file)))
        _patch_next_data(
            {"prompt": "text", "topics": "target"},
            lambda text: len(tokenizer(text).input_ids),
        )
    elif "longeval-lines" in data_path:
        tokenizer = _require_tokenizer(args, "longeval-lines")
        data = []
        for num_lines in [200, 300, 400, 500, 600, 680]:
            file = f"../datas/longeval/lines/testcases/{num_lines}_lines.jsonl"
            data.extend(_read_jsonl(_resolve_path(file)))
        _patch_next_data(
            {"prompt": "text", "expected_number": "target"},
            lambda text: len(tokenizer(text).input_ids),
            # Keeps the original key order (random_idx last). The removed
            # commented-out code suggests callers use it to build a
            # "Line <idx>: <REGISTER_CONTENT> is" query (TODO confirm).
            move_to_end=("random_idx",),
        )
    else:
        raise NotImplementedError(f"get_data: unsupported dataset path {data_path!r}")
    return data