import os
import json
import re
import numpy as np

##### Config #####
# Input files: gsm8k_train_l0.json through gsm8k_train_l4.json.
PATH = [f"../data/mutation/gsm8k_train_l{level}.json" for level in range(5)]
# Number of records to keep from each input file, matched to PATH by position:
# 30000 from the first file, 70000 from each of the remaining four.
SIZE = [30_000] + [70_000] * 4
# Output file that receives the merged, shuffled sample.
RESULT = "../data/mutation/gsm8k_train_sample.json"
##### End Config #####

# Draw a fixed-size random sample from every input file listed in PATH, merge
# the samples, shuffle the merged list and write it to RESULT as JSON.
_SEED = 998244353  # fixed seed so repeated runs produce the same sample

# zip() stops at the shorter list, so a mismatch between the two config lists
# would silently skip inputs; fail loudly instead.
if len(PATH) != len(SIZE):
    raise ValueError(
        f"PATH has {len(PATH)} entries but SIZE has {len(SIZE)}; "
        "they must be the same length"
    )

datasets = []
for path, siz in zip(PATH, SIZE):
    # Read as UTF-8 explicitly rather than relying on the platform default.
    with open(path, "r", encoding="utf-8") as fr:
        problems = json.load(fr)  # expected to be a JSON array (shuffled in place below)

    # Explicit check instead of `assert`: assertions are stripped under
    # `python -O`, in which case the slice below would quietly return fewer
    # records than requested.
    if len(problems) < siz:
        raise ValueError(
            f"{path}: requested {siz} samples but only {len(problems)} available"
        )
    print(f"{path}: {len(problems)} => {siz}")
    # NOTE: the generator is re-created from the same seed for every file, so
    # inputs of equal length receive the same permutation. Kept unchanged so
    # that previously generated samples stay reproducible.
    rng = np.random.default_rng(seed=_SEED)
    rng.shuffle(problems)
    problems = problems[:siz]

    datasets += problems

print(f"{RESULT}: {len(datasets)}")
# Final shuffle mixes records from the different input files together.
rng = np.random.default_rng(seed=_SEED)
rng.shuffle(datasets)
with open(RESULT, "w", encoding="utf-8") as fw:
    json.dump(datasets, fw)