Source code for baybe.transformations.composite
"""Composite transformations."""
from __future__ import annotations
from functools import reduce
from typing import TYPE_CHECKING, Any
from attrs import define, field
from attrs.validators import and_, deep_iterable, instance_of, max_len, min_len
from typing_extensions import override
from baybe.transformations.base import Transformation
from baybe.transformations.utils import compress_transformations
from baybe.utils.basic import compose, to_tuple
from baybe.utils.interval import Interval
if TYPE_CHECKING:
from torch import Tensor
[docs]
@define(frozen=True)
class ChainedTransformation(Transformation):
    """A sequence of transformations applied one after another."""

    # Hashing stays identity-based (``object.__hash__``) even though ``__eq__``
    # is customized below; see https://github.com/python-attrs/attrs/issues/1462
    __hash__ = object.__hash__

    transformations: tuple[Transformation, ...] = field(
        converter=compress_transformations,
        validator=[
            min_len(1),
            deep_iterable(member_validator=instance_of(Transformation)),
        ],
    )
    """The chain members in order of application (the first one runs first)."""

    @override
    def __eq__(self, other: Any, /) -> bool:
        """Compare with ``other``, treating a one-element chain as its element."""
        members = self.transformations

        # A chain holding a single member is considered interchangeable with
        # that member, so the comparison is delegated to it.
        if len(members) == 1:
            return members[0] == other

        # Longer chains only know how to compare against other chains.
        if not isinstance(other, ChainedTransformation):
            return NotImplemented
        return members == other.transformations

    @override
    def get_codomain(self, interval: Interval | None = None, /) -> Interval:
        """Pass the interval through every member's ``get_codomain`` in order."""
        current = Interval.create(interval)
        for step in self.transformations:
            current = step.get_codomain(current)
        return current

    @override
    def get_image(self, interval: Interval | None = None, /) -> Interval:
        """Pass the interval through every member's ``get_image`` in order."""
        current = Interval.create(interval)
        for step in self.transformations:
            current = step.get_image(current)
        return current

    @override
    def __call__(self, x: Tensor, /) -> Tensor:
        """Apply the composed chain to ``x``."""
        pipeline = compose(*(step.__call__ for step in self.transformations))
        return pipeline(x)
[docs]
@define(frozen=True)
class AdditiveTransformation(Transformation):
    """The sum of exactly two transformations."""

    transformations: tuple[Transformation, Transformation] = field(
        converter=to_tuple,
        validator=deep_iterable(
            iterable_validator=and_(min_len(2), max_len(2)),
            member_validator=instance_of(Transformation),
        ),
    )
    """The two summands."""

    @override
    def get_codomain(self, interval: Interval | None = None, /) -> Interval:
        """Add the summands' codomain bounds (lower to lower, upper to upper)."""
        domain = Interval.create(interval)
        lhs, rhs = self.transformations[0], self.transformations[1]

        lhs_bounds = lhs.get_codomain(domain)
        rhs_bounds = rhs.get_codomain(domain)

        lower = lhs_bounds.lower + rhs_bounds.lower
        upper = lhs_bounds.upper + rhs_bounds.upper
        return Interval(lower, upper)

    @override
    def __call__(self, x: Tensor, /) -> Tensor:
        """Evaluate both summands on ``x`` and return their sum."""
        lhs, rhs = self.transformations[0], self.transformations[1]
        return lhs(x) + rhs(x)
[docs]
@define(frozen=True)
class MultiplicativeTransformation(Transformation):
    """A transformation implementing the product of two transformations."""

    transformations: tuple[Transformation, Transformation] = field(
        converter=to_tuple,
        validator=deep_iterable(
            iterable_validator=and_(min_len(2), max_len(2)),
            member_validator=instance_of(Transformation),
        ),
    )
    """The transformations to be multiplied."""

    @override
    def get_codomain(self, interval: Interval | None = None, /) -> Interval:
        """Bound the product using the factors' codomain bounds.

        The result spans the smallest and largest of the four pairwise products
        of the factors' lower/upper codomain bounds.

        Args:
            interval: The input interval, normalized via ``Interval.create``.

        Returns:
            The interval enclosing all four boundary products.
        """

        def bound_product(a: float, b: float) -> float:
            # Under IEEE arithmetic ``0 * inf`` is NaN, and ``min``/``max``
            # give order-dependent answers once a NaN is in the list. Interval
            # arithmetic treats a zero bound times an infinite bound as zero,
            # so short-circuit whenever either factor is zero.
            # NOTE(review): assumes interval bounds are plain floats that may
            # be infinite -- confirm against ``Interval``.
            if a == 0 or b == 0:
                return 0.0
            return a * b

        interval = Interval.create(interval)
        im1 = self.transformations[0].get_codomain(interval)
        im2 = self.transformations[1].get_codomain(interval)

        # The extremes of a product of two intervals are always attained at a
        # combination of the factors' boundary values.
        boundary_products = [
            bound_product(a, b)
            for a in (im1.lower, im1.upper)
            for b in (im2.lower, im2.upper)
        ]
        return Interval(min(boundary_products), max(boundary_products))

    @override
    def __call__(self, x: Tensor, /) -> Tensor:
        """Evaluate both factors on ``x`` and return their product."""
        return self.transformations[0](x) * self.transformations[1](x)