From d3e283c4037fc2ce2b4a618d2988e4b74ff688bb Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Thu, 23 May 2024 11:28:09 +0200 Subject: [PATCH] add return types to functions --- seedbank/__init__.py | 7 ++++--- seedbank/cupy.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/seedbank/__init__.py b/seedbank/__init__.py index 6e590b7..9310889 100644 --- a/seedbank/__init__.py +++ b/seedbank/__init__.py @@ -11,6 +11,7 @@ from types import ModuleType import numpy as np +from numpy.random import SeedSequence from typing_extensions import Optional from seedbank._keys import RNGKey, SeedLike, make_seed @@ -47,7 +48,7 @@ ] -def initialize(seed: SeedLike, *keys: RNGKey): +def initialize(seed: SeedLike, *keys: RNGKey) -> SeedSequence: """ Initialize the random infrastructure with a seed. This function should generally be called very early in the setup. This initializes all known and available RNGs with @@ -78,7 +79,7 @@ def initialize(seed: SeedLike, *keys: RNGKey): return _root_state.seed -def derive_seed(*keys: RNGKey, base: Optional[np.random.SeedSequence] = None): +def derive_seed(*keys: RNGKey, base: Optional[np.random.SeedSequence] = None) -> SeedSequence: """ Derive a seed from the root seed, optionally with additional seed keys. @@ -96,7 +97,7 @@ def derive_seed(*keys: RNGKey, base: Optional[np.random.SeedSequence] = None): return _root_state.derive(base, keys).seed -def root_seed(): +def root_seed() -> SeedSequence: """ Get the current root seed. diff --git a/seedbank/cupy.py b/seedbank/cupy.py index e4c387c..08bcbeb 100644 --- a/seedbank/cupy.py +++ b/seedbank/cupy.py @@ -29,7 +29,7 @@ def seed(state): cupy.random.seed(state.int_seed) -def cupy_rng(spec: Optional[SeedLike | cupy.random.Generator] = None): +def cupy_rng(spec: Optional[SeedLike | cupy.random.Generator] = None) -> cupy.random.Generator: """ Get a CuPy random number generator. This works like :func:`numpy_rng`, but it returns a :class:`cupy.random.Generator` instead.