"""Binary targets.

Defines :class:`BinaryTarget`, a target with exactly two admissible experimental
values (a success and a failure value), together with the computational encoding
those two values are mapped to.
"""

import gc
from typing import TypeAlias

import numpy as np
import pandas as pd
from attrs import define, field
from attrs.validators import instance_of

from baybe.exceptions import InvalidTargetValueError
from baybe.serialization import SerialMixin
from baybe.targets.base import Target
from baybe.utils.validation import validate_not_nan

ChoiceValue: TypeAlias = bool | int | float | str
"""Types of values that a :class:`BinaryTarget` can take."""

# Numeric encodings emitted by `BinaryTarget.transform` in place of the
# experimental success/failure values.
_SUCCESS_VALUE_COMP: float = 1.0
"""Computational representation of the success value."""

_FAILURE_VALUE_COMP: float = 0.0
"""Computational representation of the failure value."""
@define(frozen=True)
class BinaryTarget(Target, SerialMixin):
    """Class for binary targets.

    A binary target accepts exactly two experimental values: ``success_value``
    and ``failure_value``. :meth:`transform` maps them to the computational values
    ``1.0`` (success) and ``0.0`` (failure), respectively.
    """

    # FIXME[typing]: https://github.com/python-attrs/attrs/issues/1336
    success_value: ChoiceValue = field(
        default=True,
        validator=[instance_of(ChoiceValue), validate_not_nan],  # type: ignore[call-overload]
        kw_only=True,
    )
    """Experimental representation of the success value."""

    failure_value: ChoiceValue = field(
        default=False,
        validator=[instance_of(ChoiceValue), validate_not_nan],  # type: ignore[call-overload]
        kw_only=True,
    )
    """Experimental representation of the failure value."""

    @failure_value.validator
    def _validate_values(self, _, value):
        """Validate that the two choice values of the target are different.

        Raises:
            ValueError: If the failure value compares equal to the success value.
        """
        # attrs runs validators only after all fields have been assigned, so
        # ``self.success_value`` is already available here. Note that the check
        # uses ``==``, hence e.g. ``True``/``1``/``1.0`` count as the same value.
        if value == self.success_value:
            raise ValueError(
                f"The two choice values of a '{BinaryTarget.__name__}' must be "
                f"different but the following value was provided for both choices of "
                f"target '{self.name}': {value}"
            )

    def transform(self, data: pd.DataFrame) -> pd.DataFrame:
        """Map experimental binary values to their computational representation.

        Args:
            data: Single-column dataframe holding experimental values of this
                target.

        Returns:
            A dataframe with the same index and column as ``data`` that contains
            ``1.0`` where the value equals :attr:`success_value` and ``0.0`` where
            it equals :attr:`failure_value`.

        Raises:
            InvalidTargetValueError: If any entry is neither the success nor the
                failure value (missing values included).
        """
        # TODO: The method (signature) needs to be refactored, potentially when
        # enabling multi-target settings. The current input type suggests that passing
        # dataframes is allowed, but the code was designed for single targets and
        # desirability objectives, where only one column is present.
        assert data.shape[1] == 1

        values = data.iloc[:, 0]

        # Validate target values using a 1-D boolean mask aligned with the rows
        is_valid = values.isin([self.success_value, self.failure_value])
        if not is_valid.all():
            invalid = data.loc[~is_valid]
            raise InvalidTargetValueError(
                f"The following values entered for target '{self.name}' are not in the "
                f"set of accepted choice values "
                f"{{{self.success_value}, {self.failure_value}}}: \n{invalid}"
            )

        # Transform: after validation, every non-success entry is a failure entry
        success_idx = values == self.success_value
        return pd.DataFrame(
            np.where(success_idx, _SUCCESS_VALUE_COMP, _FAILURE_VALUE_COMP),
            index=data.index,
            columns=data.columns,
        )

    def summary(self) -> dict:  # noqa: D102
        # See base class.
        return dict(
            Type=self.__class__.__name__,
            Name=self.name,
            Success_value=self.success_value,
            Failure_value=self.failure_value,
        )
# Collect leftover original slotted classes processed by `attrs.define`gc.collect()