from glob import glob as glob
import h5py
import logging
from omegaconf import DictConfig, OmegaConf
import hydra
import os 
import json
import numpy as np
from sklearn.decomposition import PCA
from pathlib import Path

log = logging.getLogger(__name__)

# Default dataset root; override with ``data_prep.root_path`` in the Hydra config.
DEFAULT_ROOT_PATH = "/storage/user/semantic-decoding-brainbits"

# Training sessions used when ``data_prep.sessions`` is not configured.
DEFAULT_SESSIONS = (2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 18, 20)

# Test experiments whose responses are projected with the fitted PCA.
TEST_EXPERIMENTS = (
    "imagined_speech",
    "perceived_movie",
    "perceived_multispeaker",
    "perceived_speech",
)


def _load_response(path: str) -> np.ndarray:
    """Read dataset ``data`` from the HDF5 file at ``path``.

    NaN is replaced with 0 and +/-inf with large finite values
    (``np.nan_to_num`` defaults). The file is always closed, even on error.
    """
    with h5py.File(path, "r") as hf:
        return np.nan_to_num(hf["data"][:])


def _save_response(path: str, data: np.ndarray) -> None:
    """Write ``data`` as dataset ``data`` to a new HDF5 file at ``path``.

    Mode ``"w"`` truncates any existing file. The file is always closed,
    even if writing fails.
    """
    with h5py.File(path, "w") as hf:
        hf.create_dataset("data", data=data)


@hydra.main(config_path="../conf")
def main(cfg: DictConfig) -> None:
    """Fit PCA on a subject's training responses and project train and test data.

    Keys read from ``cfg.data_prep``:
      - ``subject``: subject directory name.
      - ``in_dir``: input subdirectory under ``data_train/train_response/<subject>``.
      - ``n_components``: passed to ``sklearn.decomposition.PCA``.
      - ``sessions`` (optional): training sessions, default ``DEFAULT_SESSIONS``.
      - ``out_dir`` (optional): output subdirectory name,
        default ``pca_<n_components>``.
      - ``root_path`` (optional): dataset root, default ``DEFAULT_ROOT_PATH``.

    PCA is fit on the vertically stacked responses of every training story.
    Each training story is then transformed and written to
    ``data_train/train_response/<subject>/<out_dir>/<story>.hf5``. Every file
    under ``data_test/test_response/<subject>/<exp>/orig/`` is transformed and
    written to ``.../<exp>/<out_dir>/<stem>.hf5``.

    Raises:
        ValueError: if the selected sessions map to no training stories.
    """
    log.info("PCA compression of fMRI")
    log.info(OmegaConf.to_yaml(cfg, resolve=True))
    log.info(f'Working directory {os.getcwd()}')

    prep = cfg.data_prep
    sessions = prep.sessions if "sessions" in prep else DEFAULT_SESSIONS
    subject = prep.subject
    root_path = prep.root_path if "root_path" in prep else DEFAULT_ROOT_PATH
    data_train_dir = os.path.join(root_path, "data_train")

    # Training stories, in session order. JSON object keys are strings.
    with open(os.path.join(data_train_dir, "sess_to_story.json"), "r") as f:
        sess_to_story = json.load(f)
    stories = [story for sess in sessions for story in sess_to_story[str(sess)]]
    if not stories:
        raise ValueError(f"no training stories found for sessions {list(sessions)}")

    subject_dir = os.path.join(data_train_dir, "train_response", subject, prep.in_dir)
    resp = {}
    for story in stories:
        log.info(f"reading data {story}")
        resp[story] = _load_response(os.path.join(subject_dir, f"{story}.hf5"))

    # Rows of every training story are stacked into one matrix for fitting.
    all_stack = np.vstack([resp[story] for story in stories])
    pca = PCA(n_components=prep.n_components)
    pca.fit(all_stack)
    del all_stack  # drop the stacked copy before producing outputs

    out_dir = prep.out_dir if "out_dir" in prep else f"pca_{prep.n_components}"

    out_dir_path = os.path.join(data_train_dir, "train_response", subject, out_dir)
    Path(out_dir_path).mkdir(exist_ok=True, parents=True)

    # Reuse the responses already in memory instead of re-reading each file.
    for story in stories:
        log.info(f"writing data {story}")
        _save_response(
            os.path.join(out_dir_path, f"{story}.hf5"),
            pca.transform(resp[story]),
        )
    del resp

    test_root = os.path.join(root_path, "data_test", "test_response", subject)
    for exp in TEST_EXPERIMENTS:
        # Sorted for a deterministic processing order.
        tasks = sorted(glob(os.path.join(test_root, exp, "orig", "*")))
        if not tasks:
            log.warning(f"no test responses found for {exp}")
        out_dir_path = os.path.join(test_root, exp, out_dir)
        Path(out_dir_path).mkdir(exist_ok=True, parents=True)

        for task_path in tasks:
            task_name = Path(task_path).stem
            log.info(f"writing {task_name}")
            _save_response(
                os.path.join(out_dir_path, f"{task_name}.hf5"),
                pca.transform(_load_response(task_path)),
            )

if __name__ == "__main__":
    # The @hydra.main decorator supplies ``cfg``, so main() is called with no arguments.
    main()
