"""
For inter-rater agreement, we need to compare the results of two or more
annotators:
1. Inter-rater agreement
2. Alignment between transcribed answer and crowd-sourced answer
"""

from typing import List
import numpy as np
from data import CausalDataset, MoralDataset
from nltk import word_tokenize
from tqdm import trange
from collections import defaultdict


def is_proper_subset(set_a, set_b):
    """Return True when every element of ``set_a`` is in ``set_b`` and the two differ."""
    if not set_a.issubset(set_b):
        return False
    # A subset that is not equal to its superset is a proper subset.
    return set_a != set_b


def is_euqal(set_a, set_b):
    """Return whether ``set_a`` and ``set_b`` compare equal."""
    same = set_a == set_b
    return same


def lcs(X, Y):
    """Return the length of the longest common subsequence of ``X`` and ``Y``.

    Elements are compared with ``==``; the subsequence does not need to be
    contiguous in either input. Only the length is computed, not the
    subsequence itself.
    """
    # previous_row[j] is the LCS length of the part of X consumed so far
    # and the first j elements of Y. Before any element of X is consumed
    # every entry is 0.
    previous_row = [0] * (len(Y) + 1)

    for x_item in X:
        # Column 0 pairs the X prefix with an empty Y prefix: always 0.
        current_row = [0]
        for col, y_item in enumerate(Y, start=1):
            if x_item == y_item:
                # Matching elements extend the LCS of both shorter prefixes.
                current_row.append(previous_row[col - 1] + 1)
            else:
                # Otherwise drop the last element of one side or the other.
                current_row.append(max(previous_row[col], current_row[col - 1]))
        previous_row = current_row

    # The last cell covers all of X and all of Y.
    return previous_row[-1]


def compute_overlap_words(sentence1: List[str], sentence2: List[str]):
    """Score the token overlap of two sentences.

    The score is the length of the longest common subsequence of the two
    token lists divided by their mean length, so identical lists score 1.0
    and lists with no common token score 0.0.

    Args:
        sentence1: Tokens of the first sentence.
        sentence2: Tokens of the second sentence.

    Returns:
        The overlap score as a float. Two empty token lists score 1.0.
    """
    combined_len = len(sentence1) + len(sentence2)
    if combined_len == 0:
        # Both inputs are empty: the mean length is 0 and the ratio is
        # undefined. Treat two empty sentences as identical, the same way the
        # set-based IoU in this file scores an empty union as 1.0.
        return 1.0
    lcs_len = lcs(sentence1, sentence2)
    return lcs_len / (combined_len / 2)


def flatten_list(list_list):
    """Concatenate the inner iterables of ``list_list`` into one flat list.

    Only one level of nesting is removed; the order of the items is kept.
    """
    flattened = []
    for inner in list_list:
        flattened.extend(inner)
    return flattened


def factor_overlap(data1, data2):
    """Print how well two datasets agree on the annotated factors per example.

    For every example index that is not in ``data1.ignore_ids``, the set of
    ``annotation.factor`` values of the example's annotated sentences is built
    for each dataset and the two sets are compared. The function prints the
    indices whose factor sets differ and the mean intersection-over-union
    (IoU) over the compared examples. Nothing is returned.

    Both datasets must have the same length.
    """
    assert len(data1) == len(data2)

    def factors_of(dataset, idx):
        # Distinct factors annotated on the sentences of one example.
        return {
            sentence.annotation.factor
            for sentence in dataset.examples[idx].annotated_sentences
        }

    overlaps = []
    for i in range(len(data1)):
        if i in data1.ignore_ids:
            continue

        first = factors_of(data1, i)
        second = factors_of(data2, i)

        record = {
            "idx": i,
            "a subset b": is_proper_subset(first, second),
            "b subset a": is_proper_subset(second, first),
            "a = b": is_euqal(first, second),
        }

        shared = first & second
        combined = first | second
        if combined:
            record["iou"] = len(shared) / len(combined)
        else:
            # Neither side annotated anything: count as full agreement.
            record["iou"] = 1.0
        overlaps.append(record)

    non_equal_idxs = [record["idx"] for record in overlaps if not record["a = b"]]
    print(f"{non_equal_idxs=}")

    iou = np.mean([record["iou"] for record in overlaps])
    print(f"iou: {iou:.4f}")


def value_overlap(data1, data2):
    """Print how well two datasets agree on (factor, value) pairs per example.

    For every example index that is not in ``data1.ignore_ids``, the set of
    ``(annotation.factor, annotation.value)`` pairs of the example's annotated
    sentences is built for each dataset and the two sets are compared. The
    function prints the indices whose sets differ and the mean
    intersection-over-union (IoU) over the compared examples. Nothing is
    returned.

    Both datasets must have the same length.
    """
    assert len(data1) == len(data2)

    def pairs_of(dataset, idx):
        # Distinct (factor, value) pairs annotated on one example's sentences.
        return {
            (sentence.annotation.factor, sentence.annotation.value)
            for sentence in dataset.examples[idx].annotated_sentences
        }

    overlaps = []
    for i in range(len(data1)):
        if i in data1.ignore_ids:
            continue

        first = pairs_of(data1, i)
        second = pairs_of(data2, i)

        record = {
            "idx": i,
            "a subset b": is_proper_subset(first, second),
            "b subset a": is_proper_subset(second, first),
            "a = b": is_euqal(first, second),
        }

        shared = first & second
        combined = first | second
        if combined:
            record["iou"] = len(shared) / len(combined)
        else:
            # Neither side annotated anything: count as full agreement.
            record["iou"] = 1.0
        overlaps.append(record)

    non_equal_idxs = [record["idx"] for record in overlaps if not record["a = b"]]
    print(f"{non_equal_idxs=}")

    iou = np.mean([record["iou"] for record in overlaps])
    print(f"iou: {iou:.4f}")


def sentence_overlap(data1, data2):
    """Print how similar the annotated text is for annotations both datasets share.

    For every example index that is not in ``data1.ignore_ids``, the annotated
    sentences of each dataset are tokenized with ``word_tokenize`` and grouped
    by their ``(annotation.factor, annotation.value)`` pair. Within a group the
    token lists are sorted and concatenated into one token sequence. For each
    pair present in both datasets the two sequences are scored with
    ``compute_overlap_words``; pairs present on only one side are skipped.

    The function prints the example indices that have at least one score
    different from 1, and the mean score over all compared pairs. Nothing is
    returned. Both datasets must have the same length.
    """
    assert len(data1) == len(data2)

    def tokens_by_annotation(example):
        # Map each (factor, value) pair to one flat token sequence.
        grouped = defaultdict(list)
        for sentence in example.annotated_sentences:
            label = (sentence.annotation.factor, sentence.annotation.value)
            grouped[label].append(word_tokenize(sentence.text))
        # Sorting the token lists before flattening makes the sequence
        # independent of the order the sentences were annotated in.
        return {
            label: flatten_list(sorted(token_lists))
            for label, token_lists in grouped.items()
        }

    overlaps = []
    for i in range(len(data1)):
        if i in data1.ignore_ids:
            continue

        first = tokens_by_annotation(data1.examples[i])
        second = tokens_by_annotation(data2.examples[i])

        for label, first_tokens in first.items():
            if label not in second:
                continue
            second_tokens = second[label]
            overlaps.append(
                {
                    "idx": i,
                    "sentence1": first_tokens,
                    "sentence2": second_tokens,
                    "iou": compute_overlap_words(first_tokens, second_tokens),
                }
            )

    # An example can contribute several records; report each index once.
    non_equal_idxs = list({record["idx"] for record in overlaps if record["iou"] != 1})
    print(f"{non_equal_idxs=}")

    iou = np.mean([record["iou"] for record in overlaps])
    print(f"iou: {iou:.4f}")


if __name__ == "__main__":
    # Sanity check: the longest common subsequence of the two token lists has
    # length 2 and their mean length is 2.5, so the score must be 0.8.
    sanity_score = compute_overlap_words(
        "hello world".split(), "hello hello world".split()
    )
    assert sanity_score == 0.8

    # Load two versions of each dataset; the comparisons below index both
    # versions by position, so their lengths must match.
    cd1 = CausalDataset(json_file="../../data/causal_dataset_v1.json")
    cd2 = CausalDataset(json_file="../../data/causal_dataset_hd.json")
    assert len(cd1) == len(cd2)
    md1 = MoralDataset(json_file="../../data/moral_dataset_v1.json")
    md2 = MoralDataset(json_file="../../data/moral_dataset_mk.json")
    assert len(md1) == len(md2)

    # Each report runs on the causal pair first, then on the moral pair.
    reports = (
        ("1. factor overlap", factor_overlap),
        ("2. value overlap", value_overlap),
        ("3. sentence overlap", sentence_overlap),
    )
    for heading, report in reports:
        print(heading)
        report(cd1, cd2)
        report(md1, md2)
        print("\n\n")
