import time
from typing import Any, Dict, Optional

import torch
from botorch.acquisition.logei import qLogNoisyExpectedImprovement
from botorch.acquisition.objective import LinearMCObjective
from botorch.optim.optimize import optimize_acqf
from botorch.test_functions.synthetic import Hartmann
from torch.quasirandom import SobolEngine

from ..robust_gp.experiment_utils import (
    constant_outlier_generator,
    CorruptedTestProblem,
    normal_outlier_corruption,
    uniform_corruption,
    uniform_input_corruption,
)
from ..robust_gp.models import (
    get_power_transformed_model,
    get_robust_model,
    get_student_t_model,
    get_trimmed_mll_model,
    get_vanilla_model,
    get_winsorized_model,
)
from .problems.bo import PyTorchCNNProblem, RoverProblem, SVMProblem

# `UCI_DATASET_PATH` must point to a local copy of the "slice_localization_data.csv"
# file from the "Relative location of CT slices on axial axis" UCI data set:
# https://archive.ics.uci.edu/dataset/206/relative+location+of+ct+slices+on+axial+axis
# It is passed as `dataset_path` to `SVMProblem` when `function_name == "svm"`.
UCI_DATASET_PATH = "/path/to/slice_localization_data.csv"


# Method names handled by `run_one_bo_replication`. All entries except "sobol"
# and "oracle" select the model that is fit in each iteration of the BO loop.
METHOD_NAMES = [
    "vanilla",
    "relevance_pursuit",
    "student_t",
    "trimmed_mll",
    "power_transform",
    "winsorize",
    "sobol",  # All `n_evals` points come from the initial Sobol draw
    "oracle",  # Vanilla GP with no outliers
]


# Model constructors keyed by method name. Each entry takes the training inputs,
# the standardized training targets and the `minimize` flag of the problem.
# This table is also the source of truth for validating `method_name`.
# Lambdas are used so the model factories are only looked up when called.
_MODEL_BUILDERS = {
    "vanilla": lambda X, Y, minimize: get_vanilla_model(X=X, Y=Y),
    "relevance_pursuit": lambda X, Y, minimize: get_robust_model(
        X=X,
        Y=Y,
        use_forward_algorithm=False,
        convex_parameterization=False,
    ),
    "student_t": lambda X, Y, minimize: get_student_t_model(X=X, Y=Y),
    "trimmed_mll": lambda X, Y, minimize: get_trimmed_mll_model(X=X, Y=Y),
    "power_transform": lambda X, Y, minimize: get_power_transformed_model(X=X, Y=Y),
    "winsorize": lambda X, Y, minimize: get_winsorized_model(
        X=X, Y=Y, winsorize_lower=not minimize
    ),
}


def _make_outlier_generator(name: str, kwargs: Dict[str, Any]) -> Optional[Any]:
    """Build the `outlier_generator(f, X, bounds)` callable for `name`.

    Args:
        name: One of "none", "constant", "uniform_input", "uniform", "normal".
        kwargs: Extra keyword arguments forwarded to the corruption function.
            The key set must match exactly: {"constant"} for "constant",
            empty for "uniform_input", {"lower", "upper"} for "uniform" and
            {"noise_std"} for "normal". Not inspected when `name` is "none".

    Returns:
        None if `name` is "none", otherwise a callable with the signature
        `(f, X, bounds)` that delegates to the selected corruption function.

    Raises:
        ValueError: If `name` is unknown or `kwargs` has unexpected keys.
    """
    if name == "none":
        return None

    required_keys = {
        "constant": {"constant"},
        "uniform_input": set(),
        "uniform": {"lower", "upper"},
        "normal": {"noise_std"},
    }
    if name not in required_keys:
        raise ValueError(f"Unknown outlier generator name: {name}")
    if kwargs.keys() != required_keys[name]:
        raise ValueError(f"Unknown outlier generator kwargs: {kwargs}")

    corruption = {
        "constant": constant_outlier_generator,
        "uniform_input": uniform_input_corruption,
        "uniform": uniform_corruption,
        "normal": normal_outlier_corruption,
    }[name]

    def outlier_generator(f, X, bounds):
        # The keys of `kwargs` were validated above, so they can be forwarded as-is.
        return corruption(f=f, X=X, bounds=bounds, **kwargs)

    return outlier_generator


def run_one_bo_replication(
    results_fpath: str,
    seed: int,
    method_name: str,
    function_name: str,
    outlier_fraction: float,
    outlier_generator_name: str,
    batch_size: int,
    n_evals: int,
    outlier_generator_kwargs: Optional[Dict[str, Any]] = None,
    n_init: Optional[int] = None,
    dtype: torch.dtype = torch.double,
    device: Optional[torch.device] = None,
    max_sobol_fallbacks: int = 3,  # To handle rare model fitting errors for Student-t
) -> None:
    """Run one Bayesian optimization replication and save the results.

    The run starts from a scrambled Sobol design of `n_init` points. While fewer
    than `n_evals` observations have been collected, the observations are
    standardized, the model selected by `method_name` is fit, `batch_size`
    candidates are generated by optimizing `qLogNoisyExpectedImprovement`, and
    the candidates are evaluated and appended. If model fitting raises, the
    batch is drawn from a Sobol sequence instead (a "Sobol fallback").

    Each iteration also records the noise-free objective value
    (`evaluate_true`) at the in-sample point with the best posterior mean, or
    at the best observed point when the iteration is a Sobol fallback.

    Args:
        results_fpath: Path the result dictionary is written to via `torch.save`.
        seed: Seed for the initial Sobol design and for the RNG used while
            evaluating the initial design.
        method_name: "sobol", "oracle", or a key of `_MODEL_BUILDERS`. "sobol"
            draws all `n_evals` points in the initial design. "oracle" runs the
            "vanilla" method with `outlier_fraction=0.0` and no outlier
            generator, and is saved with `method_name="vanilla"`.
        function_name: "hartmann6", "pytorch_cnn", "svm", or a name containing
            "rover_" and ending in the integer dimension (e.g. "rover_60").
        outlier_fraction: Passed to the problem that generates the outliers.
        outlier_generator_name: "none", "constant", "uniform_input", "uniform"
            or "normal". Must be "none" for the Rover and SVM problems.
        batch_size: Number of candidates generated per iteration.
        n_evals: Evaluation budget. The loop stops once at least this many
            observations exist. Capped at 100 for "trimmed_mll".
        outlier_generator_kwargs: Extra arguments for the outlier generator;
            see `_make_outlier_generator` for the required keys.
        n_init: Size of the initial Sobol design. Required unless
            `method_name` is "sobol", in which case it is set to `n_evals`.
        dtype: Tensor dtype used throughout.
        device: Tensor device used throughout.
        max_sobol_fallbacks: The run is aborted as soon as the number of Sobol
            fallbacks reaches this value.

    Raises:
        ValueError: If the method, function or outlier generator name is
            unknown, or the outlier generator kwargs have unexpected keys.
        RuntimeError: If model fitting fails `max_sobol_fallbacks` times.
    """
    import warnings

    tkwargs = {"dtype": dtype, "device": device}
    outlier_generator_kwargs = outlier_generator_kwargs or {}
    if method_name == "sobol":
        n_init = n_evals
    else:
        assert n_init is not None, "n_init must be specified for non-Sobol methods"

    if method_name == "oracle":
        outlier_fraction = 0.0
        outlier_generator_name = "none"
        method_name = "vanilla"  # Going forward this is just a vanilla GP
    elif method_name == "trimmed_mll":
        n_evals = min(n_evals, 100)

    # Validate up front. Raising this inside the fitting `try` below would be
    # mistaken for a model fitting failure and silently turned into Sobol draws.
    if method_name != "sobol" and method_name not in _MODEL_BUILDERS:
        raise ValueError(f"Unknown method name: {method_name}")

    # Outlier generator
    outlier_generator = _make_outlier_generator(
        name=outlier_generator_name, kwargs=outlier_generator_kwargs
    )

    # Base problem function
    # NOTE: The domain is assumed to be [0, 1]^d for all problems.
    is_rover = "rover_" in function_name
    if function_name == "hartmann6":
        dim = 6
        base_test_problem = Hartmann(dim=dim)
        minimize = True
    elif function_name == "pytorch_cnn":
        dim = 5
        base_test_problem = PyTorchCNNProblem(outlier_fraction=outlier_fraction)
        minimize = False
    elif is_rover:
        dim = int(function_name.split("_")[-1])
        base_test_problem = RoverProblem(dim=dim, outlier_fraction=outlier_fraction)
        minimize = False
        assert (
            outlier_generator_name == "none"
        ), "Outlier generator must be None for Rover."
    elif function_name == "svm":
        dim = 3
        base_test_problem = SVMProblem(
            dataset_path=UCI_DATASET_PATH, outlier_fraction=outlier_fraction
        )
        minimize = True
        assert (
            outlier_generator_name == "none"
        ), "Outlier generator must be None for SVM."
    else:
        raise ValueError(f"Unknown function name: {function_name}")
    bounds = torch.cat((torch.zeros(1, dim), torch.ones(1, dim))).to(**tkwargs)  # pyre-ignore

    # Test problem
    # These problems receive `outlier_fraction` themselves and add the outliers.
    problem_adds_outliers = function_name in ("pytorch_cnn", "svm") or is_rover
    if outlier_generator is None or problem_adds_outliers:
        objective_function = base_test_problem
    else:
        objective_function = CorruptedTestProblem(
            base_test_problem=base_test_problem,
            outlier_generator=outlier_generator,
            outlier_fraction=outlier_fraction,
        )

    # Sobol batch
    X = SobolEngine(dimension=dim, scramble=True, seed=seed).draw(n_init).to(**tkwargs)  # pyre-ignore
    with torch.random.fork_rng():
        torch.manual_seed(seed)
        Y = objective_function(X).unsqueeze(-1)

    # The acquisition objective does not change between iterations.
    weight = -1 if minimize else 1
    objective = LinearMCObjective(torch.tensor([weight], **tkwargs))  # pyre-ignore

    true_Y_inference = []
    num_sobol_fallbacks = 0
    fit_time, gen_time = 0, 0
    while len(Y) < n_evals:
        # Fit model
        start_time = time.monotonic()
        train_Y = (Y - Y.mean()) / Y.std()
        try:
            model = _MODEL_BUILDERS[method_name](X, train_Y, minimize)
        except Exception as e:
            num_sobol_fallbacks += 1
            # Enforce the cap where the failure happens, so a run whose fits
            # keep failing is aborted instead of degrading to random search.
            if num_sobol_fallbacks >= max_sobol_fallbacks:
                raise RuntimeError(
                    f"Too many Sobol fallbacks ({num_sobol_fallbacks})"
                ) from e
            warnings.warn(
                f"Model fitting failed ({e!r}); falling back to a Sobol batch.",
                RuntimeWarning,
                stacklevel=2,
            )

            # Use observed values for inference regret since there is no model
            best_idx = Y.argmin() if minimize else Y.argmax()
            true_Y_inference.append(
                objective_function.evaluate_true(X[best_idx]).item()
            )

            # Sobol fallback
            candidates = (
                SobolEngine(dimension=dim, scramble=True).draw(batch_size).to(**tkwargs)  # pyre-ignore
            )
        else:
            fit_time += time.monotonic() - start_time

            # Compute in-sample inference regret
            posterior_mean = model.posterior(X).mean
            best_idx = posterior_mean.argmin() if minimize else posterior_mean.argmax()
            true_Y_inference.append(
                objective_function.evaluate_true(X[best_idx]).item()
            )

            # Optimize acquisition function
            start_time = time.monotonic()
            qLogNEI = qLogNoisyExpectedImprovement(
                model=model, X_baseline=X, objective=objective
            )
            candidates, _ = optimize_acqf(
                qLogNEI,
                bounds=bounds,
                q=batch_size,
                num_restarts=8,
                raw_samples=1024,
            )
            gen_time += time.monotonic() - start_time

        # Evaluate and append
        Y_next = objective_function(candidates).unsqueeze(-1)
        X = torch.cat((X, candidates))
        Y = torch.cat((Y, Y_next))

    # Post-process
    true_Y = torch.tensor([objective_function.evaluate_true(xy) for xy in X], **tkwargs)  # pyre-ignore

    # Save the final output
    output_dict = {
        "method_name": method_name,
        "function_name": function_name,
        "outlier_generator_name": outlier_generator_name,
        "outlier_fraction": outlier_fraction,
        "outlier_generator_kwargs": outlier_generator_kwargs,
        "batch_size": batch_size,
        "n_evals": n_evals,
        "n_init": n_init,
        "X": X.cpu(),
        "Y": Y.cpu(),
        "true_Y": true_Y.cpu(),
        "true_Y_inference": torch.tensor(true_Y_inference, dtype=dtype),
        "fit_time": fit_time,
        "gen_time": gen_time,
    }
    torch.save(output_dict, results_fpath)
