import os
import sys
import random
import json
import argparse
import pprint
import pickle

import numpy as np
from common_utils import Logger
# from utils import get_all_files
# from analyze_dataset import filter_bad_replays
# import global_consts as gc
# from process_instruction import *
# import inst_dict

from create_dataset import *



if __name__ == '__main__':
    # Build train / val / dev JSON datasets plus a pickled instruction
    # dictionary from a fixed reference split of replay files, writing
    # everything (and a config log) under --output.
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=99)
    # NOTE(review): --val_ratio, --min_num_target, --min_num_instruction and
    # --raw_json_root are accepted for CLI compatibility but are not used
    # below; the split is taken verbatim from --ref_train / --ref_valid.
    parser.add_argument('--val_ratio', type=float, default=0.1)
    parser.add_argument('--min_num_target', type=int, default=0)
    parser.add_argument('--min_num_instruction', type=int, default=0)
    parser.add_argument('--raw_json_root', type=str, required=True,
                        help='used to decide which to filter')
    parser.add_argument('--processed_json_root', type=str, required=True,
                        help='used to create the dataset')
    parser.add_argument('--output', type=str, required=True)
    parser.add_argument('--ref_train', type=str, default='ref_train.json',
                        help='JSON list of replay paths for the train split')
    parser.add_argument('--ref_valid', type=str, default='ref_valid.json',
                        help='JSON list of replay paths for the val split')
    parser.add_argument('--ref_root', type=str,
                        default='/private/home/hengyuan/rts-replays/replays',
                        help='path prefix used in the ref files; replaced '
                             'by --processed_json_root')
    parser.add_argument('--dev_size', type=int, default=2000,
                        help='number of leading val examples in dev.json')

    args = parser.parse_args()

    # The Logger and every dump below write into args.output, so make sure
    # the directory exists first.
    os.makedirs(args.output, exist_ok=True)
    logger_path = os.path.join(args.output, 'config')
    sys.stdout = Logger(logger_path)

    print('configs:')
    pprint.pprint(vars(args))

    # --seed used to be parsed but never applied; seed both RNG sources so
    # any randomness in the dataset helpers is reproducible.
    random.seed(args.seed)
    np.random.seed(args.seed)

    with open(args.ref_train) as f:
        train_files = json.load(f)
    with open(args.ref_valid) as f:
        val_files = json.load(f)

    def modify_path(path):
        """Map a reference replay path to its processed '.p0.json' file."""
        path = path.replace(args.ref_root, args.processed_json_root)
        return path + '.p0.json'

    def dump_json(obj, filename, **kwargs):
        """Write ``obj`` as JSON to ``<args.output>/<filename>``."""
        with open(os.path.join(args.output, filename), 'w') as f:
            json.dump(obj, f, **kwargs)

    train_files = [modify_path(f) for f in train_files]
    val_files = [modify_path(f) for f in val_files]

    trainset = create_dataset(train_files)
    add_base_frame(trainset)
    valset = create_dataset(val_files)
    add_base_frame(valset)

    print('len(trainset) =', len(trainset))
    print('len(valset) =', len(valset))

    inst_dict, trainset, valset = \
        create_dictionary_and_correct_instruction(trainset, valset)
    with open(os.path.join(args.output, 'dict.pt'), 'wb') as f:
        pickle.dump(inst_dict, f)

    print('writing dev to file')
    # dev.json is a prefix of the (already corrected) validation set.
    devset = valset[:args.dev_size]
    dump_json(devset, 'dev.json')

    print('writing val to file')
    dump_json(valset, 'val.json')
    dump_json(val_files, 'val_files.json', indent=4)

    print('writing train to file')
    dump_json(trainset, 'train.json')
    dump_json(train_files, 'train_files.json', indent=4)
