Skip to content

Commit b32a602

Browse files
authored
Refactor strategies (#445)
1 parent a83a0c7 commit b32a602

File tree

1 file changed

+26
-26
lines changed

1 file changed

+26
-26
lines changed

tests/strategies.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,20 @@ def cftime_arrays(
6565
return cftime.num2date(values, units=unit, calendar=cal)
6666

6767

68+
def insert_nans(draw: st.DrawFn, array: np.ndarray) -> np.ndarray:
69+
if array.dtype.kind in "cf":
70+
nan_idx = draw(
71+
st.lists(
72+
st.integers(min_value=0, max_value=array.shape[-1] - 1),
73+
max_size=array.shape[-1] - 1,
74+
unique=True,
75+
)
76+
)
77+
if nan_idx:
78+
array[..., nan_idx] = np.nan
79+
return array
80+
81+
6882
numeric_dtypes = (
6983
npst.integer_dtypes(endianness="=")
7084
| npst.unsigned_integer_dtypes(endianness="=")
@@ -96,20 +110,18 @@ def cftime_arrays(
96110
SKIPPED_FUNCS = ["var", "std", "nanvar", "nanstd"]
97111

98112
func_st = st.sampled_from([f for f in ALL_FUNCS if f not in NON_NUMPY_FUNCS and f not in SKIPPED_FUNCS])
99-
numeric_arrays = npst.arrays(
100-
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=numeric_dtypes
101-
)
102-
numeric_like_arrays = npst.arrays(
103-
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=numeric_like_dtypes
104-
)
105-
all_arrays = (
106-
npst.arrays(
107-
elements={"allow_subnormal": False},
108-
shape=npst.array_shapes(),
109-
dtype=numeric_like_dtypes,
110-
)
111-
| cftime_arrays()
112-
)
113+
114+
115+
@st.composite
116+
def numpy_arrays(draw: st.DrawFn, *, dtype) -> np.ndarray:
117+
array = draw(npst.arrays(elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=dtype))
118+
array = insert_nans(draw, array)
119+
return array
120+
121+
122+
numeric_arrays = numpy_arrays(dtype=numeric_dtypes)
123+
numeric_like_arrays = numpy_arrays(dtype=numeric_like_dtypes)
124+
all_arrays = numeric_like_arrays | cftime_arrays()
113125

114126

115127
def by_arrays(
@@ -153,16 +165,4 @@ def chunked_arrays(
153165
) -> dask.array.Array:
154166
array = draw(arrays)
155167
chunks = draw(chunks(shape=array.shape))
156-
157-
if array.dtype.kind in "cf":
158-
nan_idx = draw(
159-
st.lists(
160-
st.integers(min_value=0, max_value=array.shape[-1] - 1),
161-
max_size=array.shape[-1] - 1,
162-
unique=True,
163-
)
164-
)
165-
if nan_idx:
166-
array[..., nan_idx] = np.nan
167-
168168
return from_array(array, chunks=chunks)

0 commit comments

Comments
 (0)