Registering Custom Hooks

This example demonstrates the basic mechanics of the register_hooks utility, which lets you hook into any callable of your choice:

  • We define a hook that is compatible with the general RecommenderProtocol.recommend interface,

  • attach it to a recommender,

  • and watch it take action.

Imports

from dataclasses import dataclass
from time import perf_counter
from types import MethodType
from baybe.parameters import NumericalDiscreteParameter
from baybe.recommenders import RandomRecommender
from baybe.searchspace import SearchSpace
from baybe.utils import register_hooks

Defining the Hooks

We start by defining a simple hook that lets us inspect the names of the parameters involved in the recommendation process. For this purpose, we match its signature to that of RecommenderProtocol.recommend:

Signature components

Note that you are flexible in designing the signature of your hooks. For instance, function parameters and type annotations that you do not need in the hook body can simply be omitted. The exact rules to follow are described in the documentation of register_hooks.
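To make the idea of a reduced hook signature more tangible, the following minimal sketch matches arguments by name using inspect.signature. The helper call_with_matching_args and the toy recommend function are hypothetical and only mimic the principle; this is not the actual implementation of register_hooks.

```python
import inspect


def call_with_matching_args(hook, target, *args, **kwargs):
    """Call ``hook`` with only those arguments of ``target`` that it declares.

    Hypothetical helper for illustration; not part of BayBE.
    """
    # Map the call arguments onto the parameter names of the target
    bound = inspect.signature(target).bind(*args, **kwargs)
    bound.apply_defaults()

    # Keep only the arguments whose names appear in the hook signature
    hook_params = inspect.signature(hook).parameters
    selected = {k: v for k, v in bound.arguments.items() if k in hook_params}
    return hook(**selected)


def recommend(self, batch_size, searchspace=None):
    """Toy stand-in for a ``recommend`` method (assumption, not the real API)."""
    return [searchspace] * batch_size


def batch_size_hook(batch_size):
    """Hook that omits all parameters it does not need."""
    return f"Requested batch size: {batch_size}"


message = call_with_matching_args(
    batch_size_hook, recommend, "recommender", 3, searchspace="space"
)
print(message)
```

Matching by name is what allows a hook to drop parameters freely in this sketch: only the names present in both signatures are forwarded.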

def print_parameter_names_hook(self: RandomRecommender, searchspace: SearchSpace):
    """Print the names of the parameters spanning the search space."""
    print(f"Recommender type: {self.__class__.__name__}")
    print(f"Search space parameters: {[p.name for p in searchspace.parameters]}")

Additionally, we set up a class that provides a combination of hooks for measuring the time needed to compute the recommendations:

@dataclass
class ElapsedTimePrinter:
    """Helper class for measuring the time between two calls."""

    last_call_time: float | None = None

    def start(printer_instance):
        """Start the timer."""
        printer_instance.last_call_time = perf_counter()

    def measure(printer_instance, self: RandomRecommender):
        """Measure the elapsed time."""
        if printer_instance.last_call_time is None:
            raise RuntimeError("Must call `start` first!")
        elapsed = perf_counter() - printer_instance.last_call_time
        print(f"Consumed time of {self.__class__.__name__}: {elapsed}")

Hook instance vs. target instance

Notice the difference between the object belonging to the hook-providing class (named printer_instance) and the object whose method we intend to override (named self). This distinction is necessary because of the particular way we attach the hook below, which binds self to the object carrying the target callable as a method.
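The two roles can be observed in isolation with plain Python. In the sketch below, the Counter class is a hypothetical stand-in for a hook provider: accessing the method through the instance already binds the first parameter, so only self remains to be filled with the object carrying the target method.

```python
import inspect
from dataclasses import dataclass


@dataclass
class Counter:
    """Toy hook provider (hypothetical, for illustration only)."""

    calls: int = 0

    def count(counter_instance, self):
        """Count how often the hook was triggered."""
        counter_instance.calls += 1
        return type(self).__name__


counter = Counter()

# Accessing the method via the instance binds the first parameter ...
remaining = list(inspect.signature(counter.count).parameters)
print(remaining)  # only `self` is left for the hooked object

# ... so the remaining `self` can receive the object carrying the target method
target_name = counter.count("some target object")
print(target_name, counter.calls)
```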

Monkeypatching

Next, we create our recommender and monkeypatch its recommend method:

timer = ElapsedTimePrinter()
recommender = RandomRecommender()
recommender.recommend = MethodType(
    register_hooks(
        RandomRecommender.recommend,
        pre_hooks=[print_parameter_names_hook, timer.start],
        post_hooks=[timer.measure],
    ),
    recommender,
)

Bound methods

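Note that the explicit binding via MethodType above is required because we decorate the unbound RandomRecommender.recommend function with our hooks and then attach the result to the recommender instance. Functions stored as instance attributes are not bound automatically, so without MethodType the recommender would not be passed as self.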
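The binding behavior itself is standard Python and can be checked without any hooks. The following sketch uses a hypothetical Greeter class in place of the recommender and a plain function in place of the hook-decorated method:

```python
from types import MethodType


class Greeter:
    """Toy class standing in for a recommender (hypothetical)."""

    def greet(self):
        return f"Hello from {type(self).__name__}"


def decorated_greet(self):
    """Plain function playing the role of the hook-decorated method."""
    return Greeter.greet(self).upper()


greeter = Greeter()

# A function stored on the instance is NOT bound automatically ...
greeter.greet = decorated_greet
try:
    greeter.greet()
    bound_automatically = True
except TypeError:
    bound_automatically = False  # `self` was not passed

# ... whereas MethodType binds it to the instance explicitly
greeter.greet = MethodType(decorated_greet, greeter)
result = greeter.greet()
print(bound_automatically, result)
```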

Alternatively, we could have …

  • … overridden the class callable itself via RandomRecommender.recommend = register_hooks(RandomRecommender.recommend, ...) which, however, would affect all instances of RandomRecommender or

  • … used the bound method of the instance as reference via recommender.recommend = register_hooks(recommender.recommend, ...) but then the hooks would not have access to the recommender instance as it is not explicitly exposed in the method’s signature.
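The consequences of both alternatives can be reproduced with a simplified wrapper. The helper with_pre_hook below is a hypothetical, stripped-down substitute for register_hooks that performs no signature checks, and Toy is a stand-in for the recommender class:

```python
import functools


def with_pre_hook(target, hook):
    """Run ``hook`` with the call arguments before ``target`` (hypothetical helper)."""

    @functools.wraps(target)
    def wrapper(*args, **kwargs):
        hook(*args, **kwargs)
        return target(*args, **kwargs)

    return wrapper


class Toy:
    """Toy class standing in for a recommender (hypothetical)."""

    def run(self, value):
        return value


seen = []


def record_args(*args, **kwargs):
    """Hook that records the positional arguments it receives."""
    seen.append(args)


# Variant 1: wrapping the bound method hides the instance from the hook
first = Toy()
first.run = with_pre_hook(first.run, record_args)
first.run(1)

# Variant 2: patching the class affects every instance, including new ones
original_run = Toy.run
Toy.run = with_pre_hook(Toy.run, record_args)
second = Toy()
second.run(2)
Toy.run = original_run  # undo the class-level patch

print(seen)
```

In the first variant the hook only receives the value, while in the second it also receives the instance, at the cost of modifying the class for all of its users.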

Triggering the Hooks

When we now apply the recommender in a specific context, we immediately see the effect of the hooks:

temperature = NumericalDiscreteParameter("Temperature", values=[90, 105, 120])
concentration = NumericalDiscreteParameter("Concentration", values=[0.057, 0.1, 0.153])
searchspace = SearchSpace.from_product([temperature, concentration])
recommendation = recommender.recommend(batch_size=3, searchspace=searchspace)

Recommender type: RandomRecommender
Search space parameters: ['Concentration', 'Temperature']
Consumed time of RandomRecommender: 0.0014909209999132145