Source code for baybe.utils.scaling
"""Scaling utilities."""
from __future__ import annotations
import gc
import itertools
from typing import TYPE_CHECKING
from attrs import define, field
from attrs.validators import deep_iterable, deep_mapping, instance_of
if TYPE_CHECKING:
from botorch.models.transforms.input import InputTransform
from torch import Tensor
[docs]
@define
class ColumnTransformer:
    """Class for applying separate transforms to different column groups of tensors.

    Each entry of :attr:`mapping` pairs a tuple of column indices with an input
    transform. :meth:`fit` fits every transform on its own column group, and
    :meth:`transform` applies each fitted transform to the corresponding columns of
    a copy of the input. Columns that do not appear in any key are passed through
    unchanged.
    """

    mapping: dict[tuple[int, ...], InputTransform] = field()
    """A mapping defining what transform to apply to which columns."""

    _is_trained: bool = field(default=False, init=False)
    """Boolean indicating if the transformer has been trained."""

    @mapping.validator
    def _validate_mapping_types_lazily(self, attr, value):
        """Perform transform ``isinstance`` check using lazy import.

        Checks that ``value`` is a ``dict`` whose keys are tuples of ``int`` and
        whose values are ``InputTransform`` instances. ``InputTransform`` is
        imported here, at validation time, because the module-level import of it
        only happens under ``TYPE_CHECKING``.
        """
        from botorch.models.transforms.input import InputTransform

        validator = deep_mapping(
            mapping_validator=instance_of(dict),
            key_validator=deep_iterable(
                member_validator=instance_of(int), iterable_validator=instance_of(tuple)
            ),
            value_validator=instance_of(InputTransform),
        )
        validator(self, attr, value)

    @mapping.validator
    def _validate_mapping_is_disjoint(self, _, value: dict):
        """Validate that each column is assigned to at most one transformer.

        Raises:
            ValueError: If any two column specifications share a column index.
        """
        for x, y in itertools.combinations(value.keys(), 2):
            if not set(x).isdisjoint(y):
                raise ValueError(
                    f"The provided column specifications {x} and {y} are not disjoint."
                )

    def fit(self, x: Tensor, /) -> None:
        """Fit the transformer to the given tensor.

        Each transform is switched to training mode and called once on its column
        group. The output is discarded; the call is made only for its fitting side
        effect. Afterwards the transform is switched back to evaluation mode.

        Args:
            x: The tensor to fit on. Column groups are selected along its last
                dimension.

        Raises:
            Any exception raised by an individual transform is propagated. In that
            case the transformer remains untrained, and the failing transform is
            still returned to evaluation mode.
        """
        # Explicitly set flag to False to guarantee a clean state in case of
        # exceptions occurring in the for-loop below
        self._is_trained = False

        for cols, transformer in self.mapping.items():
            transformer.train()
            try:
                transformer(x[..., cols])
            finally:
                # Restore evaluation mode even if fitting fails, so that no
                # transform is left behind in training mode.
                transformer.eval()

        self._is_trained = True

    def transform(self, x: Tensor, /) -> Tensor:
        """Transform the given tensor.

        Args:
            x: The tensor to transform. Column groups are selected along its last
                dimension.

        Returns:
            A new tensor in which each mapped column group has been replaced by the
            output of its transform. Unmapped columns are copied unchanged, and the
            input tensor itself is not modified.

        Raises:
            RuntimeError: If :meth:`fit` has not completed successfully.
        """
        if not self._is_trained:
            raise RuntimeError(
                f"The {self.__class__.__name__} must be trained before it can be used."
            )

        # Work on a copy so the caller's tensor is left untouched. Because the
        # column groups are disjoint (enforced by the validator), reading from
        # ``out`` inside the loop is equivalent to reading from ``x``.
        out = x.clone()
        for cols, transformer in self.mapping.items():
            out[..., cols] = transformer(out[..., cols])

        return out
# Collect leftover original slotted classes processed by `attrs.define`.
# `attrs.define` builds slotted classes by replacing each decorated class with a
# newly created class object. Run a garbage-collection pass at import time so the
# discarded original class objects are reclaimed now, rather than whenever the
# collector next happens to run.
# NOTE(review): presumably those originals are reachable only through reference
# cycles (refcounting alone would otherwise free them) — confirm against attrs.
gc.collect()