"""
Utility script that collects all neurons (units, identified as ``<layer>__<channel>``)
used in an experiment, as listed in the JSON file describing that experiment.
"""

import json
import os
import sys
from argparse import ArgumentParser
from collections import Counter

import numpy as np


def extract_unit(query):
    """Convert a stimulus query path into a ``"<layer>__<channel>"`` unit id.

    The query is expected to have the layout
    ``/int_comp_out/stimuli/<model>/<layer>/channel_<n>/<images>/<batch>``,
    e.g. ``/int_comp_out/stimuli/model/layer/channel_3/natural_images/batch_0``.

    Args:
        query: The query path string of a single trial.

    Returns:
        The unit id ``f"{layer}__{n}"``. ``n`` is parsed with ``int()``, so
        e.g. ``channel_012`` yields ``12``.

    Raises:
        ValueError: If the query does not follow the expected layout.
    """
    prefix = "/int_comp_out/stimuli/"
    if not query.startswith(prefix):
        raise ValueError(f"Query does not start with {prefix!r}: {query!r}")

    parts = query[len(prefix) :].split("/")
    if len(parts) != 5:
        raise ValueError(
            f"Expected 5 path segments after {prefix!r}, got {len(parts)}: {query!r}"
        )
    _model, layer, channel, _images, _batch = parts

    channel_prefix = "channel_"
    if not channel.startswith(channel_prefix):
        raise ValueError(
            f"Channel segment does not start with {channel_prefix!r}: {query!r}"
        )
    channel_idx = int(channel[len(channel_prefix) :])
    return f"{layer}__{channel_idx}"


def main(args):
    """Collect the units used by an experiment and write them to a JSON file.

    Reads the experiment description from ``args.exp_json_file``. Each trial of
    each task is mapped to a unit id by passing its ``"queries"`` value to
    ``extract_unit``. The function prints per-unit usage counts, then writes
    the sorted distinct unit ids to ``args.units_file`` as ``{"units": [...]}``.

    Args:
        args: Namespace with ``exp_json_file`` (input path) and ``units_file``
            (output path) attributes.
    """
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(args.exp_json_file, "r", encoding="utf-8") as fhandle:
        experiment = json.load(fhandle)

    used_units = []

    tasks = experiment["tasks"]
    print(f"Found {len(tasks)} tasks.")
    for task in tasks:
        trials = task["raw_trials"]
        print(f"Found {len(trials)} trials.")
        # extract_unit parses a single query string per trial.
        used_units.extend(extract_unit(trial["queries"]) for trial in trials)

    print("Length of used units total: ", len(used_units))
    counter = Counter(used_units)
    for key, val in counter.items():
        print(key, ": ", val)
    # The Counter's keys are exactly the distinct units.
    print("Length of actually used units:", len(counter))

    res_dict = {"units": sorted(counter)}

    with open(args.units_file, "w", encoding="utf-8") as fhandle:
        json.dump(res_dict, fhandle)


if __name__ == "__main__":
    # Both paths are mandatory: without them main() would call open(None).
    parser = ArgumentParser(description=__doc__)
    parser.add_argument(
        "--exp_json_file",
        "-e",
        type=str,
        required=True,
        help="Path to the JSON file describing the experiment.",
    )
    parser.add_argument(
        "--units_file",
        "-u",
        type=str,
        required=True,
        help="Output path for the JSON file listing the used units.",
    )

    args = parser.parse_args()

    main(args)
