Skip to content

Commit 4111bc3

Browse files
Add a test helper for fixing random seeds
1 parent 11d6bd4 commit 4111bc3

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

xobjects/test_helpers.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,19 @@ def requires_context(context_name: str):
8383
return lambda test_function: test_function
8484

8585
return pytest.mark.skip(f"{context_name} is unavailable on this platform.")
86+
87+
88+
def fix_random_seed(seed: int):
89+
"""Decorator to fix the random seed for a test."""
90+
def decorator(test_function):
91+
@wraps(test_function)
92+
def wrapper(*args, **kwargs):
93+
import numpy as np
94+
rng_state = np.random.get_state()
95+
try:
96+
np.random.seed(seed)
97+
test_function(*args, **kwargs)
98+
finally:
99+
np.random.set_state(rng_state)
100+
return wrapper
101+
return decorator

0 commit comments

Comments
 (0)