"""Composite kernels (that is, kernels composed of other kernels)."""fromfunctoolsimportreducefromoperatorimportadd,mulfromattrsimportdefine,fieldfromattrs.convertersimportoptionalasoptional_cfromattrs.validatorsimportdeep_iterable,gt,instance_of,min_lenfromattrs.validatorsimportoptionalasoptional_vfrombaybe.kernels.baseimportCompositeKernel,Kernelfrombaybe.priors.baseimportPriorfrombaybe.utils.validationimportfinite_float
@define(frozen=True)
class ScaleKernel(CompositeKernel):
    """Wraps an existing kernel and equips it with an output scale."""

    base_kernel: Kernel = field(validator=instance_of(Kernel))
    """The wrapped kernel that receives the output scale."""

    outputscale_prior: Prior | None = field(
        validator=optional_v(instance_of(Prior)),
        default=None,
    )
    """Optional prior placed on the output scale."""

    outputscale_initial_value: float | None = field(
        validator=optional_v([finite_float, gt(0.0)]),
        converter=optional_c(float),
        default=None,
    )
    """Optional starting value for the output scale (must pass ``finite_float``
    and be strictly greater than zero)."""

    def to_gpytorch(self, *args, **kwargs):  # noqa: D102
        # Signature and general contract are documented on the base class.
        import torch

        from baybe.utils.torch import DTypeFloatTorch

        kernel = super().to_gpytorch(*args, **kwargs)

        start_value = self.outputscale_initial_value
        if start_value is None:
            # No user-provided start value: return the base conversion untouched.
            return kernel

        # Overwrite the converted kernel's output scale with the requested value,
        # using the project-wide floating point dtype.
        kernel.outputscale = torch.tensor(start_value, dtype=DTypeFloatTorch)
        return kernel
@define(frozen=True)
class AdditiveKernel(CompositeKernel):
    """Kernel that sums the contributions of several base kernels."""

    base_kernels: tuple[Kernel, ...] = field(
        converter=tuple,
        validator=deep_iterable(
            member_validator=instance_of(Kernel),
            iterable_validator=min_len(2),
        ),
    )
    """The kernels whose contributions are added together."""

    def to_gpytorch(self, *args, **kwargs):  # noqa: D102
        # See base class.
        # `map` is lazy, so each kernel is converted right before it is folded
        # into the running sum (same interleaving as a generator expression).
        converted = map(
            lambda kernel: kernel.to_gpytorch(*args, **kwargs), self.base_kernels
        )
        # At least two summands are ensured by the `min_len(2)` validator at
        # construction time, so `reduce` needs no initial value.
        return reduce(add, converted)
@define(frozen=True)
class ProductKernel(CompositeKernel):
    """Kernel formed by multiplying several base kernels together."""

    base_kernels: tuple[Kernel, ...] = field(
        converter=tuple,
        validator=deep_iterable(
            member_validator=instance_of(Kernel),
            iterable_validator=min_len(2),
        ),
    )
    """The kernels that act as factors of the product."""

    def to_gpytorch(self, *args, **kwargs):  # noqa: D102
        # See base class.
        gpytorch_factors = (
            factor.to_gpytorch(*args, **kwargs) for factor in self.base_kernels
        )
        # The `min_len(2)` validator ensures at least two factors at
        # construction time, so no identity element is needed for `reduce`.
        return reduce(mul, gpytorch_factors)