Example for using custom constraints in discrete searchspaces

This example shows how to create a custom constraint for a discrete searchspace, that is, how the user can define their own rule for excluding parameter combinations from the searchspace.

This example assumes some basic familiarity with BayBE. For the basics, we refer to the campaign example.

Necessary imports for this example

import os
import numpy as np
import pandas as pd
from baybe import Campaign
from baybe.constraints import DiscreteCustomConstraint
from baybe.objectives import SingleTargetObjective
from baybe.parameters import (
    CategoricalParameter,
    NumericalDiscreteParameter,
    SubstanceParameter,
)
from baybe.searchspace import SearchSpace
from baybe.targets import NumericalTarget
from baybe.utils.dataframe import add_fake_results

Experiment setup

We begin by setting up the parameters for our experiments. TEMPERATURE_RESOLUTION sets the number of different temperatures used. It is reduced from 10 to 3 when the SMOKE_TEST environment variable is set.

SMOKE_TEST = "SMOKE_TEST" in os.environ
TEMPERATURE_RESOLUTION = 3 if SMOKE_TEST else 10
dict_solvent = {
    "water": "O",
    "C1": "C",
    "C2": "CC",
    "C3": "CCC",
    "C4": "CCCC",
    "C5": "CCCCC",
    "c6": "c1ccccc1",
    "C6": "CCCCCC",
}
solvent = SubstanceParameter("Solvent", data=dict_solvent, encoding="RDKIT")
speed = CategoricalParameter(
    "Speed", values=["very slow", "slow", "normal", "fast", "very fast"], encoding="INT"
)
temperature = NumericalDiscreteParameter(
    "Temperature",
    values=list(np.linspace(100, 200, TEMPERATURE_RESOLUTION)),
    tolerance=0.5,
)
concentration = NumericalDiscreteParameter(
    "Concentration", values=[1, 2, 5, 10], tolerance=0.4
)
parameters = [solvent, speed, temperature, concentration]
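The searchspace below is built with SearchSpace.from_product, which combines every value of every parameter before the constraint removes any rows. The following stdlib-only sketch reproduces that arithmetic. `linspace` and `full_product_size` are hypothetical helpers (the former is a stand-in for `np.linspace`), and the value lists are copied from the parameter definitions above.

```python
from itertools import product


def linspace(start: float, stop: float, num: int) -> list[float]:
    """Hypothetical stdlib stand-in for np.linspace (endpoint included, num >= 2)."""
    return [start + (stop - start) * i / (num - 1) for i in range(num)]


def full_product_size(temperature_resolution: int) -> int:
    """Number of rows in the unconstrained Cartesian product of all parameters."""
    solvents = ["water", "C1", "C2", "C3", "C4", "C5", "c6", "C6"]
    speeds = ["very slow", "slow", "normal", "fast", "very fast"]
    temperatures = linspace(100, 200, temperature_resolution)
    concentrations = [1, 2, 5, 10]
    return sum(1 for _ in product(solvents, speeds, temperatures, concentrations))


print(linspace(100, 200, 3))  # [100.0, 150.0, 200.0]
print(full_product_size(3))   # 8 * 5 * 3 * 4 = 480
print(full_product_size(10))  # 8 * 5 * 10 * 4 = 1600
```

With SMOKE_TEST set, this gives 480 combinations. The campaign printout below shows 450 rows; the difference is made up of the entries removed by the constraint.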

Creating the constraint

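Constraints are applied when the searchspace object is created, so we need to define our constraint first. A custom constraint is defined by a validator function that receives a dataframe and returns a boolean series that is True for every row that should be kept.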

def custom_function(df: pd.DataFrame) -> pd.Series:
    """This constraint implements a custom user-defined filter/validation
functionality."""  # noqa: D401
    # Situation 1: For the solvent water, exclude entries that have both a
    # temperature > 120 and a concentration > 5
    mask_bad1 = (
        (df["Solvent"] == "water")
        & (df["Temperature"] > 120)
        & (df["Concentration"] > 5)
    )

    # Situation 2: For the solvent C2, exclude entries that have both a
    # temperature > 180 and a concentration > 3
    mask_bad2 = (
        (df["Solvent"] == "C2") & (df["Temperature"] > 180) & (df["Concentration"] > 3)
    )

    # Situation 3: For the solvent C3, exclude entries that have both a
    # temperature > 150 and a concentration > 3
    mask_bad3 = (
        (df["Solvent"] == "C3") & (df["Temperature"] > 150) & (df["Concentration"] > 3)
    )

    # Combine all situations
    mask_good = ~(mask_bad1 | mask_bad2 | mask_bad3)

    return mask_good
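Each mask removes a row only when all three of its conditions hold at once. For example, water at temperature 150 with concentration 1 is kept. The stdlib-only sketch below restates `custom_function` as a per-row predicate and counts the surviving combinations. The names `RULES`, `is_valid` and `count_valid` are hypothetical. Speed is not involved in the constraint, so it only multiplies the count by its five values.

```python
from itertools import product

# Hypothetical restatement of custom_function: solvent -> (temperature limit,
# concentration limit). A row is removed only if it exceeds BOTH limits.
RULES = {"water": (120, 5), "C2": (180, 3), "C3": (150, 3)}


def is_valid(solvent: str, temperature: float, concentration: float) -> bool:
    """Return True if the row should be kept (same logic as mask_good)."""
    if solvent not in RULES:
        return True
    max_temp, max_conc = RULES[solvent]
    return not (temperature > max_temp and concentration > max_conc)


def count_valid(temperatures: list[float]) -> int:
    """Count kept rows of the full product for the given temperature values."""
    solvents = ["water", "C1", "C2", "C3", "C4", "C5", "c6", "C6"]
    concentrations = [1, 2, 5, 10]
    n_speeds = 5  # Speed is unconstrained, so it only multiplies the count
    kept = sum(
        is_valid(s, t, c) for s, t, c in product(solvents, temperatures, concentrations)
    )
    return kept * n_speeds


print(is_valid("water", 150.0, 1))   # True: concentration is not above 5
print(is_valid("water", 150.0, 10))  # False: both limits are exceeded
print(count_valid([100.0, 150.0, 200.0]))  # 450
```

Storing the limits in a dictionary is only one possible design. The original function expresses the same rules with vectorized pandas masks, which matches what the validator has to return: one boolean per row.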

We now initialize the DiscreteCustomConstraint, passing the validator function and the names of all parameters it needs access to.

constraint = DiscreteCustomConstraint(
    parameters=["Concentration", "Solvent", "Temperature"], validator=custom_function
)
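The `parameters` argument names the columns the validator works with; Speed is not listed because `custom_function` never uses it. The sketch below mimics the filtering step on plain dicts instead of a dataframe. `filter_rows`, `water_rule` and the row format are hypothetical. They illustrate the idea that the validator sees only the listed columns while the kept rows retain all of theirs. This is not BayBE's actual implementation.

```python
from typing import Callable

Row = dict[str, object]


def filter_rows(
    rows: list[Row], parameters: list[str], validator: Callable[[list[Row]], list[bool]]
) -> list[Row]:
    """Hypothetical sketch: keep the rows for which the validator returns True."""
    # The validator only sees the columns named in `parameters`
    projected = [{name: row[name] for name in parameters} for row in rows]
    keep = validator(projected)
    # The kept rows are returned in full, including unlisted columns like Speed
    return [row for row, ok in zip(rows, keep) if ok]


def water_rule(rows: list[Row]) -> list[bool]:
    # Same logic as mask_bad1 in custom_function, negated and evaluated row by row
    return [
        not (r["Solvent"] == "water" and r["Temperature"] > 120 and r["Concentration"] > 5)
        for r in rows
    ]


rows = [
    {"Solvent": "water", "Speed": "fast", "Temperature": 150.0, "Concentration": 10},
    {"Solvent": "water", "Speed": "slow", "Temperature": 100.0, "Concentration": 10},
    {"Solvent": "C1", "Speed": "fast", "Temperature": 200.0, "Concentration": 10},
]
kept = filter_rows(rows, ["Concentration", "Solvent", "Temperature"], water_rule)
print(len(kept))  # 2: only the first row violates the water rule
```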

Creating the searchspace and the objective

searchspace = SearchSpace.from_product(parameters=parameters, constraints=[constraint])
objective = SingleTargetObjective(target=NumericalTarget(name="yield", mode="MAX"))

Creating and printing the campaign

campaign = Campaign(searchspace=searchspace, objective=objective)
print(campaign)
Campaign
   Meta Data
      Batches done: 0
      Fits done: 0
   SearchSpace
      Search Space Type: DISCRETE
      SubspaceDiscrete
         Discrete Parameters
                        Name             Type  Num_Values         Encoding
            0  Concentration  NumericalDis...           4             None
            1        Solvent  SubstancePar...           8  SubstanceEnc...
            2          Speed  CategoricalP...           5  CategoricalE...
            3    Temperature  NumericalDis...           3             None
         Experimental Representation
                Solvent      Speed  Temperature  Concentration
            0        C1       fast        100.0            1.0
            1        C1       fast        100.0            2.0
            2        C1       fast        100.0            5.0
            ..      ...        ...          ...            ...
            447   water  very slow        200.0            1.0
            448   water  very slow        200.0            2.0
            449   water  very slow        200.0            5.0
            
            [450 rows x 4 columns]
         Meta Data
            was_recommended: 0/450
            was_measured: 0/450
            dont_recommend: 0/450
         Constraints
                          Type Affected_Paramet
            0  DiscreteCust...  [Concentrati...
         Computational Representation
                 Concentration  Solvent_RDKIT_Ma  ...  Speed  Temperature
            0              1.0              0.0   ...    0.0        100.0
            1              2.0              0.0   ...    0.0        100.0
            2              5.0              0.0   ...    0.0        100.0
            ..             ...              ...   ...    ...          ...
            447            1.0              0.0   ...    4.0        200.0
            448            2.0              0.0   ...    4.0        200.0
            449            5.0              0.0   ...    4.0        200.0
            
            [450 rows x 9 columns]
   Objective
      Type: SingleTargetObjective
      Targets
                       Type   Name  ... Upper_Bound  Transformation
         0  NumericalTarget  yield  ...         inf            None
         
         [1 rows x 6 columns]
   TwoPhaseMetaRecommender
      Initial recommender
         RandomRecommender
            Compatibility: SearchSpaceType.HYBRID
      Recommender
         BotorchRecommender
            Surrogate
               GaussianProcessSurrogate
                  Supports Transfer Learning: True
                  Kernel factory: DefaultKernelFactory()
            Acquisition function: qLogExpectedImprovement()
            Compatibility: SearchSpaceType.HYBRID
            Sequential continuous: False
            Hybrid sampler: None
            Sampling percentage: 1.0
      Switch after: 1

Manual verification of the constraint

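The following loop requests a few recommendations and adds fake measurements for them. In each iteration, it also counts the searchspace entries that match each of the three excluded combinations. Since the constraint removes these entries, all counts should be zero.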

N_ITERATIONS = 3
for k_iter in range(N_ITERATIONS):
    print(f"\n\n#### ITERATION {k_iter + 1} ####")

    # The constraint is applied when the searchspace is created, so none of the
    # excluded combinations should appear in its experimental representation
    exp_rep = campaign.searchspace.discrete.exp_rep

    print("## ASSERTS ##")
    print(
        "Number of entries with water, temp > 120 and concentration > 5:      ",
        (
            exp_rep["Concentration"].gt(5)
            & exp_rep["Temperature"].gt(120)
            & exp_rep["Solvent"].eq("water")
        ).sum(),
    )
    print(
        "Number of entries with C2, temp > 180 and concentration > 3:         ",
        (
            exp_rep["Concentration"].gt(3)
            & exp_rep["Temperature"].gt(180)
            & exp_rep["Solvent"].eq("C2")
        ).sum(),
    )
    print(
        "Number of entries with C3, temp > 150 and concentration > 3:         ",
        (
            exp_rep["Concentration"].gt(3)
            & exp_rep["Temperature"].gt(150)
            & exp_rep["Solvent"].eq("C3")
        ).sum(),
    )

    # Request a batch of recommendations, attach fake target values to them,
    # and add them to the campaign as measurements
    rec = campaign.recommend(batch_size=5)
    add_fake_results(rec, campaign.targets)
    campaign.add_measurements(rec)
#### ITERATION 1 ####
## ASSERTS ##
Number of entries with water, temp > 120 and concentration > 5:       0
Number of entries with C2, temp > 180 and concentration > 3:          0
Number of entries with C3, temp > 150 and concentration > 3:          0


#### ITERATION 2 ####
## ASSERTS ##
Number of entries with water, temp > 120 and concentration > 5:       0
Number of entries with C2, temp > 180 and concentration > 3:          0
Number of entries with C3, temp > 150 and concentration > 3:          0




#### ITERATION 3 ####
## ASSERTS ##
Number of entries with water, temp > 120 and concentration > 5:       0
Number of entries with C2, temp > 180 and concentration > 3:          0
Number of entries with C3, temp > 150 and concentration > 3:          0