import os
import fire
import json
import tqdm
from datasets import get_dataset_config_names, load_dataset


def _write_jsonl(filename, records):
    """Write each dict in *records* to *filename* as one JSON object per line.

    The file is opened with ``with`` so it is closed even if serialization
    fails part-way through.
    """
    with open(filename, 'w', encoding='utf-8') as out_file:
        for record in records:
            out_file.write(json.dumps(record) + '\n')


def _reference_records(subset_name, split):
    """Yield one record per example: its input and its list of targets."""
    for example_idx, example in enumerate(split):
        yield {
            'dataset': subset_name,
            'example_idx': example_idx,
            'instruction': example['inputs'],
            'references': example['targets'],
        }


def _generator_records(subset_name, split):
    """Yield one record per (example, target) pair, the target as response."""
    for example_idx, example in enumerate(split):
        for tgt in example['targets']:
            yield {
                'dataset': subset_name,
                'example_idx': example_idx,
                'instruction': example['inputs'],
                'response': tgt,
                'references': example['targets'],
            }


def _scorer_records(subset_name, split):
    """Yield labelled (instruction, response) records.

    Every entry of ``targets`` is emitted with label ``1.0``. Every entry of
    ``multiple_choice_targets`` that is not already in ``targets`` is emitted
    with the matching ``multiple_choice_scores`` value as its label.
    """
    for example_idx, example in enumerate(split):
        targets = example['targets']
        for tgt in targets:
            yield {
                'dataset': subset_name,
                'example_idx': example_idx,
                'instruction': example['inputs'],
                'response': tgt,
                'references': targets,
                'label': 1.0,
            }

        for tgt, score in zip(
                example['multiple_choice_targets'],
                example['multiple_choice_scores']):
            # Skip choices that duplicate a reference; those were already
            # emitted above with label 1.0.
            if tgt not in targets:
                yield {
                    'dataset': subset_name,
                    'example_idx': example_idx,
                    'instruction': example['inputs'],
                    'response': tgt,
                    'references': targets,
                    'label': score,
                }


def main(cache_dir='./cache', output_dir='./bigbench_data'):
    """Export selected ``tasksource/bigbench`` subsets as JSONL files.

    A subset is exported when it loads successfully, has non-empty ``train``
    and a ``validation`` split, and its first training example has an empty
    ``multiple_choice_targets`` list. For each exported subset these files
    are written to *output_dir*:

    * ``{subset}_train_original.jsonl``: one record per training example.
    * ``{subset}_train_generator.jsonl``: one record per (example, target).
    * ``{subset}_train_scorer.jsonl``: targets labelled 1.0, plus
      non-target multiple-choice entries labelled with their score.
    * ``{subset}_validation.jsonl``: one record per validation example.

    The list of exported subset names is written to ``subset_names.json``.

    Args:
        cache_dir: Passed to ``load_dataset`` as its ``cache_dir``.
        output_dir: Destination directory; created if it does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)

    subset_names = []
    for subset_name in tqdm.tqdm(
            get_dataset_config_names('tasksource/bigbench'),
            desc='processing datasets'):
        try:
            ds = load_dataset(
                'tasksource/bigbench', subset_name, cache_dir=cache_dir)
        except Exception as err:
            # Best-effort: report and skip a subset that fails to load.
            # Catching Exception (not a bare except) lets Ctrl-C still stop
            # the run.
            print(subset_name, 'Error!', err)
            continue

        # Check both splits up front so no partial output is written for a
        # subset that would fail later.
        if 'train' not in ds or 'validation' not in ds:
            print(subset_name, 'Error! missing train/validation split')
            continue
        train, validation = ds['train'], ds['validation']

        # The first training example decides whether the subset is skipped;
        # an empty split would make train[0] raise IndexError.
        if len(train) == 0 or len(train[0]['multiple_choice_targets']) > 0:
            continue
        subset_names.append(subset_name)

        _write_jsonl(
            f'{output_dir}/{subset_name}_train_original.jsonl',
            _reference_records(subset_name, train))
        _write_jsonl(
            f'{output_dir}/{subset_name}_train_generator.jsonl',
            _generator_records(subset_name, train))
        _write_jsonl(
            f'{output_dir}/{subset_name}_train_scorer.jsonl',
            _scorer_records(subset_name, train))
        _write_jsonl(
            f'{output_dir}/{subset_name}_validation.jsonl',
            _reference_records(subset_name, validation))

    with open(f'{output_dir}/subset_names.json', 'w',
              encoding='utf-8') as names_file:
        json.dump(subset_names, names_file, indent=4)


if __name__ == '__main__':
    # Hand main to python-fire, which turns its keyword arguments
    # (cache_dir, output_dir) into command-line flags.
    fire.Fire(component=main)
