"""Functionality for constraint conditions."""
from __future__ import annotations
import operator as ops
from abc import ABC, abstractmethod
from collections.abc import Callable
from functools import partial
from typing import TYPE_CHECKING, Any
import numpy as np
import pandas as pd
from attr import define, field
from attr.validators import in_
from attrs.validators import min_len
from cattrs.gen import override
from funcy import rpartial
from numpy.typing import ArrayLike
from baybe.parameters.validation import validate_unique_values
from baybe.serialization import (
SerialMixin,
converter,
get_base_structure_hook,
unstructure_base,
)
from baybe.utils.numerical import DTypeFloatNumpy
if TYPE_CHECKING:
import polars as pl
def _is_not_close(x: ArrayLike, y: ArrayLike, rtol: float, atol: float) -> np.ndarray:
"""Return a boolean array indicating where ``x`` and ``y`` are not close.
The counterpart to ``numpy.isclose``.
Args:
x: First input array to compare.
y: Second input array to compare.
rtol: The relative tolerance parameter.
atol: The absolute tolerance parameter.
Returns:
A boolean array of where ``x`` and ``y`` are not equal within the
given tolerances.
"""
return np.logical_not(_is_close(x, y, rtol=rtol, atol=atol))
def _is_close(x: ArrayLike, y: ArrayLike, rtol: float, atol: float) -> np.ndarray:
"""Return a boolean array indicating where ``x`` and ``y`` are close.
The equivalent to :func:``numpy.isclose``.
Using ``numpy.isclose`` with Polars dataframes results in this error:
``TypeError: ufunc 'isfinite' not supported for the input types``.
Args:
x: First input array to compare.
y: Second input array to compare.
rtol: The relative tolerance parameter.
atol: The absolute tolerance parameter.
Returns:
A boolean array of where ``x`` and ``y`` are equal within the
given tolerances.
"""
return np.abs(np.subtract(x, y)) <= atol + rtol * np.abs(y)
# provide threshold operators
_threshold_operators: dict[str, Callable] = {
"<": ops.lt,
"<=": ops.le,
"=": rpartial(_is_close, rtol=0.0),
"==": rpartial(_is_close, rtol=0.0),
"!=": rpartial(_is_not_close, rtol=0.0),
">": ops.gt,
">=": ops.ge,
}
# define operators that are eligible for tolerance
_valid_tolerance_operators = ["=", "==", "!="]
_valid_logic_combiners = {
"AND": ops.and_,
"OR": ops.or_,
"XOR": ops.xor,
}
[docs]
class Condition(ABC, SerialMixin):
"""Abstract base class for all conditions.
Conditions always evaluate an expression regarding a single parameter.
Conditions are part of constraints, a constraint can have multiple conditions.
"""
[docs]
@abstractmethod
def evaluate(self, data: pd.Series) -> pd.Series:
"""Evaluate the condition on a given data series.
Args:
data: A series containing parameter values.
Returns:
A boolean series indicating which elements satisfy the condition.
"""
[docs]
@abstractmethod
def to_polars(self, expr: pl.Expr, /) -> pl.Expr:
"""Apply the condition to a Polars expression.
Args:
expr: Input expression, for instance column selection etc.
Returns:
An expression that can be used for filtering.
"""
[docs]
@define
class ThresholdCondition(Condition):
"""Class for modelling threshold-based conditions."""
# object variables
threshold: float = field()
"""The threshold value used in the condition."""
operator: str = field(validator=[in_(_threshold_operators)])
"""The operator used in the condition."""
tolerance: float | None = field()
"""A numerical tolerance. Set to a reasonable default tolerance."""
@tolerance.default
def _tolerance_default(self) -> float | None:
"""Create the default value for the tolerance."""
# Default value for the tolerance.
return 1e-8 if self.operator in _valid_tolerance_operators else None
@tolerance.validator
def _validate_tolerance(self, _: Any, value: float) -> None: # noqa: DOC101, DOC103
"""Validate the threshold condition tolerance.
Raises:
ValueError: If the operator does not allow for setting a tolerance.
ValueError: If the operator allows for setting a tolerance, but the provided
tolerance is either less than 0 or ``None``.
"""
if (self.operator not in _valid_tolerance_operators) and (value is not None):
raise ValueError(
f"Setting the tolerance for a threshold condition is only valid "
f"with the following operators: {_valid_tolerance_operators}."
)
if self.operator in _valid_tolerance_operators:
if (value is None) or (value <= 0.0):
raise ValueError(
f"When using a tolerance-enabled operator"
f" ({_valid_tolerance_operators}) the tolerance cannot be None "
f"or <= 0.0, but was {value}."
)
def _make_operator_function(self):
"""Generate a function using operators to filter out undesired rows."""
func = rpartial(_threshold_operators[self.operator], self.threshold)
if self.operator in _valid_tolerance_operators:
func = rpartial(func, atol=self.tolerance)
return func
[docs]
def evaluate(self, data: pd.Series) -> pd.Series: # noqa: D102
# See base class.
if data.dtype.kind not in "iufb":
raise ValueError(
"You tried to apply a threshold condition to non-numeric data. "
"This operation is error-prone and not supported. Only use threshold "
"conditions with numerical parameters."
)
func = self._make_operator_function()
return data.apply(func)
[docs]
def to_polars(self, expr: pl.Expr, /) -> pl.Expr: # noqa: D102
# See base class.
op = self._make_operator_function()
return op(expr)
[docs]
@define
class SubSelectionCondition(Condition):
"""Class for defining valid parameter entries."""
# object variables
_selection: tuple = field(
converter=tuple,
# FIXME[typing]: https://github.com/python-attrs/attrs/issues/1197
validator=[
min_len(1),
validate_unique_values, # type: ignore
],
)
"""The internal list of items which are considered valid."""
@property
def selection(self) -> tuple: # noqa: D102
"""The list of items which are considered valid."""
return tuple(
DTypeFloatNumpy(itm) if isinstance(itm, (float, int, bool)) else itm
for itm in self._selection
)
[docs]
def evaluate(self, data: pd.Series) -> pd.Series: # noqa: D102
# See base class.
return data.isin(self.selection)
[docs]
def to_polars(self, expr: pl.Expr, /) -> pl.Expr: # noqa: D102
# See base class.
return expr.is_in(self.selection)
# Register (un-)structure hooks
_overrides = {
"_selection": override(rename="selection"),
}
# FIXME[typing]: https://github.com/python/mypy/issues/4717
converter.register_structure_hook(
Condition,
get_base_structure_hook(Condition, overrides=_overrides), # type: ignore
)
converter.register_unstructure_hook(
Condition, partial(unstructure_base, overrides=_overrides)
)