Source code for baybe.acquisition.base

"""Base classes for all acquisition functions."""

from __future__ import annotations

import gc
from abc import ABC
from typing import TYPE_CHECKING, ClassVar, Literal, overload

import pandas as pd
from attrs import define

from baybe.exceptions import (
    IncompatibleAcquisitionFunctionError,
    UnidentifiedSubclassError,
)
from baybe.objectives.base import Objective
from baybe.searchspace.core import SearchSpace
from baybe.serialization.core import (
    converter,
    get_base_structure_hook,
    unstructure_base,
)
from baybe.serialization.mixin import SerialMixin
from baybe.surrogates.base import SurrogateProtocol
from baybe.utils.basic import classproperty
from baybe.utils.boolean import is_abstract
from baybe.utils.dataframe import to_tensor

if TYPE_CHECKING:
    from botorch.acquisition import AcquisitionFunction as BotorchAcquisitionFunction


[docs] @define(frozen=True) class AcquisitionFunction(ABC, SerialMixin): """Abstract base class for all acquisition functions.""" abbreviation: ClassVar[str] """An alternative name for type resolution.""" supports_multi_output: ClassVar[bool] = False """Whether this acquisition function can handle models with multiple outputs.""" @classproperty def supports_batching(cls) -> bool: """Flag indicating whether batch recommendation is supported.""" return cls.abbreviation.startswith("q") @classproperty def supports_pending_experiments(cls) -> bool: """Flag indicating whether pending experiments are supported. This is based on the same mechanism underlying batched recommendations. """ return cls.supports_batching @classproperty def _non_botorch_attrs(cls) -> tuple[str, ...]: """Names of attributes that are not passed to the BoTorch constructor.""" return ()
[docs] def to_botorch( self, surrogate: SurrogateProtocol, searchspace: SearchSpace, objective: Objective, measurements: pd.DataFrame, pending_experiments: pd.DataFrame | None = None, ) -> BotorchAcquisitionFunction: """Create the botorch-ready representation of the function. The required structure of `measurements` is specified in :meth:`baybe.recommenders.base.RecommenderProtocol.recommend`. """ from baybe.acquisition._builder import BotorchAcquisitionFunctionBuilder if pending_experiments is not None and not self.supports_pending_experiments: raise IncompatibleAcquisitionFunctionError( f"The chosen acquisition function of type '{self.__class__.__name__}' " f"does not support pending experiments." ) return BotorchAcquisitionFunctionBuilder( self, surrogate, searchspace, objective, measurements, pending_experiments ).build()
@overload def evaluate( self, candidates: pd.DataFrame, surrogate: SurrogateProtocol, searchspace: SearchSpace, objective: Objective, measurements: pd.DataFrame, pending_experiments: pd.DataFrame | None = None, *, jointly: Literal[True], ) -> float: ... @overload def evaluate( self, candidates: pd.DataFrame, surrogate: SurrogateProtocol, searchspace: SearchSpace, objective: Objective, measurements: pd.DataFrame, pending_experiments: pd.DataFrame | None = None, *, jointly: Literal[False] = False, ) -> pd.Series: ...
[docs] def evaluate( self, candidates: pd.DataFrame, surrogate: SurrogateProtocol, searchspace: SearchSpace, objective: Objective, measurements: pd.DataFrame, pending_experiments: pd.DataFrame | None = None, *, jointly: bool = False, ) -> pd.Series | float: """Get the acquisition values for the given candidates. Args: candidates: The candidate points in experimental representation. For details, see :meth:`baybe.surrogates.base.Surrogate.posterior`. surrogate: The surrogate model to use for the acquisition function. searchspace: The search space. See :meth:`baybe.recommenders.base.RecommenderProtocol.recommend`. objective: The objective. See :meth:`baybe.recommenders.base.RecommenderProtocol.recommend`. measurements: Available experimentation data. See :meth:`baybe.recommenders.base.RecommenderProtocol.recommend`. pending_experiments: Optional pending experiments. See :meth:`baybe.recommenders.base.RecommenderProtocol.recommend`. jointly: If ``False``, the acquisition values are computed independently for each candidate. If ``True``, a single joint acquisition value is computed for the entire candidate set. Returns: Depending on the joint mode, either a single batch acquisition value or a series of individual acquisition values. """ import torch # Assemble the Botorch acquisition function and its input botorch_acqf = self.to_botorch( surrogate, searchspace, objective, measurements, pending_experiments ) comp = to_tensor(searchspace.transform(candidates)) # Depending on joint mode, evaluate using t- or q-batching in_ = comp if jointly else comp.unsqueeze(-2) with torch.no_grad(): out = botorch_acqf(in_) if jointly: return out.item() return pd.Series(out.numpy(), index=candidates.index)
def _get_botorch_acqf_class( baybe_acqf_cls: type[AcquisitionFunction], / ) -> type[BotorchAcquisitionFunction]: """Extract the BoTorch acquisition class for the given BayBE acquisition class.""" import botorch for cls in baybe_acqf_cls.mro(): if ( acqf_cls := getattr(botorch.acquisition, cls.__name__, False) or getattr(botorch.acquisition.multi_objective, cls.__name__, False) or getattr(botorch.acquisition.multi_objective.parego, cls.__name__, False) ): if is_abstract(acqf_cls): continue return acqf_cls # type: ignore raise UnidentifiedSubclassError( f"No BoTorch acquisition function class match found for " f"'{baybe_acqf_cls.__name__}'." ) # Register (un-)structure hooks converter.register_structure_hook( AcquisitionFunction, get_base_structure_hook(AcquisitionFunction) ) converter.register_unstructure_hook(AcquisitionFunction, unstructure_base) # Collect leftover original slotted classes processed by `attrs.define` gc.collect()