import collections
import json
import os
import sys


def extract_top_axtree_roles(
    folder: str,
    num_files: int = 10,
    output_path: str = "data_vis/stats/top_axtree_roles.json",
) -> None:
    """Count accessibility-tree roles across JSONL shards and save the stats.

    Reads ``{folder}/en0000-01.NN.jsonl`` for NN in ``0 .. num_files - 1``.
    Each non-blank line of a shard is parsed as a JSON object whose
    ``"ax_tree"`` field is a newline-separated string. For every tree line
    with at least two whitespace-separated tokens, the second token is
    counted as that node's role.

    The result is written to *output_path* as a JSON object mapping each
    role to ``[count, count / total]``. Roles are ordered from most to least
    frequent, and ties keep first-seen order.

    Args:
        folder: Directory containing the shard files.
        num_files: Number of consecutive shards to read. The default of 10
            matches the previously hard-coded value.
        output_path: Destination JSON file. Its parent directory is created
            if missing.

    Raises:
        FileNotFoundError: If a shard file does not exist.
        KeyError: If a record lacks the ``"ax_tree"`` field.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    role_frequency = collections.Counter()

    def calc_frequency(file_path: str) -> None:
        """Add the roles found in one JSONL shard to ``role_frequency``."""
        # Stream line by line instead of materialising the whole shard first.
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    # Tolerate blank lines (e.g. a stray trailing newline).
                    continue
                tree = json.loads(line)["ax_tree"]
                for node in tree.split("\n"):
                    toks = node.split()
                    if len(toks) < 2:
                        continue
                    role_frequency[toks[1]] += 1

    for index in range(num_files):
        calc_frequency(f"{folder}/en0000-01.{index:02d}.jsonl")

    tot = sum(role_frequency.values())

    # most_common() performs a stable sort by count, descending. That is the
    # same ordering as sorted(items, key=count, reverse=True).
    role_stats = {k: [v, v / tot] for k, v in role_frequency.most_common()}

    # Make sure the output directory exists before writing.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(role_stats, f, indent=4)


if __name__ == "__main__":
    # Usage: python <script> [FOLDER]. When no folder argument is given, the
    # original default "data/clueweb" is used.
    extract_top_axtree_roles(sys.argv[1] if len(sys.argv) > 1 else "data/clueweb")
