"""Functionality for single-target objectives."""from__future__importannotationsimportgcfromtypingimportTYPE_CHECKING,ClassVarimportpandasaspdfromattrsimportdefine,fieldfromattrs.validatorsimportinstance_offromtyping_extensionsimportoverridefrombaybe.exceptionsimportNonGaussianityErrorfrombaybe.objectives.baseimportObjectivefrombaybe.targets.baseimportTargetfrombaybe.targets.numericalimportNumericalTargetfrombaybe.transformations.basicimportAffineTransformation,IdentityTransformationfrombaybe.utils.conversionimportto_stringfrombaybe.utils.dataframeimportpretty_print_dfifTYPE_CHECKING:frombotorch.acquisition.objectiveimport(MCAcquisitionObjective,ScalarizedPosteriorTransform,)
@define(frozen=True, slots=False)
class SingleTargetObjective(Objective):
    """An objective focusing on a single target."""

    is_multi_output: ClassVar[bool] = False  # See base class.

    _target: Target = field(validator=instance_of(Target), alias="target")
    """The single target considered by the objective."""

    @override
    def __str__(self) -> str:
        # Tabulate the summaries of all targets of the objective.
        summary_frame = pd.DataFrame([tgt.summary() for tgt in self.targets])
        type_section = to_string("Type", type(self).__name__, single_line=True)
        targets_section = to_string("Targets", pretty_print_df(summary_frame))
        return to_string("Objective", type_section, targets_section)

    @override
    @property
    def targets(self) -> tuple[Target, ...]:
        """The targets of the objective, as a one-element tuple."""
        return (self._target,)

    @override
    @property
    def output_names(self) -> tuple[str, ...]:
        """The name of the wrapped target, as a one-element tuple."""
        target_name = self._target.name
        return (target_name,)

    @override
    @property
    def supports_partial_measurements(self) -> bool:
        """Always ``False`` for this objective."""
        return False

    @override
    def to_botorch_posterior_transform(self) -> ScalarizedPosteriorTransform:
        """Convert the objective into a BoTorch scalarized posterior transform.

        Conversion is only supported for a numerical target whose assigned
        transformation is an identity or affine transformation. If the target
        is to be minimized, the transformation is negated before conversion.

        Returns:
            The posterior transform produced by the (possibly negated)
            transformation of the target.

        Raises:
            NonGaussianityError: If the target is not a ``NumericalTarget`` or
                its transformation is neither an ``IdentityTransformation`` nor
                an ``AffineTransformation``.
        """
        target = self._target
        if isinstance(target, NumericalTarget):
            transformation = target.transformation
            if isinstance(
                transformation, (IdentityTransformation, AffineTransformation)
            ):
                # Minimization is handled by negating the transformation
                # before handing it over to BoTorch.
                if target.minimize:
                    transformation = transformation.negate()
                return transformation.to_botorch_posterior_transform()

        raise NonGaussianityError(
            f"Converting an objective of type '{type(self).__name__}' is only "
            f"possible when the transformation result is Gaussian, that is, "
            f"when the target is of type '{NumericalTarget.__name__}' and the "
            f"assigned transformation is affine."
        )
# Collect leftover original slotted classes processed by `attrs.define`gc.collect()