import datasets
from src.utils.file_utils import load_jsonl

datasets.disable_caching()
import random


def shuffle(data: datasets.Dataset) -> datasets.Dataset:
    """Return a new Dataset containing the rows of *data* in random order.

    The rows are materialised as a Python list, permuted in place with the
    module-level ``random`` generator (so ``random.seed`` controls the
    order), and rebuilt into a fresh Dataset. *data* itself is not modified.
    A progress line is printed before each step.
    """
    print("Convert to list")
    rows = data.to_list()

    print("Shuffle")
    random.shuffle(rows)

    print("Convert back")
    shuffled = datasets.Dataset.from_list(rows)
    return shuffled


def _load_samples(paths):
    """Load every JSONL file in *paths* and return all records in one list, in file order."""
    samples = []
    for data_path in paths:
        # load_jsonl presumably returns a list of dict records (they are fed
        # to Dataset.from_list below) — confirm against src.utils.file_utils.
        samples.extend(load_jsonl(data_path))
    return samples


def merge_two_stage_train_data(stage1_paths, stage2_paths, save_path):
    """Merge two stages of JSONL training data into one dataset and save it.

    The records from all stage-1 files are pooled and shuffled. The records
    from all stage-2 files are pooled in file/line order. The resulting
    dataset lists every stage-1 row before any stage-2 row. It is written
    with ``Dataset.save_to_disk``.

    Args:
        stage1_paths: Sequence of JSONL file paths for stage 1; may be empty.
        stage2_paths: Sequence of JSONL file paths for stage 2; may be empty.
        save_path: Destination passed to ``Dataset.save_to_disk``.

    Raises:
        ValueError: If both ``stage1_paths`` and ``stage2_paths`` are empty.
    """
    if not stage1_paths and not stage2_paths:
        # Previously this fell through to an UnboundLocalError on stage2_ds.
        raise ValueError(
            "At least one of stage1_paths / stage2_paths must be non-empty"
        )

    parts = []
    if stage1_paths:
        print("Stage 1")
        stage1_samples = _load_samples(stage1_paths)
        random.shuffle(stage1_samples)
        parts.append(datasets.Dataset.from_list(stage1_samples))
    if stage2_paths:
        print("Stage 2")
        # Fixed: this stage previously re-read stage1_paths by mistake.
        # NOTE(review): stage 2 is kept in file order (not shuffled), as in the
        # original code — confirm this ordering is intended.
        parts.append(datasets.Dataset.from_list(_load_samples(stage2_paths)))

    # Concatenate only when both stages are present; stage order is preserved.
    ds = parts[0] if len(parts) == 1 else datasets.concatenate_datasets(parts)
    print(save_path)
    ds.save_to_disk(save_path)


if __name__ == "__main__":
    # Files handed to merge_two_stage_train_data as stage 1 (pooled and shuffled).
    stage1_files = [
        "data/cot_pretrain_data_sampled.jsonl",
        "data/tool_pretrain_data_sampled.jsonl",
    ]
    # Files handed to merge_two_stage_train_data as stage 2.
    stage2_files = ["data/finetune_data_sampled.jsonl"]
    merge_two_stage_train_data(
        stage1_paths=stage1_files,
        stage2_paths=stage2_files,
        save_path="data/train_data_sampled",
    )
