"""Utilities for working with surrogates."""from__future__importannotationsfromcollections.abcimportCallablefromfunctoolsimportwrapsfromtypingimportTYPE_CHECKING,TypeVarfrombaybe.exceptionsimportInvalidSurrogateModelErrorfrombaybe.surrogates.baseimportSurrogateifTYPE_CHECKING:frombotorch.posteriorsimportPosteriorfromtorchimportTensorfrombaybe.surrogates.naiveimportMeanPredictionSurrogate_TSurrogate=TypeVar("_TSurrogate",bound=Surrogate)_constant_target_model_store:dict[int,MeanPredictionSurrogate]={}"""Dictionary for storing constant target fallback models. Keys are the IDs of thesurrogate models that temporarily have a fallback attached because they weretrained on constant training targets. Values are the corresponding fallback models."""
def catch_constant_targets(cls: type[Surrogate], std_threshold: float = 1e-6):
    """Make a ``Surrogate`` class robustly handle constant training targets.

    More specifically, "constant training targets" can mean either of:
    * The standard deviation of the training targets is below the given threshold.
    * There is only one target and the standard deviation cannot even be computed.

    The modified class handles the above cases separately from "regular operation"
    by resorting to a :class:`baybe.surrogates.naive.MeanPredictionSurrogate`,
    which is stored outside the model in a dictionary maintained by this decorator.

    Args:
        cls: The :class:`baybe.surrogates.base.Surrogate` to be augmented.
        std_threshold: The standard deviation threshold below which operation is
            switched to the alternative model.

    Returns:
        The modified class.
    """
    from baybe.surrogates.naive import MeanPredictionSurrogate

    # Keep handles to the unmodified methods so the wrappers can delegate to them.
    original_fit = cls._fit
    original_posterior = cls._posterior

    def patched_posterior(self, candidates_comp_scaled: Tensor, /) -> Posterior:
        """Use fallback model if it exists, otherwise call original posterior."""
        fallback = _constant_target_model_store.get(id(self), None)

        # A stored fallback means the last fit saw constant targets.
        if fallback:
            return fallback._posterior(candidates_comp_scaled)

        return original_posterior(self, candidates_comp_scaled)

    def patched_fit(self, train_x: Tensor, train_y: Tensor) -> None:
        """Original fit but with fallback model creation for constant targets."""
        single_target = train_y.ndim == 2 and train_y.shape[-1] == 1
        if not single_target:
            raise NotImplementedError(
                "The current logic is only implemented for single-target surrogates."
            )

        # Constant targets: train and register the mean-prediction fallback.
        if train_y.numel() == 1 or train_y.std() < std_threshold:
            fallback = MeanPredictionSurrogate()
            fallback._fit(train_x, train_y)
            _constant_target_model_store[id(self)] = fallback
            return

        # Regular targets: drop any stale fallback and fit normally.
        _constant_target_model_store.pop(id(self), None)
        original_fit(self, train_x, train_y)

    # Install the wrappers on the class.
    cls._posterior = patched_posterior  # type: ignore
    cls._fit = patched_fit  # type: ignore

    return cls
def batchify_mean_var_prediction(
    posterior: Callable[[_TSurrogate, Tensor], tuple[Tensor, Tensor]],
) -> Callable[[_TSurrogate, Tensor], tuple[Tensor, Tensor]]:
    """Wrap a posterior method to make it evaluate t-batches as an augmented q-batch."""

    @wraps(posterior)
    def sequential_posterior(
        model: _TSurrogate, candidates: Tensor
    ) -> tuple[Tensor, Tensor]:
        # Unbatched input needs no augmentation and is forwarded untouched
        if candidates.ndim == 2:
            return posterior(model, candidates)

        # More than one t-batch dimension is not (yet) supported
        if candidates.ndim > 3:
            raise ValueError("Multiple t-batch dimensions are not supported.")

        # Remember the incoming batch structure, then merge the t-batch
        # dimension into the q-batch dimension
        n_t, n_q = candidates.shape[-3], candidates.shape[-2]
        merged = candidates.flatten(end_dim=-2)

        # Evaluate the model once on the merged input
        mean, var = posterior(model, merged)

        # Re-expand the original batch structure on the outputs
        return mean.reshape((n_t, n_q)), var.reshape((n_t, n_q))

    return sequential_posterior
def batchify_ensemble_predictor(
    base_predictor: Callable[[Tensor], Tensor],
) -> Callable[[Tensor], Tensor]:
    """Wrap an ensemble predictor to make it evaluate t-batches as an augmented q-batch.

    Args:
        base_predictor: The ensemble predictor to be wrapped.

    Returns:
        The wrapped predictor.
    """

    @wraps(base_predictor)
    def batch_predictor(candidates: Tensor) -> Tensor:
        # If no batch dimensions are given, call the model directly
        if candidates.ndim == 2:
            return base_predictor(candidates)

        # Ensemble models do not (yet) support model parameter batching
        if candidates.ndim > 3:
            raise ValueError("Multiple t-batch dimensions are not supported.")

        # Keep track of batch dimensions
        t_shape = candidates.shape[-3]
        q_shape = candidates.shape[-2]

        # Flatten the t-batch dimension into the q-batch dimension
        flattened = candidates.flatten(end_dim=-2)

        # Call the model on the entire input
        predictions = base_predictor(flattened)

        # Validate that the model provides the ensemble predictions in the correct
        # shape (otherwise the reshaping operation below could silently produce
        # wrong results). NOTE: this is an explicit check rather than an `assert`
        # because assertions are stripped under `python -O`, which would disable
        # the validation entirely.
        if predictions.ndim != 2 or predictions.shape[1] != t_shape * q_shape:
            raise InvalidSurrogateModelError(
                f"For the given input of shape {tuple(candidates.shape)}, "
                f"the ensemble model is supposed to create predictions of shape "
                f"(n_estimators, t_shape * q_shape) = "
                f"(n_estimators, {t_shape*q_shape}) "
                f"but returned an array of incompatible shape "
                f"{tuple(predictions.shape)}."
            )
        n_estimators = predictions.shape[0]

        # Restore the batch dimensions: (n_estimators, t, q) -> (t, n_estimators, q)
        predictions = predictions.reshape((n_estimators, t_shape, q_shape))
        predictions = predictions.permute((1, 0, 2))

        return predictions

    return batch_predictor