"""Numerical targets."""importwarningsfromcollections.abcimportCallable,SequencefromfunctoolsimportpartialfromtypingimportAny,castimportnumpyasnpimportpandasaspdfromattrsimportdefine,fieldfromnumpy.typingimportArrayLikefrombaybe.serializationimportSerialMixinfrombaybe.targets.baseimportTargetfrombaybe.targets.enumimportTargetMode,TargetTransformationfrombaybe.targets.transformsimport(bell_transform,linear_transform,triangular_transform,)frombaybe.utils.intervalimportInterval,convert_bounds_VALID_TRANSFORMATIONS:dict[TargetMode,Sequence[TargetTransformation]]={TargetMode.MAX:(TargetTransformation.LINEAR,),TargetMode.MIN:(TargetTransformation.LINEAR,),TargetMode.MATCH:(TargetTransformation.TRIANGULAR,TargetTransformation.BELL),}"""A mapping from target modes to allowed target transformations.If multiple transformations are allowed, the first entry is used as default option."""def_get_target_transformation(mode:TargetMode,transformation:TargetTransformation)->Callable[[ArrayLike,float,float],np.ndarray]:"""Provide the transform callable for the given target mode and transform type."""iftransformationisTargetTransformation.TRIANGULAR:returntriangular_transformiftransformationisTargetTransformation.BELL:returnbell_transformiftransformationisTargetTransformation.LINEAR:ifmodeisTargetMode.MAX:returnpartial(linear_transform,descending=False)ifmodeisTargetMode.MIN:returnpartial(linear_transform,descending=True)raiseValueError(f"Unrecognized target mode: '{mode}'.")raiseValueError(f"Unrecognized target transformation: '{transformation}'.")
@define(frozen=True)
class NumericalTarget(Target, SerialMixin):
    """Class for numerical targets."""

    # NOTE: The type annotations of `bounds` are correctly overridden by the attrs
    # converter. Nonetheless, PyCharm's linter might incorrectly raise a type warning
    # when calling the constructor. This is a known issue:
    #       https://youtrack.jetbrains.com/issue/PY-34243
    # Quote from attrs docs:
    #       If a converter’s first argument has a type annotation, that type will
    #       appear in the signature for __init__. A converter will override an explicit
    #       type annotation or type argument.

    mode: TargetMode = field(converter=TargetMode)
    """The target mode."""

    bounds: Interval = field(default=None, converter=convert_bounds)
    """Optional target bounds."""

    transformation: TargetTransformation | None = field(
        converter=lambda x: None if x is None else TargetTransformation(x)
    )
    """An optional target transformation."""

    @transformation.default
    def _default_transformation(self) -> TargetTransformation | None:
        """Provide the default transformation for bounded targets.

        A bounded target gets the first transformation listed for its mode in
        ``_VALID_TRANSFORMATIONS``, and a ``UserWarning`` announces the choice.
        An unbounded target gets no transformation.
        """
        if self.bounds.is_bounded:
            default = _VALID_TRANSFORMATIONS[self.mode][0]
            warnings.warn(
                f"The transformation for target '{self.name}' "
                f"in '{self.mode.name}' mode has not been specified. "
                f"Setting the transformation to '{default.name}'.",
                UserWarning,
            )
            return default
        return None

    @bounds.validator
    def _validate_bounds(self, _: Any, bounds: Interval) -> None:  # noqa: DOC101, DOC103
        """Validate the bounds.

        Raises:
            ValueError: If the target is defined on a half-bounded interval.
            ValueError: If the interval specified by the bounds is degenerate.
            ValueError: If the target is in ``MATCH`` mode but the provided bounds
                are infinite.
        """
        # IMPROVE: We could also include half-way bounds, which however don't work
        # for the desirability approach
        if bounds.is_half_bounded:
            raise ValueError("Targets on half-bounded intervals are not supported.")
        if bounds.is_degenerate:
            raise ValueError(
                "The interval specified by the target bounds cannot be degenerate."
            )
        if self.mode is TargetMode.MATCH and not bounds.is_bounded:
            raise ValueError(
                f"Target '{self.name}' is in {TargetMode.MATCH.name} mode, "
                f"which requires finite bounds."
            )

    @transformation.validator
    def _validate_transformation(  # noqa: DOC101, DOC103
        self, _: Any, value: TargetTransformation | None
    ) -> None:
        """Validate that the given transformation is compatible with the specified mode.

        Raises:
            ValueError: If the target transformation and mode are not compatible.
        """
        if (value is not None) and (value not in _VALID_TRANSFORMATIONS[self.mode]):
            raise ValueError(
                f"You specified bounds for target '{self.name}', but your "
                f"specified transformation '{value}' is not compatible "
                f"with the target mode '{self.mode}'. It must be one "
                f"of {_VALID_TRANSFORMATIONS[self.mode]}."
            )

    @property
    def _is_transform_normalized(self) -> bool:
        """Indicate if the computational transformation maps to the unit interval."""
        return (self.bounds.is_bounded) and (self.transformation is not None)

    def transform(self, data: pd.DataFrame) -> pd.DataFrame:  # noqa: D102
        # See base class.

        # TODO: The method (signature) needs to be refactored, potentially when
        #   enabling multi-target settings. The current input type suggests that
        #   passing dataframes is allowed, but the code was designed for single
        #   targets and desirability objectives, where only one column is present.
        # An explicit check (rather than `assert`) keeps the validation active
        # when Python runs with optimizations (`-O`).
        if data.shape[1] != 1:
            raise ValueError(
                f"Target '{self.name}' can only transform a single column, "
                f"but the provided data has {data.shape[1]} columns."
            )

        # Without a transformation, the data is returned unchanged (as a copy).
        if self.transformation is None:
            return data.copy()

        # The early return above excludes `None`; the cast only makes this
        # explicit for type checkers.
        func = _get_target_transformation(
            self.mode, cast(TargetTransformation, self.transformation)
        )

        # Keep the original index and column labels so that both branches return
        # frames labelled like the input.
        return pd.DataFrame(
            func(data, *self.bounds.to_tuple()),
            index=data.index,
            columns=data.columns,
        )

    def summary(self) -> dict:  # noqa: D102
        # See base class.
        return dict(
            Type=self.__class__.__name__,
            Name=self.name,
            Mode=self.mode.name,
            Lower_Bound=self.bounds.lower,
            Upper_Bound=self.bounds.upper,
            Transformation=(
                self.transformation.name
                if self.transformation is not None
                else "None"
            ),
        )