"""Base classes for all acquisition functions."""from__future__importannotationsimportwarningsfromabcimportABCfrominspectimportsignaturefromtypingimportTYPE_CHECKING,ClassVarimportpandasaspdfromattrsimportdefinefrombaybe.exceptionsimport(IncompatibleAcquisitionFunctionError,UnidentifiedSubclassError,)frombaybe.objectives.baseimportObjectivefrombaybe.objectives.desirabilityimportDesirabilityObjectivefrombaybe.objectives.singleimportSingleTargetObjectivefrombaybe.searchspace.coreimportSearchSpacefrombaybe.serialization.coreimport(converter,get_base_structure_hook,unstructure_base,)frombaybe.serialization.mixinimportSerialMixinfrombaybe.surrogates.baseimportSurrogateProtocolfrombaybe.targets.enumimportTargetModefrombaybe.targets.numericalimportNumericalTargetfrombaybe.utils.basicimportclassproperty,match_attributesfrombaybe.utils.booleanimportis_abstractfrombaybe.utils.dataframeimportto_tensorifTYPE_CHECKING:frombotorch.acquisitionimportAcquisitionFunctionasBotorchAcquisitionFunction
[docs]@define(frozen=True)classAcquisitionFunction(ABC,SerialMixin):"""Abstract base class for all acquisition functions."""abbreviation:ClassVar[str]"""An alternative name for type resolution."""@classpropertydefis_mc(cls)->bool:"""Flag indicating whether this is a Monte-Carlo acquisition function."""returncls.abbreviation.startswith("q")@classpropertydef_non_botorch_attrs(cls)->tuple[str,...]:"""Names of attributes that are not passed to the BoTorch constructor."""return()
[docs]defto_botorch(self,surrogate:SurrogateProtocol,searchspace:SearchSpace,objective:Objective,measurements:pd.DataFrame,pending_experiments:pd.DataFrame|None=None,):"""Create the botorch-ready representation of the function. The required structure of `measurements` is specified in :meth:`baybe.recommenders.base.RecommenderProtocol.recommend`. """importbotorch.acquisitionasbo_acqfimporttorchfrombotorch.acquisition.objectiveimportLinearMCObjectivefrombaybe.acquisition.acqfsimportqThompsonSampling# Retrieve botorch acquisition function class and match attributesacqf_cls=_get_botorch_acqf_class(type(self))params_dict=match_attributes(self,acqf_cls.__init__,ignore=self._non_botorch_attrs)[0]# Create botorch surrogate modelbo_surrogate=surrogate.to_botorch()# Get computational data representationtrain_x=to_tensor(searchspace.transform(measurements,allow_extra=True))# Collect remaining (context-specific) parameterssignature_params=signature(acqf_cls).parametersadditional_params={}additional_params["model"]=bo_surrogateif"X_baseline"insignature_params:additional_params["X_baseline"]=train_xif"mc_points"insignature_params:additional_params["mc_points"]=to_tensor(self.get_integration_points(searchspace)# type: ignore[attr-defined])ifpending_experimentsisnotNone:ifself.is_mc:pending_x=searchspace.transform(pending_experiments,allow_extra=True)additional_params["X_pending"]=to_tensor(pending_x)else:raiseIncompatibleAcquisitionFunctionError(f"Pending experiments were provided but the chosen acquisition "f"function '{self.__class__.__name__}' does not support this.")# Add acquisition objective / best observed 
valuematchobjective:caseSingleTargetObjective(NumericalTarget(mode=TargetMode.MIN)):if"best_f"insignature_params:additional_params["best_f"]=(bo_surrogate.posterior(train_x).mean.min().item())ifissubclass(acqf_cls,bo_acqf.AnalyticAcquisitionFunction):additional_params["maximize"]=Falseelifissubclass(acqf_cls,bo_acqf.MCAcquisitionFunction):additional_params["objective"]=LinearMCObjective(torch.tensor([-1.0]))else:raiseValueError(f"Unsupported acquisition function type: {acqf_cls}.")caseSingleTargetObjective()|DesirabilityObjective():if"best_f"insignature_params:additional_params["best_f"]=(bo_surrogate.posterior(train_x).mean.max().item())case_:raiseValueError(f"Unsupported objective type: {objective}")params_dict.update(additional_params)acqf=acqf_cls(**params_dict)ifisinstance(self,qThompsonSampling):asserthasattr(acqf,"_default_sample_shape")acqf._default_sample_shape=torch.Size([self.n_mc_samples])returnacqf
def _get_botorch_acqf_class(
    baybe_acqf_cls: type[AcquisitionFunction], /
) -> type[BotorchAcquisitionFunction]:
    """Extract the BoTorch acquisition class for the given BayBE acquisition class.

    The method resolution order of ``baybe_acqf_cls`` is traversed from most to
    least specific. The first class name that also exists as a non-abstract
    attribute of ``botorch.acquisition`` determines the result.

    Raises:
        UnidentifiedSubclassError: If no matching BoTorch class exists.
    """
    import botorch

    for base in baybe_acqf_cls.mro():
        candidate = getattr(botorch.acquisition, base.__name__, False)
        # Skip names without a BoTorch counterpart and abstract base classes
        if not candidate or is_abstract(candidate):
            continue
        return candidate  # type: ignore

    raise UnidentifiedSubclassError(
        f"No BoTorch acquisition function class match found for "
        f"'{baybe_acqf_cls.__name__}'."
    )


# Register de-/serialization hooks
def _add_deprecation_hook(hook):
    """Wrap a structure hook so that legacy UCB type names remain loadable.

    Used for backward compatibility only and will be removed in future versions.
    """

    def added_deprecation_hook(val: dict | str, cls: type):
        # Backwards-compatibility is only needed when deserializing via the
        # abstract base class using string-based type specifiers, because the
        # concrete classes became available only after the change.
        if not is_abstract(cls):
            return hook(val, cls)

        replacements = {
            "VarUCB": "UpperConfidenceBound",
            "qVarUCB": "qUpperConfidenceBound",
        }
        entry = val if isinstance(val, str) else val["type"]
        if entry in replacements:
            new_type = replacements[entry]
            warnings.warn(
                f"The use of `{entry}` is deprecated and will be disabled in a "
                f"future version. To get the same outcome, use the new "
                f"`{new_type}` class instead with a beta of 100.0.",
                DeprecationWarning,
            )
            val = {"type": new_type, "beta": 100.0}

        return hook(val, cls)

    return added_deprecation_hook


converter.register_structure_hook(
    AcquisitionFunction,
    _add_deprecation_hook(get_base_structure_hook(AcquisitionFunction)),
)
converter.register_unstructure_hook(AcquisitionFunction, unstructure_base)