Source code for baybe.surrogates.utils
"""Utilities for working with surrogates."""
from __future__ import annotations
from collections.abc import Callable
from functools import wraps
from typing import TYPE_CHECKING, TypeVar
from baybe.exceptions import IncompatibleSurrogateError
from baybe.surrogates.base import Surrogate
if TYPE_CHECKING:
from botorch.posteriors import Posterior
from torch import Tensor
from baybe.surrogates.naive import MeanPredictionSurrogate
_TSurrogate = TypeVar("_TSurrogate", bound=Surrogate)
_constant_target_model_store: dict[int, MeanPredictionSurrogate] = {}
"""Dictionary for storing constant target fallback models. Keys are the IDs of the
surrogate models that temporarily have a fallback attached because they were
trained on constant training targets. Values are the corresponding fallback models."""
def catch_constant_targets(
cls: type[_TSurrogate], std_threshold: float = 1e-6
) -> type[_TSurrogate]:
"""Make a ``Surrogate`` class robustly handle constant training targets.
More specifically, "constant training targets" can mean either of:
* The standard deviation of the training targets is below the given threshold.
* There is only a single training target value, so the standard deviation cannot be computed at all.
The modified class handles the above cases separately from "regular operation"
by resorting to a :class:`baybe.surrogates.naive.MeanPredictionSurrogate`,
which is stored outside the model in a dictionary maintained by this decorator.
Args:
cls: The :class:`baybe.surrogates.base.Surrogate` to be augmented.
std_threshold: The standard deviation threshold below which operation is
switched to the alternative model.
Returns:
The modified class.
"""
from baybe.surrogates.naive import MeanPredictionSurrogate
# References to original methods
_fit_original = cls._fit
_posterior_original = cls._posterior
def _posterior_new(self, candidates_comp_scaled: Tensor, /) -> Posterior:
"""Use fallback model if it exists, otherwise call original posterior."""
# Alternative model fallback
if constant_target_model := _constant_target_model_store.get(id(self), None):
return constant_target_model._posterior(candidates_comp_scaled)
# Regular operation
return _posterior_original(self, candidates_comp_scaled)
def _fit_new(self, train_x: Tensor, train_y: Tensor) -> None:
"""Original fit but with fallback model creation for constant targets."""
if not (train_y.ndim == 2 and train_y.shape[-1] == 1):
raise NotImplementedError(
"The current logic is only implemented for single-target surrogates."
)
# Alternative model fallback
if train_y.numel() == 1 or train_y.std() < std_threshold:
model = MeanPredictionSurrogate()
model._fit(train_x, train_y)
_constant_target_model_store[id(self)] = model
# Regular operation
else:
_constant_target_model_store.pop(id(self), None)
_fit_original(self, train_x, train_y)
# Replace the methods
cls._posterior = _posterior_new # type: ignore
cls._fit = _fit_new # type: ignore
return cls
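

# Illustrative usage sketch (not part of the library): `_ToySurrogate` below is
# a hypothetical stand-in that exposes only the `_fit`/`_posterior` hooks the
# decorator patches; in real code, the decorator is applied to
# `baybe.surrogates.base.Surrogate` subclasses.
def _example_catch_constant_targets() -> None:
    import torch

    @catch_constant_targets
    class _ToySurrogate:
        def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
            self._mean = train_y.mean()

        def _posterior(self, candidates_comp_scaled: Tensor) -> Posterior:
            raise NotImplementedError("Not needed for this sketch.")

    model = _ToySurrogate()
    train_x = torch.rand(5, 2)

    # Constant targets: the decorator attaches a fallback model to the store ...
    model._fit(train_x, torch.zeros(5, 1))
    assert id(model) in _constant_target_model_store

    # ... and a subsequent fit on non-constant targets removes it again.
    model._fit(train_x, torch.linspace(0.0, 1.0, 5).unsqueeze(-1))
    assert id(model) not in _constant_target_model_store
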
def batchify_mean_var_prediction(
posterior: Callable[[_TSurrogate, Tensor], tuple[Tensor, Tensor]],
) -> Callable[[_TSurrogate, Tensor], tuple[Tensor, Tensor]]:
"""Wrap a posterior method to make it evaluate t-batches as an augmented q-batch."""
@wraps(posterior)
def sequential_posterior(
model: _TSurrogate, candidates: Tensor
) -> tuple[Tensor, Tensor]:
# If no batch dimensions are given, call the model directly
if candidates.ndim == 2:
return posterior(model, candidates)
# Parameter batching is not (yet) supported
if candidates.ndim > 3:
raise ValueError("Multiple t-batch dimensions are not supported.")
# Keep track of batch dimensions
t_shape = candidates.shape[-3]
q_shape = candidates.shape[-2]
# Flatten the t-batch dimension into the q-batch dimension
flattened = candidates.flatten(end_dim=-2)
# Call the model on the entire input
mean, var = posterior(model, flattened)
# Restore the batch dimensions
mean = mean.reshape((t_shape, q_shape))
var = var.reshape((t_shape, q_shape))
return mean, var
return sequential_posterior
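

# Illustrative shape sketch (not part of the library): `_ToyMeanVarModel` below
# is a hypothetical class whose decorated method returns per-point means and
# variances for a flat (n_points, n_features) input; the wrapper takes care of
# flattening and restoring the t-batch dimension.
def _example_batchify_mean_var_prediction() -> None:
    import torch

    class _ToyMeanVarModel:
        @batchify_mean_var_prediction
        def _estimate_moments(self, candidates: Tensor) -> tuple[Tensor, Tensor]:
            return candidates.mean(dim=-1), candidates.var(dim=-1)

    model = _ToyMeanVarModel()
    candidates = torch.rand(4, 3, 2)  # (t_shape, q_shape, n_features)
    mean, var = model._estimate_moments(candidates)
    assert mean.shape == var.shape == (4, 3)  # batch dimensions restored
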
def batchify_ensemble_predictor(
base_predictor: Callable[[Tensor], Tensor],
) -> Callable[[Tensor], Tensor]:
"""Wrap an ensemble predictor to make it evaluate t-batches as an augmented q-batch.
Args:
base_predictor: The ensemble predictor to be wrapped.
Returns:
The wrapped predictor.
"""
@wraps(base_predictor)
def batch_predictor(candidates: Tensor) -> Tensor:
# If no batch dimensions are given, call the model directly
if candidates.ndim == 2:
return base_predictor(candidates)
# Ensemble models do not (yet) support model parameter batching
if candidates.ndim > 3:
raise ValueError("Multiple t-batch dimensions are not supported.")
# Keep track of batch dimensions
t_shape = candidates.shape[-3]
q_shape = candidates.shape[-2]
# Flatten the t-batch dimension into the q-batch dimension
flattened = candidates.flatten(end_dim=-2)
# Call the model on the entire input
predictions = base_predictor(flattened)
# Assert that the model provides the ensemble predictions in the correct shape
# (otherwise the reshaping operation below could silently produce wrong results)
try:
assert predictions.ndim == 2
n_estimators = predictions.shape[0]
assert predictions.shape[1] == t_shape * q_shape
except AssertionError:
raise IncompatibleSurrogateError(
f"For the given input of shape {tuple(candidates.shape)}, "
f"the ensemble model is supposed to create predictions of shape "
f"(n_estimators, t_shape * q_shape) = "
f"(n_estimators, {t_shape * q_shape}) "
f"but returned an array of incompatible shape "
f"{tuple(predictions.shape)}."
)
# Restore the batch dimensions
predictions = predictions.reshape((n_estimators, t_shape, q_shape))
predictions = predictions.permute((1, 0, 2))
return predictions
return batch_predictor
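

# Illustrative shape sketch (not part of the library): `toy_predictor` below is
# a hypothetical ensemble predictor that returns one row of predictions per
# estimator for a flat (n_points, n_features) input; the wrapper reshapes the
# result into the (t_shape, n_estimators, q_shape) layout produced above.
def _example_batchify_ensemble_predictor() -> None:
    import torch

    n_estimators = 5

    def toy_predictor(candidates: Tensor) -> Tensor:
        # Each "estimator" predicts the feature mean shifted by a constant offset.
        base = candidates.mean(dim=-1)  # shape: (n_points,)
        return torch.stack([base + i for i in range(n_estimators)])

    wrapped = batchify_ensemble_predictor(toy_predictor)
    candidates = torch.rand(4, 3, 2)  # (t_shape, q_shape, n_features)
    predictions = wrapped(candidates)
    assert predictions.shape == (4, n_estimators, 3)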