import numpy as np
import copy

from .knowledge_graph import KG, Room, Relation, GraphEdge, node_rel2str, State
from typing import List
from collections import deque

def sigmoid(z):
    """Return the logistic function ``1 / (1 + exp(-z))`` of *z*.

    Works element-wise for NumPy arrays and on plain numeric scalars.
    """
    denominator = 1 + np.exp(-z)
    return 1 / denominator

# Names of the room nodes; used to tell "walk <room>" from "walk <object>"
# and to pick out "inside <room>" edges.
rooms = ["kitchen", "bedroom", "bathroom", "livingroom"]

# Event description -> [primary edge pattern, alternative edge pattern].
# Each pattern is a (subject, relation, object) triple. The primary pattern is
# the one searched for in the graph; an object of ``None`` accepts any object.
# A hit on the primary pattern marks the event as holding, a hit on the
# alternative pattern's object marks it as not holding (see information_gain).
pre_defined_event_edges = {
    # "<device> is off" events: off state vs. on state.
    **{
        f"{device} is off": [(device, "is", "off"), (device, "is", "on")]
        for device in ("tv", "radio", "stove", "microwave", "computer")
    },
    # "<item> is not on <surface>" events: anywhere vs. on that surface.
    **{
        f"{item} is not on {surface}": [(item, "on", None), (item, "on", surface)]
        for item, surface in (
            ("apple", "desk"),
            ("mug", "coffeetable"),
            ("book", "sofa"),
            ("plate", "microwave"),
            ("towel", "washingmachine"),
        )
    },
}

def knowledge_based_action(
    knowledge_graph: KG,
    action: str,
    timestep: int
):
    """Return a clone of the graph updated as if ``action`` had been executed.

    Only ``"walk <target>"`` actions change anything; for every other action
    the unmodified clone is returned. The input graph itself is never edited,
    all edits go to ``knowledge_graph.clone()``.

    * ``walk <room>``: if the graph reports the room as adjacent to the
      agent's current room, the character's "inside" and "close" edges are
      removed and a new INSIDE edge to that room is added.
    * ``walk <object>``: if the object is recorded inside the same room as
      the agent, the character's "close" edges are removed and a CLOSE edge
      to the first node found for that object is added.

    Args:
        knowledge_graph: Current knowledge graph.
        action: Whitespace-separated action string, e.g. ``"walk kitchen"``.
        timestep: Timestep stamped on any newly added edge.

    Returns:
        The (possibly modified) clone of ``knowledge_graph``.

    Raises:
        IndexError: If ``action`` has no tokens, or is ``"walk"`` with no target.
        KeyError: If no "character inside <room>" edge exists for a walk action.
    """
    kg = knowledge_graph.clone()
    tokens = action.split()
    if tokens[0] != "walk":
        return kg

    # Map every object that sits directly inside a known room to that room.
    room_of = {}
    for edge in kg.edges:
        if edge.relation == Relation['INSIDE'] and edge.to_node.name in rooms:
            room_of[edge.from_node.name] = edge.to_node.name
    agent_room = room_of['character']
    target = tokens[1]

    if target in rooms:
        # Moving between rooms is only possible through an adjacency edge.
        if kg.search_adjacent_obj((target, "adjacent", agent_room)):
            kg.delete_edge(("character", "inside", None))
            kg.delete_edge(("character", "close", None))
            character = kg.search_node("character")[0]
            kg.add_edge(GraphEdge(character, Relation["INSIDE"], Room[target.upper()], timestep))
        return kg

    # Walking to an object: it must be known and in the agent's current room.
    if target not in room_of:
        return kg
    if room_of[target] != agent_room:
        return kg

    kg.delete_edge(("character", "close", None))
    character = kg.search_node("character")[0]
    close_relation = Relation["CLOSE"]
    candidates = kg.search_node(target)
    if candidates is not None:
        kg.add_edge(GraphEdge(character, close_relation, candidates[0], timestep))
    return kg

def information_gain(
    knowledge_graph: KG,
    event_edges: List[str],
):
    """Evaluate each event against the graph and report when it was observed.

    Each entry of ``event_edges`` is indexed as a sequence of edge patterns
    ``(subject, relation, object)``: entry ``[0]`` is the primary pattern that
    is searched for, and an optional entry ``[1]`` is the alternative pattern.

    Args:
        knowledge_graph: Graph queried through ``search_edge``.
        event_edges: One pattern list per event.

    Returns:
        A pair ``(return_dict, yes_or_no)`` of lists, one item per event
        (``return_dict`` is a list despite its name):

        * ``return_dict[i]`` is the ``timesteps`` value of the last matching
          edge, or ``-100`` when nothing matched.
        * ``yes_or_no[i]`` is ``1`` when the primary pattern matched, ``0``
          when the alternative object matched (or nothing matched in a
          non-``None`` search result), and ``-1`` when ``search_edge``
          returned ``None``.
    """
    num_events = len(event_edges)
    return_dict = [-100] * num_events
    yes_or_no = [0] * num_events

    for idx, edge in enumerate(event_edges):
        found_edges = knowledge_graph.search_edge(edge[0])
        if found_edges is None:
            # Nothing known about this subject/relation at all.
            return_dict[idx] = -100
            yes_or_no[idx] = -1
            continue

        for found in found_edges:
            target_name = node_rel2str(found.to_node)
            # A ``None`` object in the primary pattern accepts any target.
            if edge[0][2] is None or target_name == edge[0][2]:
                return_dict[idx] = found.timesteps
                yes_or_no[idx] = 1
            # The alternative pattern is checked afterwards, so it overrides
            # a primary hit on the same edge.
            if len(edge) == 2 and target_name == edge[1][2]:
                return_dict[idx] = found.timesteps
                yes_or_no[idx] = 0

    return return_dict, yes_or_no

def find_distance(
    knowledge_graph: KG,
    objectA: str,
    objectB: str
):
    """Return the number of edges on the shortest path between two nodes.

    The graph's ``(source, relation, destination)`` string triples are treated
    as undirected, unweighted edges and searched breadth-first.

    Args:
        knowledge_graph: Graph providing ``return_string_tuple()``.
        objectA: Name of the start node.
        objectB: Name of the node to reach.

    Returns:
        ``0`` when both names are equal, the hop count when a path exists,
        and ``np.inf`` when ``objectB`` cannot be reached from ``objectA``.
    """
    # Build an undirected adjacency list from the edge triples.
    adjacency_list = {}
    for src, _, dest in knowledge_graph.return_string_tuple():
        adjacency_list.setdefault(src, []).append(dest)
        adjacency_list.setdefault(dest, []).append(src)

    if objectA == objectB:
        return 0

    # Nodes are marked as visited when they are enqueued, not when they are
    # dequeued. This guarantees every node enters the queue at most once;
    # marking on dequeue lets a node be queued once per incoming edge, which
    # blows up on densely connected graphs.
    visited = {objectA}
    queue = deque([(objectA, 0)])

    while queue:
        current, distance = queue.popleft()
        for neighbor in adjacency_list.get(current, ()):
            if neighbor in visited:
                continue
            if neighbor == objectB:
                return distance + 1
            visited.add(neighbor)
            queue.append((neighbor, distance + 1))

    return np.inf

def information_function(
    knowledge_graph: KG,
    query_list: List[str],
    embedding_fns,
    query_edges=None
):
    """Compute agent-to-edge distances (and edge timesteps) for each query.

    When ``query_edges`` is ``None`` the most relevant edge(s) for every query
    are retrieved from the graph. Each retrieved item is a string whose first
    and last characters are stripped and whose remainder is split on ``", "``:
    the first three fields form the edge tuple and the last field is parsed as
    an integer timestep. When ``query_edges`` is supplied, retrieval is
    skipped and the timestep output is a fixed four-zero placeholder.

    Args:
        knowledge_graph: Graph used for retrieval and distance computation.
        query_list: Natural-language queries (only used for retrieval).
        embedding_fns: Passed through to ``knowledge_graph.retrieve``.
        query_edges: Optional previously retrieved edges, one list per query.

    Returns:
        ``(distances, info_gain, query_edges)`` where ``distances`` holds, per
        query and per edge, the graph distance from ``"character"`` to the
        edge's subject; ``info_gain`` holds the parsed timesteps (or the
        placeholder); and ``query_edges`` is the edge list that was used.
    """
    if query_edges is not None:
        # Edges were supplied by the caller; timesteps are not recomputed.
        info_gain = [0, 0, 0, 0]
    else:
        query_edges = []
        info_gain = []
        for query in query_list:
            raw_edges = knowledge_graph.retrieve(
                [query],
                num_edges=1,
                embedding_fns=embedding_fns,
                return_type='with_timestep_list',
                character_info=False,
                hold_info=False,
            )
            # Drop the surrounding brackets and split into individual fields.
            parsed = [raw[1:-1].split(", ") for raw in raw_edges]
            query_edges.append([tuple(fields[:3]) for fields in parsed])
            info_gain.append([int(fields[-1]) for fields in parsed])

    temp_distance = [
        [find_distance(knowledge_graph, "character", edge[0]) for edge in edges_for_query]
        for edges_for_query in query_edges
    ]

    return np.array(temp_distance), np.array(info_gain), query_edges

# For debug
def query_execution(
    knowledge_graph: KG,
    event_list: List[str],
):
    """Evaluate predefined events against the graph.

    Args:
        knowledge_graph: Graph to evaluate the events on.
        event_list: Event descriptions; each must be a key of
            ``pre_defined_event_edges``.

    Returns:
        The per-event status list produced by ``information_gain`` (its
        second return value).

    Raises:
        KeyError: If an event is not in ``pre_defined_event_edges``.
    """
    event_edges = [pre_defined_event_edges[event] for event in event_list]
    return information_gain(knowledge_graph, event_edges=event_edges)[1]

def calcuate_information(distance, temp_query_eval):
    """Combine per-query negative entropy with a distance-based weight.

    Args:
        distance: 2-D array-like; the mean is taken along axis 1 to get one
            distance per row.
        temp_query_eval: 2-D array; ``x * log(x + 1e-6)`` is summed along
            axis 1 for each row (presumably rows are answer probabilities
            per query; verify against caller).

    Returns:
        Scalar sum over rows of ``information_measure * (1 - distance_coef)``.
    """
    # Negative entropy per row, shifted by 0.6 and capped at 0 from above.
    neg_entropy = np.sum(temp_query_eval * np.log(temp_query_eval + 1e-6), axis=1)
    information_measure = np.clip(neg_entropy + 0.6, -1e+6, 0)

    # Weight decays exponentially with mean distance; equals 1 at distance 1.
    mean_distance = np.mean(np.array(distance), axis=1)
    distance_coef = np.exp((1 - mean_distance) / 2)

    weighted = information_measure * (1 - distance_coef)
    return np.sum(weighted)


def action_information_gain(
    knowledge_graph: KG,
    action_list: List[str],
    queries_list: List[str],
    temp_query_eval: np.ndarray,
    timestep: int,
    embedding_fns,
):
    """Score the current state and every candidate action by information value.

    The query edges are retrieved once from the current graph. Each candidate
    action is then simulated with ``knowledge_based_action`` (stamped with
    ``timestep + 1``) and scored with ``calcuate_information`` using the
    distances from the simulated graph to those same edges.

    Args:
        knowledge_graph: Current knowledge graph.
        action_list: Candidate action strings.
        queries_list: Queries whose relevant edges are retrieved.
        temp_query_eval: Per-query evaluation array passed unchanged to
            ``calcuate_information``.
        timestep: Current timestep.
        embedding_fns: Passed through to ``information_function``.

    Returns:
        ``(temp_info, act_info, mean_timesteps)``: the score of the current
        graph, the list of scores for each action (in ``action_list`` order),
        and the per-query mean of the retrieved edge timesteps.
    """
    current_distance, retrieved_timesteps, query_edges = information_function(
        knowledge_graph, queries_list, embedding_fns
    )
    temp_info = calcuate_information(current_distance, temp_query_eval)

    act_info = []
    for candidate in action_list:
        simulated_graph = knowledge_based_action(knowledge_graph, candidate, timestep + 1)
        # Reuse the already retrieved edges so only the distances change.
        simulated_distance, _, query_edges = information_function(
            simulated_graph, queries_list, embedding_fns, query_edges
        )
        act_info.append(calcuate_information(simulated_distance, temp_query_eval))

    mean_timesteps = np.mean(retrieved_timesteps, axis=1)
    return temp_info, act_info, mean_timesteps