from collections.abc import Sequence

from agentgym.agent_pipeline.base_pipeline_element import (
    BasePipelineElement,
    DummyEnv,
)
from agentgym.functions_engine.functions_engine import (
    Env,
    Function,
    FunctionCall,
)
from agentgym.types import ChatMessage


class ToolSelector(BasePipelineElement):
    """Pipeline element that narrows the available tools before later elements run.

    Concrete subclasses provide the selection logic by overriding
    :meth:`select_tools`. :meth:`query` replaces only the tool list and
    forwards every other pipeline value unchanged.
    """

    def __init__(self, llm: BasePipelineElement) -> None:
        # Kept for subclasses; nothing in this class calls it.
        self.llm = llm

    def select_tools(self, query: str, tools: Sequence[Function]) -> Sequence[Function]:
        """Return the subset of ``tools`` relevant to ``query``.

        Abstract hook: the base implementation always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError()

    def query(
        self,
        query: str,
        tools: Sequence[Function],
        env: Env = DummyEnv(),
        messages: Sequence[ChatMessage] = [],
        extra_args: dict = {},
    ) -> tuple[str, Sequence[Function], Env, Sequence[ChatMessage], dict]:
        """Swap ``tools`` for the result of :meth:`select_tools`.

        ``query``, ``env``, ``messages`` and ``extra_args`` are returned as the
        very same objects that were passed in. This method never writes to
        them, so the shared default ``[]`` / ``{}`` values stay untouched.
        """
        selected_tools = self.select_tools(query, tools)
        untouched = (env, messages, extra_args)
        return (query, selected_tools, *untouched)


class ToolUsagePlanner(BasePipelineElement):
    """Pipeline element that plans tool calls and keeps only the tools the plan uses.

    Concrete subclasses provide the planning logic by overriding :meth:`plan`.
    :meth:`query` filters ``tools`` down to those named in the plan and exposes
    the plan to later elements under ``extra_args["plan"]``.
    """

    def __init__(self, llm: BasePipelineElement) -> None:
        # Kept for subclasses' `plan` implementations; not used in this class.
        self.llm = llm

    def plan(self, query: str, tools: Sequence[Function]) -> Sequence[FunctionCall]:
        """Return the function calls planned for ``query``.

        Abstract hook: the base implementation always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError()

    def query(
        self,
        query: str,
        tools: Sequence[Function],
        env: Env = DummyEnv(),
        messages: Sequence[ChatMessage] = [],
        extra_args: dict = {},
    ) -> tuple[str, Sequence[Function], Env, Sequence[ChatMessage], dict]:
        """Plan tool usage, restrict ``tools`` to the planned ones, and record the plan.

        Args:
            query: The user query, forwarded unchanged and passed to :meth:`plan`.
            tools: Candidate tools; only those whose ``name`` appears as a
                planned call's ``"function"`` are kept, in their original order.
            env: Forwarded unchanged.
            messages: Forwarded unchanged.
            extra_args: Copied into a new dict that also carries ``"plan"``.
                The argument itself is never modified.

        Returns:
            ``(query, filtered_tools, env, messages, new_extra_args)`` where
            ``new_extra_args["plan"]`` is the sequence returned by :meth:`plan`.
        """
        plan = self.plan(query, tools)
        # A set gives O(1) membership checks while filtering. The output order
        # still follows `tools`.
        planned_names = {tool_call["function"] for tool_call in plan}
        tools_to_use = [tool for tool in tools if tool.name in planned_names]
        # Build a new dict instead of writing into `extra_args`. The default
        # `{}` is created once and shared across calls, so mutating it would
        # leak one call's plan into every later call that omits the argument.
        # It would also silently modify a dict owned by the caller.
        extra_args = {**extra_args, "plan": plan}
        return query, tools_to_use, env, messages, extra_args
