Source code for baybe.surrogates.gaussian_process.kernel_factory
"""Kernel factories for the Gaussian process surrogate."""from__future__importannotationsfromtypingimportTYPE_CHECKING,Protocolfromattrsimportdefine,fieldfromattrs.validatorsimportinstance_offrombaybe.kernels.baseimportKernelfrombaybe.searchspaceimportSearchSpacefrombaybe.serialization.coreimport(converter,get_base_structure_hook,unstructure_base,)frombaybe.serialization.mixinimportSerialMixinifTYPE_CHECKING:fromtorchimportTensor
class KernelFactory(Protocol):
    """Structural interface that every kernel factory must satisfy.

    Any callable object accepting a search space together with training
    inputs and targets, and producing a :class:`baybe.kernels.base.Kernel`,
    conforms to this protocol.
    """

    def __call__(
        self,
        searchspace: SearchSpace,
        train_x: Tensor,
        train_y: Tensor,
    ) -> Kernel:
        """Build a :class:`baybe.kernels.base.Kernel` for the given DOE context.

        Args:
            searchspace: The search space of the DOE context.
            train_x: The ``train_x`` tensor of the DOE context.
            train_y: The ``train_y`` tensor of the DOE context.

        Returns:
            The kernel produced for the given context.
        """
        ...
@define(frozen=True)
class PlainKernelFactory(KernelFactory, SerialMixin):
    """A minimal factory that always hands out the same pre-built kernel."""

    kernel: Kernel = field(validator=instance_of(Kernel))
    """The pre-defined kernel that the factory returns on every call."""

    def __call__(
        self,
        searchspace: SearchSpace,
        train_x: Tensor,
        train_y: Tensor,
    ) -> Kernel:
        """Return the stored kernel.

        The arguments exist only to satisfy the :class:`KernelFactory`
        interface; none of them is used.

        Args:
            searchspace: Ignored.
            train_x: Ignored.
            train_y: Ignored.

        Returns:
            The kernel held in :attr:`kernel`.
        """
        return self.kernel
def to_kernel_factory(x: Kernel | KernelFactory, /) -> KernelFactory:
    """Coerce the input into a kernel factory.

    Args:
        x: Either a kernel or an object that already acts as a kernel factory.

    Returns:
        The result of ``x.to_factory()`` when ``x`` is a
        :class:`baybe.kernels.base.Kernel`; otherwise ``x`` itself, unchanged.
    """
    if not isinstance(x, Kernel):
        # Anything that is not a kernel is passed through untouched.
        return x
    return x.to_factory()