# Source code for baybe.utils.scaling
"""Scaling utilities."""
from __future__ import annotations
import gc
import itertools
from typing import TYPE_CHECKING
from attrs import define, field
from attrs.validators import deep_iterable, deep_mapping, instance_of
if TYPE_CHECKING:
    from botorch.models.transforms.input import InputTransform
    from torch import Tensor
@define
class ColumnTransformer:
    """Class for applying separate transforms to different column groups of tensors.

    Each key of ``mapping`` is a tuple of column indices. The associated transform
    is applied to exactly those columns, selected along the last tensor dimension.
    Columns that do not appear in any key are passed through unchanged.
    """

    mapping: dict[tuple[int, ...], InputTransform] = field()
    """A mapping defining what transform to apply to which columns."""

    _is_trained: bool = field(default=False, init=False)
    """Boolean indicating if the transformer has been trained."""

    @mapping.validator
    def _validate_mapping_types_lazily(self, attr, value):
        """Perform transform ``isinstance`` check using lazy import.

        Checks that ``value`` is a ``dict`` whose keys are tuples of ``int`` and
        whose values are botorch ``InputTransform`` instances.

        Raises:
            TypeError: If any of the type checks fails (raised by the attrs
                ``instance_of`` validators).
        """
        # Imported here because the module-level import of ``InputTransform`` only
        # happens under ``TYPE_CHECKING`` and is not available at runtime.
        from botorch.models.transforms.input import InputTransform

        validator = deep_mapping(
            mapping_validator=instance_of(dict),
            key_validator=deep_iterable(
                member_validator=instance_of(int), iterable_validator=instance_of(tuple)
            ),
            value_validator=instance_of(InputTransform),
        )
        validator(self, attr, value)

    @mapping.validator
    def _validate_mapping_is_disjoint(self, _, value: dict):
        """Validate that each column is assigned to at most one transformer.

        Note that column indices are compared literally. A negative index and the
        non-negative index it refers to are therefore not detected as overlapping.

        Raises:
            ValueError: If two column specifications share a column index.
        """
        # Pairwise comparison so that the error message can name both offending keys.
        for x, y in itertools.combinations(value.keys(), 2):
            if not set(x).isdisjoint(y):
                raise ValueError(
                    f"The provided column specifications {x} and {y} are not disjoint."
                )

    def fit(self, x: Tensor, /) -> None:
        """Fit the transformer to the given tensor.

        Each transform is switched to training mode and called once on its column
        subset of ``x`` (selected along the last dimension). It is then switched
        back to evaluation mode. The outputs of these calls are discarded.
        Presumably the transforms learn their parameters as a side effect of being
        called in training mode.

        Any error raised while fitting a transform propagates to the caller. In
        that case the transformer is left in the untrained state, and the failing
        transform is still returned to evaluation mode.

        Args:
            x: The tensor to fit the transforms to.
        """
        # Explicitly set flag to False to guarantee a clean state in case of
        # exceptions occurring in the for-loop below
        self._is_trained = False
        for cols, transformer in self.mapping.items():
            transformer.train()
            try:
                transformer(x[..., cols])
            finally:
                # Restore evaluation mode even on failure, so that a failed fit
                # does not leave a caller-provided transform in training mode.
                transformer.eval()
        self._is_trained = True

    def transform(self, x: Tensor, /) -> Tensor:
        """Transform the given tensor.

        Args:
            x: The tensor to transform. It is not modified.

        Returns:
            A copy of ``x`` in which each column group of ``mapping`` has been
            replaced by the output of its transform. All other columns are copied
            unchanged.

        Raises:
            RuntimeError: If :meth:`fit` has not completed successfully beforehand.
        """
        if not self._is_trained:
            raise RuntimeError(
                f"The {self.__class__.__name__} must be trained before it can be used."
            )
        # Work on a copy so that the caller's tensor is left untouched.
        out = x.clone()
        for cols, transformer in self.mapping.items():
            out[..., cols] = transformer(out[..., cols])
        return out
# `attrs.define` produces slotted classes. For those, attrs builds a replacement
# class object and leaves the originally defined class behind as garbage.
# Run a full (generation 2) collection at import time so that these leftover
# originals are cleaned up now rather than whenever the collector next runs.
gc.collect(generation=2)