Example for using custom constraints in discrete searchspaces
This example shows how to create a custom constraint for a discrete searchspace, that is, how the user can define their own filter that removes unwanted parameter combinations from the searchspace.
This example assumes some basic familiarity with using BayBE. For a basic introduction, we refer to the campaign example.
Necessary imports for this example
import os
import numpy as np
import pandas as pd
from baybe import Campaign
from baybe.constraints import DiscreteCustomConstraint
from baybe.objectives import SingleTargetObjective
from baybe.parameters import (
    CategoricalParameter,
    NumericalDiscreteParameter,
    SubstanceParameter,
)
from baybe.searchspace import SearchSpace
from baybe.targets import NumericalTarget
from baybe.utils.dataframe import add_fake_measurements
Experiment setup
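We begin by setting up some parameters for our experiments. TEMPERATURE_RESOLUTION describes the number of different temperatures used; when the SMOKE_TEST environment variable is set, a coarser resolution of 3 is used to keep the run short.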
SMOKE_TEST = "SMOKE_TEST" in os.environ
TEMPERATURE_RESOLUTION = 3 if SMOKE_TEST else 10
dict_solvent = {
    "water": "O",
    "C1": "C",
    "C2": "CC",
    "C3": "CCC",
    "C4": "CCCC",
    "C5": "CCCCC",
    "c6": "c1ccccc1",
    "C6": "CCCCCC",
}
solvent = SubstanceParameter("Solvent", data=dict_solvent, encoding="RDKIT")
speed = CategoricalParameter(
    "Speed", values=["very slow", "slow", "normal", "fast", "very fast"], encoding="INT"
)
temperature = NumericalDiscreteParameter(
    "Temperature",
    values=list(np.linspace(100, 200, TEMPERATURE_RESOLUTION)),
    tolerance=0.5,
)
concentration = NumericalDiscreteParameter(
    "Concentration", values=[1, 2, 5, 10], tolerance=0.4
)
parameters = [solvent, speed, temperature, concentration]
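The size of the unconstrained discrete searchspace is simply the product of the number of values of each parameter. The minimal sketch below reproduces that arithmetic with plain NumPy, reusing the same np.linspace temperature grid; temperature_grid and full_product_size are hypothetical helpers for illustration, not part of BayBE.

```python
import numpy as np

# Number of values per parameter, mirroring the definitions in this example.
N_SOLVENTS = 8  # entries in dict_solvent
N_SPEEDS = 5  # "very slow" ... "very fast"
CONCENTRATIONS = [1, 2, 5, 10]


def temperature_grid(resolution: int) -> np.ndarray:
    """Evenly spaced temperatures from 100 to 200, both ends included."""
    return np.linspace(100, 200, resolution)


def full_product_size(resolution: int) -> int:
    """Number of rows in the full Cartesian product, before any constraint."""
    n_temperatures = len(temperature_grid(resolution))
    return N_SOLVENTS * N_SPEEDS * n_temperatures * len(CONCENTRATIONS)


print(temperature_grid(3))  # [100. 150. 200.]
print(full_product_size(3), full_product_size(10))  # 480 1600
```

With the smoke-test resolution of 3 this gives 480 candidate combinations before the custom constraint is applied; with a resolution of 10 it gives 1600.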
Creating the constraint
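Constraints are applied when the searchspace object is created, so we need to define our constraint first. The validator function receives a DataFrame of candidate entries and returns a boolean Series in which True marks the entries to keep.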
def custom_function(df: pd.DataFrame) -> pd.Series:
    """This constraint implements a custom user-defined filter/validation
    functionality."""  # noqa: D401
    # Situation 1: Exclude entries where the solvent water is combined with
    # temperatures > 120 and concentrations > 5 at the same time
    mask_bad1 = (
        (df["Solvent"] == "water")
        & (df["Temperature"] > 120)
        & (df["Concentration"] > 5)
    )
    # Situation 2: Exclude entries where the solvent C2 is combined with
    # temperatures > 180 and concentrations > 3 at the same time
    mask_bad2 = (
        (df["Solvent"] == "C2") & (df["Temperature"] > 180) & (df["Concentration"] > 3)
    )
    # Situation 3: Exclude entries where the solvent C3 is combined with
    # temperatures > 150 and concentrations > 3 at the same time
    mask_bad3 = (
        (df["Solvent"] == "C3") & (df["Temperature"] > 150) & (df["Concentration"] > 3)
    )
    # Keep every entry that matches none of the excluded situations
    # (True = keep, False = remove)
    mask_good = ~(mask_bad1 | mask_bad2 | mask_bad3)
    return mask_good
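Because each mask_bad combines its three conditions with &, an entry is only removed when all three conditions hold at once. The standalone pandas sketch below applies the same logic to a few hand-made rows (illustrative data, not BayBE output) to make that explicit.

```python
import pandas as pd


def custom_function(df: pd.DataFrame) -> pd.Series:
    """Same filter as in this example: True marks rows to keep."""
    mask_bad1 = (
        (df["Solvent"] == "water") & (df["Temperature"] > 120) & (df["Concentration"] > 5)
    )
    mask_bad2 = (
        (df["Solvent"] == "C2") & (df["Temperature"] > 180) & (df["Concentration"] > 3)
    )
    mask_bad3 = (
        (df["Solvent"] == "C3") & (df["Temperature"] > 150) & (df["Concentration"] > 3)
    )
    return ~(mask_bad1 | mask_bad2 | mask_bad3)


# Hand-made rows probing the boundaries of each rule.
rows = pd.DataFrame(
    {
        "Solvent": ["water", "water", "C2", "C3", "C1"],
        "Temperature": [150.0, 150.0, 200.0, 150.0, 200.0],
        "Concentration": [10, 5, 5, 10, 10],
    }
)
keep = custom_function(rows)
print(keep.tolist())  # [False, True, False, True, True]
```

The second row (water at 150 with concentration 5) survives because its concentration is not above 5, and the fourth row survives because the C3 rule uses a strict comparison, temperature > 150.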
We now initialize the DiscreteCustomConstraint with all parameters that this function needs access to.
constraint = DiscreteCustomConstraint(
    parameters=["Concentration", "Solvent", "Temperature"], validator=custom_function
)
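To see how many candidates this constraint removes, the filtering can be emulated with plain pandas: build the full Cartesian product of parameter values, evaluate the validator on the constrained columns, and keep the rows where it returns True. This is only an illustrative emulation, assuming the validator is evaluated on the product of all parameter values; it is not BayBE's internal implementation.

```python
import itertools

import numpy as np
import pandas as pd


def custom_function(df: pd.DataFrame) -> pd.Series:
    """Same filter as in this example: True marks rows to keep."""
    mask_bad1 = (
        (df["Solvent"] == "water") & (df["Temperature"] > 120) & (df["Concentration"] > 5)
    )
    mask_bad2 = (
        (df["Solvent"] == "C2") & (df["Temperature"] > 180) & (df["Concentration"] > 3)
    )
    mask_bad3 = (
        (df["Solvent"] == "C3") & (df["Temperature"] > 150) & (df["Concentration"] > 3)
    )
    return ~(mask_bad1 | mask_bad2 | mask_bad3)


solvents = ["water", "C1", "C2", "C3", "C4", "C5", "c6", "C6"]
speeds = ["very slow", "slow", "normal", "fast", "very fast"]
temperatures = np.linspace(100, 200, 3)  # smoke-test resolution
concentrations = [1, 2, 5, 10]

# Full Cartesian product of all parameter values.
product = pd.DataFrame(
    list(itertools.product(solvents, speeds, temperatures, concentrations)),
    columns=["Solvent", "Speed", "Temperature", "Concentration"],
)
# Only the constrained columns are handed to the validator.
keep = custom_function(product[["Concentration", "Solvent", "Temperature"]])
filtered = product[keep].reset_index(drop=True)
print(len(product), len(filtered))  # 480 450
```

Each of the three rules removes 10 rows (one or two temperatures times one or two concentrations times five speeds), so 30 of the 480 rows are dropped. The 450 remaining rows agree with the size of the experimental representation in the campaign printout, which was produced with the smoke-test resolution.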
Creating the searchspace and the objective
searchspace = SearchSpace.from_product(parameters=parameters, constraints=[constraint])
objective = SingleTargetObjective(target=NumericalTarget(name="yield", mode="MAX"))
Creating and printing the campaign
campaign = Campaign(searchspace=searchspace, objective=objective)
print(campaign)
Campaign
    Meta Data
        Batches done: 0
        Fits done: 0
    SearchSpace
        Search Space Type: DISCRETE
        SubspaceDiscrete
            Discrete Parameters
                   Name           Type             Num_Values  Encoding
                0  Concentration  NumericalDis...  4           None
                1  Solvent        SubstancePar...  8           SubstanceEnc...
                2  Speed          CategoricalP...  5           CategoricalE...
                3  Temperature    NumericalDis...  3           None
            Experimental Representation
                     Solvent  Speed      Temperature  Concentration
                0    C1       fast       100.0        1.0
                1    C1       fast       100.0        2.0
                2    C1       fast       100.0        5.0
                ..   ...      ...        ...          ...
                447  water    very slow  200.0        1.0
                448  water    very slow  200.0        2.0
                449  water    very slow  200.0        5.0
                [450 rows x 4 columns]
            Meta Data
                was_recommended: 0/450
                was_measured: 0/450
                dont_recommend: 0/450
            Constraints
                   Type             Affected_Paramet
                0  DiscreteCust...  [Concentrati...
            Computational Representation
                     Concentration  Solvent_RDKIT_Ma  ...  Speed  Temperature
                0    1.0            0.0               ...  0.0    100.0
                1    2.0            0.0               ...  0.0    100.0
                2    5.0            0.0               ...  0.0    100.0
                ..   ...            ...               ...  ...    ...
                447  1.0            0.0               ...  4.0    200.0
                448  2.0            0.0               ...  4.0    200.0
                449  5.0            0.0               ...  4.0    200.0
                [450 rows x 9 columns]
    Objective
        Type: SingleTargetObjective
        Targets
               Type             Name   ...  Upper_Bound  Transformation
            0  NumericalTarget  yield  ...  inf          None
            [1 rows x 6 columns]
    TwoPhaseMetaRecommender
        Initial recommender
            RandomRecommender
                Compatibility: SearchSpaceType.HYBRID
        Recommender
            BotorchRecommender
                Surrogate
                    GaussianProcessSurrogate
                        Supports Transfer Learning: True
                        Kernel factory: DefaultKernelFactory()
                Acquisition function: qLogExpectedImprovement()
                Compatibility: SearchSpaceType.HYBRID
                Sequential continuous: False
                Hybrid sampler: None
                Sampling percentage: 1.0
        Switch after: 1
Manual verification of the constraint
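The following loop requests a few batches of recommendations, adds fake measurements for each batch, and in every iteration counts the searchspace entries that match one of the three excluded situations. If the constraint works, all counts are zero.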
N_ITERATIONS = 3
for k_iter in range(N_ITERATIONS):
    print(f"\n\n#### ITERATION {k_iter + 1} ####")
    print("## ASSERTS ##")
    # Experimental representation of the (already constrained) searchspace
    exp_rep = campaign.searchspace.discrete.exp_rep
    print(
        "Number of entries with water, temp > 120 and concentration > 5: ",
        (
            exp_rep["Concentration"].gt(5)
            & exp_rep["Temperature"].gt(120)
            & exp_rep["Solvent"].eq("water")
        ).sum(),
    )
    print(
        "Number of entries with C2, temp > 180 and concentration > 3: ",
        (
            exp_rep["Concentration"].gt(3)
            & exp_rep["Temperature"].gt(180)
            & exp_rep["Solvent"].eq("C2")
        ).sum(),
    )
    print(
        "Number of entries with C3, temp > 150 and concentration > 3: ",
        (
            exp_rep["Concentration"].gt(3)
            & exp_rep["Temperature"].gt(150)
            & exp_rep["Solvent"].eq("C3")
        ).sum(),
    )
    # Request a batch, attach fake target values, and feed them back
    rec = campaign.recommend(batch_size=5)
    add_fake_measurements(rec, campaign.targets)
    campaign.add_measurements(rec)
#### ITERATION 1 ####
## ASSERTS ##
Number of entries with water, temp > 120 and concentration > 5: 0
Number of entries with C2, temp > 180 and concentration > 3: 0
Number of entries with C3, temp > 150 and concentration > 3: 0
#### ITERATION 2 ####
## ASSERTS ##
Number of entries with water, temp > 120 and concentration > 5: 0
Number of entries with C2, temp > 180 and concentration > 3: 0
Number of entries with C3, temp > 150 and concentration > 3: 0
#### ITERATION 3 ####
## ASSERTS ##
Number of entries with water, temp > 120 and concentration > 5: 0
Number of entries with C2, temp > 180 and concentration > 3: 0
Number of entries with C3, temp > 150 and concentration > 3: 0
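The verification loop only prints the counts. A more compact variant, sketched here with plain pandas, collects the three rules in one place and turns the check into an actual assert statement. RULES and count_violations are hypothetical names introduced for illustration, and the toy DataFrame merely stands in for campaign.searchspace.discrete.exp_rep or a batch of recommendations.

```python
import pandas as pd

# (solvent, temperature limit, concentration limit), mirroring the three
# situations excluded by custom_function.
RULES = [("water", 120, 5), ("C2", 180, 3), ("C3", 150, 3)]


def count_violations(df: pd.DataFrame) -> dict[str, int]:
    """Count rows where a solvent exceeds both its temperature and concentration limit."""
    counts = {}
    for solvent, max_temp, max_conc in RULES:
        mask = (
            df["Solvent"].eq(solvent)
            & df["Temperature"].gt(max_temp)
            & df["Concentration"].gt(max_conc)
        )
        counts[solvent] = int(mask.sum())
    return counts


# Toy stand-in for a constrained experimental representation.
exp_rep = pd.DataFrame(
    {
        "Solvent": ["water", "C2", "C3", "C1"],
        "Temperature": [200.0, 200.0, 150.0, 200.0],
        "Concentration": [5, 2, 10, 10],
    }
)
violations = count_violations(exp_rep)
print(violations)  # {'water': 0, 'C2': 0, 'C3': 0}
assert all(n == 0 for n in violations.values()), violations
```

Keeping the limits in a single table means the validator and the check cannot silently drift apart if a threshold is changed in only one place.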