@catch_constant_targets
@define
class NGBoostSurrogate(IndependentGaussianSurrogate):
    """A natural-gradient-boosting surrogate model.

    Fitting and prediction are delegated to ``NGBRegressor`` from the optional
    ``ngboost`` dependency. The predictive distribution returned by the fitted
    regressor supplies the per-candidate mean and variance reported by
    ``_estimate_moments``.
    """

    supports_transfer_learning: ClassVar[bool] = False  # See base class.

    _default_model_params: ClassVar[dict] = {"n_estimators": 25, "verbose": False}
    """Default keyword arguments for ``NGBRegressor``; user values override them."""

    model_params: dict[str, Any] = field(factory=dict, converter=dict)
    """Optional keyword arguments forwarded to the ``NGBRegressor`` constructor."""

    # TODO: Type should be `NGBRegressor | None` but is currently
    #   omitted due to: https://github.com/python-attrs/cattrs/issues/531
    _model = field(init=False, default=None, eq=False)
    """The fitted ``NGBRegressor`` instance (``None`` until ``_fit`` has run)."""

    @model_params.validator
    def _validate_model_params(self, attr, value) -> None:
        """Validate ``model_params`` against the ``NGBRegressor`` constructor."""
        from baybe._optional.ngboost import NGBRegressor

        get_model_params_validator(NGBRegressor.__init__)(self, attr, value)

    def __attrs_post_init__(self) -> None:
        # Layer the user-supplied entries on top of the class defaults, so that
        # explicitly passed values take precedence over the defaults.
        self.model_params = self._default_model_params | self.model_params

    @override
    @staticmethod
    def _make_parameter_scaler_factory(
        parameter: Parameter,
    ) -> type[InputTransform] | None:
        """Disable input scaling, which tree-like models do not require."""
        return None

    @override
    @staticmethod
    def _make_target_scaler_factory() -> type[OutcomeTransform] | None:
        """Disable target scaling, which tree-like models do not require."""
        return None

    @override
    @batchify_mean_var_prediction
    def _estimate_moments(
        self, candidates_comp_scaled: Tensor, /
    ) -> tuple[Tensor, Tensor]:
        """Compute the predictive mean and variance for the given candidates.

        Args:
            candidates_comp_scaled: The candidate points to predict on.

        Returns:
            A ``(mean, var)`` pair of tensors, converted from the NumPy arrays
            exposed by the ngboost predictive distribution.
        """
        # FIXME[typing]: It seems there is currently no better way to inform the
        #   type checker that the model is available at the time of the call.
        model = self._model
        assert model is not None

        import torch

        predictive = model.pred_dist(candidates_comp_scaled)

        # NOTE(review): `mean` is called as a method while `var` is read as an
        #   attribute, mirroring ngboost's (default) Normal distribution; confirm
        #   this still holds if a non-default `Dist` is passed via model_params.
        return (
            torch.from_numpy(predictive.mean()),
            torch.from_numpy(predictive.var),
        )

    @override
    def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
        """Train a fresh ``NGBRegressor`` and store it as the surrogate's model.

        The targets are flattened with ``ravel`` before being passed to ``fit``.
        """
        from baybe._optional.ngboost import NGBRegressor

        regressor = NGBRegressor(**self.model_params)
        self._model = regressor.fit(train_x, train_y.ravel())

    @override
    def __str__(self) -> str:
        params_repr = to_string("Model Params", self.model_params, single_line=True)
        return to_string(super().__str__(), params_repr)