Source code for baybe.surrogates.composite

"""Composite surrogate models."""

from __future__ import annotations

from collections.abc import Sequence
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar

import pandas as pd
from attrs import define, field
from typing_extensions import override

from baybe.exceptions import IncompatibleSurrogateError
from baybe.objectives.base import Objective
from baybe.searchspace.core import SearchSpace
from baybe.serialization import converter
from baybe.serialization.mixin import SerialMixin
from baybe.surrogates.base import PosteriorStatistic, SurrogateProtocol
from baybe.surrogates.gaussian_process.core import GaussianProcessSurrogate
from baybe.utils.basic import is_all_instance

if TYPE_CHECKING:
    from botorch.models.model import ModelList
    from botorch.posteriors import PosteriorList

# Type variable for the template object wrapped by `_ReplicationMapping`.
_T = TypeVar("_T")


class _SurrogateGetter(Protocol):
    """An index-based mapping from strings to surrogates.

    Any object supporting ``obj[key]`` with a string key and returning a
    surrogate satisfies this protocol.
    """

    def __getitem__(self, key: str) -> SurrogateProtocol:
        """Return the surrogate associated with the given string key."""
        ...


@define
class _ReplicationMapping(Generic[_T]):
    """Hand out per-key deep copies of a template object through indexing access.

    Behaves like a serializable ``defaultdict(lambda: deepcopy(template))``.
    """

    template: _T = field()
    """The object that is deep-copied whenever an unseen key is requested."""

    _data: dict[Any, _T] = field(init=False, factory=dict, eq=False)
    """Internal cache of the copies created so far, keyed by the requested index."""

    def __getitem__(self, key: Any, /) -> _T:
        """Return the copy stored for ``key``, creating it on first access."""
        # Already requested before: hand out the very same object again.
        if key in self._data:
            return self._data[key]

        # First request for this key: replicate the template and remember it.
        replica = deepcopy(self.template)
        self._data[key] = replica
        return replica


@define
class CompositeSurrogate(SerialMixin, SurrogateProtocol):
    """A class for composing multi-target surrogates from single-target surrogates."""

    # IMPROVE: Currently, the class is implemented in the most vanilla way, using only
    #   BayBE's existing interfaces. There are several ways how it can be
    #   further optimized by integrating it more directly with the underlying gpytorch
    #   models. However, this probably requires some additional code adaptations to
    #   achieve a full integration. Some future directions:
    #   * Instead of fitting the models sequentially, a parallel optimization can
    #     be done via `SumMarginalLogLikelihood`. However, a full integration would
    #     also require supporting different fitting routines (e.g. LOO)
    #   * The manual construction of the `PosteriorList` can be avoided when
    #     the posterior computation is triggered directly on the `ModelList`. However,
    #     this requires a clean integration of the necessary pre-processing steps
    #     (transformation to computational representation + scaling)
    #   * There is currently a lot of redundancy because each of the surrogates
    #     internally stores a reference to the fitting context (e.g. search space,
    #     objective, ...)

    surrogates: _SurrogateGetter = field()
    """An index-based mapping from target names to single-target surrogates."""

    # An explicit default is required: without it, the attribute stays unset until
    # `fit` is called, and any access before fitting (e.g. through `repr` or
    # `_surrogates_flat`) fails with an `AttributeError`.
    _target_names: tuple[str, ...] | None = field(init=False, default=None, eq=False)
    """The names of the targets modeled by the surrogate outputs.

    ``None`` as long as the surrogate has not been fitted."""

    @classmethod
    def from_replication(cls, surrogate: SurrogateProtocol) -> CompositeSurrogate:
        """Replicate a given single-target surrogate logic for multiple targets."""
        return CompositeSurrogate(_ReplicationMapping(surrogate))

    @property
    def _surrogates_flat(self) -> tuple[SurrogateProtocol, ...]:
        """The surrogates ordered according to the targets of the modeled objective."""
        assert self._target_names is not None, (
            "The target names are only available after the surrogate has been fitted."
        )
        return tuple(self.surrogates[t] for t in self._target_names)

    @override
    def fit(
        self,
        searchspace: SearchSpace,
        objective: Objective,
        measurements: pd.DataFrame,
    ) -> None:
        # Each target gets its own surrogate, fitted on a single-target objective
        for target in objective.targets:
            self.surrogates[target.name].fit(
                searchspace, target.to_objective(), measurements
            )

        # Remember the target order, which defines the order of the model outputs
        self._target_names = tuple(t.name for t in objective.targets)

    @override
    def to_botorch(self) -> ModelList:
        from botorch.models import ModelList
        from botorch.models.model_list_gp_regression import ModelListGP

        # Use the GP-specific container only if all surrogates are GP surrogates
        cls = (
            ModelListGP
            if is_all_instance(self._surrogates_flat, GaussianProcessSurrogate)
            else ModelList
        )
        return cls(*(s.to_botorch() for s in self._surrogates_flat))

    def posterior(self, candidates: pd.DataFrame) -> PosteriorList:
        """Compute the posterior for candidates in experimental representation.

        The (independent joint) posterior is represented as a collection of individual
        posterior models computed per target of the involved objective.

        For details, see :meth:`baybe.surrogates.base.Surrogate.posterior`.

        Raises:
            IncompatibleSurrogateError: If at least one of the involved surrogates
                does not offer a ``posterior`` attribute.
        """
        if not all(hasattr(s, "posterior") for s in self._surrogates_flat):
            raise IncompatibleSurrogateError(
                "A posterior can only be computed if all involved surrogates offer "
                "posterior computation."
            )

        from botorch.posteriors import PosteriorList

        # TODO[typing]: a `has_all_attrs` typeguard similar to `is_all_instance` would
        #   be handy here but unclear if this is doable with the current typing system
        posteriors = [s.posterior(candidates) for s in self._surrogates_flat]  # type: ignore[attr-defined]
        return PosteriorList(*posteriors)

    def posterior_stats(
        self,
        candidates: pd.DataFrame,
        stats: Sequence[PosteriorStatistic] = ("mean", "std"),
    ) -> pd.DataFrame:
        """See :meth:`baybe.surrogates.base.Surrogate.posterior_stats`.

        The per-surrogate results are concatenated column-wise, in target order.

        Raises:
            IncompatibleSurrogateError: If at least one of the involved surrogates
                does not offer a ``posterior_stats`` attribute.
        """
        if not all(hasattr(s, "posterior_stats") for s in self._surrogates_flat):
            raise IncompatibleSurrogateError(
                "Posterior statistics can only be computed if all involved surrogates "
                "offer this computation."
            )

        dfs = [s.posterior_stats(candidates, stats) for s in self._surrogates_flat]  # type: ignore[attr-defined]
        return pd.concat(dfs, axis=1)
def _structure_surrogate_getter(obj: dict, _) -> _SurrogateGetter:
    """Resolve the object type.

    The ``"type"`` entry of the given dictionary selects the concrete class to be
    structured from the remaining entries. The input dictionary is left untouched.

    Raises:
        KeyError: If the dictionary has no ``"type"`` entry.
        NotImplementedError: If the specified type is not supported.
    """
    # Pop the type tag from a shallow copy: popping from `obj` directly would strip
    # the tag from the caller's dictionary and break any repeated deserialization
    payload = dict(obj)
    type_ = payload.pop("type")

    if type_ == _ReplicationMapping.__name__:
        return converter.structure(payload, _ReplicationMapping[SurrogateProtocol])
    if type_ == "dict":
        return converter.structure(payload, dict[str, SurrogateProtocol])
    raise NotImplementedError(f"No structure hook implemented for '{type_}'.")


def _unstructure_surrogate_getter(obj: _SurrogateGetter) -> dict:
    """Add the object type information.

    The class name of the object is stored under the ``"type"`` key, next to the
    entries of its regular unstructured form.
    """
    # NOTE(review): the unstructured entries are merged after the type tag, so an
    #   entry that is itself keyed "type" would overwrite the tag. Confirm that
    #   this key cannot occur.
    return {"type": type(obj).__name__, **converter.unstructure(obj)}


converter.register_structure_hook_func(
    lambda t: t is _SurrogateGetter, _structure_surrogate_getter
)
converter.register_unstructure_hook_func(
    lambda t: t is _SurrogateGetter, _unstructure_surrogate_getter
)