CustomONNXSurrogate

class baybe.surrogates.custom.CustomONNXSurrogate

Bases: IndependentGaussianSurrogate

A wrapper class for custom pretrained surrogate models.

Note that these surrogates cannot be retrained.
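Conceptually, the wrapper exposes a fixed, pretrained predictor as an independent Gaussian surrogate. Each candidate receives its own mean and variance, and training data has no effect on the predictions. The stdlib-only sketch below illustrates that idea. The class `FrozenGaussianSurrogate`, its methods, and the assumption that the wrapped model returns a (mean, standard deviation) pair per candidate are hypothetical and do not reproduce BayBE's implementation.

```python
from typing import Callable, Sequence


class FrozenGaussianSurrogate:
    """Hypothetical stand-in for a pretrained, non-retrainable surrogate."""

    def __init__(self, predict: Callable[[float], tuple[float, float]]) -> None:
        # Assumed format: `predict` maps one candidate to (mean, standard deviation).
        self._predict = predict

    def fit(self, xs: Sequence[float], ys: Sequence[float]) -> None:
        # Pretrained models cannot be retrained, so the training data is ignored.
        return None

    def moments(self, xs: Sequence[float]) -> tuple[list[float], list[float]]:
        # Independent Gaussians: each candidate is evaluated on its own, and the
        # standard deviation is squared to obtain a variance.
        means, variances = [], []
        for x in xs:
            mean, std = self._predict(x)
            means.append(mean)
            variances.append(std**2)
        return means, variances


# A toy "pretrained" model: mean 2x + 1 with a constant standard deviation of 0.5.
surrogate = FrozenGaussianSurrogate(lambda x: (2 * x + 1, 0.5))
before = surrogate.moments([0.0, 1.0])
surrogate.fit([0.0, 1.0], [100.0, 100.0])  # has no effect on the predictions
after = surrogate.moments([0.0, 1.0])
```

Calling `fit` leaves the predictions unchanged, which mirrors the note that these surrogates cannot be retrained.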

Public methods

__init__(*, onnx_input_name, onnx_str)

Method generated by attrs for class CustomONNXSurrogate.

default_model()

Instantiate the ONNX inference session.

fit(searchspace, objective, measurements)

Train the surrogate model on the provided data.

from_dict(dictionary)

Create an object from its dictionary representation.

from_json(string)

Create an object from its JSON representation.

posterior(candidates, /)

Compute the posterior for candidates in experimental representation.

to_botorch()

Create the botorch-ready representation of the fitted model.

to_dict()

Create an object's dictionary representation.

to_json()

Create an object's JSON representation.

validate_compatibility(searchspace)

Validate whether the class is compatible with a given search space.

Public attributes and properties

onnx_input_name

The name of the model input, as used when constructing the ONNX string.

onnx_str

The ONNX byte string representing the model.

supports_transfer_learning

Class variable encoding whether or not the surrogate supports transfer learning.

__init__(*, onnx_input_name: str, onnx_str: bytes)

Method generated by attrs for class CustomONNXSurrogate.

For details on the parameters, see Public attributes and properties.

default_model()

Instantiate the ONNX inference session.

Return type:

onnxruntime.InferenceSession

fit(searchspace: SearchSpace, objective: Objective, measurements: DataFrame)

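Train the surrogate model on the provided data.

This method is inherited from the base class. As stated in the class description, CustomONNXSurrogate models are pretrained and cannot be retrained.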

Parameters:
  • searchspace (SearchSpace) – The search space in which experiments are conducted.

  • objective (Objective) – The objective to be optimized.

  • measurements (DataFrame) – The training data in experimental representation.

Raises:
  • ValueError – If the search space contains task parameters but the selected surrogate model type does not support transfer learning.

  • NotImplementedError – When using a continuous search space and a non-GP model.

Return type:

None
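The `ValueError` condition listed above can be sketched as a guard that compares the search space's parameters with the class-level `supports_transfer_learning` flag, which is `False` for this class. `TaskParameter`, `NumericalParameter`, `SearchSpace`, `StaticSurrogate`, and `check_transfer_learning` below are simplified, hypothetical stand-ins, not BayBE's classes or functions.

```python
from dataclasses import dataclass, field
from typing import ClassVar


@dataclass
class TaskParameter:
    """Hypothetical marker for a parameter that identifies the task/context."""
    name: str


@dataclass
class NumericalParameter:
    """Hypothetical ordinary parameter."""
    name: str


@dataclass
class SearchSpace:
    """Hypothetical search space holding a flat list of parameters."""
    parameters: list = field(default_factory=list)


class StaticSurrogate:
    # Mirrors the documented class variable of CustomONNXSurrogate.
    supports_transfer_learning: ClassVar[bool] = False


def check_transfer_learning(surrogate_cls: type, searchspace: SearchSpace) -> None:
    """Raise if task parameters are present but transfer learning is unsupported."""
    has_task = any(isinstance(p, TaskParameter) for p in searchspace.parameters)
    if has_task and not surrogate_cls.supports_transfer_learning:
        raise ValueError(
            f"{surrogate_cls.__name__} does not support transfer learning, "
            "but the search space contains task parameters."
        )


# Passes: the search space contains no task parameters.
check_transfer_learning(StaticSurrogate, SearchSpace([NumericalParameter("temp")]))
```

Because the flag is a class variable, the check needs only the class and never a fitted instance.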

classmethod from_dict(dictionary: dict)

Create an object from its dictionary representation.

Parameters:

dictionary (dict) – The dictionary representation.

Return type:

TypeVar(_T)

Returns:

The reconstructed object.

classmethod from_json(string: str)

Create an object from its JSON representation.

Parameters:

string (str) – The JSON representation of the object.

Return type:

TypeVar(_T)

Returns:

The reconstructed object.

posterior(candidates: DataFrame, /)

Compute the posterior for candidates in experimental representation.

Takes a dataframe of parameter configurations in experimental representation and returns the corresponding posterior object. As such, the method serves as the user-facing entry point for accessing model predictions.

Parameters:

candidates (DataFrame) – A dataframe containing parameter configurations in experimental representation.

Raises:

ModelNotTrainedError – When called before the model has been trained.

Return type:

Posterior

Returns:

A botorch.posteriors.Posterior object representing the posterior distribution at the given candidate points, where the posterior is also described in experimental representation. That is, the posterior values lie in the same domain as the modelled targets/objective on which the surrogate was trained via baybe.surrogates.base.Surrogate.fit().
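For an independent Gaussian surrogate, the posterior at each candidate reduces to a univariate normal distribution. Common quantities such as a central credible interval therefore follow directly from the per-candidate mean and variance. The sketch below uses `statistics.NormalDist` in place of a `botorch.posteriors.Posterior`. The moment values and the helper `credible_interval` are hypothetical.

```python
import math
from statistics import NormalDist


def credible_interval(
    mean: float, variance: float, level: float = 0.95
) -> tuple[float, float]:
    """Central credible interval of a univariate Gaussian posterior."""
    dist = NormalDist(mu=mean, sigma=math.sqrt(variance))
    tail = (1.0 - level) / 2.0
    return dist.inv_cdf(tail), dist.inv_cdf(1.0 - tail)


# Hypothetical per-candidate posterior moments.
means = [1.0, 3.0]
variances = [0.25, 0.25]
intervals = [credible_interval(m, v) for m, v in zip(means, variances)]
```

Each interval is centred on its candidate's mean, and its half-width is about 1.96 standard deviations at the 95% level.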

to_botorch()

Create the botorch-ready representation of the fitted model.

The botorch.models.model.Model created by this method must be configured so that it can be called with candidate points in computational representation, that is, with inputs of the form obtained via baybe.searchspace.core.SearchSpace.transform().

Return type:

Model
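The difference between experimental and computational representation can be illustrated with a toy transform that replaces categorical labels with numeric codes and passes numerical values through unchanged. The integer encoding, the helper `to_computational`, and the example data are hypothetical simplifications. In BayBE, the actual conversion is performed by baybe.searchspace.core.SearchSpace.transform().

```python
def to_computational(
    rows: list[dict[str, object]], categories: dict[str, list[str]]
) -> list[list[float]]:
    """Encode categorical columns as integer codes; keep numerical values as floats."""
    encoded = []
    for row in rows:
        vector = []
        for name, value in row.items():
            if name in categories:
                # Integer encoding: position of the label in the category list.
                vector.append(float(categories[name].index(value)))
            else:
                vector.append(float(value))
        encoded.append(vector)
    return encoded


# Hypothetical candidates in experimental representation.
candidates = [
    {"solvent": "water", "temperature": 25},
    {"solvent": "ethanol", "temperature": 60},
]
computational = to_computational(candidates, {"solvent": ["water", "ethanol"]})
```

Here `computational` holds one numeric row per candidate. This purely numeric form is the kind of input the botorch-ready model is expected to accept.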

to_dict()

Create an object’s dictionary representation.

Return type:

dict

to_json()

Create an object’s JSON representation.

Return type:

str

Returns:

The JSON representation as a string.
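Because `onnx_str` holds raw bytes, which JSON cannot represent directly, a dictionary/JSON round trip needs a text encoding for that field. The sketch below assumes Base64, a common choice for this purpose. The class `SerializableONNXConfig` is a hypothetical stand-in, and the encoding BayBE actually applies may differ.

```python
import base64
import json
from dataclasses import dataclass


@dataclass(frozen=True)
class SerializableONNXConfig:
    """Hypothetical stand-in illustrating to_dict/from_dict/to_json/from_json."""

    onnx_input_name: str
    onnx_str: bytes

    def to_dict(self) -> dict:
        # Assumption: bytes are stored as a Base64-encoded ASCII string.
        return {
            "onnx_input_name": self.onnx_input_name,
            "onnx_str": base64.b64encode(self.onnx_str).decode("ascii"),
        }

    @classmethod
    def from_dict(cls, dictionary: dict) -> "SerializableONNXConfig":
        # Reverse the Base64 encoding to recover the original bytes.
        return cls(
            onnx_input_name=dictionary["onnx_input_name"],
            onnx_str=base64.b64decode(dictionary["onnx_str"]),
        )

    def to_json(self) -> str:
        return json.dumps(self.to_dict())

    @classmethod
    def from_json(cls, string: str) -> "SerializableONNXConfig":
        return cls.from_dict(json.loads(string))


original = SerializableONNXConfig(onnx_input_name="input", onnx_str=b"\x00\xffmodel")
restored = SerializableONNXConfig.from_json(original.to_json())
```

Building the JSON methods on top of the dictionary methods keeps the byte-handling logic in a single place.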

classmethod validate_compatibility(searchspace: SearchSpace)

Validate whether the class is compatible with a given search space.

Parameters:

searchspace (SearchSpace) – The search space to be tested for compatibility.

Raises:

TypeError – If the search space is incompatible with the class.

Return type:

None
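A class-level compatibility check inspects only the search space, so it can run before any model is built. The rule used in this sketch, that every parameter must map to exactly one computational column, is an illustrative assumption. The criterion BayBE actually applies may differ, and the `Parameter` stand-in with its `n_comp_columns` field is hypothetical.

```python
from dataclasses import dataclass


@dataclass
class Parameter:
    """Hypothetical parameter stand-in."""

    name: str
    n_comp_columns: int  # width of the parameter's computational representation


class OneColumnSurrogate:
    """Hypothetical surrogate with an assumed one-column-per-parameter rule."""

    @classmethod
    def validate_compatibility(cls, parameters: list[Parameter]) -> None:
        """Raise TypeError if any parameter spans more than one computational column."""
        offending = [p.name for p in parameters if p.n_comp_columns != 1]
        if offending:
            raise TypeError(
                f"{cls.__name__} requires one-dimensional computational "
                f"representations; incompatible parameters: {offending}"
            )


# Compatible: both parameters occupy a single computational column.
OneColumnSurrogate.validate_compatibility(
    [Parameter("temperature", 1), Parameter("solvent", 1)]
)
```

Because the check is a classmethod, it is called on the class itself, as above, without creating an instance or an ONNX session.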

onnx_input_name: str

The name of the model input, as used when constructing the ONNX string.
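One likely reason the name must be supplied is that ONNX Runtime feeds a model through a dictionary keyed by input name, as in `InferenceSession.run(output_names, input_feed)`. The helper `build_input_feed` below is a hypothetical, stdlib-only sketch of that lookup. It uses nested lists where ONNX Runtime would expect NumPy arrays, and it assumes the model's input names are already known. For a real session, they can be listed via `session.get_inputs()`.

```python
def build_input_feed(
    onnx_input_name: str,
    model_input_names: list[str],
    rows: list[list[float]],
) -> dict[str, list[list[float]]]:
    """Key the candidate matrix by the model's input name (hypothetical helper)."""
    if onnx_input_name not in model_input_names:
        raise KeyError(
            f"'{onnx_input_name}' is not an input of the model; "
            f"available inputs: {model_input_names}"
        )
    return {onnx_input_name: rows}


# For a real session: model_input_names = [i.name for i in session.get_inputs()]
feed = build_input_feed("input", ["input"], [[0.0, 25.0], [1.0, 60.0]])
```

A mismatched name fails early with a descriptive error instead of surfacing later as a less obvious runtime failure.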

onnx_str: bytes

The ONNX byte string representing the model.

supports_transfer_learning: ClassVar[bool] = False

Class variable encoding whether or not the surrogate supports transfer learning.