Source code for baybe.transformations.basic
"""Basic transformations."""
from __future__ import annotations
import gc
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any
import numpy as np
from attrs import Factory, define, field
from attrs.validators import gt, is_callable
from typing_extensions import override
from baybe.serialization import converter
from baybe.transformations.base import (
MonotonicTransformation,
Transformation,
_image_equals_codomain,
)
from baybe.utils.dataframe import to_tensor
from baybe.utils.interval import Interval
from baybe.utils.validation import finite_float
if TYPE_CHECKING:
    # Imports needed only for static type checking; at runtime, torch and the
    # BoTorch adapter are imported lazily inside the functions that use them.
    from torch import Tensor

    from baybe.targets.botorch import AffinePosteriorTransform

    # Defined under TYPE_CHECKING because `Tensor` is only imported here; with
    # `from __future__ import annotations`, annotations referencing it are never
    # evaluated at runtime.
    TensorCallable = Callable[[Tensor], Tensor]
    """Type alias for a torch-based function mapping from reals to reals."""
[docs]
@define(frozen=True)
class CustomTransformation(Transformation):
    """A transformation that delegates to a user-supplied torch callable.

    The callable is applied as-is, so nothing is known about its codomain.
    """

    function: TensorCallable = field(validator=is_callable())
    """The torch callable representing the transformation."""

    @override
    def get_codomain(self, interval: Interval | None = None, /) -> Interval:
        """Not supported for custom transformations.

        Raises:
            NotImplementedError: Always, since the codomain of an arbitrary
                callable cannot be determined.
        """
        msg = (
            "Custom transformations do not provide details about their codomain."
        )
        raise NotImplementedError(msg)

    @override
    def __call__(self, x: Tensor, /) -> Tensor:
        """Evaluate the wrapped callable on the given tensor."""
        fn = self.function
        return fn(x)
[docs]
@define(frozen=True)
class IdentityTransformation(MonotonicTransformation):
    """The identity transformation, which returns its input unchanged."""

    def to_botorch_posterior_transform(self) -> AffinePosteriorTransform:
        """Convert to BoTorch posterior transform.

        Returns:
            An affine posterior transform with unit factor and zero shift.
        """
        from baybe.targets.botorch import AffinePosteriorTransform

        return AffinePosteriorTransform(factor=1.0, shift=0.0)

    @override
    def __call__(self, x: Tensor, /) -> Tensor:
        return x

    @override
    def __or__(self, other: Any) -> Transformation:
        # Combining with the identity leaves the other transformation as is;
        # non-transformation operands are deferred to Python's fallback logic.
        return other if isinstance(other, Transformation) else NotImplemented
[docs]
@define(frozen=True, init=False)
class ClampingTransformation(MonotonicTransformation):
    """A transformation clamping values between specified cutoffs."""

    cutoffs: Interval = field(converter=Interval.create)  # type: ignore[misc]
    """The range to which input values are clamped."""

    def __init__(self, min: float | None = None, max: float | None = None) -> None:
        """Create the transformation from optional lower and upper cutoffs.

        Args:
            min: The lower cutoff (optional).
            max: The upper cutoff (optional).

        Raises:
            ValueError: If neither ``min`` nor ``max`` is given.
        """
        no_cutoff_given = min is None and max is None
        if no_cutoff_given:
            raise ValueError("At least one cutoff value must be specified.")
        bounds = Interval(min, max)
        self.__attrs_init__(cutoffs=bounds)

    @override
    def __call__(self, x: Tensor, /) -> Tensor:
        clamp_range = self.cutoffs.to_tuple()
        return x.clamp(*clamp_range)
[docs]
@define(frozen=True, init=False)
class AffineTransformation(MonotonicTransformation):
    """An affine transformation mapping ``x`` to ``factor * x + shift``."""

    factor: float = field(default=1.0, converter=float, validator=finite_float)
    """The multiplicative factor of the transformation."""

    shift: float = field(default=0.0, converter=float, validator=finite_float)
    """The constant shift of the transformation."""

    def __init__(
        self,
        factor: float = 1.0,
        shift: float = 0.0,
        *,
        shift_first: bool = False,
    ) -> None:
        """Create an affine transformation.

        Args:
            factor: The multiplicative factor.
            shift: The constant shift.
            shift_first: If ``True``, the shift is applied before the scaling,
                i.e. the transformation computes ``factor * (x + shift)``.

        Raises:
            OverflowError: If ``shift_first=True`` and the effective shift
                ``shift * factor`` is not finite.
        """
        if shift_first:
            # factor * (x + shift) == factor * x + (factor * shift)
            shift = shift * factor
            if not np.isfinite(shift):
                raise OverflowError("The transformation produces infinite values.")
        self.__attrs_init__(factor=factor, shift=shift)

    @override
    def __eq__(self, other: Any, /) -> bool:
        if isinstance(other, IdentityTransformation):
            # An affine transformation without shift and scaling is effectively an
            # identity transformation
            return self.factor == 1.0 and self.shift == 0.0
        if isinstance(other, AffineTransformation):
            # TODO: https://github.com/python-attrs/attrs/issues/1452
            return self.factor == other.factor and self.shift == other.shift
        return NotImplemented

    # https://github.com/python-attrs/attrs/issues/1462
    # Since `__eq__` is customized, the hash is provided explicitly. It must be
    # value-based (not identity-based): objects comparing equal must hash equally.
    def __hash__(self) -> int:
        if self.factor == 1.0 and self.shift == 0.0:
            # Must agree with the identity transformation, which compares equal
            return hash(IdentityTransformation())
        return hash((self.factor, self.shift))

    def to_botorch_posterior_transform(self) -> AffinePosteriorTransform:
        """Convert to BoTorch posterior transform.

        Returns:
            The representation of the transform as BoTorch posterior transform.
        """
        from baybe.targets.botorch import AffinePosteriorTransform

        return AffinePosteriorTransform(self.factor, self.shift)

    @classmethod
    def from_values_mapped_to_unit_interval(
        cls, mapped_to_zero: float, mapped_to_one: float
    ) -> AffineTransformation:
        """Create an affine transform by specifying reference values mapped to 0/1.

        Args:
            mapped_to_zero: The input value that will be mapped to zero.
            mapped_to_one: The input value that will be mapped to one.

        Returns:
            An affine transformation calibrated to map the specified values to the
            unit interval.

        Example:
            >>> import torch
            >>> from baybe.transformations import AffineTransformation as AffineT
            >>> t1 = AffineT.from_values_mapped_to_unit_interval(3, 7)
            >>> t2 = AffineT.from_values_mapped_to_unit_interval(7, 3)
            >>> t1(torch.tensor([3, 7]))
            tensor([0., 1.])
            >>> t2(torch.tensor([3, 7]))
            tensor([1., 0.])
        """
        return cls(
            shift=-mapped_to_zero,
            factor=1 / (mapped_to_one - mapped_to_zero),
            shift_first=True,
        )

    @override
    def __call__(self, x: Tensor, /) -> Tensor:
        # A zero factor is special-cased to avoid NaN from 0 * inf for infinite inputs
        if self.factor == 0.0:
            import torch

            # Use the dtype that `x * factor + shift` would produce, so that integer
            # inputs are not truncated when filled with a floating-point shift
            dtype = torch.result_type(x, self.factor)
            return torch.full_like(x, self.shift, dtype=dtype)
        return x * self.factor + self.shift
[docs]
@_image_equals_codomain
@define(frozen=True)
class TwoSidedAffineTransformation(Transformation):
    """A piecewise transformation made of two affine segments.

    Both segments meet at the midpoint, where the transformation evaluates to zero.
    """

    slope_left: float = field(converter=float, validator=finite_float)
    """The slope of the affine segment to the left of the midpoint."""

    slope_right: float = field(converter=float, validator=finite_float)
    """The slope of the affine segment to the right of the midpoint."""

    midpoint: float = field(default=0.0, converter=float, validator=finite_float)
    """The midpoint where the two affine segments meet."""

    @override
    def get_codomain(self, interval: Interval | None = None, /) -> Interval:
        interval = Interval.create(interval)
        endpoint_images = [
            self(to_tensor(bound)).item() for bound in (interval.lower, interval.upper)
        ]
        low, high = sorted(endpoint_images)
        if not interval.contains(self.midpoint):
            return Interval(low, high)
        # The midpoint is mapped to zero, so zero belongs to the image as well
        return Interval(min(0, low), max(0, high))

    @override
    def __call__(self, x: Tensor, /) -> Tensor:
        import torch

        # Zero slopes are special-cased so that infinite inputs do not produce NaN
        # through 0 * inf
        offset = x - self.midpoint
        left = offset * self.slope_left if self.slope_left else x.new_zeros(x.shape)
        right = offset * self.slope_right if self.slope_right else x.new_zeros(x.shape)
        return torch.where(x < self.midpoint, left, right)
[docs]
@_image_equals_codomain
@define(frozen=True)
class BellTransformation(Transformation):
    """A Gaussian bell curve transformation.

    Computes ``exp(-((x - center) / sigma) ** 2 / 2)``, i.e. an unnormalized Gaussian
    whose peak value of 1 is attained at the center.
    """

    center: float = field(default=0.0, converter=float, validator=finite_float)
    """The center point of the bell curve."""

    sigma: float = field(
        default=1.0, converter=float, validator=[finite_float, gt(0.0)]
    )
    """The scale parameter of the transformation.

    Concerning the width of the bell curve, it has the same interpretation as the
    standard deviation in a Gaussian distribution. (The magnitude of the curve is
    not affected and always reaches a maximum value of 1 at the center.)
    """

    @override
    def get_codomain(self, interval: Interval | None = None, /) -> Interval:
        interval = Interval.create(interval)
        at_lower = self(to_tensor(interval.lower)).item()
        at_upper = self(to_tensor(interval.upper)).item()
        if not interval.contains(self.center):
            # Without the peak inside, the curve is monotonic on the interval
            return Interval(*sorted([at_lower, at_upper]))
        return Interval(min(at_lower, at_upper), 1)

    @override
    def __call__(self, x: Tensor, /) -> Tensor:
        standardized = (x - self.center) / self.sigma
        return (-(standardized.pow(2.0) / 2.0)).exp()
[docs]
@_image_equals_codomain
@define(frozen=True)
class AbsoluteTransformation(Transformation):
    """A transformation computing absolute values.

    Implemented via a two-sided affine transformation with slopes -1 and +1
    meeting at the default midpoint.
    """

    _transformation: Transformation = field(
        default=Factory(
            lambda: TwoSidedAffineTransformation(slope_left=-1, slope_right=1)
        ),
        init=False,
        repr=False,
    )
    """Internal transformation object handling the operations."""

    @override
    def get_codomain(self, interval: Interval | None = None, /) -> Interval:
        delegate = self._transformation
        return delegate.get_codomain(interval)

    @override
    def __call__(self, x: Tensor, /) -> Tensor:
        delegate = self._transformation
        return delegate(x)
[docs]
@_image_equals_codomain
@define(frozen=True)
class TriangularTransformation(Transformation):
    r"""A transformation with a triangular shape.

    The transformation is defined by a peak location between two cutoff values. Outside
    the region delimited by the cutoff values, the transformation is zero. Inside the
    region, the transformed values increase linearly from both cutoffs to the peak,
    where the highest value of 1 is reached:

    .. math::
        f(x) =
        \begin{cases}
            0 & \text{if } x < c_1 \\
            \frac{x - c_1}{p - c_1} & \text{if } c_1 \leq x < p \\
            \frac{c_2 - x}{c_2 - p} & \text{if } p \leq x < c_2 \\
            0 & \text{if } c_2 \leq x
        \end{cases}

    where :math:`c_1` and :math:`c_2` are the left and right cutoffs, respectively, and
    :math:`p` is the peak location, with :math:`c_1 < p < c_2`.
    """

    # TODO[typing]: https://github.com/python-attrs/attrs/issues/1435
    cutoffs: Interval = field(converter=Interval.create)  # type: ignore[misc]
    """The cutoff values where the transformation reaches zero."""

    peak: float = field(
        default=Factory(lambda self: self.cutoffs.center, takes_self=True),
        converter=float,
    )
    """The peak location of the transformation. By default, centered between cutoffs."""

    _transformation: Transformation = field(init=False, repr=False)
    """Internal transformation object handling the operations."""

    def __attrs_post_init__(self) -> None:
        # We use post-init here to ensure the attribute validators run first,
        # since otherwise the validators of the transformation object created here
        # would be executed first, raising difficult-to-interpret errors
        margin_left, margin_right = self.margins
        slope_left = 1 / margin_left
        slope_right = -1 / margin_right
        if np.isinf([slope_left, slope_right]).any():
            raise OverflowError(
                "The triangular transformation could not be initialized because "
                "the cutoffs are too close to the peak, leading to numerical overflow "
                "when computing the slopes."
            )

        # Reuse the slopes validated above: shifting the two-sided affine function
        # up by one places the maximum of 1 at the peak, and clamping at zero
        # produces the flat tails outside the cutoffs
        t = (
            TwoSidedAffineTransformation(
                slope_left=slope_left,
                slope_right=slope_right,
                midpoint=self.peak,
            )
            + 1
        ).clamp(min=0)
        object.__setattr__(self, "_transformation", t)

    @cutoffs.validator
    def _validate_cutoffs(self, _, cutoffs: Interval) -> None:
        if not cutoffs.is_bounded:
            raise ValueError(
                "The cutoffs of the transformation must be bounded. "
                f"Given cutoffs: {cutoffs.to_tuple()}."
            )

    @peak.validator
    def _validate_peak(self, _, peak: float) -> None:
        if not (self.cutoffs.lower < peak < self.cutoffs.upper):
            raise ValueError(
                f"The peak of the transformation must be located strictly between the "
                f"specified cutoff values. Given peak location: {peak}. "
                f"Given cutoffs: {self.cutoffs.to_tuple()}."
            )

    @property
    def margins(self) -> tuple[float, float]:
        """The left and right margin denoting the width of the triangle."""
        return self.peak - self.cutoffs.lower, self.cutoffs.upper - self.peak

    @classmethod
    def from_margins(
        cls, peak: float, margins: Sequence[float]
    ) -> TriangularTransformation:
        """Create a triangular transformation from a peak location and margins.

        Args:
            peak: The peak location.
            margins: The distances from the peak to the left and right cutoff.

        Raises:
            ValueError: If ``margins`` does not contain exactly two values.

        Returns:
            The triangular transformation with cutoffs
            ``(peak - margins[0], peak + margins[1])``.
        """
        if len(margins) != 2:
            raise ValueError(
                "The margins must be provided as a sequence of two values."
            )
        return cls(peak=peak, cutoffs=Interval(peak - margins[0], peak + margins[1]))

    @classmethod
    def from_width(cls, peak: float, width: float) -> TriangularTransformation:
        """Create a triangular transformation from a peak location and width.

        The width is split equally into the left and right margin.
        """
        return cls.from_margins(peak, (width / 2, width / 2))

    @override
    def get_codomain(self, interval: Interval | None = None, /) -> Interval:
        return self._transformation.get_codomain(interval)

    @override
    def __call__(self, x: Tensor, /) -> Tensor:
        return self._transformation(x)
[docs]
@_image_equals_codomain
@define(frozen=True)
class LogarithmicTransformation(MonotonicTransformation):
    """A natural-logarithm transformation."""

    @override
    def __call__(self, x: Tensor, /) -> Tensor:
        return x.log()

    @override
    def get_codomain(self, interval: Interval | None = None, /) -> Interval:
        # Negative inputs produce NaN rather than extending the codomain, so only
        # the non-negative part of the interval is relevant
        nonnegative_part = Interval.create(interval).clamp(min=0.0)
        return super().get_codomain(nonnegative_part)
[docs]
@define(frozen=True)
class ExponentialTransformation(MonotonicTransformation):
    """A transformation applying the natural exponential function elementwise."""

    @override
    def __call__(self, x: Tensor, /) -> Tensor:
        exponentiated = x.exp()
        return exponentiated
[docs]
@_image_equals_codomain
@define(frozen=True)
class PowerTransformation(Transformation):
    """A transformation raising its input to a fixed power."""

    # Annotated as float (matching the converter) so that type-driven structuring
    # does not truncate non-integer exponents
    exponent: float = field(converter=float, validator=finite_float)
    """The exponent of the power transformation."""

    @override
    def get_codomain(self, interval: Interval | None = None, /) -> Interval:
        interval = Interval.create(interval)
        image_lower = self(to_tensor(interval.lower)).item()
        image_upper = self(to_tensor(interval.upper)).item()
        is_even = self.exponent % 2 == 0

        if self.exponent > 0 and is_even and interval.contains(0.0):
            # Positive even powers attain their minimum of zero at the origin
            return Interval(0, max(image_lower, image_upper))

        if self.exponent < 0 and interval.lower < 0.0 < interval.upper:
            # Negative powers have a pole at the origin, hence the image is unbounded:
            # even powers tend to +inf from both sides, odd ones to -inf and +inf
            if is_even:
                return Interval(min(image_lower, image_upper), None)
            return Interval(None, None)

        # Otherwise, the power is monotonic on the interval (this includes the
        # constant case of a zero exponent)
        return Interval(*sorted([image_lower, image_upper]))

    @override
    def __call__(self, x: Tensor, /) -> Tensor:
        """Raise the input to the configured power.

        Raises:
            RuntimeError: If the exponent is not an integer and the input contains
                negative values.
        """
        # `Tensor.any` is used instead of the builtin `any`, which iterates over the
        # first dimension and thus fails for 0-d and multi-dimensional tensors
        if not self.exponent.is_integer() and (x < 0).any():
            raise RuntimeError(
                "For non-integer exponents, the provided input tensor must contain "
                "non-negative elements only."
            )
        return x.pow(self.exponent)
[docs]
@define(frozen=True)
class SigmoidTransformation(MonotonicTransformation):
    """A sigmoid transformation ``x -> 1 / (1 + exp(steepness * (x - center)))``.

    With this parametrization, positive steepness values yield a decreasing curve
    and negative values an increasing one.
    """

    center: float = field(default=0.0, converter=float)
    """The center of the sigmoid function, where it crosses 0.5."""

    steepness: float = field(default=1.0, converter=float)
    """The steepness of the sigmoid function."""

    @classmethod
    def from_anchors(cls, anchors: Sequence[Sequence[float]]) -> SigmoidTransformation:
        """Create a sigmoid transformation from two anchor points.

        Args:
            anchors: The anchor points defining the sigmoid transformation.
                Must be convertible to two pairs of floats, where each pair represents
                an anchor point through which the sigmoid curve passes.

        Raises:
            ValueError: If the input given as anchors does not represent two points.
            ValueError: If the ordinates of the anchors are not in the unit interval.
            ValueError: If the anchors share their abscissa or their ordinate, in
                which case no unique sigmoid passes through both points.

        Returns:
            A sigmoid transformation passing through the specified anchor points.

        Example:
            >>> import torch
            >>> p1 = (-2, 0.1)
            >>> p2 = (5, 0.6)
            >>> t = SigmoidTransformation.from_anchors([p1, p2])
            >>> out = t(torch.tensor([p1[0], p2[0]]))
            >>> assert torch.equal(out, torch.tensor([p1[1], p2[1]]))
        """
        import cattrs

        # Extract point coordinates from the input
        try:
            anchors = cattrs.structure(
                anchors, tuple[tuple[float, float], tuple[float, float]]
            )  # type: ignore[call-arg]
        except cattrs.IterableValidationError as ex:
            raise ValueError(
                f"The specified anchor point argument must be convertible to two "
                f"pairs of floats. Given: {anchors}"
            ) from ex
        (x1, y1), (x2, y2) = anchors

        if not ((0.0 < y1 < 1.0) and (0.0 < y2 < 1.0)):
            raise ValueError(
                f"The ordinates of the anchor points must be in the open "
                f"interval (0, 1). Given: {y1=} and {y2=}."
            )

        # Solving 1 / (1 + exp(s * (x_i - c))) = y_i for the exponent gives
        # s * (x_i - c) = k_i, a linear system in the unknowns s and c
        k1 = np.log(1 / y1 - 1)
        k2 = np.log(1 / y2 - 1)

        # Degenerate anchors would otherwise divide by zero below and silently
        # produce infinite/NaN parameters
        if x1 == x2 or k1 == k2:
            raise ValueError(
                f"The anchor points must have distinct abscissas and distinct "
                f"ordinates. Given: {(x1, y1)} and {(x2, y2)}."
            )

        center = (k2 * x1 - k1 * x2) / (k2 - k1)
        steepness = (k2 - k1) / (x2 - x1)
        return cls(center, steepness)

    @override
    def __call__(self, x: Tensor, /) -> Tensor:
        import torch

        return 1 / (1 + torch.exp(self.steepness * (x - self.center)))
@converter.register_structure_hook
def _(dct, _) -> ClampingTransformation:
    """Structure a clamping transformation from its serialized ``cutoffs`` entry."""
    # The custom `__init__` takes the individual bounds instead of an `Interval`
    bounds = Interval(**dct["cutoffs"]).to_tuple()
    return ClampingTransformation(*bounds)
# Collect leftover original slotted classes processed by `attrs.define`.
# NOTE(review): for slotted classes, `attrs.define` builds a replacement class
# object, so the original class objects presumably linger until garbage
# collection (e.g. still visible in subclass lookups) — confirm against the
# code relying on this.
gc.collect()