import json

from datasets import load_from_disk, load_dataset
import pathlib
import fire

from measure_coverage_patch import log, extract_coverages_from_eval_output, load_eval_outputs, BLACKLIST, BLACKLIST_FULL


def extract_total_coverage(coverage):
    """Summarise a per-file coverage mapping into three totals.

    ``coverage`` is a mapping whose values are sized sequences of
    two-element entries. The second item of each entry is compared against
    zero (presumably a hit count for a source line -- confirm with the
    producer of the mapping).

    Returns a tuple ``(covered, total, num_files)``:
      * ``covered``   -- number of entries whose second item is > 0,
      * ``total``     -- number of entries across all values,
      * ``num_files`` -- number of keys in ``coverage``.
    """
    covered = 0
    total = 0
    for file_lines in coverage.values():
        total += len(file_lines)
        for _, hits in file_lines:
            if hits > 0:
                covered += 1
    return covered, total, len(coverage)

from nltk.tokenize import word_tokenize


def main(
    eval_output_dir: str = "evaluation_output/swt_lite_golden_test/mode_vanillafuzzy",
    dataset: str = "princeton-nlp/SWE-bench_Lite",
    split: str = "test",
):
    """Print the average and maximum of several dataset/coverage statistics.

    For every example in ``dataset[split]`` that is not in ``BLACKLIST_FULL``
    the function records the token count of the problem statement and, when
    the evaluation output yields exactly four coverage reports, the totals of
    the first and third report (named ``original`` and ``patched_w_test``
    here).

    Args:
        eval_output_dir: Directory passed to ``load_eval_outputs``; the result
            is indexed by ``instance_id``.
        dataset: Name or path handed to ``datasets.load_dataset``.
        split: Split of the loaded dataset to iterate over.

    Output:
        One line per statistic in the form ``name & average & maximum``.
        A statistic without any samples is printed as ``name & n/a & n/a``
        instead of aborting the whole summary.

    Raises:
        KeyError: If an instance of the dataset has no entry in the loaded
            evaluation outputs (assuming a plain mapping is returned).
    """
    eval_outputs_by_instance = load_eval_outputs(eval_output_dir)
    # Keep the ``dataset`` argument intact instead of rebinding it to the
    # loaded object.
    examples = load_dataset(dataset)[split]

    num_files = []
    num_lines = []
    len_issues = []
    coverages = []
    for example in examples:
        instance_id = example["instance_id"]
        if instance_id in BLACKLIST_FULL:
            continue
        # Issue length is recorded for every non-blacklisted instance, even
        # if its coverage data turns out to be incomplete below.
        len_issues.append(len(word_tokenize(example["problem_statement"])))

        eval_outputs = eval_outputs_by_instance[instance_id]
        reports = extract_coverages_from_eval_output(eval_outputs)
        if len(reports) != 4:
            # Incomplete evaluation run: skip the coverage statistics.
            continue
        original, _original_w_test, patched_w_test, _patched = reports

        for report in (original, patched_w_test):
            covered, total, files = extract_total_coverage(report)
            num_lines.append(total)
            num_files.append(files)
            # A report without any recorded line has no defined coverage
            # ratio; leaving it out avoids a ZeroDivisionError.
            if total:
                coverages.append(covered / total)

    statistics = [
        ("num_files", num_files),
        ("num_lines", num_lines),
        ("len_issues", len_issues),
        ("coverage", coverages),
    ]
    for name, values in statistics:
        if not values:
            # Nothing was collected (e.g. every instance was skipped).
            print(name, "&", "n/a", "&", "n/a")
            continue
        print(name, "&", sum(values) / len(values), "&", max(values))




# Script entry point: expose main() as a command-line interface through
# python-fire, so its keyword arguments can be given as command-line flags.
if __name__ == "__main__":
    fire.Fire(main)
