"""Utilities targeting random number generation."""importcontextlibimportrandomimportnumpyasnp
def set_random_seed(seed: int):
    """Seed every global random number generator used by this project.

    The same (bounded) seed is handed, in order, to PyTorch via
    ``torch.manual_seed``, to Python's built-in :mod:`random` module, and to
    NumPy's legacy global generator via ``np.random.seed``.

    Args:
        seed: The chosen global random seed. It is reduced modulo ``2**32``
            first, so that NumPy's legacy seeding (which only accepts values
            in ``[0, 2**32)``) never rejects it; e.g. ``-1`` behaves like
            ``2**32 - 1``.
    """
    import torch

    bounded_seed = seed % 2**32
    for seeder in (torch.manual_seed, random.seed, np.random.seed):
        seeder(bounded_seed)
@contextlib.contextmanager
def temporary_seed(seed: int):  # noqa: DOC402, DOC404
    """Context manager for setting a temporary random seed.

    On entry, this saves the states of three generators:

    - Python's built-in :mod:`random` generator,
    - NumPy's legacy global generator,
    - PyTorch's CPU generator.

    It then seeds all of them via :func:`set_random_seed`. That function calls
    ``torch.manual_seed``, which also reseeds the CUDA generators. For that
    reason, the per-device CUDA generator states are saved as well whenever
    CUDA is already initialized.

    On exit, every saved state is restored. This also happens when the body
    raises, and when seeding itself fails part-way.

    Note: CUDA generators that are first initialized *inside* the block are
    not snapshotted and therefore not restored.

    Args:
        seed: The chosen random seed.
    """
    import torch

    # Collect the current RNG states
    state_builtin = random.getstate()
    state_np = np.random.get_state()
    state_torch = torch.get_rng_state()
    # Only snapshot CUDA generators that already exist: querying their state
    # would otherwise trigger CUDA initialization as a side effect.
    state_cuda = torch.cuda.get_rng_state_all() if torch.cuda.is_initialized() else None

    try:
        # Seed inside the try so that a failure part-way through seeding
        # (after some generators were already reseeded) still restores them.
        # set_random_seed applies the 2**32 bound itself.
        set_random_seed(seed)
        # Run the context-specific code
        yield
    finally:
        # Restore the original RNG states
        random.setstate(state_builtin)
        np.random.set_state(state_np)
        torch.set_rng_state(state_torch)
        if state_cuda is not None:
            torch.cuda.set_rng_state_all(state_cuda)