# Source code for baybe.targets.utils

"""Target utilities."""

from __future__ import annotations

import inspect
from collections.abc import Callable
from functools import wraps
from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec

from attrs import evolve, fields, fields_dict

from baybe.transformations.basic import IdentityTransformation

if TYPE_CHECKING:
    # These names are only referenced in annotations within this module. Since
    # `from __future__ import annotations` keeps annotations unevaluated, they do
    # not need to exist at runtime. Functions that need `NumericalTarget` at
    # runtime import it locally.
    from baybe.targets.numerical import NumericalTarget

    P = ParamSpec("P")


def _validate_numerical_target_combination(
    t1: NumericalTarget, t2: NumericalTarget, /
) -> None:
    """Check that two numerical targets are compatible for combination.

    Both targets are compared after replacing their transformations with an
    identity transformation, so that only the remaining attributes decide
    whether the comparison succeeds.

    Args:
        t1: The first target.
        t2: The second target.

    Raises:
        ValueError: If the targets differ in any attribute other than their
            transformation.
    """
    # Local import: at module level, the class is only imported for type checking
    from baybe.targets.numerical import NumericalTarget

    # Neutralize the transformations so that they do not affect the comparison
    stripped_1, stripped_2 = (
        evolve(target, transformation=IdentityTransformation())  # type: ignore[call-arg]
        for target in (t1, t2)
    )

    if stripped_1 != stripped_2:
        excluded_attribute = fields(NumericalTarget).transformation.name
        raise ValueError(
            f"Two objects of type '{NumericalTarget.__name__}' can only be "
            f"combined if they are identical in all attributes except for the "
            f"'{excluded_attribute}'. "
            f"Given: {stripped_1!r} and {stripped_2!r}."
        )


def combine_numerical_targets(
    t1: NumericalTarget, t2: NumericalTarget, /, operator
) -> NumericalTarget:
    """Combine two numerical targets using a binary operator.

    Args:
        t1: The first target. It also serves as the template for the result.
        t2: The second target.
        operator: A callable that receives the transformations of ``t1`` and
            ``t2`` (in that order) and returns the transformation to be used
            for the combined target.

    Returns:
        A target created from ``t1`` whose transformation is replaced by the
        result of applying ``operator``.

    Raises:
        ValueError: If the two targets differ in any attribute other than
            their transformation.
    """
    _validate_numerical_target_combination(t1, t2)
    combined_transformation = operator(t1.transformation, t2.transformation)
    return evolve(t1, transformation=combined_transformation)  # type: ignore[call-arg]


[docs] def capture_constructor_info( constructor: Callable[Concatenate[type[NumericalTarget], P], NumericalTarget], ) -> Callable[Concatenate[type[NumericalTarget], P], NumericalTarget]: """Capture constructor history upon object creation. To be used as decorator with classmethods. """ @wraps(constructor) def wrapper( cls: type[NumericalTarget], *args: P.args, **kwargs: P.kwargs ) -> NumericalTarget: from baybe.targets.numerical import NumericalTarget target = constructor(cls, *args, **kwargs) # Reconstruct arguments sig = inspect.signature(constructor) bound = sig.bind(cls, *args, **kwargs) bound.apply_defaults() # To make it consistent with results for __init__ bound.arguments.pop("cls") # Ignore "cls" # Store argument history constructor_info: dict[str, Any] = { "constructor": constructor.__name__, **{ k: v for k, v in bound.arguments.items() if k not in fields_dict(target.__class__) # Ignore persistent attributes }, } object.__setattr__( target, fields(NumericalTarget)._constructor_info.name, constructor_info ) return target return wrapper