"""Validation functionality for surrogates."""from__future__importannotationsfromcollections.abcimportCallablefromtypingimportAnyfrombaybe.surrogates.baseimportSurrogate
def _positional_arg_names(func: Callable) -> tuple[str, ...] | None:
    """Return the positional parameter names of ``func`` in declaration order.

    Args:
        func: The callable to inspect.

    Returns:
        The positional parameter names, or ``None`` if ``func`` carries no
        ``__code__`` object (e.g. a builtin or a callable class instance).
    """
    code = getattr(func, "__code__", None)
    if code is None:
        return None
    # ``co_varnames`` lists the parameters first and then the local variables,
    # so slicing to ``co_argcount`` keeps only the positional parameters.
    return code.co_varnames[: code.co_argcount]


def _validate_method_signature(method: Callable, name: str) -> None:
    """Check that ``method`` matches the positional args of ``Surrogate.<name>``.

    Args:
        method: The user-provided method to check.
        name: The name of the corresponding ``Surrogate`` method.

    Raises:
        ValueError: When ``method`` cannot be inspected or its positional
            parameter names differ from those of the ``Surrogate`` method.
    """
    params = _positional_arg_names(method)
    # Short-circuit: an uninspectable method is rejected before the reference
    # method on ``Surrogate`` is looked up. Both sides are sliced the same way,
    # so local variables of the reference never affect the comparison.
    if params is None or params != _positional_arg_names(getattr(Surrogate, name)):
        raise ValueError(
            f"Invalid args in `{name}` method definition for custom architecture. "
            f"Please refer to Surrogate.{name} for the required function signature."
        )


def validate_custom_architecture_cls(model_cls: type) -> None:
    """Validate a custom architecture to have the correct attributes.

    Args:
        model_cls: The user defined model class.

    Raises:
        ValueError: When model_cls does not have _fit or _posterior.
        ValueError: When _fit or _posterior is not a callable method.
        ValueError: When _fit does not have the required signature (including
            when it is a callable without a ``__code__`` object).
        ValueError: When _posterior does not have the required signature
            (including when it is a callable without a ``__code__`` object).
    """
    # Methods must exist
    if not (hasattr(model_cls, "_fit") and hasattr(model_cls, "_posterior")):
        raise ValueError(
            "`_fit` and a `_posterior` must exist for custom architectures"
        )

    fit = model_cls._fit
    posterior = model_cls._posterior

    # They must be methods
    if not (callable(fit) and callable(posterior)):
        raise ValueError(
            "`_fit` and a `_posterior` must be methods for custom architectures"
        )

    # Methods must have the same positional argument names, in the same order,
    # as the corresponding ``Surrogate`` methods
    _validate_method_signature(fit, "_fit")
    _validate_method_signature(posterior, "_posterior")
def get_model_params_validator(model_init: Callable | None = None) -> Callable:
    """Construct a validator based on the model class.

    Args:
        model_init: The init method for the model. Its declared parameter names
            (positional and keyword-only) define the accepted model params.

    Returns:
        A validator function to validate parameters.
    """

    def validate_model_params(  # noqa: DOC101, DOC103
        obj: Any, _: Any, model_params: dict
    ) -> None:
        """Validate the model params attribute of an object.

        The ``(obj, _, model_params)`` signature presumably follows the attrs
        validator convention (instance, attribute, value). TODO confirm at the
        call site.

        Raises:
            ValueError: When model params are given for non-supported objects.
            ValueError: When surrogate is not recognized (no valid model_init).
            ValueError: When invalid params are given for a model.
        """
        # Get model class name
        model = obj.__class__.__name__

        # Nothing to validate
        if not model_params:
            return

        # GP does not support additional model params
        # Neither does custom models
        if "GaussianProcess" in model or "Custom" in model:
            raise ValueError(f"{model} does not support model params.")

        if not model_init:
            raise ValueError(
                f"Cannot validate model params for unrecognized Surrogate: {model}"
            )

        # ``co_varnames`` holds the parameter names followed by the function's
        # local variables. Slice it to the positional and keyword-only parameters
        # so that a key which merely matches a local variable is not accepted.
        code = model_init.__code__
        accepted = code.co_varnames[: code.co_argcount + code.co_kwonlyargcount]

        # Invalid params
        invalid_params = ", ".join(
            key for key in model_params if key not in accepted
        )
        if invalid_params:
            raise ValueError(f"Invalid model params for {model}: {invalid_params}.")

    return validate_model_params