@define
class BetaBernoulliMultiArmedBanditSurrogate(Surrogate):
    """A multi-armed bandit model with Bernoulli likelihood and beta prior."""

    supports_transfer_learning: ClassVar[bool] = False
    # See base class.

    prior: BetaPrior = field(factory=lambda: BetaPrior(1, 1))
    """The beta prior for the win rates of the bandit arms. Uniform by default."""

    # TODO: type should be `torch.Tensor | None` but is currently
    #   omitted due to: https://github.com/python-attrs/cattrs/issues/531
    _win_lose_counts = field(init=False, default=None, eq=False)
    """Sufficient statistics for the Bernoulli likelihood model: (# wins, # losses)."""

    def posterior_modes(self) -> Tensor:
        """Compute the posterior mode win rates for all arms.

        Returns:
            A tensor of length ``N`` containing the posterior mode estimates of
            the win rates, where ``N`` is the number of bandit arms. Contains
            ``float('nan')`` for arms with undefined mode.
        """
        # Local import keeps torch loading lazy, consistent with the rest of
        # the class (see the note in `to_botorch`).
        from torch.distributions import Beta

        return Beta(*self._posterior_beta_parameters()).mode

    def posterior_means(self) -> Tensor:
        """Compute the posterior mean win rates for all arms.

        Returns:
            A tensor of length ``N`` containing the posterior mean estimates of
            the win rates, where ``N`` is the number of bandit arms.
        """
        from torch.distributions import Beta

        return Beta(*self._posterior_beta_parameters()).mean

    def _posterior_beta_parameters(self) -> Tensor:
        """Compute the posterior parameters of the beta distribution.

        Raises:
            ModelNotTrainedError: If accessed before the model was trained.

        Returns:
            A tensor of shape ``(2, N)`` containing the posterior beta
            parameters, where ``N`` is the number of bandit arms.
        """
        if self._win_lose_counts is None:
            raise ModelNotTrainedError(
                f"'{self.__class__.__name__}' must be trained before posterior "
                f"information can be accessed."
            )
        import torch

        # Conjugate update: posterior (alpha, beta) = prior + (wins, losses),
        # broadcast across all arms via the trailing unsqueeze.
        return self._win_lose_counts + torch.tensor(
            [self.prior.alpha, self.prior.beta]
        ).unsqueeze(-1)

    def to_botorch(self) -> Model:  # noqa: D102
        # See base class.

        # We register the sampler on the fly to avoid eager loading of torch
        from botorch.sampling.base import MCSampler
        from botorch.sampling.get_sampler import GetSampler
        from torch.distributions import Beta

        class CustomMCSampler(MCSampler):
            """Custom sampler for beta posterior."""

            def forward(self, posterior: TorchPosterior) -> Tensor:
                """Sample the posterior."""
                # Seeded sampling for reproducibility; the seed is restored
                # afterwards by the context manager.
                with temporary_seed(self.seed):
                    samples = posterior.rsample(self.sample_shape)
                return samples

        @GetSampler.register(Beta)
        def get_custom_sampler(_, sample_shape, seed: int | None = None):
            """Get the sampler for the beta posterior."""
            return CustomMCSampler(sample_shape=sample_shape, seed=seed)

        return super().to_botorch()

    @staticmethod
    def _make_input_scaler_factory():
        # See base class.
        #
        # Due to enforced one-hot encoding, no input scaling is needed.
        return None

    @staticmethod
    def _make_target_scaler_factory():
        # See base class.
        #
        # We directly use the binary computational representation from the target.
        return None

    def _posterior(self, candidates: Tensor, /) -> TorchPosterior:
        # See base class.
        from botorch.posteriors import TorchPosterior
        from torch.distributions import Beta

        # `argmax(-1)` maps each one-hot candidate row to its arm index, which
        # selects the corresponding (alpha, beta) column pair.
        beta_params_for_candidates = self._posterior_beta_parameters().T[
            candidates.argmax(-1)
        ]
        return TorchPosterior(Beta(*beta_params_for_candidates.split(1, -1)))

    def _fit(self, train_x: Tensor, train_y: Tensor, _: Any = None) -> None:
        # See base class.

        # TODO: Fix requirement of OHE encoding. This is likely a long-term
        #   goal since probably requires decoupling parameter from encodings
        #   and associating the latter with the surrogate.
        # TODO: Generalize to arbitrary number of categorical parameters

        # Structural match enforces the supported search space layout: exactly
        # one one-hot-encoded categorical parameter.
        match self._searchspace:
            case SearchSpace(
                parameters=[CategoricalParameter(encoding=CategoricalEncoding.OHE)]
            ):
                pass
            case _:
                raise IncompatibleSearchSpaceError(
                    f"'{self.__class__.__name__}' currently only supports search "
                    f"spaces spanned by exactly one categorical parameter using "
                    f"one-hot encoding."
                )

        import torch

        # IMPROVE: The training inputs/targets can actually be represented as
        #   integers/boolean values but the transformation pipeline currently
        #   converts them float. Potentially, this can be improved by making
        #   the type conversion configurable.

        # Per-arm win/loss counts: the one-hot inputs act as masks selecting
        # the arm of each observation.
        wins = (train_x * (train_y == float(_SUCCESS_VALUE_COMP))).sum(dim=0)
        losses = (train_x * (train_y == float(_FAILURE_VALUE_COMP))).sum(dim=0)
        self._win_lose_counts = torch.vstack([wins, losses]).to(torch.int)

    def __str__(self) -> str:
        fields = [to_string("Prior", self.prior, single_line=True)]
        return to_string(super().__str__(), *fields)
# Collect leftover original slotted classes processed by `attrs.define`gc.collect()