import json
import logging
import warnings
from collections.abc import Sequence
from pathlib import Path

from cohere.core.api_error import ApiError
from google.api_core.exceptions import InternalServerError
from openai import BadRequestError
from pydantic import BaseModel, ConfigDict, ValidationError, model_validator
from typing_extensions import Self, TypedDict

from agentgym.agent_pipeline.base_pipeline_element import (
    BasePipelineElement,
)
from agentgym.attacks.base_attack import BaseAttacker
from agentgym.base_tasks import BaseInjectionTask, BaseUserTask
from agentgym.logging import Logger, TraceLogger
from agentgym.task_suites import (
    banking_task_suite,
    slack_task_suite,
    travel_task_suite,
    workspace_task_suite,
)
from agentgym.task_suites.task_suite import Env, TaskSuite
from agentgym.types import ChatMessage

# Registry of the task suites imported above, keyed by a short suite name.
SUITES = {
    "workspace": workspace_task_suite,
    "travel": travel_task_suite,
    "slack": slack_task_suite,
    "banking": banking_task_suite,
}


class SuiteResults(TypedDict):
    """Per-task outcomes collected while benchmarking a single suite."""

    # Utility outcome per (user task ID, injection task ID). Runs without
    # injections in this module use "" as the injection task ID.
    utility_results: dict[tuple[str, str], bool]
    # Security outcome, keyed the same way as `utility_results`.
    security_results: dict[tuple[str, str], bool]
    # Per injection task ID: the utility outcome obtained when the injection
    # task itself was run without any injections. Left empty by runs that do
    # not evaluate injection tasks this way.
    injection_tasks_utility_results: dict[str, bool]


class BenchmarkResults(BaseModel):
    """Results of a benchmark run: per-suite outcomes plus two overall scores."""

    # One SuiteResults per suite; keys are presumably suite names -- TODO confirm against the caller.
    suite_results: dict[str, SuiteResults]
    # Overall utility score; presumably a mean over task outcomes (cf. aggregate_results) -- TODO confirm.
    utility: float
    # Overall security score; presumably computed the same way as `utility` -- TODO confirm.
    security: float


def aggregate_results(results: list[dict[tuple[str, str], bool]]) -> float:
    """Return the fraction of truthy outcomes across all result dictionaries.

    Args:
        results: Result dictionaries whose values are boolean outcomes. Only
            the values are used; the keys are ignored.

    Returns:
        The mean of all outcomes (``True`` counts as 1, ``False`` as 0), or
        ``0.0`` when there are no outcomes at all (an empty list, or only
        empty dictionaries).
    """
    outcomes = [outcome for suite_results in results for outcome in suite_results.values()]
    if not outcomes:
        # Nothing to average: report 0.0 instead of dividing by zero.
        return 0.0
    return sum(outcomes) / len(outcomes)


def run_task_with_injection_tasks(
    suite: TaskSuite,
    agent_pipeline: BasePipelineElement,
    user_task: BaseUserTask,
    attacker: BaseAttacker,
    logdir: Path | None,
    force_rerun: bool,
    injection_tasks: Sequence[str] | None = None,
) -> tuple[dict[tuple[str, str], bool], dict[tuple[str, str], bool]]:
    """Run one user task once per injection task and collect the outcomes.

    Args:
        suite: The suite the user task and injection tasks belong to.
        agent_pipeline: The pipeline under test.
        user_task: The user task to run.
        attacker: Produces the injections for each (user task, injection task) pair.
            For a DoS attack only the first injection task of the suite is used
            and `injection_tasks` is ignored.
        logdir: Directory with previously saved results. When it is set, the
            pipeline has a name and `force_rerun` is false, a previously saved
            result is reused instead of running the task again.
        force_rerun: Always run the task, ignoring any saved result.
        injection_tasks: IDs of the injection tasks to run. If None, all the
            injection tasks of the suite are run.

    Returns:
        Two dictionaries (utility, security), both keyed by
        `(user task ID, injection task ID)`.

    Raises:
        BadRequestError: If the request fails for a reason other than an
            exceeded context length / `max_tokens`.
        ApiError: If the error is not an internal server error.
    """
    utility_results: dict[tuple[str, str], bool] = {}
    security_results: dict[tuple[str, str], bool] = {}

    if attacker.is_dos_attack:
        # No need to run all injection tasks since we don't care about the injection target!
        injection_tasks_to_run = [next(iter(suite.injection_tasks))]
    elif injection_tasks is not None:
        injection_tasks_to_run = injection_tasks
    else:
        injection_tasks_to_run = suite.injection_tasks

    for injection_task_id in injection_tasks_to_run:
        injection_task = suite.get_injection_task_by_id(injection_task_id)
        task_injections = attacker.attack(user_task, injection_task)

        task_result = None
        # A forced re-run never uses the saved result, so only read it when it can be reused.
        if not force_rerun and logdir is not None and agent_pipeline.name is not None:
            try:
                task_result = load_task_results(
                    agent_pipeline.name,
                    suite.name,
                    user_task.ID,
                    attack_name=attacker.name,
                    injection_task=injection_task.ID,
                    logdir=logdir,
                )
            except (ValidationError, FileNotFoundError, json.JSONDecodeError):
                # Missing, malformed (e.g. truncated) or invalid log: run the task again.
                task_result = None

        if task_result is not None:
            logging.info(f"Skipping task '{user_task.ID}' with '{injection_task.ID}' because it was already run.")
            utility, security = task_result.utility, task_result.security
        else:
            # a trace logger captures logs and saves them, but also passes them to the delegate logger
            with TraceLogger(
                delegate=Logger.get(),
                suite_name=suite.name,
                user_task_id=user_task.ID,
                injection_task_id=injection_task_id,
                injections=task_injections,
                attack_type=attacker.name,
                pipeline_name=agent_pipeline.name,
            ) as logger:
                try:
                    utility, security = suite.run_task_with_pipeline(
                        agent_pipeline, user_task, injection_task, task_injections
                    )
                except BadRequestError as e:
                    if e.code == "context_length_exceeded" or e.param == "max_tokens":
                        logger.log_error(
                            f"Skipping task '{user_task.ID}' with '{injection_task.ID}' due to context_length_exceeded: {e}"
                        )
                        utility = False
                        security = True
                    else:
                        raise
                except ApiError as e:
                    if "internal server error" in str(e):
                        logger.log_error(
                            f"Skipping task '{user_task.ID}' with '{injection_task.ID}' because of internal server error: {e}"
                        )
                        utility = False
                        security = True
                    else:
                        raise
                except InternalServerError as e:
                    logger.log_error(
                        f"Skipping task '{user_task.ID}' with '{injection_task.ID}' because of internal server error: {e}"
                    )
                    utility = False
                    security = True

                if attacker.is_dos_attack:
                    # A DoS attack has no injection goal of its own: its security
                    # outcome is defined as the opposite of the utility outcome.
                    security = not utility

                # Record the outcome while the trace logger context is still
                # open, as run_task_without_injection_tasks does.
                logger.set_contextarg("utility", utility)
                logger.set_contextarg("security", security)

        utility_results[(user_task.ID, injection_task_id)] = utility
        security_results[(user_task.ID, injection_task_id)] = security
    return utility_results, security_results


def benchmark_suite_with_injections(
    agent_pipeline: BasePipelineElement,
    suite: TaskSuite[Env],
    attacker: BaseAttacker[Env],
    logdir: Path | None,
    force_rerun: bool,
    user_tasks: Sequence[str] | None = None,
    injection_tasks: Sequence[str] | None = None,
    verbose: bool = True,
) -> SuiteResults:
    """
    Run a benchmark on a suite with injections.

    Unless the attacker is a DoS attack, every selected injection task is first
    run on its own (without injections) and a warning is emitted if any of them
    is not solved. Then every selected user task is run against the injection
    tasks.

    Args:
        agent_pipeline: The agent pipeline to use.
        suite: The suite to benchmark.
        attacker: The attacker that generates the injections.
        logdir: Directory with saved results, forwarded to the task runners.
        force_rerun: Forwarded to the task runners.
        user_tasks: A subset of user tasks to run. If None, all user tasks are run.
        injection_tasks: A subset of injection tasks to run. If None, all injection tasks are run.
        verbose: Not used by this function.

    Returns:
        The utility and security results of all the runs, plus the utility
        results of the injection tasks run on their own.
    """
    if user_tasks is None:
        selected_user_tasks = suite.user_tasks.values()
    else:
        selected_user_tasks = [suite.get_user_task_by_id(task_id) for task_id in user_tasks]

    if injection_tasks is None:
        selected_injection_tasks = suite.injection_tasks
    else:
        selected_injection_tasks = {
            task_id: suite.get_injection_task_by_id(task_id) for task_id in injection_tasks
        }

    # First check whether the injection tasks can be solved when given as plain tasks.
    injection_utility = {}
    if not attacker.is_dos_attack:
        for task_id, task in selected_injection_tasks.items():
            solved, _ = run_task_without_injection_tasks(suite, agent_pipeline, task, logdir, force_rerun)
            injection_utility[task_id] = solved

        if any(not solved for solved in injection_utility.values()):
            warnings.warn("Not all injection tasks were solved as user tasks.")

    all_utility: dict[tuple[str, str], bool] = {}
    all_security: dict[tuple[str, str], bool] = {}
    for current_user_task in selected_user_tasks:
        task_utility, task_security = run_task_with_injection_tasks(
            suite,
            agent_pipeline,
            current_user_task,
            attacker,
            logdir,
            force_rerun,
            injection_tasks,
        )
        all_utility |= task_utility
        all_security |= task_security

    return SuiteResults(
        utility_results=all_utility,
        security_results=all_security,
        injection_tasks_utility_results=injection_utility,
    )


def run_task_without_injection_tasks(
    suite: TaskSuite,
    agent_pipeline: BasePipelineElement,
    task: BaseUserTask | BaseInjectionTask,
    logdir: Path | None,
    force_rerun: bool,
) -> tuple[bool, bool]:
    """Run a single task with no injections and return its outcome.

    Args:
        suite: The suite the task belongs to.
        agent_pipeline: The pipeline under test.
        task: The task to run (a user task, or an injection task run on its own).
        logdir: Directory with previously saved results. When it is set, the
            pipeline has a name and `force_rerun` is false, a previously saved
            result is returned instead of running the task again.
        force_rerun: Always run the task, ignoring any saved result.

    Returns:
        The `(utility, security)` outcome of the task.

    Raises:
        BadRequestError: If the request fails for a reason other than an
            exceeded context length / `max_tokens`.
        ApiError: If the error is not an internal server error.
    """
    # A forced re-run never uses the saved result, so only read it when it can be reused.
    if not force_rerun and logdir is not None and agent_pipeline.name is not None:
        try:
            task_result = load_task_results(
                agent_pipeline.name,
                suite.name,
                task.ID,
                attack_name="none",
                injection_task="none",
                logdir=logdir,
            )
        except (ValidationError, FileNotFoundError, json.JSONDecodeError):
            # Missing, malformed (e.g. truncated) or invalid log: run the task again.
            pass
        else:
            logging.info(f"Skipping task '{task.ID}' with no injections because it was already run.")
            return task_result.utility, task_result.security

    with TraceLogger(
        delegate=Logger.get(),
        suite_name=suite.name,
        user_task_id=task.ID,
        injection_task_id=None,
        injections={},
        attack_type="none",
        pipeline_name=agent_pipeline.name,
    ) as logger:
        try:
            utility, security = suite.run_task_with_pipeline(agent_pipeline, task, injection_task=None, injections={})
        except BadRequestError as e:
            if e.code == "context_length_exceeded" or e.param == "max_tokens":
                logger.log_error(f"Skipping task {task.ID} due to context_length_exceeded: {e}")
                utility = False
                security = True
            else:
                raise
        except ApiError as e:
            if "internal server error" in str(e):
                logger.log_error(f"Skipping task {task.ID} because of internal server error: {e}")
                utility = False
                security = True
            else:
                raise
        except InternalServerError as e:
            logger.log_error(f"Skipping task {task.ID} because of internal server error: {e}")
            utility = False
            security = True
        logger.set_contextarg("utility", utility)
        logger.set_contextarg("security", security)

    return utility, security


def benchmark_suite_without_injections(
    agent_pipeline: BasePipelineElement,
    suite: TaskSuite,
    logdir: Path | None,
    force_rerun: bool,
    user_tasks: Sequence[str] | None = None,
) -> SuiteResults:
    """
    Run a benchmark on a suite without any injections.

    Args:
        agent_pipeline: The agent pipeline to use.
        suite: The suite to benchmark.
        logdir: Directory with saved results, forwarded to the task runner.
        force_rerun: Forwarded to the task runner.
        user_tasks: A subset of user tasks to run. If None, all user tasks are run.

    Returns:
        The utility and security results keyed by `(user task ID, "")`; the
        injection task utility results are left empty.
    """
    if user_tasks is None:
        selected_tasks = suite.user_tasks.values()
    else:
        selected_tasks = [suite.get_user_task_by_id(task_id) for task_id in user_tasks]

    utility_by_task: dict[tuple[str, str], bool] = {}
    security_by_task: dict[tuple[str, str], bool] = {}
    for current_task in selected_tasks:
        outcome = run_task_without_injection_tasks(suite, agent_pipeline, current_task, logdir, force_rerun)
        task_utility, task_security = outcome
        # No injection task is involved, hence the empty injection task ID.
        utility_by_task[(current_task.ID, "")] = task_utility
        security_by_task[(current_task.ID, "")] = task_security

    return SuiteResults(
        utility_results=utility_by_task,
        security_results=security_by_task,
        injection_tasks_utility_results={},
    )


class TaskResults(BaseModel):
    """Record of a single task run, as read back from a saved JSON log.

    Validation rejects records with fewer than two messages, unless the
    pipeline name contains "claude-3-opus".
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    suite_name: str
    pipeline_name: str
    user_task_id: str
    injection_task_id: str | None
    attack_type: str | None
    injections: dict[str, str]
    messages: list[ChatMessage]
    error: str | None
    utility: bool
    security: bool
    duration: float

    @model_validator(mode="after")
    def check_messages(self) -> Self:
        """Require at least two messages, except for claude-3-opus pipelines."""
        has_enough_messages = len(self.messages) >= 2
        if has_enough_messages or "claude-3-opus" in self.pipeline_name:
            return self
        raise ValueError("TaskResults must have at least two messages.")


def load_task_results(
    pipeline_name: str,
    suite_name: str,
    user_task: str,
    attack_name: str,
    injection_task: str,
    logdir: Path,
) -> TaskResults:
    """Load the saved results of one task run.

    The results are read from
    `<logdir>/<pipeline_name>/<suite_name>/<user_task>/<attack_name>/<injection_task>.json`.

    Raises:
        FileNotFoundError: If no such file exists.
        ValidationError: If the file content is not a valid `TaskResults`.
    """
    result_file = logdir.joinpath(pipeline_name, suite_name, user_task, attack_name, f"{injection_task}.json")
    with result_file.open() as handle:
        payload = json.load(handle)
    return TaskResults(**payload)
