baybe.surrogates.utils.batchify

baybe.surrogates.utils.batchify(posterior: Callable[[Surrogate, Tensor], tuple[Tensor, Tensor]])

Wrap Surrogate posterior functions to enable proper batching.

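More precisely, this wraps models that are incompatible with t- and q-batching so that they can process batched inputs. Here, a batched input has shape (*t, q, d): t stands for any number of leading, independent batch dimensions, and q is the number of candidates evaluated jointly.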
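The following is a minimal sketch of the idea. It is not BayBE's actual implementation. It assumes the original posterior function accepts only a single unbatched input of shape (q, d) and returns a mean of shape (q,) and a covariance of shape (q, q). The names batchify_sketch and toy_posterior are hypothetical. NumPy arrays stand in for torch tensors to keep the example dependency-light. The wrapper flattens all leading t-batch dimensions, calls the original function once per batch, and restores the batch shape on both outputs.

```python
from functools import wraps
from typing import Any, Callable

import numpy as np

# Hypothetical posterior signature: (model, candidates) -> (mean, covariance).
PosteriorFn = Callable[[Any, np.ndarray], tuple[np.ndarray, np.ndarray]]


def batchify_sketch(posterior: PosteriorFn) -> PosteriorFn:
    """Lift a posterior that only handles (q, d) inputs to (*t, q, d) inputs."""

    @wraps(posterior)
    def batched(model: Any, candidates: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        # Unbatched (q, d) input: pass it through unchanged.
        if candidates.ndim == 2:
            return posterior(model, candidates)

        t_shape = candidates.shape[:-2]
        q, d = candidates.shape[-2:]

        # Flatten all t-batch dimensions into a single leading axis and
        # evaluate the original posterior on one (q, d) slice at a time.
        flat = candidates.reshape(-1, q, d)
        results = [posterior(model, x) for x in flat]
        means = np.stack([m for m, _ in results])  # shape (prod(t), q)
        covs = np.stack([c for _, c in results])  # shape (prod(t), q, q)

        # Restore the original t-batch shape on both outputs.
        return means.reshape(*t_shape, q), covs.reshape(*t_shape, q, q)

    return batched


def toy_posterior(model: Any, candidates: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Hypothetical model that only accepts a single (q, d) input."""
    if candidates.ndim != 2:
        raise ValueError("toy_posterior expects a 2-D (q, d) input")
    q = candidates.shape[0]
    # Mean: row sums; covariance: independent unit-variance outputs.
    return candidates.sum(axis=-1), np.eye(q)


x = np.arange(4 * 3 * 2 * 5, dtype=float).reshape(4, 3, 2, 5)  # t=(4, 3), q=2, d=5
mean, cov = batchify_sketch(toy_posterior)(None, x)
print(mean.shape, cov.shape)  # (4, 3, 2) (4, 3, 2, 2)
```

Looping over t-batches trades speed for compatibility. Each call to the underlying model sees only the (q, d) shape it supports, at the cost of one call per batch.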

Parameters:

posterior (Callable[[Surrogate, Tensor], tuple[Tensor, Tensor]]) – The original posterior function.

Return type:

Callable[[Surrogate, Tensor], tuple[Tensor, Tensor]]

Returns:

The wrapped posterior function.