"""Random forest surrogates."""from__future__importannotationsfromcollections.abcimportCollectionfromtypingimportTYPE_CHECKING,ClassVar,Literal,Protocol,TypedDictimportnumpyasnpimportnumpy.typingasnptfromattrsimportdefine,fieldfromnumpy.randomimportRandomStatefromtyping_extensionsimportoverridefrombaybe.parameters.baseimportParameterfrombaybe.surrogates.baseimportSurrogatefrombaybe.surrogates.utilsimportbatchify_ensemble_predictor,catch_constant_targetsfrombaybe.surrogates.validationimportmake_dict_validatorfrombaybe.utils.conversionimportto_stringifTYPE_CHECKING:frombotorch.models.ensembleimportEnsemblePosteriorfrombotorch.models.transforms.inputimportInputTransformfrombotorch.models.transforms.outcomeimportOutcomeTransformfromtorchimportTensorclass_RandomForestRegressorParams(TypedDict,total=False):"""Optional RandomForestRegressor parameters. See :class:`~sklearn.ensemble.RandomForestRegressor`. """n_estimators:intcriterion:Literal["squared_error","absolute_error","friedman_mse","poisson"]max_depth:intmin_samples_split:int|floatmin_samples_leaf:int|floatmin_weight_fraction_leaf:floatmax_features:Literal["sqrt","log2"]|int|float|Nonemax_leaf_nodes:int|Nonemin_impurity_decrease:floatbootstrap:booloob_score:booln_jobs:int|Nonerandom_state:int|RandomState|Noneverbose:intwarm_start:boolccp_alpha:floatmax_samples:int|float|Nonemonotonic_cst:npt.ArrayLike|int|Noneclass_Predictor(Protocol):"""A basic predictor."""defpredict(self,x:np.ndarray,/)->np.ndarray:...


@catch_constant_targets
@define
class RandomForestSurrogate(Surrogate):
    """A random forest surrogate model."""

    supports_transfer_learning: ClassVar[bool] = False
    # See base class.

    model_params: _RandomForestRegressorParams = field(
        factory=dict,
        converter=dict,
        validator=make_dict_validator(_RandomForestRegressorParams),
    )
    """Optional model parameters that will be passed to the surrogate constructor.

    For allowed keys and values, see :class:`~sklearn.ensemble.RandomForestRegressor`.
    """

    # TODO: type should be `RandomForestRegressor | None` but is currently omitted due
    #   to: https://github.com/python-attrs/cattrs/issues/531
    _model = field(init=False, default=None, eq=False)
    """The actual model."""

    @override
    @staticmethod
    def _make_parameter_scaler_factory(
        parameter: Parameter,
    ) -> type[InputTransform] | None:
        # Tree-like models do not require any input scaling
        return None

    @override
    @staticmethod
    def _make_target_scaler_factory() -> type[OutcomeTransform] | None:
        # Tree-like models do not require any output scaling
        return None

    @override
    def _posterior(self, candidates_comp_scaled: Tensor, /) -> EnsemblePosterior:
        from botorch.models.ensemble import EnsemblePosterior

        @batchify_ensemble_predictor
        def predict(candidates_comp_scaled: Tensor) -> Tensor:
            """Make the end-to-end ensemble prediction."""
            import torch

            # FIXME[typing]: It seems there is currently no better way to inform the
            #   type checker that the attribute is available at the time of the
            #   function call
            assert self._model is not None

            return torch.from_numpy(
                self._predict_ensemble(
                    self._model.estimators_, candidates_comp_scaled.numpy()
                )
            )

        return EnsemblePosterior(predict(candidates_comp_scaled).unsqueeze(-1))

    @staticmethod
    def _predict_ensemble(
        predictors: Collection[_Predictor], candidates: np.ndarray
    ) -> np.ndarray:
        """Evaluate an ensemble of predictors on a given candidate set."""
        # Extract shapes
        n_candidates = len(candidates)
        n_estimators = len(predictors)

        # Evaluate all trees
        predictions = np.zeros((n_estimators, n_candidates))
        for p, predictor in enumerate(predictors):
            predictions[p] = predictor.predict(candidates)

        return predictions

    @override
    def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
        from sklearn.ensemble import RandomForestRegressor

        self._model = RandomForestRegressor(**(self.model_params))
        self._model.fit(train_x.numpy(), train_y.numpy().ravel())

    @override
    def __str__(self) -> str:
        fields = [to_string("Model Params", self.model_params, single_line=True)]
        return to_string(super().__str__(), *fields)
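

# --- Usage sketch (illustrative, not part of the module) ---
# A minimal example of how the surrogate might be configured with custom
# RandomForestRegressor parameters. The keys shown are standard sklearn
# arguments accepted via ``model_params``; the surrounding search space,
# objective, and recommender setup that would normally drive fitting and
# prediction is assumed and not shown here.
#
#     surrogate = RandomForestSurrogate(
#         model_params={"n_estimators": 200, "max_depth": 8, "random_state": 0}
#     )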