@define
class MeanPredictionSurrogate(IndependentGaussianSurrogate):
    """A trivial surrogate model.

    It provides the average value of the training targets
    as posterior mean and a (data-independent) constant posterior variance.
    """

    supports_transfer_learning: ClassVar[bool] = False
    # See base class.

    _model: float | None = field(init=False, default=None, eq=False)
    """The estimated posterior mean value of the training targets."""

    @override
    @batchify_mean_var_prediction
    def _estimate_moments(
        self, candidates_comp_scaled: Tensor, /
    ) -> tuple[Tensor, Tensor]:
        """Return a constant mean and a unit variance for every candidate.

        The returned tensors have one entry per element along the first
        dimension of ``candidates_comp_scaled``. The candidate values themselves
        are never inspected.
        """
        import torch

        # TODO: use target value bounds for covariance scaling when explicitly provided
        n_candidates = len(candidates_comp_scaled)

        # Broadcast the fitted scalar mean to all candidates. If `_fit` has not
        # run yet, `_model` is still None and this multiplication raises.
        posterior_mean = self._model * torch.ones(n_candidates)  # type: ignore[operator]

        # Constant, data-independent variance of one per candidate.
        posterior_var = torch.ones(n_candidates)

        return posterior_mean, posterior_var

    @override
    def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
        """Store the average of all training target values as a Python float.

        ``train_x`` is unused; only the targets contribute to the model.
        """
        target_average = train_y.mean()
        self._model = target_average.item()
# `attrs.define` builds a replacement for each slotted class it processes; the
# original class objects left behind are not freed until a garbage collection
# runs, so force a full collection (generation 2) right away.
gc.collect(2)