import json
import os
import sys
import re

# Path to the JSON-lines file with the model's inference outputs
# (second command-line argument).
inference_output_path = sys.argv[2]

# Parse one JSON record per line. Blank lines (for example a trailing empty
# line at the end of the file) are skipped instead of crashing json.loads.
with open(inference_output_path) as f:
    model_outputs = [json.loads(line) for line in f if line.strip()]

# Index the records by their "instance_id" field; if an id occurs more than
# once, the last record wins.
model_outputs_by_instance_id = {
    record["instance_id"]: record for record in model_outputs
}

# Markers written to a log when the predicted patch was applied, one per
# patch-application strategy.
_APPLIED_PATCH_MARKERS = (
    ">>>>> Applied Patch (fuzzy_try)",
    ">>>>> Applied Patch (custom_try)",
    ">>>>> Applied Patch (pred_try)",
    ">>>>> Applied Patch (pred_minimal_try)",
)
_GOLDEN_RESULT_PREFIX = "Tests passed before/after golden patch: "
_INSTANCE_ID_MARKER = "\t- Instance ID: "
# "items?" also accepts the singular "collected 1 item", which is what pytest
# prints when exactly one test is collected.
_COLLECTED_RE = re.compile(r"collected (\d+) items?")


def count_successful_diff(path):
    """Classify every log file directly inside the directory *path*.

    Each regular file in *path* is scanned line by line (in sorted file-name
    order, so results are deterministic; sub-directories are skipped).

    For every log whose tests failed both before and after the golden patch
    ("False;False"), the instance id and the model's ``full_output`` are
    printed and the function waits for the user to press Enter.

    Returns:
        A tuple of eight lists, in this order:

        * all log file names,
        * logs containing an "Applied Patch" marker,
        * logs mentioning ``SyntaxError`` or ``IndentationError``,
        * logs with golden-patch result ``False;False``,
        * logs with golden-patch result ``False;True``,
        * logs with golden-patch result ``True;True``,
        * logs with golden-patch result ``True;False``,
        * for each log with exactly three "collected N item(s)" lines, the
          second count minus the first count (ints).
    """
    total = []
    total_ap = []
    total_se = []
    total_fp = []
    total_ff = []
    total_pp = []
    total_pf = []
    total_new_test_cases = []

    for logfile in sorted(os.listdir(path)):
        log_path = os.path.join(path, logfile)
        if not os.path.isfile(log_path):
            # A sub-directory cannot be opened as a log; skip it.
            continue
        with open(log_path) as f:
            lines = f.readlines()

        # Reset per log so an id from a previous file is never reported for
        # a log that lacks its own "Instance ID" line.
        instance_id = None
        applied_patch = False
        syntax_error = False
        outcomes = set()
        collected_counts = []

        for line in lines:
            if any(marker in line for marker in _APPLIED_PATCH_MARKERS):
                applied_patch = True
            if _GOLDEN_RESULT_PREFIX in line:
                for outcome in ("False;True", "False;False", "True;False", "True;True"):
                    if _GOLDEN_RESULT_PREFIX + outcome in line:
                        outcomes.add(outcome)
            if "SyntaxError" in line or "IndentationError" in line:
                syntax_error = True
            collected = _COLLECTED_RE.search(line)
            if collected:
                collected_counts.append(int(collected.group(1)))
            if _INSTANCE_ID_MARKER in line:
                instance_id = line.split(_INSTANCE_ID_MARKER)[1].strip()

        total.append(logfile)
        if applied_patch:
            total_ap.append(logfile)
        if syntax_error:
            total_se.append(logfile)
        if "False;True" in outcomes:
            total_fp.append(logfile)
        if "False;False" in outcomes:
            total_ff.append(logfile)
            # Interactive inspection of fail/fail cases: show the model
            # output and pause until the user presses Enter.
            print("------------------------------------------------------")
            print("FF")
            print(instance_id)
            model_output = model_outputs_by_instance_id.get(instance_id)
            if model_output is None:
                print(f"<no model output found for instance {instance_id!r}>")
            else:
                print(model_output["full_output"])
            input()
        if "True;False" in outcomes:
            total_pf.append(logfile)
        if "True;True" in outcomes:
            total_pp.append(logfile)
        # Only logs with exactly three collection summaries are counted; the
        # first is taken as the original test count, the second as the new one.
        if len(collected_counts) == 3:
            orig_num, new_num = collected_counts[0], collected_counts[1]
            total_new_test_cases.append(new_num - orig_num)

    return total, total_ap, total_se, total_ff, total_fp, total_pp, total_pf, total_new_test_cases


def aggregate_successful_diff(path, seeds=(1, 2, 3, 4, 5)):
    """Average the log statistics of ``count_successful_diff`` over *seeds*.

    The first ``seed=<digits>`` occurrence in *path* is replaced with each
    seed in turn, and the resulting directory is analysed with
    ``count_successful_diff``.

    Args:
        path: Log directory path containing a ``seed=<digits>`` component.
        seeds: Seeds to substitute into *path*.

    Returns:
        A tuple of eight floats, each averaged over the seeds:
        number of logs, logs with an applied patch, logs with an applied
        patch and no syntax error, ``False;False`` logs, ``False;True`` logs,
        ``True;True`` logs, ``True;False`` logs, and the per-seed mean number
        of newly added test cases (0 for a seed with no such measurements).

    Raises:
        ZeroDivisionError: If *seeds* is empty.
    """
    # `count` is passed by keyword: passing it positionally to re.sub is
    # deprecated since Python 3.13.
    per_seed = [
        count_successful_diff(re.sub(r"seed=\d+", f"seed={seed}", path, count=1))
        for seed in seeds
    ]
    amt = len(seeds)

    n_total = n_applied = n_applied_clean = 0
    n_ff = n_fp = n_pp = n_pf = 0
    mean_new_tests = 0
    for t, ap, se, ff, fp, pp, pf, new_test_cases in per_seed:
        n_total += len(t)
        n_applied += len(ap)
        # Logs where the patch applied and no syntax error was reported.
        n_applied_clean += len(set(ap) - set(se))
        n_ff += len(ff)
        n_fp += len(fp)
        n_pp += len(pp)
        n_pf += len(pf)
        if new_test_cases:
            mean_new_tests += sum(new_test_cases) / len(new_test_cases)

    return (
        n_total / amt,
        n_applied / amt,
        n_applied_clean / amt,
        n_ff / amt,
        n_fp / amt,
        n_pp / amt,
        n_pf / amt,
        mean_new_tests / amt,
    )


if __name__ == "__main__":
    # Usage: <script> EVAL_OUTPUT_PATH INFERENCE_OUTPUT_PATH
    # (the inference output path, sys.argv[2], is already read at module
    # level, so only the evaluation path is read here).
    eval_output_path = sys.argv[1]
    seeds = (1,)
    temps = ("0",)
    for temp in temps:
        # Temperature "0" runs use the single seed 0; any other temperature
        # uses the configured seeds.
        local_seeds = (0,) if temp == "0" else seeds
        # `count` is passed by keyword: passing it positionally to re.sub is
        # deprecated since Python 3.13.
        res = aggregate_successful_diff(
            re.sub(r"temperature=\d+", f"temperature={temp}", eval_output_path, count=1),
            local_seeds,
        )
        # One line of space-separated averages, one decimal place each.
        print(*[f"{x:.1f}" for x in res])
