from avatarchat.agent import *
from avatarchat.communication import *
import logging

# Resolve directories relative to this module: ``file_path`` is the directory
# containing this file and ``project_path`` is its parent, under which the
# "prompts" directory is looked up. ``abspath`` guards against a relative
# ``__file__``, which would otherwise collapse both paths to "".
file_path = os.path.dirname(os.path.abspath(__file__))
project_path = os.path.dirname(file_path)

class Mode:
    """Hold the settings of one AvatarChat experiment mode.

    The mode name is read from ``global_config["mode"]["mode"]``; it decides
    which agent classes are built for the instructor (acting for ``sender``)
    and the assistant (acting for ``receiver``), and which communication class
    connects them.

    TODO: move this table into a structured config that is easier to edit
    and read.

        ModeName                 MindPin     MultiCommunication      Memory
    1.  Full                     Double        yes                   Mixed
    2.  DoubleMindPin            Double        yes                   --
    3.  FullSC                   Double        --                    Mixed
    4.  DoubleMindPinSC          Double        --                    --
    5.  SingleMindPin            Single        yes                   --
    6.  SingleMindPinSC          Single        --                    --
    7.  Vanilla                  --            yes                   --
    8.  FuzzyMemory              Double        yes                   Fuzzy
    9.  DistinctMemory           Double        yes                   Distinct
    10. VanillaMemory            --            yes                   Mixed
    11. VanillaSC                --            --                    --


    For Schedule/Needle in the Persona dataset, run experiments 2/4/5/7:
    2-->4 ablates MultiCommunication
    2-->5 ablates DoubleMindPin (reduced to SingleMindPin)
    2-->7 ablates MindPin entirely

    For the Friends dataset, run experiments 1/3/8/9/10:
    1-->3  ablates MultiCommunication
    1-->8  ablates DistinctMemory (only fuzzy memory is kept)
    1-->9  ablates FuzzyMemory (only distinct memory is kept)
    1-->10 ablates MindPin

    For modes with MultiCommunication enabled, a newly raised third-party
    communication uses the same mode as the communication that invoked it,
    but with MultiCommunication disabled, so communications never nest.
    For example, a communication raised in Full mode runs as Full without
    further MultiCommunication, which is effectively FullSC mode.
    This is handled automatically by the MultiCommunication class.

    Modes 5/7/8/9/10 do not invoke MultiCommunication again; instead they
    reuse the MultiCommunication results recorded by modes 1/2, via the
    OfflineLoadMultiCommunication.set_communication_history() method.
    """

    # Every mode name accepted by __init__ (see the table above).
    REALIZED_MODES = frozenset({"Full",
                                "DoubleMindPin",
                                "FullSC",
                                "DoubleMindPinSC",
                                "SingleMindPin",
                                "SingleMindPinSC",
                                "Vanilla",
                                "FuzzyMemory",
                                "DistinctMemory",
                                "VanillaMemory",
                                "VanillaSC"})

    def __init__(self, sender, receiver, task, global_config, rewrite_prompt=True) -> None:
        """Configure the mode and optionally rewrite the task with the LLM.

        Args:
            sender: master of the instructor agent.
            receiver: master of the assistant agent.
            task: task description; the unmodified value is kept in ``raw_task``.
            global_config: nested config mapping; this class reads
                ``mode.mode``, ``backend.provider``, ``agent``
                (including ``agent.max_communication_turns``) and
                ``mysql.database``.
            rewrite_prompt: when True, ask the backend LLM to rewrite ``task``
                using the ``rewrite_task`` entry of ``prompts/tool_prompt.json``.

        Raises:
            ValueError: if the mode or the backend provider is not supported.
        """
        self.sender = sender
        self.receiver = receiver
        self.task = task
        self.raw_task = task
        self.global_config = global_config
        self.mode_name = self.global_config.get("mode").get("mode")

        # Validate the mode before any LLM call so a config typo fails fast
        # instead of after a paid rewrite query. Raise instead of assert:
        # assertions are stripped when Python runs with -O.
        self.realized_modes = set(self.REALIZED_MODES)
        if self.mode_name not in self.realized_modes:
            raise ValueError("{} not realized, AvatarChat now supports: {}".format(
                self.mode_name, str(sorted(self.realized_modes))))

        # Only the selected query function is looked up, so backends whose
        # helpers are unavailable do not break the others.
        self.backend = global_config.get('backend').get('provider')
        if self.backend == "gemini":
            self.query_func = query_gemini
        elif self.backend == "gpt":
            self.query_func = query_gpt
        elif self.backend == "gpt4":
            self.query_func = query_gpt4
        elif self.backend == "claude":
            self.query_func = query_claude
        else:
            raise ValueError("{} backend not implemented".format(self.backend))

        # The prompt file may contain non-ASCII text; do not depend on the
        # platform's default encoding.
        with open(os.path.join(project_path, "prompts", "tool_prompt.json"), "r", encoding="utf-8") as f:
            self.tool_prompt = json.load(f)

        # Rewrite the task with the LLM (the original stays in ``raw_task``).
        self.rewrite_prompt = rewrite_prompt
        if self.rewrite_prompt:
            query_prompt = "\n".join(self.tool_prompt['rewrite_task']).format(sender=sender, receiver=receiver, task=self.task)
            self.task = self.query_func(query_prompt)
            AvatarLogger.log(query_prompt, self.task, "【rewrite task】")

        AvatarLogger.log(instruction=self._global_config_summary())

    def _global_config_summary(self):
        """Return a printable summary of the global config sections worth logging."""
        sections = [
            ("Global LLM Config", self.global_config.get("backend").get("provider")),
            ("Global Agent Config", self.global_config.get("agent")),
            ("Global Mode Config", self.global_config.get("mode")),
            ("Global Database Config", self.global_config.get("mysql").get("database")),
        ]
        return "".join("{}:\n{}\n".format(title, str(value)) for title, value in sections)

    def get_instructor_agent(self):
        """Build the instructor agent (master: ``sender``) for the current mode.

        Raises:
            ValueError: if the mode has no instructor mapping.
        """
        if self.mode_name in {'Full', 'FullSC'}:
            return OnlineMemoryAgent(master=self.sender, backend=self.backend, task=self.task, enable_fuzzy_memory=True)
        elif self.mode_name in {'VanillaMemory'}:
            return OnlineMemoryVanillaAgent(master=self.sender, backend=self.backend, task=self.task, enable_fuzzy_memory=True)
        elif self.mode_name in {'Vanilla', 'VanillaSC'}:
            return OnlineAgent(master=self.sender, backend=self.backend, task=self.task)
        elif self.mode_name in {'DoubleMindPin', 'DoubleMindPinSC', 'SingleMindPin', 'SingleMindPinSC'}:
            # The instructor thinks in both Single- and DoubleMindPin modes;
            # only the assistant side differs between them.
            return OnlineThinkAgent(master=self.sender, backend=self.backend, task=self.task)
        elif self.mode_name in {'DistinctMemory'}:
            # Uses OnlineMemoryAgent's default memory flags.
            return OnlineMemoryAgent(master=self.sender, backend=self.backend, task=self.task)
        elif self.mode_name in {'FuzzyMemory'}:
            return OnlineMemoryAgent(master=self.sender, backend=self.backend, task=self.task, enable_distinct_memory=False, enable_fuzzy_memory=True)
        raise ValueError("{} mode has no instructor agent".format(self.mode_name))

    def get_assistant_agent(self):
        """Build the assistant agent (master: ``receiver``) for the current mode.

        Raises:
            ValueError: if the mode has no assistant mapping.
        """
        if self.mode_name in {'Full', 'FullSC'}:
            return OnlineMemoryAgent(master=self.receiver, backend=self.backend, task=self.task, enable_fuzzy_memory=True, is_assistant=True)
        elif self.mode_name in {'VanillaMemory'}:
            # The assistant always acts for the receiver, as in every other mode.
            return OnlineMemoryVanillaAgent(master=self.receiver, backend=self.backend, task=self.task, enable_fuzzy_memory=True, is_assistant=True)
        elif self.mode_name in {'SingleMindPin', 'SingleMindPinSC', 'Vanilla', 'VanillaSC'}:
            return OnlineAgent(master=self.receiver, backend=self.backend, task=self.task, is_assistant=True)
        elif self.mode_name in {'DoubleMindPin', 'DoubleMindPinSC'}:
            return OnlineThinkAgent(master=self.receiver, backend=self.backend, task=self.task, is_assistant=True)
        elif self.mode_name in {'DistinctMemory'}:
            # Uses OnlineMemoryAgent's default memory flags.
            return OnlineMemoryAgent(master=self.receiver, backend=self.backend, task=self.task, is_assistant=True)
        elif self.mode_name in {'FuzzyMemory'}:
            return OnlineMemoryAgent(master=self.receiver, backend=self.backend, task=self.task, enable_distinct_memory=False, enable_fuzzy_memory=True, is_assistant=True)
        raise ValueError("{} mode has no assistant agent".format(self.mode_name))

    def get_communication(self, is_offline=False):
        """Build the communication object that matches the current mode.

        The choice combines:
        1. whether the communication is offline,
        2. whether MultiCommunication is enabled,
        3. whether a consensus conclusion is requested,
        4. whether recorded MultiCommunication results are reused
           (OfflineLoadMultiCommunication).

        Args:
            is_offline (bool, optional): whether offline. Defaults to False.
                Modes that reuse recorded results ignore it and always use
                OfflineLoadMultiCommunication.

        Returns:
            Communication: constructed communication based on the mode.

        Raises:
            ValueError: if the mode has no communication mapping.
        """
        instructor_agent = self.get_instructor_agent()
        assistant_agent = self.get_assistant_agent()

        if self.mode_name in {'Full', 'DoubleMindPin'}:
            comm_cls = OfflineMultiCommunication if is_offline else OnlineMultiCommunication
            consensus = True
        elif self.mode_name in {'FullSC', 'DoubleMindPinSC'}:
            comm_cls = OfflineCommunication if is_offline else OnlineCommunication
            consensus = True
        elif self.mode_name in {'SingleMindPinSC', 'VanillaSC'}:
            comm_cls = OfflineCommunication if is_offline else OnlineCommunication
            consensus = False
        elif self.mode_name in {'FuzzyMemory', 'DistinctMemory'}:
            comm_cls = OfflineLoadMultiCommunication
            consensus = True
        elif self.mode_name in {'SingleMindPin', 'Vanilla', 'VanillaMemory'}:
            comm_cls = OfflineLoadMultiCommunication
            consensus = False
        else:
            raise ValueError("{} mode has no communication".format(self.mode_name))

        comm_kwargs = {
            "instructor": instructor_agent,
            "assistant": assistant_agent,
            # Read from this instance's config so the round limit matches the
            # config the mode was built with.
            "max_round": self.global_config.get("agent").get("max_communication_turns"),
        }
        # Pass the flag only when enabling it; otherwise the communication
        # class keeps its own default.
        if consensus:
            comm_kwargs["is_consensus_conclusion"] = True
        comm = comm_cls(**comm_kwargs)

        # Announce the rewritten task in the communication.
        if self.rewrite_prompt:
            comm.send_message_agent(comm.instructor,
                                    comm.assistant,
                                    "【rewrite task prompt】: from '{}' --> '{}'".format(self.raw_task, self.task))

        return comm