"""
An implementation of a LIBRO-like operation.
THIS IS NOT FINISHED YET; EXPECT WEIRD BEHAVIOR.

Takes several unit-test generation samples,
looks at their evaluation traces, and picks the one that most closely
resembles the issue description (e.g., as picked by an LLM).
"""
import bisect
import functools
import re
from collections import defaultdict
from typing import List, Optional, Tuple
import orjson as json
import pathlib

import fire
from cachier import cachier
from unidiff import PatchSet

from datasets import load_from_disk, load_dataset

# True until the first call to `log` has printed the CSV header row.
first = True
def log(msg):
    """Print ``msg`` (a dict) as one comma-separated row on stdout.

    The very first call additionally prints the dict's keys as a header row.
    Values are converted with ``str`` and are not escaped, so a comma inside a
    value splits it across columns.
    """
    global first
    if first:
        header = ",".join(msg.keys())
        print(header, flush=True)
        first = False
    row = ",".join(map(str, msg.values()))
    print(row, flush=True)

@functools.lru_cache(maxsize=128)
def cached_extract_coverages_from_eval_output(eval_output: Optional[str]) -> Optional[List[dict]]:
    """Memoized wrapper around ``extract_coverages_from_eval_output``.

    Returns None when ``eval_output`` is None (for example a ``dict.get`` miss for an
    instance without a log) instead of failing inside the parser.

    The cache is bounded because every key is an entire log file. The returned
    list is shared by all calls with the same input, so treat it as read-only.
    """
    if eval_output is None:
        return None
    return extract_coverages_from_eval_output(eval_output)

def extract_coverages_from_eval_output(eval_output: Optional[str]) -> Optional[List[dict]]:
    """Parse every coverage report embedded in an evaluation log.

    A report starts at a line beginning with ``Coverage Script: `` that mentions
    ``coverage.cover`` and ends at the next line beginning with ``Testbed: ``.
    Parsing starts two lines after the script line; the line in between is not
    parsed. Every line that parses as a JSON object of the shape
    ``{path: {line_number: hits}}`` is merged into the report, and other lines are
    skipped. The testbed directory is cut from the front of each path.

    Per the original notes, the reports are (in order) coverage before adding the
    golden patch and test case, after the test case, and after the golden patch.

    Args:
        eval_output: Full log text, or None.

    Returns:
        None if ``eval_output`` is None. Otherwise a list with one
        ``{relative_path: [(line, hits), ...]}`` dict per complete report, each
        list sorted by line number. A trailing report without a ``Testbed:``
        line (a truncated log) is not included.
    """
    if eval_output is None:
        return None
    lines = eval_output.splitlines(keepends=True)
    coverages = []
    start_line = 0
    while True:
        script_idx = next(
            (
                idx
                for idx in range(start_line, len(lines))
                if lines[idx].startswith("Coverage Script: ") and "coverage.cover" in lines[idx]
            ),
            None,
        )
        # A marker on the very last line cannot be followed by any report data.
        if script_idx is None or script_idx >= len(lines) - 1:
            return coverages
        testbed_idx = next(
            (idx for idx in range(script_idx, len(lines)) if lines[idx].startswith("Testbed: ")),
            None,
        )
        if testbed_idx is None:
            # Truncated log: this report never reached its "Testbed:" line.
            return coverages
        testbed = lines[testbed_idx][len("Testbed: "):].strip() + "/"
        res = {}
        for raw in lines[script_idx + 2:testbed_idx]:
            try:
                d = json.loads(raw.strip())
            except json.JSONDecodeError:
                continue
            try:
                for k, v in d.items():
                    hits = sorted(((int(ln), int(h)) for ln, h in v.items()), key=lambda x: x[0])
                    # NOTE(review): assumes every key starts with the testbed path;
                    # other keys are truncated by the same length regardless.
                    res[k[len(testbed):]] = hits
            except (AttributeError, TypeError, ValueError):
                # Valid JSON but not {path: {line: hits}}; skip the rest of this line.
                continue
        coverages.append(res)
        start_line = testbed_idx

def extract_changed_lines_from_patch(patch: PatchSet) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]:
    """Collect the lines touched by ``patch``.

    Each path loses its first two characters (presumably the ``a/`` / ``b/``
    diff prefix). Removed lines are reported with their line number in the
    source file, and added lines with their number in the target file.

    Returns:
        ``(removed_lines, added_lines)``, each a list of ``(path, line_number)``
        pairs in patch order.
    """
    removed, added = [], []
    for patched_file in [*patch.modified_files, *patch.added_files, *patch.removed_files]:
        for hunk in patched_file:
            for diff_line in hunk:
                if diff_line.is_removed:
                    removed.append((patched_file.source_file[2:], diff_line.source_line_no))
                if diff_line.is_added:
                    added.append((patched_file.target_file[2:], diff_line.target_line_no))
    return removed, added


def no_calls(coverage, file, line):
    """Return the hit count that ``coverage`` attributes to ``line`` of ``file``.

    ``coverage`` maps a path to a list of ``(line_number, hits)`` pairs sorted by
    line number. A line without its own entry takes the count of the closest
    entry at or before it.

    Returns 0 when the file has no entries or when ``line`` precedes every entry.
    Previously the latter case wrapped around to the file's last entry.
    """
    items = coverage.get(file)
    if not items:
        return 0
    # +inf puts the insertion point just after any entry for `line`, so
    # index - 1 is the last entry whose line number is <= line.
    index = bisect.bisect_right(items, (line, float("inf")))
    if index == 0:
        return 0
    return items[index - 1][1]

def lines_covered(original_cov, add_coverage, changes):
    """Return the changed ``(file, line)`` pairs hit more often under ``add_coverage`` than under ``original_cov``."""
    return [
        (file, line)
        for file, line in changes
        if no_calls(original_cov, file, line) < no_calls(add_coverage, file, line)
    ]

def no_lines_covered(original_cov, add_coverage, changes):
    """Count the changed lines hit more often under ``add_coverage`` than under ``original_cov``."""
    more_often = lines_covered(original_cov, add_coverage, changes)
    return len(more_often)

def missed_lines_covered(original_cov, add_coverage, changes):
    """Return the changed ``(file, line)`` pairs with no hits under ``original_cov`` but some under ``add_coverage``."""
    return [
        (file, line)
        for file, line in changes
        if no_calls(original_cov, file, line) == 0 and no_calls(add_coverage, file, line) > 0
    ]

def no_missed_lines_covered(original_cov, add_coverage, changes):
    """Count the changed lines that only ``add_coverage`` executes."""
    newly_reached = missed_lines_covered(original_cov, add_coverage, changes)
    return len(newly_reached)

def missed_lines(original_cov, changes):
    """Return the changed ``(file, line)`` pairs that ``original_cov`` records no hits for."""
    return [(file, line) for file, line in changes if no_calls(original_cov, file, line) == 0]

def no_missed_lines(original_cov, changes):
    """Count the changed lines that ``original_cov`` records no hits for."""
    never_hit = missed_lines(original_cov, changes)
    return len(never_hit)

def coverage_of_patchset(coverage, changes):
    """Restrict ``coverage`` to the entries that govern the changed lines.

    For every ``(file, line)`` in ``changes``, look up the entry of
    ``coverage[file]`` with the largest line number <= ``line``. Entries with a
    positive hit count are collected per file. A changed line that precedes every
    entry of its file is ignored. Several changed lines can map to the same
    entry, and that entry is then appended once per changed line.

    The search for each line resumes after the previous match, so ``changes``
    must be grouped by file with ascending line numbers inside a file.

    Returns:
        ``defaultdict(list)`` mapping file -> list of ``(line_number, hits)``.
    """
    diffed_coverage = defaultdict(list)
    cur_file = None
    coverage_of_file = []
    start = 0
    for file, line in changes:
        if file != cur_file:
            cur_file = file
            coverage_of_file = coverage.get(file, [])
            start = 0
        if not coverage_of_file:
            continue
        # Absolute insertion point; searching with lo= avoids copying a slice.
        pos = bisect.bisect_right(coverage_of_file, (line, float("inf")), lo=start)
        if pos == 0:
            # `line` precedes every recorded line of this file (index -1 would wrap around).
            continue
        cov = coverage_of_file[pos - 1]
        if cov[1] > 0:
            diffed_coverage[file].append(cov)
        start = pos

    return diffed_coverage



def extract_good_case_from_eval_output(eval_output: str) -> Tuple[bool, bool, bool, bool, bool]:
    """Classify the test runs recorded in an evaluation log.

    A run ends at a line starting with ``>>>>> All Tests Passed`` or
    ``>>>>> Some Tests Failed``. For an "All Tests Passed" marker, the lines
    between the previous marker and this one are scanned bottom-up for a summary
    line. That is a line containing "==" or "finished:" together with a result
    word, a line starting with "OK", or failing that the per-test lines starting
    with "PASSED"/"FAILED". A "Some Tests Failed" marker is always recorded as 2.

    Outcome codes: 0 = failed, 1 = passed, 2 = error / undetermined. Code 3
    (compilation error) is checked for but never produced.

    Args:
        eval_output: Full text of one evaluation log.

    Returns:
        ``(fail_to_pass, error_to_pass, fails_initially, error_initially,
        compilation_error)``, judged from the second and third runs. With exactly
        two runs the first two flags are False. With fewer than two runs
        ``(False, False, True, True, True)`` is returned.
    """
    runs = []
    num_failed = []
    lines = eval_output.splitlines()
    # Index of the previous marker line. Starting at -1 makes the first run's
    # window start at line 0, and makes it empty (instead of wrapping around to
    # the end of the log) when the marker itself is on line 0.
    prev_line = -1
    for i, l in enumerate(lines):
        if l.startswith(">>>>> All Tests Passed"):
            counter = [0, 0]  # PASSED, FAILED
            res = None
            # Scan this run's output bottom-up, stopping at the first summary line.
            for test_line in reversed(lines[prev_line + 1:i]):
                if any(s in test_line for s in ["==", "finished:"]) and any(s in test_line for s in ["passed", "failed", "warning", "error", "exceptions"]):
                    num_failed_here = re.findall(r"(\d+) failed", test_line)
                    if num_failed_here:
                        num_failed.append(int(num_failed_here[0]))
                    else:
                        num_failed.append(0)
                    if "error" in test_line:
                        res = 2  # 2 error
                        break
                    elif "failed" not in test_line and "error" not in test_line and "exceptions" not in test_line:
                        res = 1  # 1 pass
                        break
                    else:
                        res = 0  # 0 fail
                        break
                elif test_line.startswith("OK"):
                    res = 1
                    break
                elif test_line.startswith("FAILED"):
                    counter[1] += 1
                elif test_line.startswith("PASSED"):
                    counter[0] += 1
            # No summary line found: fall back to the per-test PASSED/FAILED lines.
            if res is None and counter[1] > 0:
                res = 0
                num_failed.append(counter[1])
            elif res is None and counter[0] > 0:
                res = 1
                num_failed.append(0)
            if res is None:
                res = 2
            runs.append(res)
            prev_line = i
        if l.startswith(">>>>> Some Tests Failed"):
            runs.append(2)
            prev_line = i
    if len(runs) < 3:
        if len(runs) < 2:
            return False, False, True, True, True
        return False, False, runs[1] == 0, runs[1] == 2, runs[1] == 3
    # If every run fails, still count the third run as passing when the failure
    # count rose in the second run and went back to the first run's count in the
    # third. Each failing run adds exactly one num_failed entry, so when runs[0..2]
    # are all failures, num_failed[0..2] belong to those runs.
    try:
        if runs[0] == runs[1] == runs[2] == 0:
            if num_failed[0] == num_failed[2] and num_failed[1] > num_failed[0]:
                runs[2] = 1
    except IndexError:
        pass
    # 3 symbolizes compilation error but we can currently not reliably detect this
    # F->P, E->P, F, E, CompilationError
    return runs[1] == 0 and runs[2] == 1, runs[1] == 2 and runs[2] == 1, runs[1] == 0, runs[1] == 2, runs[1] == 3

def extract_number_added_tests_from_patch(patch: PatchSet) -> int:
    # extract the number of tests that were added by the patch
    sm = 0
    for file in patch.modified_files + patch.added_files + patch.removed_files:
        for hunk in file:
            for i, line in enumerate(hunk):
                if line.is_added and line.value.strip().startswith("def") and "test" in line.value:
                    sm += 1

                if line.is_removed and line.value.strip().startswith("def") and "test" in line.value:
                    sm -= 1
    return sm

def save_div(a, b, default=1):
    """Return ``a / b``, or ``default`` when ``b`` is zero (a "safe" division)."""
    return default if b == 0 else a / b


def extract_patch_from_eval_output(eval_output: str):
    """Return the ``"patch"`` field of the JSON object on the line after the marker.

    The marker is the first occurrence of ``Patch extracted and applied``. Returns
    None when ``eval_output`` is None or the marker is absent. Raises IndexError
    if nothing follows the marker's line, and a JSON decode error if the next
    line is not valid JSON.
    """
    if eval_output is None:
        return None
    _, marker, after_marker = eval_output.partition("Patch extracted and applied")
    if not marker:
        return None
    # after_marker begins with the rest of the marker's own line; the payload is the next line.
    payload_line = after_marker.split("\n", 2)[1]
    return json.loads(payload_line.strip())["patch"]

def visually_combine_patch_and_coverage(patchs: PatchSet, removed_lines_covered, added_lines_covered):
    """Mark the covered lines of ``patchs`` and print it.

    Every removed or added line has its text prefixed with ``"C "`` if its
    ``(file.path, line_number)`` pair is in the matching collection
    (source line numbers for removed lines, target line numbers for added lines),
    and with two spaces otherwise. ``patchs`` is modified in place.
    """
    for patched_file in patchs:
        path = patched_file.path
        for hunk in patched_file:
            for diff_line in hunk:
                if diff_line.is_removed:
                    marker = "C" if (path, diff_line.source_line_no) in removed_lines_covered else " "
                    diff_line.value = f"{marker} {diff_line.value}"
                if diff_line.is_added:
                    marker = "C" if (path, diff_line.target_line_no) in added_lines_covered else " "
                    diff_line.value = f"{marker} {diff_line.value}"
    print(patchs)


def load_eval_outputs(eval_output_dir: str):
    """Read every ``*.log`` file below ``eval_output_dir`` (recursively), keyed by instance id.

    The id is the text after the first ``Instance ID: `` up to the end of that
    line, stripped. Files without that marker are skipped; previously they were
    stored under a garbage or empty id. When several logs share an id, the one
    read last (in glob order) wins. Undecodable bytes are replaced so that one
    bad log does not abort the whole load.

    Returns:
        dict mapping instance id -> full log text.
    """
    marker = "Instance ID: "
    eval_output_by_instance = {}
    for eval_output_file in pathlib.Path(eval_output_dir).glob("**/*.log"):
        if not eval_output_file.is_file():
            continue
        eval_output = eval_output_file.read_text(encoding="utf-8", errors="replace")
        start = eval_output.find(marker)
        if start == -1:
            continue
        instance_id = eval_output[start + len(marker):].split("\n", 1)[0].strip()
        eval_output_by_instance[instance_id] = eval_output
    return eval_output_by_instance

def line_delta(local_coverage1, local_coverage2):
    """Return the entries of ``local_coverage2`` that add coverage over ``local_coverage1``.

    Both inputs are lists of ``(line_number, hits)`` sorted by line number and are
    walked together in one merge pass. An entry whose line does not appear in
    ``local_coverage1`` is returned unchanged, whatever its hit count. A line in
    both lists is returned with the increase in hits, if that increase is positive.
    """
    delta = []
    n1 = len(local_coverage1)
    a = 0
    for entry in local_coverage2:
        lineno = entry[0]
        # Skip baseline entries for lines before this one.
        while a < n1 and local_coverage1[a][0] < lineno:
            a += 1
        if a == n1 or local_coverage1[a][0] > lineno:
            delta.append(entry)
            continue
        gained = entry[1] - local_coverage1[a][1]
        if gained > 0:
            delta.append((lineno, gained))
        a += 1
    return delta


def coverage_diff(coverage1, coverage2):
    """
    Returns the additional lines covered by coverage2 compared to coverage1

    Files missing from ``coverage1`` are copied from ``coverage2`` as they are.
    Files with identical lists are skipped. For other files the per-line gain is
    computed with ``line_delta``, and the file is kept only if that gain is non-empty.
    """
    diff = {}
    for path, cov2 in coverage2.items():
        if path not in coverage1:
            diff[path] = cov2
            continue
        cov1 = coverage1[path]
        if cov2 == cov1:
            continue
        gained = line_delta(cov1, cov2)
        if gained:
            diff[path] = gained
    return diff


def test_diff():
    """line_delta keeps lines new in the second list and lines whose hit count grew."""
    before = [(2, 1), (7, 1), (8, 1), (26, 1), (27, 1), (28, 1), (31, 1), (32, 2), (33, 2), (34, 1), (36, 2), (37, 2), (38, 1), (40, 2), (41, 2), (42, 1), (44, 2), (45, 2), (46, 1), (48, 2), (49, 2), (50, 1), (52, 2), (53, 2), (54, 1), (59, 1), (60, 1), (61, 1), (62, 1), (63, 1), (64, 1), (65, 1)]
    after = [(2, 1), (7, 1), (8, 1), (26, 1), (27, 1), (28, 1), (31, 2), (32, 2), (33, 1), (34, 1), (35, 1), (38, 1), (39, 2), (40, 2), (41, 1), (43, 2), (44, 2), (45, 1), (47, 2), (48, 2), (49, 1), (51, 2), (52, 2), (53, 1), (55, 2), (56, 2), (57, 1), (59, 2), (60, 2), (61, 1), (63, 1), (64, 1), (65, 1), (66, 2)]
    expected = [(31, 1), (35, 1), (39, 2), (43, 2), (47, 2), (51, 2), (55, 2), (56, 2), (57, 1), (59, 1), (60, 1), (66, 2)]
    actual = line_delta(before, after)
    assert actual == expected

def line_intersection(local_coverage1, local_coverage2):
    """Return the lines present in both sorted ``(line_number, hits)`` lists.

    Each common line is paired with the smaller of its two hit counts. Lines whose
    smaller count is not positive are dropped.
    """
    common = []
    a = b = 0
    while a < len(local_coverage1) and b < len(local_coverage2):
        line_a = local_coverage1[a][0]
        line_b = local_coverage2[b][0]
        if line_a < line_b:
            a += 1
            continue
        if line_a > line_b:
            b += 1
            continue
        shared = min(local_coverage2[b][1], local_coverage1[a][1])
        if shared > 0:
            common.append((line_b, shared))
        a += 1
        b += 1
    return common

def coverage_intersection(coverage1, coverage2):
    """
    Returns the intersection of lines covered by coverage2 and coverage1

    Only files present in both inputs are considered. Identical lists are
    returned as they are (taken from ``coverage1``). Otherwise the per-line
    intersection from ``line_intersection`` is kept if it is non-empty.
    """
    intersection = {}
    for path in (p for p in coverage2 if p in coverage1):
        cov1 = coverage1[path]
        if coverage2[path] == cov1:
            intersection[path] = cov1
            continue
        common = line_intersection(cov1, coverage2[path])
        if common:
            intersection[path] = common
    return intersection


def test_intersection():
    """line_intersection keeps common lines with the smaller of the two hit counts."""
    first_cov = [(2, 1), (7, 1), (8, 1), (26, 1), (27, 1), (28, 1), (31, 1), (32, 2), (33, 2), (34, 1), (36, 2), (37, 2), (38, 1), (40, 2), (41, 2), (42, 1), (44, 2), (45, 2), (46, 1), (48, 2), (49, 2), (50, 1), (52, 2), (53, 2), (54, 1), (59, 1), (60, 1), (61, 1), (62, 1), (63, 1), (64, 1), (65, 1)]
    second_cov = [(2, 1), (7, 1), (8, 1), (26, 1), (27, 1), (28, 1), (31, 2), (32, 2), (33, 1), (34, 1), (35, 1), (38, 1), (39, 2), (40, 2), (41, 1), (43, 2), (44, 2), (45, 1), (47, 2), (48, 2), (49, 1), (51, 2), (52, 2), (53, 1), (55, 2), (56, 2), (57, 1), (59, 2), (60, 2), (61, 1), (63, 1), (64, 1), (65, 1), (66, 2)]
    expected = [(2, 1), (7, 1), (8, 1), (26, 1), (27, 1), (28, 1), (31, 1), (32, 2), (33, 1), (34, 1), (38, 1), (40, 2), (41, 1), (44, 2), (45, 1), (48, 2), (49, 1), (52, 2), (53, 1), (59, 1), (60, 1), (61, 1), (63, 1), (64, 1), (65, 1)]

    actual = line_intersection(first_cov, second_cov)
    assert actual == expected

def line_union(local_coverage1, local_coverage2):
    """Merge two sorted ``(line_number, hits)`` lists into one sorted list.

    A line present in only one list is copied unchanged. A line in both lists gets
    the larger of its two hit counts, and is dropped if that maximum is not positive.
    """
    merged = []
    a = b = 0
    n1, n2 = len(local_coverage1), len(local_coverage2)
    while a < n1 or b < n2:
        first_done = a >= n1
        second_done = b >= n2
        if second_done or (not first_done and local_coverage1[a][0] < local_coverage2[b][0]):
            merged.append(local_coverage1[a])
            a += 1
            continue
        if first_done or local_coverage1[a][0] > local_coverage2[b][0]:
            merged.append(local_coverage2[b])
            b += 1
            continue
        best = max(local_coverage2[b][1], local_coverage1[a][1])
        if best > 0:
            merged.append((local_coverage2[b][0], best))
        a += 1
        b += 1
    return merged

def coverage_union(coverage1, coverage2):
    """
    Returns the union of lines covered by coverage2 and coverage1

    Files present in only one input are copied from that input as they are.
    Files with identical lists keep ``coverage1``'s list. For other files the
    per-line union from ``line_union`` is kept if it is non-empty. Previously,
    files present only in ``coverage1`` were silently dropped.
    """
    union = {}
    for file in coverage2:
        if file not in coverage1:
            union[file] = coverage2[file]
        elif coverage2[file] == coverage1[file]:
            union[file] = coverage1[file]
        else:
            local_union = line_union(coverage1[file], coverage2[file])
            if local_union:
                union[file] = local_union
    # Files only coverage1 knows about belong to the union as well.
    for file in coverage1:
        if file not in coverage2:
            union[file] = coverage1[file]
    return union

def number_lines(coverage):
    """Return the total number of ``(line, hits)`` entries across all files of ``coverage``."""
    return sum(map(len, coverage.values()))

def compute_overlap(lines_covered, golden_lines_covered):
    """Compare a test's covered lines with the golden test's covered lines.

    Both arguments map file path -> list of ``(line_number, hits)`` pairs.

    Returns:
        ``(shared, covered, golden)``: the number of entries in the intersection
        of the two coverages, in ``lines_covered``, and in ``golden_lines_covered``.
    """
    # The union that used to be computed here was never used, so it is no longer built.
    shared = number_lines(coverage_intersection(lines_covered, golden_lines_covered))
    return shared, number_lines(lines_covered), number_lines(golden_lines_covered)

def load_blacklisted(path: str = "blacklisted_cases_full.txt"):
    """Return the set of instance ids listed one per line in ``path``.

    Surrounding whitespace is stripped and blank lines are ignored. With no
    argument it reads ``blacklisted_cases_full.txt``, which is what a bare
    ``load_blacklisted()`` call returned before the two copies were merged.
    """
    with open(path, encoding="utf-8") as f:
        return {entry for entry in (raw.strip() for raw in f) if entry}


# Both files are read relative to the current working directory when the module
# is imported; a missing file makes the import fail.
BLACKLIST = load_blacklisted("blacklisted_cases.txt")
BLACKLIST_FULL = load_blacklisted("blacklisted_cases_full.txt")

def used_fuzzy(eval_output: str) -> bool:
    """Return True if the log contains the ``>>>>> Applied Patch (fuzzy_try)`` marker."""
    fuzzy_marker = ">>>>> Applied Patch (fuzzy_try)"
    return fuzzy_marker in eval_output

@cachier()
def main(
    eval_output_dir: str = "evaluation_output/swt_lite_golden_test/mode_vanillafuzzy",
    golden_eval_output_dir: str = "evaluation_output/swt_lite_golden_test/mode_vanillafuzzy",
    dataset: str = "datasets/swt_bench_lite_aug1_bm25_diff_27k_cl100k",
    split: str = "test",
    fuzzy: bool = True,
):
    """Score generated tests against the golden tests, one record per dataset instance.

    For every non-blacklisted instance of ``dataset[split]``, this compares the
    changed lines of ``example["test_patch"]`` that the generated test newly
    covers with those the golden test covers (recall/precision). It also
    classifies the generated test's runs (fail-to-pass etc.).

    Args:
        eval_output_dir: Directory searched recursively for the generated tests' ``*.log`` files.
        golden_eval_output_dir: Same, for the golden tests.
        dataset: Path passed to ``load_from_disk``.
        split: Dataset split to iterate.
        fuzzy: If False, instances whose log shows a fuzzy patch application are skipped.

    Returns:
        A list of dicts. Each is either ``{"instance_id", "message"}`` explaining
        why an instance was skipped, or that instance's metrics.

    NOTE(review): ``@cachier()`` persists results keyed on the arguments, so a
    rerun after the logs on disk change may return stale results.
    """
    # Named `records` so it does not shadow the module-level `log` function.
    records = []
    eval_output_by_instance = load_eval_outputs(eval_output_dir)
    golden_eval_output_by_instance = load_eval_outputs(golden_eval_output_dir)

    # dataset = load_dataset(dataset)
    dataset = load_from_disk(dataset)
    for example in dataset[split]:
        instance_id = example["instance_id"]
        if instance_id in BLACKLIST:
            continue
        if instance_id not in eval_output_by_instance:
            records.append({
                "instance_id": instance_id,
                "message": "no eval output found",
            })
            continue
        eval_outputs = eval_output_by_instance[instance_id]
        coverage = extract_coverages_from_eval_output(eval_outputs)
        # Four coverage reports are required (see the unpacking below).
        if coverage is None or len(coverage) != 4:
            records.append({
                "instance_id": instance_id,
                "message": "coverage not found for all 3 steps",
            })
            continue
        coverage_original, coverage_after_pred, coverage_after_patch, coverage_original_after_patch = coverage
        if not fuzzy and used_fuzzy(eval_outputs):
            records.append({
                "instance_id": instance_id,
                "message": "used fuzzy",
            })
            continue

        # An instance may have no golden log; do not hand None to the parser.
        golden_eval_output = golden_eval_output_by_instance.get(instance_id)
        golden_coverage = (
            cached_extract_coverages_from_eval_output(golden_eval_output)
            if golden_eval_output is not None
            else None
        )
        if golden_coverage is None or len(golden_coverage) != 4:
            records.append({
                "instance_id": instance_id,
                "message": "no golden coverage found",
            })
            continue
        golden_coverage_original, golden_coverage_after_pred, golden_coverage_after_patch, golden_coverage_original_after_patch = golden_coverage

        # NOTE(review): the scored lines come from example["test_patch"], not
        # example["patch"]; confirm this is the intended "golden patch".
        golden_patch = PatchSet(example["test_patch"])
        removed_lines, added_lines = extract_changed_lines_from_patch(golden_patch)

        # Removed lines are scored against the pre-patch coverages, added lines against the post-patch ones.
        additional_lines_covered_pre_patch = coverage_diff(coverage_of_patchset(coverage_original, removed_lines), coverage_of_patchset(coverage_after_pred, removed_lines))
        additional_lines_covered_post_patch = coverage_diff(coverage_of_patchset(coverage_original_after_patch, added_lines), coverage_of_patchset(coverage_after_patch, added_lines))
        golden_lines_covered_pre_patch = coverage_union(coverage_of_patchset(golden_coverage_original, removed_lines), coverage_of_patchset(golden_coverage_after_pred, removed_lines))
        golden_lines_covered_post_patch = coverage_union(coverage_of_patchset(golden_coverage_original_after_patch, added_lines), coverage_of_patchset(golden_coverage_after_patch, added_lines))
        inter_pre_patch, covered_pre_patch, golden_pre_patch = compute_overlap(additional_lines_covered_pre_patch, golden_lines_covered_pre_patch)
        inter_post_patch, covered_post_patch, golden_post_patch = compute_overlap(additional_lines_covered_post_patch, golden_lines_covered_post_patch)
        patch_executable = golden_pre_patch + golden_post_patch > 0
        # Both ratios default to 1 when their denominator is 0.
        recall = save_div(inter_pre_patch + inter_post_patch, golden_pre_patch + golden_post_patch, 1)
        precision = save_div(inter_pre_patch + inter_post_patch, covered_pre_patch + covered_post_patch, 1)

        patch = extract_patch_from_eval_output(eval_outputs)
        if patch is None:
            records.append({
                "instance_id": instance_id,
                # NOTE: same message as a missing log; here the log exists but has no extracted patch.
                "message": "no eval output found",
            })
            continue
        unittest_patch = PatchSet(patch)

        num_tests = extract_number_added_tests_from_patch(unittest_patch)

        ftp, etp, fails_initially, error_initially, compilation_error = extract_good_case_from_eval_output(eval_outputs)

        records.append({
            "instance_id": instance_id,
            "recall": recall,
            "precision": precision,
            "patch_executable": int(patch_executable),
            "good_case": int(ftp or etp),
            "ftp": int(ftp),
            "etp": int(etp),
            "fails_initially": int(fails_initially),
            "error_initially": int(error_initially),
            "compilation_error": int(compilation_error),
            "no_added_tests": num_tests,
        })
    return records



        


if __name__ == "__main__":
    # Expose `main`'s keyword arguments as command-line flags via fire.
    fire.Fire(component=main)
