"""Base classes for all constraints."""from__future__importannotationsimportgcfromabcimportABC,abstractmethodfromtypingimportTYPE_CHECKING,Any,ClassVarimportpandasaspdfromattrimportdefine,fieldfromattr.validatorsimportge,instance_of,min_lenfrombaybe.constraints.deprecationimportstructure_constraintsfrombaybe.serializationimport(SerialMixin,converter,unstructure_base,)ifTYPE_CHECKING:importpolarsaspl
@define
class Constraint(ABC, SerialMixin):
    """Abstract base class for all constraints.

    Concrete subclasses declare, via class variables, at which stages of the
    workflow they are evaluated, and hold the names of the parameters they act on.
    """

    # class variables
    # TODO: it might turn out these are not needed at a later development stage
    eval_during_creation: ClassVar[bool]
    """Class variable encoding whether the condition is evaluated during creation."""

    eval_during_modeling: ClassVar[bool]
    """Class variable encoding whether the condition is evaluated during modeling."""

    eval_during_augmentation: ClassVar[bool] = False
    """Class variable encoding whether the constraint could be considered during data
    augmentation."""

    numerical_only: ClassVar[bool] = False
    """Class variable encoding whether the constraint is valid only for numerical
    parameters."""

    # Object variables
    parameters: list[str] = field(validator=min_len(1))
    """The list of parameters used for the constraint."""

    @parameters.validator
    def _validate_params(  # noqa: DOC101, DOC103
        self, _: Any, params: list[str]
    ) -> None:
        """Ensure that no parameter name appears more than once.

        Raises:
            ValueError: If ``params`` contains duplicate values.
        """
        # A set drops repeated names, so a size mismatch reveals duplicates.
        n_distinct = len(set(params))
        if n_distinct != len(params):
            raise ValueError(
                f"The given 'parameters' list must have unique values "
                f"but was: {params}."
            )

    def summary(self) -> dict:
        """Return a custom summarization of the constraint."""
        return {
            "Type": self.__class__.__name__,
            "Affected_Parameters": self.parameters,
        }

    @property
    def is_continuous(self) -> bool:
        """Boolean indicating if this is a constraint over continuous parameters."""
        return isinstance(self, ContinuousConstraint)

    @property
    def is_discrete(self) -> bool:
        """Boolean indicating if this is a constraint over discrete parameters."""
        return isinstance(self, DiscreteConstraint)
@define
class DiscreteConstraint(Constraint, ABC):
    """Abstract base class for discrete constraints.

    Discrete constraints combine conditions in order to remove unwanted entries
    from the search space.
    """

    # class variables (meaning documented on the base class)
    eval_during_creation: ClassVar[bool] = True
    eval_during_modeling: ClassVar[bool] = False

    @abstractmethod
    def get_invalid(self, data: pd.DataFrame) -> pd.Index:
        """Get the indices of dataframe entries that are invalid under the constraint.

        Args:
            data: A dataframe where each row represents a particular parameter
                combination.

        Returns:
            The dataframe indices of rows where the constraint is violated.
        """

    def get_invalid_polars(self) -> pl.Expr:
        """Translate the constraint to Polars expression identifying undesired rows.

        The base implementation always raises; subclasses that support Polars
        override it.

        Returns:
            The Polars expressions to pass to :meth:`polars.LazyFrame.filter`.

        Raises:
            NotImplementedError: If the constraint class does not have a Polars
                implementation.
        """
        cls_name = self.__class__.__name__
        raise NotImplementedError(
            f"'{cls_name}' does not have a Polars implementation."
        )
@define
class ContinuousConstraint(Constraint, ABC):
    """Abstract base class for continuous constraints."""

    # Class variables; their meaning is documented on the base class.
    # Continuous constraints are evaluated during modeling rather than creation
    # and apply to numerical parameters only.
    eval_during_creation: ClassVar[bool] = False
    eval_during_modeling: ClassVar[bool] = True
    numerical_only: ClassVar[bool] = True
@define
class CardinalityConstraint(Constraint, ABC):
    """Abstract base class for cardinality constraints.

    Places a constraint on the set of nonzero (i.e. "active") values among the
    specified parameters, bounding it between the two given integers::

        min_cardinality <= |{p_i : p_i != 0}| <= max_cardinality

    where ``{p_i}`` are the parameters specified for the constraint.

    Note that this can be equivalently regarded as an L0-constraint on the vector
    containing the specified parameters.
    """

    # class variable
    numerical_only: ClassVar[bool] = True
    # See base class.

    # object variables
    min_cardinality: int = field(default=0, validator=[instance_of(int), ge(0)])
    "The minimum required cardinality."

    max_cardinality: int = field(validator=instance_of(int))
    "The maximum allowed cardinality."

    @max_cardinality.default
    def _default_max_cardinality(self):
        """Use the number of involved parameters as the upper limit by default."""
        return len(self.parameters)

    def __attrs_post_init__(self):
        """Validate the cardinality bounds.

        Raises:
            ValueError: If the provided cardinality bounds are invalid.
            ValueError: If the provided cardinality bounds impose no constraint.
        """
        if self.min_cardinality > self.max_cardinality:
            raise ValueError(
                f"The lower cardinality bound cannot be larger than the upper bound. "
                f"Provided values: {self.max_cardinality=}, {self.min_cardinality=}."
            )

        if self.max_cardinality > len(self.parameters):
            raise ValueError(
                f"The cardinality bound cannot exceed the number of parameters. "
                f"Provided values: {self.max_cardinality=}, {len(self.parameters)=}."
            )

        # Bounds [0, n_params] admit every configuration, so the constraint would
        # be a no-op; reject it to surface likely user mistakes.
        if self.min_cardinality == 0 and self.max_cardinality == len(self.parameters):
            raise ValueError(
                f"No constraint of type '{self.__class__.__name__}' is required "
                f"when the lower cardinality bound is zero and the upper bound equals "
                f"the number of parameters. Provided values: {self.min_cardinality=}, "
                f"{self.max_cardinality=}, {len(self.parameters)=}."
            )
class ContinuousNonlinearConstraint(ContinuousConstraint, ABC):
    """Abstract base class for nonlinear constraints over continuous parameters."""

    # NOTE(review): unlike its sibling base classes, this class is not decorated
    # with ``@define``; confirm whether that is intentional.
# Register (un-)structure hooks for the ``Constraint`` hierarchy.
# Structuring uses ``structure_constraints`` from ``baybe.constraints.deprecation``
# rather than the generic ``get_base_structure_hook(Constraint)``, because
# constraint structuring is currently affected by a deprecation.
converter.register_structure_hook(Constraint, structure_constraints)
converter.register_unstructure_hook(Constraint, unstructure_base)

# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()