Skip to content

Commit 369e2f9

Browse files
committed
Support numpy StringDType in from_dtype
from_dtype now natively supports numpy 2.0+ variable-width strings (StringDType, kind 'T'), generating arbitrary strings via text(). Passing a dtype class such as np.dtypes.StringDType where an instance was expected now raises a clear error, instead of the previous confusing message (or silent coercion to the object dtype in arrays).
1 parent b1cc932 commit 369e2f9

3 files changed

Lines changed: 75 additions & 0 deletions

File tree

hypothesis/RELEASE.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
RELEASE_TYPE: minor

:func:`~hypothesis.extra.numpy.from_dtype` now supports the variable-width
string dtype :attr:`numpy:numpy.dtypes.StringDType`, generating arbitrary
strings via :func:`~hypothesis.strategies.text` (:issue:`4039`).

Additionally, passing a dtype *class* such as ``np.dtypes.StringDType`` where an
instance like ``np.dtypes.StringDType()`` was expected now raises a clear error,
rather than the previous confusing message (or silent coercion to the object
dtype in :func:`~hypothesis.extra.numpy.arrays`).

hypothesis/src/hypothesis/extra/numpy.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,19 @@ def _try_import(mod_name: str, attr_name: str) -> Any:
110110
NP_FIXED_UNICODE = tuple(int(x) for x in np.__version__.split(".")[:2]) >= (1, 19)
111111

112112

113+
def _reject_dtype_class(dtype: object) -> None:
114+
# A common mistake is to pass a dtype *class*, e.g. np.dtypes.StringDType,
115+
# rather than an instance such as np.dtypes.StringDType(). numpy silently
116+
# coerces such classes to the object dtype, so we reject them with a more
117+
# helpful message than the resulting confusion further down the line.
118+
if isinstance(dtype, type) and issubclass(dtype, np.dtype):
119+
name = getattr(dtype, "__name__", repr(dtype))
120+
raise InvalidArgument(
121+
f"Cannot infer a strategy from the dtype class {name}; pass an "
122+
f"instance instead, e.g. {name}() rather than {name}."
123+
)
124+
125+
113126
@defines_strategy(force_reusable_values=True)
114127
def from_dtype(
115128
dtype: np.dtype,
@@ -137,6 +150,7 @@ def from_dtype(
137150
:func:`arrays` which allow a variety of numeric dtypes, as it seamlessly
138151
handles the ``width`` or representable bounds for you.
139152
"""
153+
_reject_dtype_class(dtype)
140154
check_type(np.dtype, dtype, "dtype")
141155
kwargs = {k: v for k, v in locals().items() if k != "dtype" and v is not None}
142156

@@ -214,6 +228,14 @@ def compat_kw(*args, **kw):
214228
result = st.text(**compat_kw("alphabet", "min_size", max_size=max_size)).filter(
215229
lambda b: b[-1:] != "\0"
216230
)
231+
elif dtype.kind == "T":
232+
# NumPy 2.0+ variable-width strings (StringDType). Unlike the fixed-width
233+
# "U"/"S" dtypes, these store arbitrary Python strings with no length
234+
# limit and no null-termination, so we can use st.text() directly - but
235+
# the UTF-8 backing storage means we must exclude lone surrogates.
236+
if "alphabet" not in kwargs:
237+
kwargs["alphabet"] = st.characters(codec="utf-8")
238+
result = st.text(**compat_kw("alphabet", "min_size", "max_size"))
217239
elif dtype.kind in ("m", "M"):
218240
if "[" in dtype.str:
219241
res = st.just(dtype.str.split("[")[-1][:-1])
@@ -555,6 +577,7 @@ def arrays(
555577
lambda s: arrays(dtype, s, elements=elements, fill=fill, unique=unique)
556578
)
557579
# From here on, we're only dealing with values and it's relatively simple.
580+
_reject_dtype_class(dtype)
558581
dtype = np.dtype(dtype) # type: ignore[arg-type]
559582
assert isinstance(dtype, np.dtype) # help mypy out a bit...
560583
if elements is None or isinstance(elements, Mapping):

hypothesis/tests/numpy/test_from_dtype.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@
2323

2424
np_version = tuple(int(x) for x in np.__version__.split(".")[:2])
2525

26+
# Marker for tests that need numpy's variable-width string dtype.  The nested
# getattr with a None default also covers numpy versions which have no
# ``np.dtypes`` namespace at all, rather than raising AttributeError on import.
skipif_no_stringdtype = pytest.mark.skipif(
    not hasattr(getattr(np, "dtypes", None), "StringDType"),
    reason="StringDType was added in NumPy 2.0",
)
30+
31+
2632
STANDARD_TYPES = [
2733
np.dtype(t)
2834
for t in (
@@ -281,6 +287,42 @@ def test_float_subnormal_generation(allow_subnormal, width):
281287
assert_no_examples(strat, lambda n: -smallest_normal < n < smallest_normal)
282288

283289

290+
@skipif_no_stringdtype
@given(st.data())
def test_stringdtype_generates_strings(data):
    """from_dtype(StringDType()) yields plain Python ``str`` values."""
    # The dtype is constructed inside the test body: the skipif mark only
    # prevents the test from running, it does not stop decorators being
    # evaluated at import time on a numpy without StringDType.
    dt = np.dtypes.StringDType()
    result = data.draw(nps.from_dtype(dt))
    assert isinstance(result, str)
296+
297+
298+
@skipif_no_stringdtype
@given(st.data())
def test_stringdtype_respects_kwargs(data):
    """min_size, max_size and alphabet are honoured for StringDType."""
    dt = np.dtypes.StringDType()
    result = data.draw(nps.from_dtype(dt, min_size=2, max_size=4, alphabet="abc"))
    # Both size bounds are inclusive, and only characters from the explicit
    # alphabet may appear.
    assert 2 <= len(result) <= 4
    assert set(result).issubset("abc")
305+
306+
307+
@skipif_no_stringdtype
@given(st.data())
def test_stringdtype_arrays_roundtrip(data):
    """Generated strings survive being stored in a StringDType array."""
    # StringDType stores arbitrary strings, so anything we generate (including
    # null bytes and arbitrarily-long strings) must read back unchanged.
    dt = np.dtypes.StringDType()
    ex = data.draw(nps.from_dtype(dt))
    # Start from a one-element array and overwrite it, so the generated value
    # goes through item assignment into the array's storage.
    arr = np.array([""], dtype=dt)
    arr[0] = ex
    assert arr[0] == ex
317+
318+
319+
@skipif_no_stringdtype
@pytest.mark.parametrize("func", [nps.from_dtype, lambda dt: nps.arrays(dt, 3)])
def test_helpful_error_on_uninstantiated_dtype_class(func):
    """Passing the StringDType class (not an instance) raises InvalidArgument."""
    # Both entry points are exercised with the class itself; the match string
    # pins the part of the message that names the offending dtype class.
    with pytest.raises(InvalidArgument, match="dtype class StringDType"):
        check_can_generate_examples(func(np.dtypes.StringDType))
324+
325+
284326
@pytest.mark.parametrize("allow_subnormal", [False, True])
285327
@pytest.mark.parametrize("width", [64, 128])
286328
def test_complex_subnormal_generation(allow_subnormal, width):

0 commit comments

Comments
 (0)