
import os
import sys
import json
import numpy as np
import random
import argparse

from tqdm import tqdm

# Seed the stdlib RNG up front so any later use of `random` is reproducible.
random.seed(42)

# Command-line interface: both options are plain strings that main()
# interpolates into the dataset paths it reads and writes.
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default='nsd')
parser.add_argument("--subject", type=str, default='subj01')

# Parsed at module level; main() reads the resulting `args` global.
args = parser.parse_args()


def _build_split(image_ids, split_tag, root_dir, dataset, subject, captions):
    """Group the fMRI trials of one split by the image they belong to.

    Args:
        image_ids: Sequence of image ids; the position of an id in the
            sequence is the trial index used in the beta file name.
            (Assumed to be integers, since they are zero-padded with
            ``:06`` and used to index ``captions`` -- TODO confirm.)
        split_tag: Infix of the beta file names for this split
            (``main`` passes ``'tr'`` for train and ``'te'`` for val).
        root_dir: Root directory of the dataset.
        dataset: Dataset name, used as the file-name prefix.
        subject: Subject identifier, used as a directory name.
        captions: Object indexable by image id whose items expose a
            ``"captions"`` entry.

    Returns:
        A list with one record per distinct image id, in order of first
        appearance. Each record holds a running index (``ids``), the
        subject, the image path, every fMRI beta path recorded for that
        image, the vision-embedding path, and the captions.
    """
    fmri_dir = f'{root_dir}/fmris/{subject}/whole'
    records = {}
    for idx, image_id in enumerate(image_ids):
        fmri_path = f'{fmri_dir}/{dataset}_betas_{split_tag}_{idx:06}.npy'
        record = records.get(image_id)
        if record is None:
            records[image_id] = {
                # len() is taken before insertion, so ids run 0, 1, 2, ...
                'ids': len(records),
                'subject': subject,
                'image': f'{root_dir}/images/{dataset}_image_{image_id:06}.png',
                'fmri': [fmri_path],
                'vision_embeds': f'{root_dir}/vision_embeds/vision_{image_id:06}.npy',
                'caption': captions[image_id]["captions"],
            }
        else:
            # Repeated presentation of an image: keep one record and
            # collect every trial's beta file under it.
            record['fmri'].append(fmri_path)
    return list(records.values())


def main():
    """Write ``pretrain_new.json`` for the dataset/subject given in ``args``.

    Reads the trial-to-image mapping and the captions file, groups trials
    by image for the ``train`` and ``val`` splits, and dumps the result
    together with the paths of the mean/std arrays and, if the file
    exists, the subject's atlas (otherwise ``None``).

    Raises:
        OSError: If an input file cannot be opened or the output file
            cannot be written.
        KeyError: If the mapping lacks a ``'train'`` or ``'val'`` entry.
    """
    root_dir = f'/mnt/NSD_dataset/datasets/{args.dataset}'
    subject_dir = f'{root_dir}/fmris/{args.subject}'

    # Context managers make sure the input files are closed after parsing.
    with open(f'{subject_dir}/{args.dataset}_fmri2image.json', 'r', encoding='utf-8') as f:
        source = json.load(f)
    with open(f'{root_dir}/{args.dataset}_captions.json', 'r', encoding='utf-8') as f:
        coco_caption = json.load(f)

    train_list = _build_split(
        source['train'], 'tr', root_dir, args.dataset, args.subject, coco_caption)
    test_list = _build_split(
        source['val'], 'te', root_dir, args.dataset, args.subject, coco_caption)

    mean = f'{subject_dir}/whole/{args.dataset}_whole_betas_mean.npy'
    std = f'{subject_dir}/whole/{args.dataset}_whole_betas_std.npy'
    atlas_path = f'{subject_dir}/atlas.json'
    # The atlas is optional; record None when the subject has none.
    atlas = atlas_path if os.path.exists(atlas_path) else None

    with open(f'{subject_dir}/pretrain_new.json', 'w', encoding='utf-8') as f:
        json.dump({
            'mean': mean,
            'std': std,
            'atlas': atlas,
            'train': train_list,
            'val': test_list
        }, f, indent=4)


# Build the index only when this file is executed as a script.
if __name__ == '__main__':
    main()
