Source code for baybe.surrogates.linear

"""Linear surrogates."""

from __future__ import annotations

import gc
from typing import TYPE_CHECKING, ClassVar, TypedDict

from attrs import define, field
from typing_extensions import override

from baybe.surrogates.base import IndependentGaussianSurrogate
from baybe.surrogates.utils import batchify_mean_var_prediction, catch_constant_targets
from baybe.surrogates.validation import make_dict_validator
from baybe.utils.conversion import to_string

if TYPE_CHECKING:
    from torch import Tensor


class _ARDRegressionParams(TypedDict, total=False):
    """Optional ARDRegression parameters.

    See :class:`~sklearn.linear_model.ARDRegression`.
    """

    max_iter: int
    tol: float
    alpha_1: float
    alpha_2: float
    lambda_1: float
    lambda_2: float
    compute_score: bool
    threshold_lambda: float
    fit_intercept: bool
    copy_X: bool
    verbose: bool


@catch_constant_targets
@define
class BayesianLinearSurrogate(IndependentGaussianSurrogate):
    """A Bayesian linear regression surrogate model."""

    supports_transfer_learning: ClassVar[bool] = False
    # See base class.

    model_params: _ARDRegressionParams = field(
        factory=dict,
        converter=dict,
        validator=make_dict_validator(_ARDRegressionParams),
    )
    """Optional model parameters that will be passed to the surrogate constructor.

    For allowed keys and values, see :class:`~sklearn.linear_model.ARDRegression`.
    """

    # TODO: type should be `ARDRegression | None` but is currently omitted due to:
    #   https://github.com/python-attrs/cattrs/issues/531
    _model = field(init=False, default=None, eq=False)
    """The actual model."""

    @override
    @batchify_mean_var_prediction
    def _estimate_moments(
        self, candidates_comp_scaled: Tensor, /
    ) -> tuple[Tensor, Tensor]:
        # FIXME[typing]: It seems there is currently no better way to inform the
        #   type checker that the attribute is available at the time of the call
        assert self._model is not None

        import torch

        # Get predictions
        dists = self._model.predict(candidates_comp_scaled.numpy(), return_std=True)

        # Split into posterior mean and variance
        mean = torch.from_numpy(dists[0])
        var = torch.from_numpy(dists[1]).pow(2)

        return mean, var

    @override
    def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
        from sklearn.linear_model import ARDRegression

        self._model = ARDRegression(**self.model_params)
        self._model.fit(train_x, train_y.ravel())

    @override
    def __str__(self) -> str:
        fields = [to_string("Model Params", self.model_params, single_line=True)]
        return to_string(super().__str__(), *fields)

# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
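
# Usage sketch (an assumption for illustration, not part of the module): the
# private methods above are normally reached through the public surrogate API,
# but the raw flow they implement looks roughly like this. Shapes and names
# are hypothetical.
#
#     import torch
#
#     surrogate = BayesianLinearSurrogate(model_params={"max_iter": 300})
#     train_x = torch.rand(10, 3)   # 10 scaled training points, 3 features
#     train_y = torch.rand(10, 1)   # single computational target
#     surrogate._fit(train_x, train_y)
#
#     candidates = torch.rand(5, 3)
#     mean, var = surrogate._estimate_moments(candidates)  # posterior moments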