ColumnTransformer

class baybe.utils.scaling.ColumnTransformer

Bases: object

Class for applying separate transforms to different column groups of tensors.
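To illustrate the idea, here is a dependency-free sketch of how such a class might behave. It is not BayBE's implementation. It assumes 2-D tensors, models them as lists of rows, and uses a hypothetical MinMaxScaler in place of a BoTorch InputTransform. It also assumes that columns not listed in the mapping pass through unchanged.

```python
from typing import Dict, List, Tuple

Rows = List[List[float]]


class MinMaxScaler:
    """Hypothetical stand-in for a BoTorch InputTransform: scales each column to [0, 1]."""

    def fit(self, rows: Rows) -> None:
        # Record the per-column minimum and maximum of the columns passed in.
        cols = list(zip(*rows))
        self.mins = [min(c) for c in cols]
        self.maxs = [max(c) for c in cols]

    def transform(self, rows: Rows) -> Rows:
        # Map each value to [0, 1]; a constant column maps to 0.0.
        return [
            [(v - lo) / (hi - lo) if hi > lo else 0.0
             for v, lo, hi in zip(row, self.mins, self.maxs)]
            for row in rows
        ]


class ColumnTransformerSketch:
    """Applies one transform per group of column indices (illustrative only)."""

    def __init__(self, mapping: Dict[Tuple[int, ...], MinMaxScaler]) -> None:
        self.mapping = mapping

    def fit(self, x: Rows, /) -> None:
        # Each transform is fitted only on the columns assigned to it.
        for cols, t in self.mapping.items():
            t.fit([[row[c] for c in cols] for row in x])

    def transform(self, x: Rows, /) -> Rows:
        out = [list(row) for row in x]  # copy so the input is not modified
        for cols, t in self.mapping.items():
            sub = t.transform([[row[c] for c in cols] for row in x])
            for out_row, sub_row in zip(out, sub):
                for c, v in zip(cols, sub_row):
                    out_row[c] = v
        return out


x = [[0.0, 5.0, 10.0], [2.0, 7.0, 30.0]]
ct = ColumnTransformerSketch({(0, 2): MinMaxScaler()})
ct.fit(x)
print(ct.transform(x))  # [[0.0, 5.0, 0.0], [1.0, 7.0, 1.0]]
```

Columns 0 and 2 are each rescaled using their own minimum and maximum, while column 1 is not in the mapping and keeps its original values.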

Public methods

__init__(mapping)

Method generated by attrs for class ColumnTransformer.

fit(x, /)

Fit the transformer to the given tensor.

transform(x, /)

Transform the given tensor.

Public attributes and properties

mapping

A mapping defining what transform to apply to which columns.

__init__(mapping: dict[tuple[int, ...], InputTransform])

Method generated by attrs for class ColumnTransformer.

For details on the parameters, see Public attributes and properties.
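The keys of mapping are tuples of column indices. This page does not say whether two groups may share a column. The sketch below assumes they should not, since a column claimed by two transforms would be transformed ambiguously. check_mapping is a hypothetical helper, not part of BayBE, and the string values stand in for transform objects.

```python
from itertools import combinations
from typing import Any, Dict, Tuple


def check_mapping(mapping: Dict[Tuple[int, ...], Any]) -> None:
    """Hypothetical validator: keys must be tuples of ints, and no column may appear in two groups."""
    for cols in mapping:
        if not isinstance(cols, tuple) or not all(isinstance(c, int) for c in cols):
            raise TypeError(f"Key {cols!r} must be a tuple of column indices.")
    # Compare every pair of groups and report any shared columns.
    for a, b in combinations(mapping, 2):
        shared = set(a) & set(b)
        if shared:
            raise ValueError(
                f"Columns {sorted(shared)} are assigned to more than one transform."
            )


check_mapping({(0, 1): "scaler_a", (2,): "scaler_b"})  # passes silently
```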

fit(x: Tensor, /)

Fit the transformer to the given tensor.

Return type:

None
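Assuming fit trains each transform only on the columns it is assigned, the statistics a transform learns depend on its own column group alone. The sketch below shows this with a hypothetical Standardizer that records per-column means and population standard deviations. Like the documented method, the hypothetical fit_groups returns None and works by side effect on the transforms. Lists of rows stand in for tensors.

```python
import statistics
from typing import Dict, List, Tuple

Rows = List[List[float]]


class Standardizer:
    """Hypothetical transform: learns per-column mean and population std."""

    def fit(self, rows: Rows) -> None:
        cols = list(zip(*rows))
        self.means = [statistics.fmean(c) for c in cols]
        self.stds = [statistics.pstdev(c) for c in cols]


def fit_groups(x: Rows, mapping: Dict[Tuple[int, ...], Standardizer]) -> None:
    # Slice out each group's columns and fit its transform on that slice only.
    for cols, transform in mapping.items():
        transform.fit([[row[c] for c in cols] for row in x])


x = [[1.0, 100.0, 0.0], [3.0, 300.0, 1.0]]
s = Standardizer()
fit_groups(x, {(0, 1): s})
print(s.means, s.stds)  # [2.0, 200.0] [1.0, 100.0]
```

Column 2 is not in the mapping, so it contributes nothing to the learned statistics.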

transform(x: Tensor, /)

Transform the given tensor.

Return type:

Tensor
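One plausible reading of transform is sketched below. It assumes the method writes each group's transformed columns into a copy of the input, so the caller's tensor is left intact and columns outside the mapping pass through unchanged. Already-fitted transforms are replaced by plain functions (the hypothetical double and negate), and lists of rows replace tensors.

```python
from typing import Callable, Dict, List, Tuple

Rows = List[List[float]]
RowFn = Callable[[Rows], Rows]


def transform_groups(x: Rows, mapping: Dict[Tuple[int, ...], RowFn]) -> Rows:
    out = [list(row) for row in x]  # work on a copy; x itself is never written to
    for cols, fn in mapping.items():
        sub = fn([[row[c] for c in cols] for row in x])
        for out_row, sub_row in zip(out, sub):
            for c, v in zip(cols, sub_row):
                out_row[c] = v  # scatter results back to their original positions
    return out


def double(rows: Rows) -> Rows:
    return [[2 * v for v in r] for r in rows]


def negate(rows: Rows) -> Rows:
    return [[-v for v in r] for r in rows]


x = [[1.0, 2.0, 3.0, 4.0]]
y = transform_groups(x, {(0,): double, (3, 2): negate})
print(y)  # [[2.0, 2.0, -3.0, -4.0]]
print(x)  # [[1.0, 2.0, 3.0, 4.0]]
```

Because results are written back by index, the order of indices inside a key does not matter: the group (3, 2) still places each negated value in its original column.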

mapping: dict[tuple[int, ...], InputTransform]

A mapping defining what transform to apply to which columns.
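As an illustration of how a mapping might be laid out, suppose columns 0 and 1 hold continuous values to be scaled while columns 2 to 4 hold a one-hot encoding that should stay as-is. The hypothetical helper passthrough_columns reports which columns no transform claims, under the assumption that such columns are left untouched. Strings stand in for transform objects such as BoTorch InputTransform instances.

```python
from typing import Any, Dict, List, Tuple


def passthrough_columns(n_cols: int, mapping: Dict[Tuple[int, ...], Any]) -> List[int]:
    """Hypothetical helper: indices of columns not covered by any key in `mapping`."""
    covered = {c for cols in mapping for c in cols}
    return [c for c in range(n_cols) if c not in covered]


# Only the continuous columns are assigned a transform.
mapping = {(0, 1): "scale_continuous"}
print(passthrough_columns(5, mapping))  # [2, 3, 4]
```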