import torch
import numpy as np
from .knowledge_graph import KG
from .retriever import EmbeddingFnsClass
from .value_fn import query_execution
from openai import OpenAI
# Module-level OpenAI client, created at import time. Nothing in the visible
# part of this file uses it.
# NOTE(review): "key" looks like a placeholder credential — presumably the real
# key is meant to come from configuration or the environment; confirm before use.
client = OpenAI(api_key="key")


def softmax(x):
    """Return the softmax of *x* taken along axis 0.

    The maximum along axis 0 is subtracted before exponentiating. This leaves
    the result mathematically unchanged but keeps ``np.exp`` from overflowing
    to ``inf`` (and the quotient from becoming ``nan``) for large inputs.

    Args:
        x: Array-like of scores; normalisation runs over the first axis.

    Returns:
        ``numpy.ndarray`` with the same shape as *x* whose entries along
        axis 0 sum to 1.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalise; np.max would raise on an empty reduction.
        return np.exp(x)
    exps = np.exp(x - np.max(x, axis=0, keepdims=True))
    return exps / np.sum(exps, axis=0)

class LLMAgent:
    """Agent that queries a language model using facts retrieved from a knowledge graph.

    Two graphs are kept: ``knowledge_graph`` (the current one) and
    ``prev_knowledge_graph`` (a clone taken the last time the timestep
    advanced, see :meth:`kg_add`). :meth:`query` contrasts answers obtained
    from both graphs; :meth:`get_logprob_actions` / :meth:`predict` score
    candidate action strings with next-token log-probabilities.
    """

    def __init__(self, tokenizer, LLMmodel, knowledge_graph: "KG", embedding_fns: "EmbeddingFnsClass" = None):
        """Store the graphs, retriever embeddings and language model.

        Args:
            tokenizer: Tokenizer called as ``tokenizer(text, return_tensors="pt")``.
                Ignored (stored as ``None``) when *LLMmodel* is a string.
            LLMmodel: Either a string, in which case ``self.model`` is set to
                the fixed name ``"gpt-3.5-turbo-0125"``, or a model object
                exposing ``.eval()`` whose call result has a ``.logits``
                attribute.
            knowledge_graph: Graph object providing ``clone``, ``add``,
                ``retrieve`` and ``history_retrieve``.
            embedding_fns: Passed through unchanged to the graph's retrieval
                methods.
        """
        self.knowledge_graph = knowledge_graph
        self.prev_knowledge_graph = knowledge_graph.clone()
        self.embedding_fns = embedding_fns
        self.memory_augmented_planner = "gpt-3.5-turbo-0125"
        self.timestep = 0
        if isinstance(LLMmodel, str):
            # NOTE(review): the passed string itself is discarded and the
            # tokenizer is None, so the log-prob based methods below cannot
            # run in this mode — confirm this is intended.
            self.model = "gpt-3.5-turbo-0125"
            self.tokenizer = None
            self.return_type = None
        else:
            self.tokenizer = tokenizer
            self.model = LLMmodel.eval()

    def kg_add(self, new_graph, t, use_refinement=True):
        """Merge *new_graph* into the current graph at time *t*.

        When *t* is later than the stored timestep, the current graph is first
        cloned into ``prev_knowledge_graph`` and the timestep is advanced, so
        the "previous" graph always reflects the state before this timestep.
        """
        if self.timestep < t:
            self.timestep = t
            self.prev_knowledge_graph = self.knowledge_graph.clone()
        self.knowledge_graph.add(new_graph, t, use_refinement)

    def _answer_distribution(self, prompt, tokenized_answer, temperature):
        """Return normalised [True, False, Unknown] scores for *prompt* as a NumPy array.

        Rows 0-2 of *tokenized_answer* are the lower-case answers and rows 3-5
        the capitalised ones; both spellings are added together. The token in
        column 1 of each row is used as the answer token.
        NOTE(review): column 1 presumably skips a leading BOS token — confirm
        for the tokenizer in use.
        """
        _, token_logprob = self.get_logprob_tokens(prompt, tokenized_answer[0])
        lower_case = torch.exp(token_logprob[tokenized_answer[:3, 1]] / temperature)
        capitalised = torch.exp(token_logprob[tokenized_answer[3:6, 1]] / temperature)
        scores = lower_case + capitalised
        scores = scores / (torch.sum(scores) + 1e-20)  # epsilon guards against 0/0
        return scores.numpy()

    def query(self, query, prev_prob=None):
        """Estimate True/False/Unknown probabilities for *query*.

        Ten rounds are run. Each round first asks the model using facts from
        the previous graph, then asks again using facts from the current graph
        with the first answer included in the prompt. Only the second answers
        are collected and averaged.

        Args:
            query: Query text.
            prev_prob: Optional earlier estimate, indexable as
                ``[p_true, p_false, ...]``.

        Returns:
            ``(probs, non_update)``: *probs* is the mean of the ten
            post-update distributions. *non_update* is ``True`` when no round
            retrieved a current-graph fact stamped with the current timestep,
            or when the averaged True-minus-False margin is neither larger in
            magnitude than the previous margin nor of opposite sign.
        """
        # Signed True-minus-False margin of the earlier estimate (1.0 if none).
        prev_ent = 1.0 if prev_prob is None else prev_prob[0] - prev_prob[1]

        answer_list = ["true", "false", "unknown", "True", "False", "Unknown"]
        tokenized_answer = self.tokenizer(answer_list, return_tensors="pt", padding=True).input_ids
        # NOTE(review): the prompt asks for "yes or no" although true/false/unknown
        # tokens are scored, and it misspells "unknown". Left byte-identical on
        # purpose: editing the prompt changes model behaviour.
        system_prompt = "Evaluate the query based on environment knowledge. Knowledge consists of (object, relation, subject, observation time). Answer yes or no, taking into account changes over time. If the information is missing or too old to be accurate, answer unkown.\n"
        temperature = 0.33
        observed_now = False
        result_list = []
        for _ in range(10):
            # Pass 1: answer from the graph as it was before the current timestep.
            env_info = self.retrieve_from_prev([query], num_edges=8, return_type='with_timestep', character_info=False, replace=True)
            history = self.history_retrieve_from_prev([query], num_edges=8, return_type='with_timestep', character_info=False)
            user_prompt = "History: " + history + "\n"
            user_prompt += "Environment knowledge: " + env_info + "\n"
            user_prompt += "Temp timestep: " + str(self.timestep) + "\nQuery: " + query + "\nAnswer:"
            print(system_prompt + "\n" + user_prompt)
            preupdate_result = self._answer_distribution(system_prompt + "\n" + user_prompt, tokenized_answer, temperature)

            # Pass 2: answer from the current graph, showing the model its pass-1 answer.
            env_info = self.retrieve([query], num_edges=8, return_type='with_timestep_list', character_info=False, replace=True)
            # The last whitespace-separated field of each entry (outer characters
            # stripped) is parsed as that fact's observation time.
            env_timesteps = [int(d[1: -1].split()[-1]) for d in env_info]
            history = self.history_retrieve([query], num_edges=8, return_type='with_timestep', character_info=False)
            user_prompt = "History: " + history + "\n"
            user_prompt += "Environment knowledge: " + ", ".join(env_info) + "\nBefore update result: True({}%), False({}%), Unknown({}%)\n".format(int(preupdate_result[0]*100), int(preupdate_result[1]*100), int(preupdate_result[2]*100))
            user_prompt += "Temp timestep: " + str(self.timestep) + "\nQuery: " + query + "\nAnswer:"
            print(system_prompt + "\n" + user_prompt)
            result = self._answer_distribution(system_prompt + "\n" + user_prompt, tokenized_answer, temperature)
            print("\nAfter update result: True({}%), False({}%), Unknown({}%)\n".format(int(result[0]*100), int(result[1]*100), int(result[2]*100)))
            result_list.append(result)
            if self.timestep in env_timesteps:
                observed_now = True

        mean_result = np.mean(np.array(result_list), axis=0)
        if not observed_now:
            return mean_result, True
        margin = mean_result[0] - mean_result[1]
        # An update counts only if the answer became more decisive or flipped sign.
        changed = np.abs(margin) > np.abs(prev_ent) or margin * prev_ent < 0
        return mean_result, not changed

    def get_logprob_tokens(self, inputs, token_candidates):
        """Return next-token log-probabilities after the prompt *inputs*.

        Args:
            inputs: Prompt text.
            token_candidates: Index (or index tensor) of the token ids of interest.

        Returns:
            ``(selected, full)``: *full* is the log-probability vector for the
            position after the last prompt token, moved to CPU; *selected* is
            ``full[token_candidates]``.
        """
        tokenized_inputs = self.tokenizer(inputs, return_tensors="pt")
        # Same device as before when CUDA exists; fall back to CPU instead of crashing.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        with torch.no_grad():
            outputs = self.model(tokenized_inputs.input_ids.to(device))
        # log_softmax avoids the -inf that log(softmax(x)) yields when a
        # probability underflows to zero.
        token_logprob = torch.log_softmax(outputs.logits[0, -1], dim=0).cpu()
        return token_logprob[token_candidates], token_logprob

    def get_logprob_actions(self, inputs, action_list, return_dict=False):
        """Score each candidate action as a continuation of *inputs*.

        An action's score is the sum of next-token log-probabilities of its
        non-special tokens, accumulated only until its token prefix differs
        from every other (non-identical) candidate. Distributions are cached
        per prompt prefix so candidates sharing a prefix reuse one model call.

        Args:
            inputs: Prompt text.
            action_list: Candidate action strings.
            return_dict: Return ``{action: score}`` instead of a list aligned
                with *action_list*.
        """
        tokenized_action = self.tokenizer(action_list, return_tensors="pt", padding=True).input_ids
        special_ids = set(self.tokenizer.all_special_ids)
        cache = {}
        action_list_logprob = {} if return_dict else []
        for i, row in enumerate(tokenized_action):
            temp_inputs = inputs
            action_logprob = 0
            for token_idx, token in enumerate(row):
                if int(token) in special_ids:
                    continue
                if temp_inputs in cache:
                    token_logprob = cache[temp_inputs][token]
                else:
                    token_logprob, entire_logprob = self.get_logprob_tokens(temp_inputs, token)
                    cache[temp_inputs] = entire_logprob

                action_logprob = action_logprob + token_logprob
                piece = self.tokenizer.convert_ids_to_tokens(int(token))
                # '▁' is presumably the tokenizer's word-boundary marker; map it back to a space.
                temp_inputs = temp_inputs + piece.replace('▁', ' ')

                prefix = row[:token_idx + 1]
                shared = any(
                    not (row == other).all() and (prefix == other[:token_idx + 1]).all()
                    for other in tokenized_action
                )
                if not shared:
                    # Prefix already identifies this action; remaining tokens are not scored.
                    break

            if return_dict:
                action_list_logprob[action_list[i]] = action_logprob
            else:
                action_list_logprob.append(action_logprob)
        return action_list_logprob

    def retrieve_from_prev(self, instructions, num_edges=50, return_type="str", character_info=True, replace=False):
        """Call ``retrieve`` on the previous graph, passing ``self.embedding_fns``."""
        return self.prev_knowledge_graph.retrieve(instructions, self.embedding_fns, num_edges, return_type, character_info, replace)

    def retrieve(self, instructions, num_edges=50, return_type="str", character_info=True, replace=False):
        """Call ``retrieve`` on the current graph, passing ``self.embedding_fns``."""
        return self.knowledge_graph.retrieve(instructions, self.embedding_fns, num_edges, return_type, character_info, replace)

    def history_retrieve_from_prev(self, instructions, num_edges=50, return_type="str", character_info=True):
        """Call ``history_retrieve`` on the previous graph, passing ``self.embedding_fns``."""
        return self.prev_knowledge_graph.history_retrieve(instructions, self.embedding_fns, num_edges, return_type, character_info)

    def history_retrieve(self, instructions, num_edges=50, return_type="str", character_info=True):
        """Call ``history_retrieve`` on the current graph, passing ``self.embedding_fns``."""
        return self.knowledge_graph.history_retrieve(instructions, self.embedding_fns, num_edges, return_type, character_info)

    def predict(self, prompts, action_list, return_list=False):
        """Pick the most likely action for *prompts*.

        Returns:
            ``(best_action, ranking)`` where *ranking* is the list of
            ``(action, log_prob)`` pairs sorted from most to least likely when
            *return_list* is true, otherwise ``None``.
        """
        log_prob = self.get_logprob_actions(prompts, action_list, return_dict=True)
        sorted_items = sorted(log_prob.items(), key=lambda item: item[1], reverse=True)
        # Descending order: index 0 holds the highest log-probability.
        best_action = sorted_items[0][0]
        return best_action, (sorted_items if return_list else None)

    def evaluate_query(self, queries, executions, prev_prob=None, temporal_consistency=True):
        """Evaluate every query and select the executions whose query is judged true.

        Args:
            queries: Query strings.
            executions: Execution entries aligned with *queries*.
            prev_prob: Optional per-query earlier estimates (arrays or plain
                sequences).
            temporal_consistency: When true and :meth:`query` reports
                ``non_update``, lean on the earlier estimate: blend 10% new /
                90% old if the new answer is less decisive, otherwise keep
                the old estimate as is.

        Returns:
            ``(probabilities, flags, selected_executions)``. *flags* holds 1
            where ``p_true > 0.5`` and 0 elsewhere. If nothing is selected,
            *selected_executions* is a single ``"Find the ..."`` string built
            from all queries.
        """
        query_temp_eval = []
        inferred_query_execs = []
        for idx, q in enumerate(queries):
            if prev_prob is None:
                probs, _ = self.query(q)
            else:
                # Coerce so list-valued estimates support the arithmetic below.
                previous = np.asarray(prev_prob[idx])
                probs, non_update = self.query(q, prev_prob[idx])
                if temporal_consistency and non_update:
                    if np.abs(probs[0] - probs[1]) < np.abs(previous[0] - previous[1]):
                        probs = probs * 0.1 + previous * 0.9
                    else:
                        probs = previous
            query_temp_eval.append(probs)
            inferred_query_execs.append(1 if probs[0] > 0.5 else 0)
        selected_executions = [executions[idx] for idx, flag in enumerate(inferred_query_execs) if flag == 1]
        if not selected_executions:
            selected_executions.append("Find the {}".format(", ".join(queries)))
        return query_temp_eval, inferred_query_execs, selected_executions
