@catch_constant_targets
@define
class BayesianLinearSurrogate(IndependentGaussianSurrogate):
    """A Bayesian linear regression surrogate model."""

    supports_transfer_learning: ClassVar[bool] = False
    # See base class.

    model_params: dict[str, Any] = field(
        factory=dict,
        converter=dict,
        validator=get_model_params_validator(ARDRegression.__init__),
    )
    """Optional model parameters that will be passed to the surrogate constructor."""

    _model: ARDRegression | None = field(init=False, default=None, eq=False)
    """The actual model."""

    @override
    @batchify_mean_var_prediction
    def _estimate_moments(
        self, candidates_comp_scaled: Tensor, /
    ) -> tuple[Tensor, Tensor]:
        # FIXME[typing]: It seems there is currently no better way to inform the type
        #   checker that the attribute is available at the time of the function call
        assert self._model is not None

        import torch

        # Get predictions
        dists = self._model.predict(candidates_comp_scaled.numpy(), return_std=True)

        # Split into posterior mean and variance
        mean = torch.from_numpy(dists[0])
        var = torch.from_numpy(dists[1]).pow(2)

        return mean, var

    @override
    def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
        self._model = ARDRegression(**(self.model_params))
        self._model.fit(train_x, train_y.ravel())

    @override
    def __str__(self) -> str:
        fields = [to_string("Model Params", self.model_params, single_line=True)]
        return to_string(super().__str__(), *fields)


# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
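

# The sketch below is a minimal, self-contained illustration (not part of BayBE)
# of the moment extraction performed by `_estimate_moments` above: scikit-learn's
# `ARDRegression.predict(..., return_std=True)` returns the posterior predictive
# mean and standard deviation, and the standard deviation is squared to obtain the
# variance expected by the independent-Gaussian surrogate interface. The toy data
# is hypothetical and only stands in for computationally encoded candidate points.
if __name__ == "__main__":
    import numpy as np
    import torch
    from sklearn.linear_model import ARDRegression

    rng = np.random.default_rng(0)
    train_x = rng.normal(size=(20, 3))
    train_y = train_x @ np.array([1.0, -2.0, 0.5]) + rng.normal(scale=0.1, size=20)
    candidates = rng.normal(size=(5, 3))

    model = ARDRegression()
    model.fit(train_x, train_y.ravel())

    # Posterior predictive mean and standard deviation for the candidate points
    pred_mean, pred_std = model.predict(candidates, return_std=True)

    # Convert to torch tensors; squaring the standard deviation yields the variance
    mean = torch.from_numpy(pred_mean)
    var = torch.from_numpy(pred_std).pow(2)

    print(mean.shape, var.shape)  # torch.Size([5]) torch.Size([5])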