Skip to content

Commit 3d611e2

Browse files
Add a test helper for fixing random seeds
1 parent 11d6bd4 commit 3d611e2

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

xobjects/test_helpers.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,23 @@ def requires_context(context_name: str):
8383
return lambda test_function: test_function
8484

8585
return pytest.mark.skip(f"{context_name} is unavailable on this platform.")
86+
87+
88+
def fix_random_seed(seed: int):
89+
"""Decorator to fix the random seed for a test."""
90+
91+
def decorator(test_function):
92+
@wraps(test_function)
93+
def wrapper(*args, **kwargs):
94+
import numpy as np
95+
96+
rng_state = np.random.get_state()
97+
try:
98+
np.random.seed(seed)
99+
test_function(*args, **kwargs)
100+
finally:
101+
np.random.set_state(rng_state)
102+
103+
return wrapper
104+
105+
return decorator

0 commit comments

Comments
 (0)