"""Binary targets."""

import gc
import warnings
from typing import TypeAlias

import numpy as np
import pandas as pd
from attrs import define, field
from attrs.validators import instance_of
from typing_extensions import override

from baybe.exceptions import InvalidTargetValueError
from baybe.serialization import SerialMixin
from baybe.targets.base import Target
from baybe.utils.validation import validate_not_nan

# Union of the Python scalar types accepted for a binary target's choice values.
ChoiceValue: TypeAlias = bool | int | float | str
"""Types of values that a :class:`BinaryTarget` can take."""

# Numeric encodings that ``BinaryTarget.transform`` emits for the two choices.
_SUCCESS_VALUE_COMP = 1.0
"""Computational representation of the success value."""

_FAILURE_VALUE_COMP = 0.0
"""Computational representation of the failure value."""
@define(frozen=True)
class BinaryTarget(Target, SerialMixin):
    """Class for binary targets."""

    # FIXME[typing]: https://github.com/python-attrs/attrs/issues/1336
    success_value: ChoiceValue = field(
        default=True,
        validator=[instance_of(ChoiceValue), validate_not_nan],  # type: ignore[call-overload]
        kw_only=True,
    )
    """Experimental representation of the success value."""

    failure_value: ChoiceValue = field(
        default=False,
        validator=[instance_of(ChoiceValue), validate_not_nan],  # type: ignore[call-overload]
        kw_only=True,
    )
    """Experimental representation of the failure value."""

    @failure_value.validator
    def _validate_values(self, _, value):
        """Validate that the two choice values of the target are different.

        Args:
            _: The attrs attribute being validated (unused).
            value: The candidate failure value.

        Raises:
            ValueError: If ``value`` compares equal to ``self.success_value``.
                Because Python equality is used, values such as ``True``,
                ``1`` and ``1.0`` count as identical and are rejected together.
        """
        # Reads ``self.success_value``, which attrs has already assigned by the
        # time validators run.
        if value == self.success_value:
            raise ValueError(
                f"The two choice values of a '{BinaryTarget.__name__}' must be "
                f"different but the following value was provided for both choices of "
                f"target '{self.name}': {value}"
            )

    @override
    def transform(
        self,
        series: pd.Series | None = None,
        /,
        *,
        data: pd.DataFrame | None = None,
    ) -> pd.Series:
        """Map experimental choice values to their computational representation.

        Args:
            series: The target values to be transformed.
            data: Deprecated. A single-column dataframe holding the target values.
                Pass a series as first positional argument instead.

        Raises:
            ValueError: If not exactly one of ``series`` and ``data`` is given, or
                if ``data`` does not consist of exactly one column.
            InvalidTargetValueError: If the values contain anything other than
                the configured success and failure values.

        Returns:
            A float series with the same index and name as the input, holding
            ``_SUCCESS_VALUE_COMP`` where the input equals the success value and
            ``_FAILURE_VALUE_COMP`` everywhere else.
        """
        # >>>>>>>>>> Deprecation
        # Exactly one of the two input channels must be used.
        if not ((series is None) ^ (data is None)):
            raise ValueError(
                "Provide the data to be transformed as first positional argument."
            )

        if data is not None:
            # Explicit check instead of ``assert``: asserts are stripped under
            # ``python -O``, which would silently drop any extra columns.
            if data.shape[1] != 1:
                raise ValueError(
                    f"The dataframe provided via the deprecated `data` argument "
                    f"must contain exactly one column but has {data.shape[1]}."
                )
            series = data.iloc[:, 0]
            warnings.warn(
                "Providing a dataframe via the `data` argument is deprecated and "
                "will be removed in a future version. Please pass your data "
                "in form of a series as positional argument instead.",
                DeprecationWarning,
            )

        # Mypy does not infer from the above that `series` must be a series here
        assert isinstance(series, pd.Series)
        # <<<<<<<<<< Deprecation

        # Validate target values. The mask is converted to a plain array so that
        # selection is positional and independent of the series index.
        invalid = series[
            ~series.isin([self.success_value, self.failure_value]).to_numpy()
        ]
        if not invalid.empty:
            raise InvalidTargetValueError(
                f"The following values entered for target '{self.name}' are not in the "
                f"set of accepted choice values "
                f"{ {self.success_value, self.failure_value} }: {set(invalid)}"
            )

        # Transform: success -> _SUCCESS_VALUE_COMP, everything else (i.e. the
        # failure value, after validation) -> _FAILURE_VALUE_COMP.
        success_idx = series == self.success_value
        return pd.Series(
            np.where(success_idx, _SUCCESS_VALUE_COMP, _FAILURE_VALUE_COMP),
            index=series.index,
            name=series.name,
        )