"""Linear surrogates."""

from __future__ import annotations

import gc
from typing import TYPE_CHECKING, ClassVar, TypedDict

from attrs import define, field
from typing_extensions import override

from baybe.surrogates.base import IndependentGaussianSurrogate
from baybe.surrogates.utils import batchify_mean_var_prediction, catch_constant_targets
from baybe.surrogates.validation import make_dict_validator
from baybe.utils.conversion import to_string

if TYPE_CHECKING:
    from torch import Tensor


class _ARDRegressionParams(TypedDict, total=False):
    """Optional ARDRegression parameters.

    See :class:`~sklearn.linear_model.ARDRegression`.
    """

    # ``total=False`` makes every key optional, so any subset of these may be given.
    # The key names are intended to match the keyword arguments accepted by
    # sklearn's ARDRegression constructor (see the class docstring).
    max_iter: int
    tol: float
    alpha_1: float
    alpha_2: float
    lambda_1: float
    lambda_2: float
    compute_score: bool
    threshold_lambda: float
    fit_intercept: bool
    copy_X: bool
    verbose: bool
@catch_constant_targets
@define
class BayesianLinearSurrogate(IndependentGaussianSurrogate):
    """A Bayesian linear regression surrogate model.

    Wraps :class:`~sklearn.linear_model.ARDRegression`. Its predictive mean and
    squared predictive standard deviation are returned as the surrogate's mean
    and variance.
    """

    supports_transfer_learning: ClassVar[bool] = False
    # See base class.

    model_params: _ARDRegressionParams = field(
        factory=dict,
        converter=dict,
        validator=make_dict_validator(_ARDRegressionParams),
    )
    """Optional model parameter that will be passed to the surrogate constructor.

    For allowed keys and values, see :class:`~sklearn.linear_model.ARDRegression`.
    """

    # TODO: type should be `ARDRegression | None` but is currently omitted due to:
    #   https://github.com/python-attrs/cattrs/issues/531
    _model = field(init=False, default=None, eq=False)
    """The actual model."""

    @override
    @batchify_mean_var_prediction
    def _estimate_moments(
        self, candidates_comp_scaled: Tensor, /
    ) -> tuple[Tensor, Tensor]:
        """Predict the posterior mean and variance for the given candidates.

        Args:
            candidates_comp_scaled: The scaled computational representation of
                the candidates. It is converted via ``.numpy()`` before being
                handed to sklearn.

        Returns:
            A tuple ``(mean, var)`` of tensors, where ``var`` is the square of
            the predictive standard deviation reported by the fitted model.

        Raises:
            RuntimeError: If called before the model has been fitted.
        """
        # An explicit check (instead of ``assert``) narrows the type for the type
        # checker and, unlike ``assert``, is not stripped under ``python -O``.
        if self._model is None:
            raise RuntimeError(
                f"'{type(self).__name__}' must be fitted before making predictions."
            )

        import torch

        # With ``return_std=True``, sklearn returns a (mean, std) pair of arrays.
        mean_np, std_np = self._model.predict(
            candidates_comp_scaled.numpy(), return_std=True
        )

        # Convert back to tensors; the variance is the squared standard deviation.
        mean = torch.from_numpy(mean_np)
        var = torch.from_numpy(std_np).pow(2)

        return mean, var

    @override
    def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
        """Create a fresh ARDRegression model from ``model_params`` and fit it.

        Args:
            train_x: Training inputs.
            train_y: Training targets; flattened with ``ravel()`` before fitting.
        """
        # Imported lazily so sklearn is only loaded when a fit actually happens.
        from sklearn.linear_model import ARDRegression

        self._model = ARDRegression(**(self.model_params))
        self._model.fit(train_x, train_y.ravel())

    @override
    def __str__(self) -> str:
        """Return the base representation extended with the model parameters."""
        fields = [to_string("Model Params", self.model_params, single_line=True)]
        return to_string(super().__str__(), *fields)
# Collect leftover original slotted classes processed by `attrs.define`.
# NOTE: this must be a separate statement; on the same line as the comment the
# call would be part of the comment and never run.
gc.collect()