from typing import Optional

import gin
import numpy as np
import torch
from botorch.acquisition import ExpectedImprovement, qExpectedImprovement
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from botorch.sampling import SobolQMCNormalSampler
from botorch.sampling.pathwise import MatheronPath
from gpytorch.kernels import MaternKernel
from torch.quasirandom import SobolEngine

from bounce.projection import AxUS
from bounce.kernel.categorical_mixture import MixtureKernel
from bounce.trust_region import TrustRegion
from bounce.util.benchmark import ParameterType


@gin.configurable
def create_candidates_discrete(
        x_scaled: torch.Tensor,
        fx_scaled: torch.Tensor,
        acquisition_function: Optional[MatheronPath | ExpectedImprovement],
        model: SingleTaskGP,
        axus: AxUS,
        trust_region: TrustRegion,
        device: str,
        batch_size: int = 1,
        x_bests: Optional[torch.Tensor] = None,
        add_spray_points: bool = True,
        sampler: Optional[SobolQMCNormalSampler] = None,
) -> tuple[torch.Tensor, torch.Tensor, dict]:
    """
    Create candidate points for the next batch by local search over the discrete dimensions.

    For every batch element, candidates are sampled around the trust region center (optionally
    extended by the Hamming neighbors of the center, the "spray points"). Up to 20 of them with the
    lowest acquisition value are then improved by greedy hill climbing. Each step moves to the best
    Hamming neighbor that lies inside the trust region and is not already contained in ``x_scaled``.
    Only discrete dimensions are modified. The continuous dimensions are meant to keep the center's
    values, which the assertion at the end of the function checks.

    Args:
        x_scaled: The current points in the trust region, in [0, 1]^d
        fx_scaled: The function values at the current points (the center is their argmin)
        acquisition_function: The approximate posterior samples (Matheron paths) or an EI
            acquisition function. If None, a qExpectedImprovement is built for every batch element
            from ``model`` and ``sampler``, with the already selected candidates as pending points.
        model: The current GP model (only used to build the qEI when ``acquisition_function`` is None)
        axus: The current AxUS embedding for the trust region
        trust_region: The current trust region state (``length_discrete`` is the Hamming radius)
        device: The device to use ('cpu' or 'cuda')
        batch_size: The number of candidate points to create
        x_bests: Per-batch points whose continuous dimensions replace those of the center. They are
            mapped with (x + 1) / 2, so presumably they are given in [-1, 1]^d -- TODO confirm with caller.
        add_spray_points: Whether to add the Hamming neighbors of the center to the initial candidates
        sampler: The QMC sampler used to build the qEI when ``acquisition_function`` is None

    Returns:
        The candidate points in [-1, 1]^d (shape (batch_size, d)), the acquisition values of the
        candidates (shape (batch_size,), lower is better, ``inf`` if no candidate was found), and the
        trust region state as a dict with the keys "center" and "length".

    """

    # Get the indices of the continuous parameters
    indices_not_to_optimize = torch.tensor(
        [
            i for b, i in axus.bins_and_indices_of_type(ParameterType.CONTINUOUS)
        ]
    )

    # Find the center of the trust region
    x_centers = torch.clone(x_scaled[fx_scaled.argmin(), :]).detach()
    # x_center should be in [0, 1]^d at this point
    x_centers = torch.repeat_interleave(x_centers.unsqueeze(0), batch_size, dim=0)
    if x_bests is not None:
        # replace the continuous dimensions of each center, mapping x_bests from [-1, 1] to [0, 1]
        x_centers[:, indices_not_to_optimize] = (x_bests[:, indices_not_to_optimize] + 1) / 2

    # define the number of candidates as in the TuRBO paper
    n_candidates = min(5000, max(2000, 200 * axus.target_dim))

    x_batch_return = torch.zeros((batch_size, axus.target_dim), dtype=x_scaled.dtype, device=x_scaled.device)
    fx_batch_return = torch.zeros((batch_size, 1), dtype=fx_scaled.dtype, device=fx_scaled.device)

    for batch_index in range(batch_size):
        _acquisition_function = acquisition_function
        if acquisition_function is None:
            assert sampler is not None, "Either acquisition_function or sampler must be provided"
            # candidates already chosen for earlier batch elements are treated as pending evaluations
            x_pending = x_batch_return[:batch_index, :] if batch_index > 0 else None
            _acquisition_function = qExpectedImprovement(
                model=model,
                best_f=(-fx_scaled).max().item(),
                sampler=sampler,
                X_pending=x_pending,
            )

        def ts(
                x: torch.Tensor,
                batch_index: int
        ):
            """
            Evaluate the acquisition function on a set of points so that lower is better.

            For EI-type acquisition functions the (negated) value is computed with every point as a
            separate q=1 batch. For Matheron paths, the sample row of ``batch_index`` is returned.

            Args:
                x: The points to evaluate, shape (n, d)
                batch_index: The index of the batch to evaluate the posterior sample for

            Returns: The acquisition values at the given points (lower is better)

            """

            if type(_acquisition_function).__name__ != "MatheronPath":
                # EI is maximized; negate it so that the caller can minimize uniformly
                return -_acquisition_function(x.unsqueeze(1))
            else:
                return _acquisition_function(x)[batch_index, :]

        x_candidates = sample_initial_points_discrete(
            x_center=x_centers[batch_index],
            axus=axus,
            tr_length=trust_region.length_discrete,
            n_initial_points=n_candidates,
        )

        if add_spray_points:
            x_spray = hamming_neighbors_within_tr(
                x_center=x_centers[batch_index],
                x=x_centers[batch_index],
                tr_length=trust_region.length_discrete,
                axus=axus,
            )
            x_candidates = torch.vstack((x_candidates, x_spray))

        # Evaluate the acquisition function for all candidates
        with torch.no_grad():
            candidate_acquisition_values = ts(x_candidates, batch_index=batch_index)
        # Find the (up to 20) candidates with the lowest acquisition value
        top_k_candidate_indices = \
            torch.topk(candidate_acquisition_values, k=min(20, len(candidate_acquisition_values)), largest=False)[1]
        # Start local search
        best_posterior_value = torch.inf
        x_best = None

        for top_index in top_k_candidate_indices:
            x_candidate = x_candidates[top_index, :].clone().unsqueeze(0)

            posterior_value_k = candidate_acquisition_values[top_index].item()

            if posterior_value_k < best_posterior_value:
                best_posterior_value = posterior_value_k
                x_best = x_candidate
            # greedy hill climbing: move to the best neighbor while it improves the acquisition value
            while True:
                x_start_neighbors = hamming_neighbors_within_tr(
                    x_center=x_centers[batch_index],
                    x=x_candidate,
                    tr_length=trust_region.length_discrete,
                    axus=axus,
                )

                # remove rows from x_start_neighbors that are already in x_scaled (already evaluated points)
                for x_eval in x_scaled.to(device=device):
                    x_start_neighbors = x_start_neighbors[~torch.all(x_start_neighbors == x_eval, dim=1)]

                if x_start_neighbors.numel() == 0:
                    # no neighbors left, continue with next top candidate
                    break

                with torch.no_grad():
                    neighbors_acq_val = ts(x_start_neighbors, batch_index=batch_index)

                if len(neighbors_acq_val) > 0 and torch.min(neighbors_acq_val) < posterior_value_k:
                    # from here on x_candidate is a 1-d row (shape (d,)), hence the unsqueeze below
                    x_candidate = x_start_neighbors[torch.argmin(neighbors_acq_val)]
                    posterior_value_k = torch.min(neighbors_acq_val).item()
                else:
                    # could not find a better neighbor, continue with next top candidate
                    break
                if posterior_value_k < best_posterior_value:
                    best_posterior_value = posterior_value_k
                    x_best = x_candidate.unsqueeze(0)

        if x_best is None:
            # no candidate at all: fall back to the trust region center (its value stays inf)
            x_best = x_centers[batch_index].unsqueeze(0)
        # store the best point and its acquisition value for this batch element
        x_batch_return[batch_index, :] = x_best.squeeze()
        fx_batch_return[batch_index, :] = best_posterior_value

    # NOTE(review): torch.any only checks that at least one continuous entry is unchanged; torch.all
    # would match the message more closely -- confirm the intent before tightening.
    assert len(indices_not_to_optimize) == 0 or torch.any(
        x_centers[:, indices_not_to_optimize].squeeze() == x_batch_return[:, indices_not_to_optimize].squeeze()
    ), "x_ret should not be optimized at indices_not_to_optimize"

    # transform to [-1, 1], was [0, 1]
    x_batch_return = x_batch_return * 2 - 1

    tr_state = {
        "center": x_scaled[fx_scaled.argmin(), :].detach().cpu().numpy().reshape(1, -1),
        "length": np.array([trust_region.length_discrete]),
    }

    return x_batch_return, fx_batch_return.reshape(batch_size), tr_state


def create_candidates_continuous(
        x_scaled: torch.Tensor,
        fx_scaled: torch.Tensor,
        acquisition_function: Optional[MatheronPath | ExpectedImprovement],
        model: SingleTaskGP,
        axus: AxUS,
        trust_region: TrustRegion,
        device: str,
        batch_size: int,
        indices_to_optimize: Optional[torch.Tensor] = None,
        x_bests: Optional[torch.Tensor] = None,
        sampler: Optional[SobolQMCNormalSampler] = None,
) -> tuple[torch.Tensor, torch.Tensor, dict]:
    """
    Create candidate points for the next batch by optimizing over the continuous dimensions.

    For every batch element, a box-shaped trust region (stretched per dimension by the GP
    lengthscales) is placed around the center. The candidate inside it is found either with
    ``optimize_acqf`` (EI / qEI) or, for Matheron paths, by minimizing the posterior sample over
    Sobol perturbations of the center.

    Args:
        x_scaled: The current points in the trust region, in [0, 1]^d
        fx_scaled: The function values at the current points (the center is their argmin)
        acquisition_function: The Matheron paths or an EI acquisition function. If None, a
            qExpectedImprovement is built for every batch element from ``model`` and ``sampler``,
            with the already selected candidates as pending points.
        model: The current GP model (its lengthscales shape the trust region)
        axus: The current AxUS embedding for the trust region
        trust_region: The current trust region state (``length_continuous`` is the box side length)
        device: The device to use ('cpu' or 'cuda')
        batch_size: The number of candidate points to create
        indices_to_optimize: The indices of the dimensions to optimize (in case of mixed spaces);
            all other dimensions stay fixed at the center's values
        x_bests: Per-batch points whose non-optimized dimensions replace those of the center. They
            are mapped with (x + 1) / 2, so presumably they are given in [-1, 1]^d.
        sampler: The QMC sampler used to build the qEI when ``acquisition_function`` is None

    Returns:
        The candidate points in [-1, 1]^d (shape (batch_size, d)), the acquisition values of the
        candidates (shape (batch_size,), lower is better: negated EI or the posterior sample value),
        and the trust region state of the last batch element.

    Raises:
        NotImplementedError: If the model's base kernel is neither a MixtureKernel nor a MaternKernel
    """

    if indices_to_optimize is None:
        indices_to_optimize = torch.arange(axus.target_dim)
    all_indices = torch.arange(axus.target_dim)
    indices_not_to_optimize = all_indices[~torch.isin(all_indices, indices_to_optimize)]

    def acq_for_batch(
            x: torch.Tensor,
            batch_index: int
    ):
        """
        Evaluate the acquisition function; for Matheron paths, return the sample of ``batch_index``.

        Args:
            x: The points to evaluate the acquisition function at
            batch_index: The index of the posterior sample to use (Matheron paths only)

        Returns: The acquisition values at the given points

        """
        if type(acquisition_function).__name__ != "MatheronPath":
            return acquisition_function(x)
        else:
            return acquisition_function(x)[batch_index, :]

    x_centers = torch.clone(x_scaled[fx_scaled.argmin(), :]).detach()
    # repeat x_centers batch_size many times
    x_centers = torch.repeat_interleave(x_centers.unsqueeze(0), batch_size, dim=0)

    if x_bests is not None:
        x_centers[:, indices_not_to_optimize] = (x_bests[:, indices_not_to_optimize] + 1) / 2

    assert len(x_centers.shape) == 2, "x_center should be a 2d tensor"

    # The lengthscale weights stretch the trust region per dimension (TuRBO-style). They only depend
    # on the model, so compute them once; out-of-place ops never touch the model's tensors.
    base_kernel = model.covar_module.base_kernel
    if isinstance(base_kernel, MixtureKernel):
        lengthscales = base_kernel.continuous_kernel.lengthscale
    elif isinstance(base_kernel, MaternKernel):
        lengthscales = base_kernel.lengthscale
    else:
        raise NotImplementedError("Only MixtureKernel and MaternKernel are supported")
    weights = lengthscales.detach().squeeze(0)
    weights = weights / weights.mean()
    # normalize so that the geometric mean of the weights is 1 (box volume set by the TR length)
    weights = weights / torch.prod(torch.pow(weights, 1 / len(weights)))

    fx_mins = torch.zeros(batch_size, dtype=torch.double, device=device)
    x_cand_downs = torch.zeros((batch_size, axus.target_dim), dtype=torch.double, device=device)
    for batch_index in range(batch_size):
        x_center = x_centers[batch_index, :]

        _x_center = x_center[indices_to_optimize]
        _tr_lb = torch.clip(_x_center - trust_region.length_continuous * weights / 2, 0, 1)
        _tr_ub = torch.clip(_x_center + trust_region.length_continuous * weights / 2, 0, 1)
        tr_lb = torch.zeros(axus.target_dim, dtype=torch.double, device=device)
        tr_ub = torch.ones(axus.target_dim, dtype=torch.double, device=device)
        tr_lb[indices_to_optimize] = _tr_lb
        tr_ub[indices_to_optimize] = _tr_ub

        if not isinstance(acquisition_function, MatheronPath):
            _acquisition_function = acquisition_function
            if acquisition_function is None:
                assert sampler is not None, "Either acquisition_function or sampler must be provided"
                # candidates already chosen for earlier batch elements are pending evaluations
                x_pending = x_cand_downs[:batch_index, :] if batch_index > 0 else None
                _acquisition_function = qExpectedImprovement(
                    model=model,
                    best_f=(-fx_scaled).max().item(),
                    sampler=sampler,
                    X_pending=x_pending,
                )

            # EI-based acquisition function
            candidate, acq_value = optimize_acqf(
                acq_function=_acquisition_function,
                bounds=torch.stack([tr_lb, tr_ub], dim=0),
                q=1,
                fixed_features={i: x_center[i].item() for i in indices_not_to_optimize.tolist()},
                num_restarts=10,
                raw_samples=512,
            )
            x_cand_downs[batch_index, :] = candidate
            # optimize_acqf maximizes EI; store the negated value so that lower is better,
            # consistent with the Matheron branch and the discrete candidate generation
            fx_mins[batch_index] = -acq_value
        else:
            # Matheron path-based acquisition function
            n_candidates = min(5000, max(2000, 200 * len(indices_to_optimize)))
            sobol = SobolEngine(dimension=len(indices_to_optimize), scramble=True)
            pert = sobol.draw(n_candidates).to(device=device)
            # pert only covers the optimized dimensions, so scale it with the matching sub-box
            sub_lb = tr_lb[indices_to_optimize]
            sub_ub = tr_ub[indices_to_optimize]
            pert = pert * (sub_ub - sub_lb) + sub_lb

            # Create candidate points from the perturbations and the mask
            x_cand_down = torch.clone(torch.repeat_interleave(x_center.unsqueeze(0), n_candidates, dim=0))

            # Create a perturbation mask: perturb ~20 of the optimized dimensions per candidate
            prob_perturb = min(20 / len(indices_to_optimize), 1.0)
            mask = torch.zeros(x_cand_down.shape, device=device, dtype=torch.bool)
            mask[:, indices_to_optimize] = torch.rand(pert.shape, device=device) < prob_perturb

            x_cand_down[mask] = pert[mask[:, indices_to_optimize]]

            with torch.no_grad():
                fx_cand = acq_for_batch(x_cand_down, batch_index=batch_index)
            # fx_cand already belongs to this batch element: take its overall minimum
            best_index = fx_cand.argmin()
            fx_mins[batch_index] = fx_cand.min()
            x_cand_downs[batch_index, :] = x_cand_down[best_index, :]

    tr_state = {
        "center": x_scaled[fx_scaled.argmin(), :].detach().cpu().numpy().reshape(1, -1),
        "length": np.array([trust_region.length_continuous]),
        "lb"    : tr_lb.detach().cpu().numpy(),
        "ub"    : tr_ub.detach().cpu().numpy(),
    }

    # transform to [-1, 1], was [0, 1]
    return x_cand_downs * 2 - 1, fx_mins.reshape(batch_size), tr_state


def hamming_distance(
        x: torch.Tensor,
        y: torch.Tensor,
) -> torch.Tensor:
    """
    Compute the Hamming distance between each row of a set of points (x) and a vector (y).

    Args:
        x: The set of points, shape (n, d); a single point of shape (d,) is treated as (1, d)
        y: The reference vector, shape (d,); a (1, d) or (d, 1) tensor is flattened to (d,)

    Returns:
        A tensor of shape (n,) holding, for each row of x, the number of coordinates that differ from y
    """
    if x.dim() == 1:
        x = x.unsqueeze(0)
    assert x.dim() == 2, "x must be a matrix"
    # reshape rather than squeeze: squeeze would collapse a (1, 1) vector (d == 1) into a 0-d tensor
    if y.dim() == 2 and 1 in y.shape:
        y = y.reshape(-1)
    assert y.dim() == 1, "y must be a vector"

    return torch.sum(x != y, dim=1)


def hamming_neighbors_within_tr(
        x: torch.Tensor,
        x_center: torch.Tensor,
        tr_length: torch.Tensor,
        axus: AxUS,
) -> torch.Tensor:
    """
    Find the neighbors of x that are one discrete move away and still within the trust region.

    A move either flips one binary dimension or activates a different (currently inactive) entry
    of one categorical parameter's index group. Only neighbors whose Hamming distance to
    ``x_center`` is at most ``tr_length`` are kept.

    Args:
        x: The point to compute the neighbors for, shape (d,) or (1, d)
        x_center: The center of the trust region, shape (d,)
        tr_length: The length of the trust region (maximum Hamming distance to ``x_center``)
        axus: The AxUS embedding, used to look up the dimensions of each parameter type

    Returns:
        A (m, d) tensor with the unique neighbors of x (x itself excluded). m is 0 if the
        embedding has no discrete parameters or no neighbor lies within the trust region.

    Raises:
        NotImplementedError: If the embedding contains ordinal parameters
    """
    x = torch.clone(x)
    # reshape rather than squeeze: squeeze would collapse a (1, 1) input (d == 1) into a 0-d tensor
    if x.dim() == 2 and 1 in x.shape:
        x = x.reshape(-1)
    assert x.dim() == 1, "x must be a vector"

    discrete_parameter_types = [pt for pt in ParameterType if pt != ParameterType.CONTINUOUS]

    neighbors_for_type = dict()

    for parameter_type in discrete_parameter_types:
        if axus.n_bins_of_type(parameter_type) == 0:
            # No parameters of this type
            continue
        if parameter_type == ParameterType.BINARY:
            indices = torch.tensor([i for b, i in axus.bins_and_indices_of_type(parameter_type)])
            diagonal = torch.zeros_like(x)
            diagonal[indices] = 1
            diag_nonzero = diagonal != 0

            # for 0/1 values, |diag(e) - x| flips one binary dimension per row; keep the binary rows
            type_neighbors = torch.abs(torch.diag(diagonal) - x.unsqueeze(0))[diag_nonzero, :]
        elif parameter_type == ParameterType.CATEGORICAL:
            indicess = [i for b, i in axus.bins_and_indices_of_type(parameter_type)]
            # same dtype as x so that stacking does not depend on dtype promotion
            type_neighbors = torch.zeros((0, len(x)), dtype=x.dtype, device=x.device)
            for indices in indicess:
                # find inactive indices
                inactive_indices = [i for i in indices if x[i] == 0]
                # one copy of x per inactive index, with that index activated instead
                x_copies = torch.repeat_interleave(x.unsqueeze(0), len(inactive_indices), dim=0)
                x_copies[:, indices] = 0
                for i, inactive_index in enumerate(inactive_indices):
                    x_copies[i, inactive_index] = 1
                type_neighbors = torch.vstack((type_neighbors, x_copies))
        elif parameter_type == ParameterType.ORDINAL:
            raise NotImplementedError("Ordinal parameters are not supported yet")
        else:
            raise ValueError(f"Unknown parameter type {parameter_type}")

        neighbors_for_type[parameter_type] = type_neighbors

    if not neighbors_for_type:
        # no discrete parameters: there are no discrete neighbors (vstack of [] would raise)
        return x.new_zeros((0, len(x)))

    # stack all neighbors
    neighbors = torch.vstack(list(neighbors_for_type.values()))
    # remove duplicates
    neighbors = torch.unique(neighbors, dim=0)
    # remove the original point
    neighbors = neighbors[torch.any(neighbors != x, dim=1), :]
    # remove the neighbors that are not within the trust region
    neighbors = neighbors[hamming_distance(neighbors, x_center) <= tr_length, :]
    return neighbors


def sample_initial_points_discrete(
        x_center: torch.Tensor,
        tr_length: torch.Tensor,
        axus: AxUS,
        n_initial_points: int,
) -> torch.Tensor:
    """
    Sample random initial candidates around the trust region center by changing discrete dimensions.

    Every candidate starts as a copy of ``x_center``. Binary dimensions are then overwritten with
    random 0/1 values. Categorical index groups are reset and get one randomly chosen entry set to 1.
    Continuous dimensions are left untouched.

    Args:
        x_center: the center of the trust region, shape (d,)
        tr_length: the length of the trust region (maximum Hamming distance to ``x_center``)
        axus: the AxUS embedding, used to look up the dimensions of each parameter type
        n_initial_points: the number of candidates to draw before de-duplication and filtering

    Returns:
        x_cand: the unique sampled candidates that differ from ``x_center`` and lie within the trust
        region. If none lies within the trust region, a warning is printed and all unique candidates
        that differ from ``x_center`` are returned instead (this set may be empty).

    Raises:
        NotImplementedError: if the embedding contains ordinal parameters

    """
    discrete_parameter_types = [pt for pt in ParameterType if pt != ParameterType.CONTINUOUS]

    # copy x_center n_initial_points times
    x_cand = torch.repeat_interleave(x_center.unsqueeze(0), n_initial_points, dim=0)

    for parameter_type in discrete_parameter_types:
        if axus.n_bins_of_type(parameter_type) == 0:
            # No parameters of this type
            continue
        if parameter_type == ParameterType.BINARY:
            indices = torch.tensor([i for b, i in axus.bins_and_indices_of_type(parameter_type)])
            # draw min(tr_length - 1, len(indices)) distinct binary indices for each candidate
            # NOTE(review): the original comment said min(tr_length, ...); with tr_length == 1 no binary
            # index is drawn at all -- confirm whether the "- 1" is intended.
            indices_for_cand = torch.tensor(
                np.array(
                    [
                        np.random.choice(indices, min(tr_length - 1, len(indices)), replace=False)
                        for _ in range(n_initial_points)
                    ]
                ), dtype=torch.long, device=x_cand.device
            )
            # draw a uniform 0/1 value for each drawn index (it may equal the center's value)
            values_for_cand = torch.randint(
                0,
                2,
                (n_initial_points, len(indices_for_cand[0])),
                dtype=x_cand.dtype,
                device=x_cand.device
            )
            # set values for each candidate
            x_cand = x_cand.scatter_(1, indices_for_cand, values_for_cand)
        elif parameter_type == ParameterType.CATEGORICAL:
            # one tensor of indices per categorical parameter (the group of dimensions it occupies)
            indicess = [i for b, i in axus.bins_and_indices_of_type(parameter_type)]
            if len(indicess) > tr_length:
                # more categorical parameters than the trust region allows to change: for every
                # candidate, re-sample only a random subset of tr_length of them
                index_setss = [
                    np.random.choice(np.arange(len(indicess)), min(tr_length, len(indicess)), replace=False)
                    for _ in range(n_initial_points)
                ]
                for i, index_sets in enumerate(index_setss):
                    # map the drawn parameter numbers to their index groups
                    # (the comprehension's own `i` does not overwrite the candidate index `i`)
                    index_sets = [indicess[i] for i in index_sets]
                    # set x_cand to 0 for each index
                    x_cand[i, torch.cat(index_sets)] = 0
                    # TODO there's some bug here, fix this at some point
                    #                    n_index_sets = len(index_sets)
                    #                    all_equal_length = len(set([len(indices) for indices in index_sets])) == 1
                    #                    if all_equal_length:
                    #                        length_indices = len(index_sets[0])
                    #                        x_cand[torch.vstack(index_sets)[
                    #                           torch.arange(n_index_sets), np.random.choice(length_indices, n_index_sets)]] = 1
                    #                        pass
                    if True:  # else:
                        # this is the expensive part
                        for indices in index_sets:
                            # set one randomly chosen index of the group to 1
                            x_cand[i, np.random.choice(indices)] = 1
            else:
                # every categorical parameter may change: re-sample all of them for all candidates at once
                for indices in indicess:
                    # set x_cand to 0 for each index
                    x_cand[:, indices] = 0
                    # sample n_initial_points indices
                    indices_for_cand = np.random.choice(indices, n_initial_points)
                    # set one index to 1
                    x_cand[torch.arange(n_initial_points), indices_for_cand] = 1
            pass


        elif parameter_type == ParameterType.ORDINAL:
            raise NotImplementedError("Ordinal parameters are not supported yet")
        else:
            raise ValueError(f"Unknown parameter type {parameter_type}")

    # remove duplicates
    x_cand = torch.unique(x_cand, dim=0)
    # remove points that coincide with x_center
    x_cand = x_cand[torch.any(x_cand != x_center, dim=1), :]
    # remove candidates that are not within the trust region
    x_cand_in_tr = x_cand[hamming_distance(x_cand, x_center) <= tr_length, :]
    if len(x_cand_in_tr) == 0:
        print("Warning: no initial points in trust region")
    # fall back to the (out-of-trust-region) candidates rather than returning nothing
    return x_cand_in_tr if len(x_cand_in_tr) > 0 else x_cand
