from __future__ import annotations

import math

from abc import ABC, abstractmethod
from copy import copy, deepcopy
from functools import partial
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    Iterable,
    List,
    Mapping,
    Optional,
    Tuple,
    Type,
)

import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models.model import Model
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from torch import Tensor
from torch.nn.parameter import Parameter

# Default budget for the marginal likelihood optimizer used by the forward and
# backward relevance pursuit algorithms: the iteration cap ("maxiter") and the
# tolerance used for both "ftol" and "gtol".
MLL_ITER: int = 2**10
MLL_TOL: float = 1e-8


class RelevancePursuitMixin(ABC):
    """Mixin class to convert between the sparse and dense representations of the
    relevance pursuit models' sparse parameters, as well as to compute the generalized
    support acquisition and support deletion criteria.

    Two representations of ``sparse_parameter`` are used:

    - sparse: one entry per feature in ``support``, ordered like ``support``.
    - dense: ``dim`` entries, one per feature. Entries of features outside of the
      support are zero-filled by ``to_dense`` and ``contract_support``.

    Subclasses provide the storage via the ``sparse_parameter`` property and
    ``set_sparse_parameter``.
    """

    dim: int  # the total number of features

    # IDEA: could generalize this to sets of parameters Dict[str, List[int]]
    # Beside looping over the parameters for all the sparse / dense conversions,
    # we'd need to introduce a vectorial representation of all the parameters
    # for the selection of the acquisition / deletion indices.
    # We don't really need to enforce a vectorial parameter storage for this, we
    # only need to introduce a helper that computes the (parameter, index) pair
    # that maximize the acquisition criterion.
    # potentially relevant: get_tensors_as_ndarray_1d
    _support: List[int]  # indices of the features in the support, subset of range(dim)

    def __init__(
        self,
        dim: int,
        support: Optional[List[int]],
    ) -> None:
        """Initializes the support bookkeeping.

        Args:
            dim: The total number of features.
            support: The initial support, a subset of ``range(dim)``. The list is
                copied so that in-place support updates never mutate the caller's
                list. ``None`` means the empty support.
        """
        self.dim = dim
        self._support = list(support) if support is not None else []
        # Assumption: sparse_parameter is initialized in sparse representation
        self._is_sparse = True
        # optional default modifiers for the expansion / contraction criteria
        self._expansion_modifier = None
        self._contraction_modifier = None

    @property
    @abstractmethod
    def sparse_parameter(self) -> Parameter:
        """The sparse parameter, in either the sparse or the dense representation."""
        ...

    @abstractmethod
    def set_sparse_parameter(self, value: Parameter) -> None:
        """Sets the sparse parameter.

        NOTE: We can't use a property setter because torch Parameter setters
        intercept prior to the canonical setter.
        """
        pass

    @staticmethod
    def _from_model(model: Model) -> RelevancePursuitMixin:
        """Retrieves a RelevancePursuitMixin from a model."""
        raise NotImplementedError

    @property
    def is_sparse(self) -> bool:
        """Whether ``sparse_parameter`` is currently in the sparse representation."""
        # Do we need to differentiate between a full support sparse representation and
        # a full support dense representation? The order the of the indices could be
        # different, unless we keep them sorted.
        return self._is_sparse

    @property
    def support(self) -> List[int]:
        """The indices of the active parameters."""
        return self._support

    def _support_index(self) -> Tensor:
        """The support as a long tensor on the device of the sparse parameter."""
        return torch.tensor(
            self.support, dtype=torch.long, device=self.sparse_parameter.device
        )

    @property
    def is_active(self) -> Tensor:
        """A Boolean Tensor of length dim, indicating which of the d dimensions are in
        the support, i.e. "active".
        """
        is_active = torch.zeros(
            self.dim, dtype=torch.bool, device=self.sparse_parameter.device
        )
        is_active[self._support_index()] = True  # no-op for an empty support
        return is_active

    @property
    def active_parameters(self) -> Tensor:
        """The entries of the sparse parameter that belong to the support."""
        if self.is_sparse:
            return self.sparse_parameter
        return self.sparse_parameter[self._support_index()]

    @property
    def inactive_indices(self) -> Tensor:
        """The feature indices in ``range(dim)`` that are not in the support."""
        device = self.sparse_parameter.device
        return torch.arange(self.dim, device=device)[~self.is_active]

    def to_sparse(self) -> RelevancePursuitMixin:
        """Converts the sparse parameter to the sparse representation, in place."""
        # should we prohibit this for the case where the support is the full set?
        if not self.is_sparse:
            self.set_sparse_parameter(
                torch.nn.Parameter(self.sparse_parameter[self._support_index()])
            )
            self._is_sparse = True
        return self

    def to_dense(self) -> RelevancePursuitMixin:
        """Converts the sparse parameter to the dense representation, in place.

        Features outside of the support get the value zero.
        """
        if not self.is_sparse:
            return self  # already dense
        sparse = self.sparse_parameter
        dense = torch.zeros(self.dim, dtype=sparse.dtype, device=sparse.device)
        # scatter the support values into their feature positions; detaching makes
        # the new parameter a fresh leaf, independent of the old one
        dense[self._support_index()] = sparse.detach()
        self.set_sparse_parameter(torch.nn.Parameter(dense))
        self._is_sparse = False
        return self

    def append_support(self, index: int) -> RelevancePursuitMixin:
        """Adds a single feature to the support."""
        self.expand_support([index])
        return self

    def expand_support(self, indices: Iterable[int]) -> RelevancePursuitMixin:
        """Adds features to the support.

        In the sparse representation, a zero entry is appended to the sparse parameter
        for every new feature. In the dense representation only the support changes.

        Args:
            indices: The features to add. Any iterable (including generators) works.

        Raises:
            ValueError: If a feature is already in the support, is listed more than
                once, or lies outside of ``range(dim)``. All indices are validated
                before anything is modified.
        """
        indices = list(indices)  # materialize: the indices are traversed twice
        for j, i in enumerate(indices):
            if i in self.support:
                raise ValueError(f"Feature {i} already in support.")
            if i in indices[:j]:
                raise ValueError(f"Feature {i} is listed more than once.")
            if not 0 <= i < self.dim:
                raise ValueError(f"Feature {i} is not in range({self.dim}).")

        self.support.extend(indices)
        # we need to add the parameter in the sparse representation
        if self.is_sparse:
            sparse = self.sparse_parameter
            self.set_sparse_parameter(
                torch.nn.Parameter(
                    torch.cat((sparse, torch.zeros(len(indices)).to(sparse)))
                )
            )
        return self

    def contract_support(self, indices: Iterable[int]) -> RelevancePursuitMixin:
        """Removes features from the support.

        In the sparse representation, the corresponding entries are removed from the
        sparse parameter. In the dense representation, they are set to zero.

        Args:
            indices: The features to remove.

        Raises:
            ValueError: If a feature is not in the support or is listed more than
                once. All indices are validated before anything is modified.
        """
        indices = list(indices)
        for j, i in enumerate(indices):
            if i not in self.support:
                raise ValueError(f"Feature {i} is not in support.")
            if i in indices[:j]:
                raise ValueError(f"Feature {i} is listed more than once.")

        # positions (into the sparse representation) of the features to *keep*
        keep = [j for j, i in enumerate(self.support) if i not in indices]
        # update in place, preserving the order of the remaining features
        self.support[:] = [self.support[j] for j in keep]

        # we need to remove the parameter in the sparse representation
        if self.is_sparse:
            keep_index = torch.tensor(
                keep, dtype=torch.long, device=self.sparse_parameter.device
            )
            self.set_sparse_parameter(Parameter(self.sparse_parameter[keep_index]))
        elif len(indices) > 0:
            # zero the removed features without recording the in-place update
            with torch.no_grad():
                self.sparse_parameter[indices] = 0.0
        return self

    def _drop_masked_from_support(self, to_drop: Tensor) -> RelevancePursuitMixin:
        """Removes the support features flagged by the Boolean tensor ``to_drop``,
        which is aligned with the current representation of ``sparse_parameter``.
        """
        if self.is_sparse:
            indices = [self.support[j] for j, b in enumerate(to_drop) if b]
        else:
            active = set(self.support)
            indices = [i for i, b in enumerate(to_drop) if b and i in active]
        return self.contract_support(indices)

    def drop_zeros_from_support(self, threshold: float = 0.0) -> RelevancePursuitMixin:
        """Drops the support features whose parameter values are <= ``threshold``."""
        # TODO: figure out batch_shape if necessary, this seems complicated
        # to make batched, unless we force the support to be the same for
        # all batches.
        return self._drop_masked_from_support(self.sparse_parameter <= threshold)

    def drop_threshold_from_support(
        self, lower: float, upper: float
    ) -> RelevancePursuitMixin:
        """Drops the support features whose parameter values are <= ``lower`` or
        >= ``upper``.
        """
        # TODO: figure out batch_shape if necessary, this seems complicated
        # to make batched, unless we force the support to be the same for
        # all batches.
        to_drop = (self.sparse_parameter <= lower) | (self.sparse_parameter >= upper)
        return self._drop_masked_from_support(to_drop)

    # support initialization helpers
    def full_support(self) -> RelevancePursuitMixin:
        """Adds all features to the support and switches to the dense representation."""
        active = set(self.support)
        self.expand_support([i for i in range(self.dim) if i not in active])
        self.to_dense()  # no reason to be sparse with full support
        return self

    def remove_support(self) -> RelevancePursuitMixin:
        """Empties the support, keeping the current representation and the
        ``requires_grad`` flag of the sparse parameter.
        """
        self._support = []
        requires_grad = self.sparse_parameter.requires_grad
        if self.is_sparse:
            self.set_sparse_parameter(
                torch.nn.Parameter(torch.tensor([]).to(self.sparse_parameter))
            )
        else:
            with torch.no_grad():
                self.sparse_parameter[:] = 0.0
        self.sparse_parameter.requires_grad_(requires_grad)
        return self

    def random_support(self, n: int) -> RelevancePursuitMixin:
        """Replaces the support by ``n`` uniformly random features.

        Args:
            n: The support size, ``0 <= n <= dim``. ``n == 0`` yields the empty
                support and ``n == dim`` the full (dense) support.

        Raises:
            ValueError: If ``n`` is negative or larger than ``dim``.
        """
        # randperm could also be interesting as an expansion tactic in cases
        # where we want to avoid evaluating other criteria
        self.remove_support()
        if n == self.dim:
            self.full_support()
        elif 0 < n < self.dim:
            # random support initialization
            self.expand_support(torch.randperm(self.dim)[:n].tolist())
        elif n < 0:
            raise ValueError(f"The support size must be non-negative, got {n}.")
        elif n > self.dim:
            raise ValueError(f"Cannot add more than {self.dim} indices to support.")
        # n == 0: the support was already emptied above
        return self

    # the following two methods are the only ones that are specific to the marginal
    # likelihood optimization problem
    def support_expansion(
        self,
        mll: ExactMarginalLogLikelihood,
        n: int = 1,
        modifier: Optional[Callable[[Tensor], Tensor]] = None,
    ) -> bool:
        """Computes the indices of the features that maximize the gradient of the sparse
        parameter and that are not already in the support, and subsequently expands the
        support to include the features if their gradient is positive.

        Args:
            mll: The marginal likelihood, containing the model to optimize.
                NOTE: Virtually all of the rest of the code is not specific to the
                marginal likelihood optimization, so we could generalize this to work
                with any objective.
            n: The number of features to select.
            modifier: A function that modifies the gradient before computing
                the support expansion criterion. This is useful, for example,
                when we want to select the maximum gradient magnitude for real-valued
                (not non-negative) parameters, in which case modifier = torch.abs.
                Defaults to ``self._expansion_modifier``.

        Returns:
            True if the support was expanded, False otherwise.
        """
        g = self.expansion_objective(mll)

        modifier = modifier if modifier is not None else self._expansion_modifier
        if modifier is not None:
            # IDEA: could compute a Newton step here / use the approximation to the
            # Hessian that is returned by L-BFGS.
            g = modifier(g)

        # support is already removed from consideration
        # gradient of the support parameters is not necessarily zero,
        # even for a converged solution in the presence of constraints.
        # IDEA: could use the vectorized representation of all
        # parameters in the optimizer stack to make this selection
        # over multiple parameter groups.
        # NOTE: these indices are relative to self.inactive_indices.
        indices = g.argsort(descending=True)[:n]
        indices = indices[g[indices] > 0]
        if indices.numel() == 0:  # no indices with positive gradient
            return False
        self.expand_support(self.inactive_indices[indices].tolist())

        return True

    # TODO: generalize contraction_objective
    def expansion_objective(self, mll: ExactMarginalLogLikelihood) -> Tensor:
        """Computes an objective value for all the inactive parameters, i.e.
        self.sparse_parameter[~self.is_active] since we can't add already active
        parameters to the support.
        """
        return self._sparse_parameter_gradient(mll)

    def _sparse_parameter_gradient(self, mll: ExactMarginalLogLikelihood) -> Tensor:
        """Computes the gradient of the marginal likelihood with respect to the
        sparse parameter.

        The parameter is temporarily converted to the dense representation; the
        original representation and ``requires_grad`` flag are restored afterwards,
        also if the evaluation raises.

        Args:
            mll: The marginal likelihood, containing the model to optimize.

        Returns:
            The gradient of the marginal likelihood with respect to the inactive
            sparse parameters.

        Raises:
            ValueError: If no gradient was populated for the sparse parameter.
        """
        is_sparse = self.is_sparse  # in order to restore the original representation
        self.to_dense()  # need the parameter in its dense parameterization

        param = self.sparse_parameter
        requires_grad = param.requires_grad
        param.requires_grad_(True)
        if param.grad is not None:
            param.grad.zero_()
        try:
            mll.train()  # NOTE: this changes model.train_inputs
            X, Y = mll.model.train_inputs[0], mll.model.train_targets
            cast(Tensor, mll(mll.model(X), Y)).backward()  # evaluation
            g = param.grad
        finally:
            param.requires_grad_(requires_grad)
            if is_sparse:
                self.to_sparse()

        if g is None:
            raise ValueError("Gradient is not available.")
        return g[~self.is_active]  # only need the inactive parameters

    def support_contraction(
        self,
        mll: ExactMarginalLogLikelihood,
        n: int = 1,
        modifier: Optional[Callable[[Tensor], Tensor]] = None,
    ) -> bool:
        """Computes the indices of the features that have the smallest coefficients,
        and subsequently contracts the support to exclude these features.

        Args:
            mll: The marginal likelihood, containing the model to optimize.
                NOTE: Virtually all of the rest of the code is not specific to the
                marginal likelihood optimization, so we could generalize this to work
                with any objective.
            n: The number of features to select for removal.
            modifier: A function that modifies the parameter values before computing
                the support contraction criterion. Defaults to
                ``self._contraction_modifier``.

        Returns:
            True if the support was contracted, False if the support was empty.
        """
        if len(self.support) == 0:
            return False

        is_sparse = self.is_sparse
        self.to_sparse()
        x = self.sparse_parameter

        modifier = modifier if modifier is not None else self._contraction_modifier
        if modifier is not None:
            x = modifier(x)

        # IDEA: for non-negative parameters, could break ties at zero
        # depending with derivative
        sparse_indices = x.argsort(descending=False)[:n].tolist()
        indices = [self.support[i] for i in sparse_indices]
        self.contract_support(indices)
        if not is_sparse:
            self.to_dense()
        return True

    def optimize_mll(
        self,
        mll: ExactMarginalLogLikelihood,
        model_trace: Optional[List[Model]] = None,
        reset_parameters: bool = True,
        reset_dense_parameters: bool = False,
        optimizer_kwargs: Optional[Mapping[str, Any]] = None,
    ):
        """Optimizes the marginal likelihood in the sparse representation.

        Args:
            mll: The marginal likelihood, containing the model to optimize.
            model_trace: If not None, a list to which a deepcopy of the model state is
                appended. NOTE This operation is *in place*.
            reset_parameters: If true, initializes the sparse parameter to all zeros.
            reset_dense_parameters: If true, re-initializes the model's other
                parameters with ``initialize_dense_parameters`` before optimizing.
            optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.

        Returns:
            The marginal likelihood after optimization.
        """
        if reset_parameters:
            # this might be beneficial because the parameters can
            # end up at a constraint boundary, which can anecdotally make
            # it more difficult to move the newly added parameters.
            # should we only do this after expansion?
            # IDEA: should we also reset the dense parameters?
            with torch.no_grad():
                self.sparse_parameter.zero_()

        if reset_dense_parameters:
            # re-initialize dense parameters
            initialize_dense_parameters(mll.model)

        # move to sparse representation for optimization
        # NOTE: this function should never force the dense representation, because some
        # models might never need it, and it would be inefficient.
        self.to_sparse()
        fit_gpytorch_mll(mll, optimizer_kwargs=optimizer_kwargs)
        if model_trace is not None:
            # need to record the full model here, rather than just the sparse parameter
            # since other hyper-parameters are co-adapted to the sparse parameter.
            model_trace.append(deepcopy(mll.model))
        return mll


################################# Optimization Algorithms ############################
def relevance_pursuit(
    sparse_module: RelevancePursuitMixin,
    mll: ExactMarginalLogLikelihood,
    num_iter: int,
    num_init: int = 0,
    num_expand: int = 1,
    num_contract: int = 0,
    mll_iter: int = 10_000,  # let's take convergence seriously
    mll_tol: float = 1e-8,
    optimizer_kwargs: Optional[Mapping[str, Any]] = None,
    reset_parameters: bool = True,
    reset_dense_parameters: bool = False,
    record_model_trace: bool = False,
) -> Tuple[RelevancePursuitMixin, Optional[List[Model]]]:
    """Relevance pursuit algorithm for the sparse marginal likelihood optimization
    of Gaussian process parameters. In its most general form, it is a forward-backward
    algorithm, but the forward and backward stages can be called independently, and
    modulated with the num_expand and num_contract arguments.

    The marginal likelihood is only re-optimized after a step that actually changed
    the support, so the model trace does not contain duplicate entries for an
    unchanged support.

    Ideas:
        - Could re-optimize after every single expansion, even if num_expand > 1.
        - Could drop exact zeros.

    Args:
        sparse_module: The relevance pursuit module.
        mll: The marginal likelihood, containing the model to optimize.
        num_iter: The maximum number of expansion / contraction iterations.
        num_init: The number of features to initialize the support with.
            NOTE(review): currently not used by this function; callers such as
            `subspace_pursuit` initialize the support before calling it.
        num_expand: The number of features to expand the support with.
        num_contract: The number of features to contract the support with.
        mll_iter: The maximum number of optimizer iterations, used only if
            `optimizer_kwargs` is None.
        mll_tol: The optimizer's ftol / gtol, used only if `optimizer_kwargs` is None.
        optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
        reset_parameters: If true, initializes the sparse parameter to the all zeros
            vector before every marginal likelihood optimization step. If false, the
            optimization is warm-started with the parameters of the previous iteration.
            TODO: check that this doesn't interfere with model trace record
        reset_dense_parameters: If true, re-initializes the model's other parameters
            before every marginal likelihood optimization step.
        record_model_trace: If true, records the model state after every optimization.

    Returns:
        The relevance pursuit module and, if `record_model_trace` is true, the list of
        optimized models (otherwise None).
    """
    if optimizer_kwargs is None:
        optimizer_kwargs = {
            "options": {"maxiter": mll_iter, "ftol": mll_tol, "gtol": mll_tol}
        }

    model_trace: Optional[List[Model]] = [] if record_model_trace else None

    def optimize_mll(mll):
        return sparse_module.optimize_mll(
            mll=mll,
            model_trace=model_trace,
            reset_parameters=reset_parameters,
            reset_dense_parameters=reset_dense_parameters,
            optimizer_kwargs=optimizer_kwargs,
        )

    optimize_mll(mll)  # initial optimization

    for _ in range(num_iter):
        expanded = False
        if num_expand > 0:
            expanded = sparse_module.support_expansion(mll=mll, n=num_expand)
            if expanded:
                optimize_mll(mll)  # re-optimize the expanded support

        contracted = False
        if num_contract > 0:
            contracted = sparse_module.support_contraction(mll=mll, n=num_contract)
            if contracted:
                optimize_mll(mll)  # re-optimize the contracted support

        # IDEA: could stop here if the marginal likelihood decreases, assuming that
        # the posterior pdf of the support size is uni-modal.
        if not expanded and not contracted:  # stationary support
            break

    return sparse_module, model_trace


def forward_relevance_pursuit(
    sparse_module: RelevancePursuitMixin,
    mll: ExactMarginalLogLikelihood,
    sparsity_levels: Optional[List[int]] = None,  # add levels for forward-backward
    mll_iter: int = MLL_ITER,
    mll_tol: float = MLL_TOL,
    optimizer_kwargs: Optional[Mapping[str, Any]] = None,
    reset_parameters: bool = False,
    reset_dense_parameters: bool = False,
    record_model_trace: bool = True,
    initial_support: Optional[Iterable[int]] = None,
) -> Tuple[RelevancePursuitMixin, Optional[List[Model]]]:
    """Forward Relevance Pursuit.

    Starting from `initial_support`, the support is greedily expanded to each of the
    sparsity levels in increasing order, re-optimizing the marginal likelihood after
    every expansion. The algorithm stops early if an expansion adds no feature.

    Args:
        sparse_module: The relevance pursuit module.
        mll: The marginal likelihood, containing the model to optimize.
        sparsity_levels: The sparsity levels to expand the support to. Duplicates are
            removed and the levels are sorted in increasing order. Defaults to every
            support size from the size of the initial support up to `dim`.
        mll_iter: The maximum number of optimizer iterations, used only if
            `optimizer_kwargs` is None.
        mll_tol: The optimizer's ftol / gtol, used only if `optimizer_kwargs` is None.
        optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
        reset_parameters: If true, initializes the sparse parameter to the all zeros
            after each iteration.
        reset_dense_parameters: If true, re-initializes the model's other parameters
            before every marginal likelihood optimization step.
        record_model_trace: If true, records the model state after every iteration.
        initial_support: The support with which to initialize the sparse module. By
            default, the support is initialized to the empty set.

    Returns:
        The relevance pursuit module after forward relevance pursuit optimization, and
        a list of models with different supports that were optimized (None if
        `record_model_trace` is false).

    Raises:
        ValueError: If a sparsity level is smaller than the initial support size.
    """
    sparse_module.remove_support()
    if initial_support is not None:
        # materialize so that generators are supported
        sparse_module.expand_support(list(initial_support))

    if sparsity_levels is None:
        sparsity_levels = list(
            range(len(sparse_module.support), sparse_module.dim + 1)
        )

    # since this is the forward algorithm, potential sparsity levels
    # must be in increasing order and unique.
    sparsity_levels = sorted(set(sparsity_levels))

    if optimizer_kwargs is None:
        optimizer_kwargs = {
            "options": {"maxiter": mll_iter, "ftol": mll_tol, "gtol": mll_tol}
        }

    model_trace = [] if record_model_trace else None

    def optimize_mll(mll):
        return sparse_module.optimize_mll(
            mll=mll,
            model_trace=model_trace,
            reset_parameters=reset_parameters,
            reset_dense_parameters=reset_dense_parameters,
            optimizer_kwargs=optimizer_kwargs,
        )

    # if sparsity levels contains the initial support, remove it
    if sparsity_levels and sparsity_levels[0] == len(sparse_module.support):
        sparsity_levels.pop(0)

    optimize_mll(mll)  # initial optimization

    for sparsity in sparsity_levels:
        support_size = len(sparse_module.support)
        num_expand = sparsity - support_size
        if not (num_expand > 0):
            raise ValueError(
                "sparsity_levels need to be increasing and larger than initial support."
            )

        expanded = sparse_module.support_expansion(mll=mll, n=num_expand)
        # IDEA: could stop here if the marginal likelihood decreases, assuming that
        # the posterior pdf of the support size is uni-modal.
        if not expanded:  # stationary support
            break

        optimize_mll(mll)  # re-optimize support

    return sparse_module, model_trace


# This is probably the most powerful algorithm, but needs the most steps unless
# we contract the support by more than one in each iteration.
def backward_relevance_pursuit(
    sparse_module: RelevancePursuitMixin,
    mll: ExactMarginalLogLikelihood,
    sparsity_levels: Optional[List[int]] = None,
    mll_iter: int = MLL_ITER,
    mll_tol: float = MLL_TOL,
    optimizer_kwargs: Optional[Mapping[str, Any]] = None,
    reset_parameters: bool = False,
    reset_dense_parameters: bool = False,
    record_model_trace: bool = True,
    initial_support: Optional[Iterable[int]] = None,
) -> Tuple[RelevancePursuitMixin, Optional[List[Model]]]:
    """Backward Relevance Pursuit.

    Starting from `initial_support` (the full support by default), the support is
    contracted to each of the sparsity levels in decreasing order, removing the
    features with the smallest coefficients and re-optimizing the marginal
    likelihood after every contraction.

    Args:
        sparse_module: The relevance pursuit module.
        mll: The marginal likelihood, containing the model to optimize.
        sparsity_levels: The sparsity levels to contract the support to. Duplicates
            are removed and the levels are sorted in decreasing order. Defaults to
            every support size from the size of the initial support down to zero.
        mll_iter: The maximum number of optimizer iterations, used only if
            `optimizer_kwargs` is None.
        mll_tol: The optimizer's ftol / gtol, used only if `optimizer_kwargs` is None.
        optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
        reset_parameters: If true, initializes the sparse parameter to the all zeros
            after each iteration.
        reset_dense_parameters: If true, re-initializes the model's other parameters
            before every marginal likelihood optimization step.
        record_model_trace: If true, records the model state after every iteration.
        initial_support: The support with which to initialize the sparse module. By
            default, the support is initialized to the full set.

    Returns:
        The relevance pursuit module after backward relevance pursuit optimization,
        and a list of models with different supports that were optimized (None if
        `record_model_trace` is false).

    Raises:
        ValueError: If a sparsity level is larger than the initial support size.
    """
    if initial_support is not None:
        sparse_module.remove_support()
        # materialize so that generators are supported
        sparse_module.expand_support(list(initial_support))
    else:
        sparse_module.full_support()

    if sparsity_levels is None:
        sparsity_levels = list(range(len(sparse_module.support) + 1))

    # since this is the backward algorithm, potential sparsity levels
    # must be in decreasing order and unique.
    sparsity_levels = sorted(set(sparsity_levels), reverse=True)

    if optimizer_kwargs is None:
        optimizer_kwargs = {
            "options": {"maxiter": mll_iter, "ftol": mll_tol, "gtol": mll_tol}
        }

    model_trace = [] if record_model_trace else None

    def optimize_mll(mll):
        return sparse_module.optimize_mll(
            mll=mll,
            model_trace=model_trace,
            reset_parameters=reset_parameters,
            reset_dense_parameters=reset_dense_parameters,
            optimizer_kwargs=optimizer_kwargs,
        )

    # if sparsity levels contains the initial support, remove it
    if sparsity_levels and sparsity_levels[0] == len(sparse_module.support):
        sparsity_levels.pop(0)

    optimize_mll(mll)  # initial optimization

    for sparsity in sparsity_levels:
        support_size = len(sparse_module.support)
        num_contract = support_size - sparsity
        if not (num_contract > 0):
            raise ValueError(
                "sparsity_levels need to be decreasing and less than initial support."
            )

        contracted = sparse_module.support_contraction(mll=mll, n=num_contract)
        # IDEA: could stop here if the marginal likelihood decreases, assuming that
        # the posterior pdf of the support size is uni-modal.
        if not contracted:  # stationary support
            break

        optimize_mll(mll)  # re-optimize support

    return sparse_module, model_trace


def subspace_pursuit(
    sparse_module: RelevancePursuitMixin,
    mll: ExactMarginalLogLikelihood,
    sparsity: int,
    num_iter: int = 16,
    optimizer_kwargs: Optional[Mapping[str, Any]] = None,
    random_init: bool = True,
    forward_init: bool = False,
    reset_parameters: bool = True,
    record_model_trace: bool = False,
) -> Tuple[RelevancePursuitMixin, Optional[List[Model]]]:
    """Searches for a good support of a fixed size `sparsity`.

    After initializing the support, every iteration of `relevance_pursuit` expands
    the support by up to `sparsity` features and then contracts it by `sparsity`
    features, re-optimizing the marginal likelihood in between.

    Args:
        sparse_module: The relevance pursuit module.
        mll: The marginal likelihood, containing the model to optimize.
        sparsity: The target support size.
        num_iter: The maximum number of expansion / contraction iterations.
        optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
        random_init: If true, initializes the support with `sparsity` random features.
        forward_init: If true (and `random_init` is false), initializes the support
            with a single forward relevance pursuit step to `sparsity` features.
        reset_parameters: If true, zeros the sparse parameter before every marginal
            likelihood optimization step.
        record_model_trace: If true, records the models optimized after the
            initialization.

    Returns:
        The relevance pursuit module and the model trace (None if
        `record_model_trace` is false).
    """
    # TODO: solve problem where expansion does not add `sparsity`` features
    # leading the contraction to prune more features than is desired.
    if random_init:
        sparse_module.random_support(sparsity)

    elif forward_init:
        # a single forward step straight to `sparsity` keeps initialization cheap
        forward_relevance_pursuit(
            sparse_module=sparse_module,
            mll=mll,
            sparsity_levels=[sparsity],
            optimizer_kwargs=optimizer_kwargs,
            reset_parameters=reset_parameters,
            record_model_trace=False,  # the initialization trace would be discarded
        )

    return relevance_pursuit(
        sparse_module=sparse_module,
        mll=mll,
        num_iter=num_iter,
        num_expand=sparsity,
        num_contract=sparsity,
        optimizer_kwargs=optimizer_kwargs,
        reset_parameters=reset_parameters,
        record_model_trace=record_model_trace,
    )


################################ Bayesian Model Comparison #############################
def get_posterior_over_support(
    rp_class: Type[RelevancePursuitMixin],
    model_trace: List[Model],
    log_support_prior: Optional[Callable[[Tensor], Tensor]] = None,
    prior_mean_of_support: Optional[float] = None,
) -> Tuple[Tensor, Tensor]:
    """Computes the posterior distribution over a list of models.
    Assumes we are storing both likelihood and GP model in the model_trace.

    IDEA: Could use this inside of relevance_pursuit as a stopping criterion.

    Args:
        rp_class: The relevance pursuit class to use for computing the support size.
            This is used to get the RelevancePursuitMixin from the Model via the static
            method `_from_model`. We could generalize this and let the user pass this
            getter instead.
        model_trace: A list of models with different support sizes, usually generated
            with relevance_pursuit.
        log_support_prior: Computes the log prior probability of a support size.
        prior_mean_of_support: A mean value for the default exponential prior
            distribution over the support size. Defaults to 1.0.

    Returns:
        A tuple of two tensors with one entry per model in the trace: the support
        sizes, and the posterior probabilities of the models (which sum to one).

    Raises:
        ValueError: If `model_trace` is empty.
    """
    if len(model_trace) == 0:
        raise ValueError("model_trace must contain at least one model.")

    if log_support_prior is None:
        if prior_mean_of_support is None:
            prior_mean_of_support = 1.0
        log_support_prior = partial(_exp_log_pdf, mean=prior_mean_of_support)

    log_support_prior = cast(Callable[[Tensor], Tensor], log_support_prior)

    def log_prior(model: Model) -> Tuple[Tensor, Tensor]:
        sparse_module = rp_class._from_model(model)
        num_support = torch.tensor(len(sparse_module.support))
        return num_support, log_support_prior(num_support)  # pyre-ignore[29]

    log_mll_list = []
    log_prior_list = []
    support_size_list = []
    for model in model_trace:
        mll = ExactMarginalLogLikelihood(likelihood=model.likelihood, model=model)
        mll.train()  # NOTE: switches the model and likelihood to training mode
        X, Y = mll.model.train_inputs[0], mll.model.train_targets
        log_mll_list.append(mll(mll.model(X), Y))
        support_size, log_prior_i = log_prior(model)
        support_size_list.append(support_size)
        log_prior_list.append(log_prior_i)
        # IDEA: could also compute posterior probability that a specific data point is an outlier

    log_mll_trace = torch.stack(log_mll_list)
    # the support sizes are created on the CPU, so the prior values typically are too;
    # align device and dtype with the marginal likelihoods before combining them.
    log_prior_trace = torch.stack(log_prior_list).to(log_mll_trace)
    support_size_trace = torch.stack(support_size_list)

    unnormalized_posterior_trace = log_mll_trace + log_prior_trace
    evidence = unnormalized_posterior_trace.logsumexp(dim=-1)
    posterior_probabilities = (unnormalized_posterior_trace - evidence).exp()
    return support_size_trace, posterior_probabilities


def _exp_log_pdf(x: Tensor, mean: Tensor) -> Tensor:
    """Compute the exponential log probability density.

    Args:
        x: A tensor of values.
        mean: A tensor of means.

    Returns:
        A tensor of log probabilities.
    """
    return -x / mean - math.log(mean)


def exponential_sparsity_levels(
    n: int, base: int = 2, decreasing: bool = False
) -> List[int]:
    """Generates a list of exponentially spaced integers, from zero to n,
    inclusive, which can be used as the default sparsity levels for the
    forward and backward algorithms, striking a balance between considering
    a wide range of support sizes and the computational cost.

    The list starts with 0, continues with the powers ``base**k`` for ``k >= 1``
    that are strictly smaller than ``n`` (``base**0 == 1`` is not included), and
    ends with ``n``. For example, ``n=10`` yields ``[0, 2, 4, 8, 10]``.

    Args:
        n: The maximum value of the list, non-negative.
        base: The base of the exponential, at least 2.
        decreasing: If true, the list is generated in decreasing order.

    Returns:
        A list of integers.

    Raises:
        ValueError: If ``n`` is negative or ``base`` is smaller than 2.
    """
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}.")
    if base < 2:
        raise ValueError(f"base must be at least 2, got {base}.")
    # integer arithmetic avoids the rounding errors of math.log(n, base)
    sparsity_levels = [0]
    level = base
    while level < n:
        sparsity_levels.append(level)
        level *= base
    if sparsity_levels[-1] < n:
        sparsity_levels.append(n)
    if decreasing:
        sparsity_levels.reverse()
    return sparsity_levels


def initialize_dense_parameters(model: Model) -> Tuple[Model, Dict[str, Any]]:
    """Sets the dense parameters of a model to their initial values. Infers initial
    values from the bounds of their constraints, if no initial values are provided
    or the provided ones are out of bounds. If a parameter has no bounds, it is
    initialized to zero.

    Rule-based initialization (in the constrained space):
        - only a finite lower bound: lower + 1
        - only a finite upper bound: upper - 1
        - both bounds finite: lower + min(1, (upper - lower) / 2)
        - no finite bound: 0

    Args:
        model: The model to initialize.

    Returns:
        The re-initialized model, and a dictionary mapping parameter names to the
        values passed to `model.initialize`, i.e. mapped into the latent (raw) space
        via the constraint's `inverse_transform` where a constraint exists.
    """
    constraints = dict(model.named_constraints())
    parameters = dict(model.named_parameters())
    neg_inf = torch.tensor(-torch.inf)
    pos_inf = torch.tensor(torch.inf)

    initial_values: Dict[str, Any] = {}
    for n in parameters:
        c = constraints.get(n + "_constraint", None)
        v = getattr(c, "_initial_value", None)
        lower = getattr(c, "lower_bound", neg_inf)
        upper = getattr(c, "upper_bound", pos_inf)
        # if no initial value is provided, or the initial value is outside the bounds,
        # use a rule-based initialization. torch.all supports multi-element bounds.
        in_bounds = v is not None and bool(
            torch.as_tensor(lower <= v).all() and torch.as_tensor(v <= upper).all()
        )
        if not in_bounds:
            if upper.isinf():
                v = 0.0 if lower.isinf() else lower + 1
            elif lower.isinf():  # implies upper is finite
                v = upper - 1
            else:  # both are finite
                v = lower + torch.minimum(torch.ones_like(lower), (upper - lower) / 2)
        # the initial values need to be converted to the transformed (latent) space
        if c is not None:
            v = c.inverse_transform(v)
        initial_values[n] = v

    model.initialize(**initial_values)
    return model, initial_values
