"""Categorical parameters."""importgcfromfunctoolsimportcached_propertyimportnumpyasnpimportpandasaspdfromattrsimportConverter,define,fieldfromattrs.validatorsimportdeep_iterable,instance_of,min_lenfromtyping_extensionsimportoverridefrombaybe.parameters.baseimport_DiscreteLabelLikeParameterfrombaybe.parameters.enumimportCategoricalEncodingfrombaybe.parameters.validationimportvalidate_unique_valuesfrombaybe.utils.conversionimportnonstring_to_tuplefrombaybe.utils.numericalimportDTypeFloatNumpydef_convert_values(value,self,field)->tuple[str,...]:"""Sort and convert values for categorical parameters."""value=nonstring_to_tuple(value,self,field)returntuple(sorted(value,key=lambdax:(str(type(x)),x)))def_validate_label_min_len(self,attr,value)->None:"""An attrs-compatible validator to ensure minimum label length."""# noqa: D401ifisinstance(value,str)andlen(value)<1:raiseValueError(f"Strings used as '{attr.alias}' for '{self.__class__.__name__}' must "f"have at least 1 character.")
@define(frozen=True, slots=False)
class CategoricalParameter(_DiscreteLabelLikeParameter):
    """Parameter class for categorical parameters."""

    # object variables
    _values: tuple[str | bool, ...] = field(
        alias="values",
        converter=Converter(_convert_values, takes_self=True, takes_field=True),  # type: ignore
        validator=(  # type: ignore
            validate_unique_values,
            deep_iterable(
                member_validator=(instance_of((str, bool)), _validate_label_min_len),
                iterable_validator=min_len(2),
            ),
        ),
    )
    # See base class.

    encoding: CategoricalEncoding = field(
        default=CategoricalEncoding.OHE, converter=CategoricalEncoding
    )
    # See base class.

    @override
    @property
    def values(self) -> tuple:
        """The values of the parameter."""
        return self._values

    @override
    @cached_property
    def comp_df(self) -> pd.DataFrame:
        """The computational representation of the parameter values.

        The returned frame is indexed by the parameter values, and its entries
        use dtype ``DTypeFloatNumpy``.

        * ``CategoricalEncoding.OHE``: one column per value, forming an identity
          matrix (one-hot rows).
        * ``CategoricalEncoding.INT``: a single column named after the parameter,
          holding each value's position in ``values``.

        Raises:
            NotImplementedError: If ``encoding`` is not one of the handled
                encodings above.
        """
        if self.encoding is CategoricalEncoding.OHE:
            # Boolean labels get a "b" prefix so that e.g. True and "True" map
            # to distinct column names.
            cols = [
                f"{self.name}_{'b' if isinstance(val, bool) else ''}{val}"
                for val in self.values
            ]
            comp_df = pd.DataFrame(
                np.eye(len(self.values), dtype=DTypeFloatNumpy), columns=cols
            )
        elif self.encoding is CategoricalEncoding.INT:
            comp_df = pd.DataFrame(
                range(len(self.values)), dtype=DTypeFloatNumpy, columns=[self.name]
            )
        else:
            # Fail explicitly: without this branch an unhandled encoding would
            # surface as an UnboundLocalError on `comp_df` below.
            raise NotImplementedError(
                f"Encoding '{self.encoding}' is not supported for "
                f"'{self.__class__.__name__}'."
            )
        comp_df.index = pd.Index(self.values)

        return comp_df
@define(frozen=True, slots=False)
class TaskParameter(CategoricalParameter):
    """Parameter class for task parameters.

    Behaves like a categorical parameter, but its ``encoding`` is pinned to
    ``CategoricalEncoding.INT`` and is not accepted as a constructor argument.
    """

    # Re-declares the inherited `encoding` field: the integer encoding is
    # hard-wired here, and `init=False` removes it from `__init__`.
    encoding: CategoricalEncoding = field(init=False, default=CategoricalEncoding.INT)
# See base class.# Collect leftover original slotted classes processed by `attrs.define`gc.collect()