Skip to content

Commit c0e1b78

Browse files
authored
Add support for dataclasses.InitVar as a type qualifier (#31)
1 parent 2ee3a8e commit c0e1b78

File tree

2 files changed

+60
-34
lines changed

2 files changed

+60
-34
lines changed

src/typing_inspection/introspection.py

+41-14
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
import sys
66
import types
77
from collections.abc import Generator, Sequence
8+
from dataclasses import InitVar
89
from enum import Enum, IntEnum, auto
9-
from typing import Any, Literal, NamedTuple
10+
from typing import Any, Literal, NamedTuple, cast
1011

11-
from typing_extensions import TypeAlias, assert_never, get_origin
12+
from typing_extensions import TypeAlias, assert_never, get_args, get_origin
1213

1314
from . import typing_objects
1415

@@ -172,9 +173,11 @@ def get_literal_values(
172173
yield from (p for p, _ in dct)
173174

174175

175-
Qualifier: TypeAlias = Literal['required', 'not_required', 'read_only', 'class_var', 'final']
176+
Qualifier: TypeAlias = Literal['required', 'not_required', 'read_only', 'class_var', 'init_var', 'final']
176177
"""A [type qualifier][]."""
177178

179+
_all_qualifiers: set[Qualifier] = set(get_args(Qualifier))
180+
178181

179182
# TODO at some point, we could switch to an enum flag, so that multiple sources
180183
# can be combined. However, is there a need for this?
@@ -187,7 +190,7 @@ class AnnotationSource(IntEnum):
187190
Depending on the source, different [type qualifiers][type qualifier] may be (dis)allowed.
188191
"""
189192

190-
ASSIGNMENT_OR_VARIABLE = 1
193+
ASSIGNMENT_OR_VARIABLE = auto()
191194
"""An annotation used in an assignment or variable annotation:
192195
193196
```python
@@ -198,7 +201,7 @@ class AnnotationSource(IntEnum):
198201
**Allowed type qualifiers:** [`Final`][typing.Final].
199202
"""
200203

201-
CLASS = 2
204+
CLASS = auto()
202205
"""An annotation used in the body of a class:
203206
204207
```python
@@ -210,7 +213,20 @@ class Test:
210213
**Allowed type qualifiers:** [`ClassVar`][typing.ClassVar], [`Final`][typing.Final].
211214
"""
212215

213-
TYPED_DICT = 3
216+
DATACLASS = auto()
217+
"""An annotation used in the body of a dataclass:
218+
219+
```python
220+
@dataclass
221+
class Test:
222+
x: Final[int] = 1
223+
y: InitVar[str] = 'test'
224+
```
225+
226+
**Allowed type qualifiers:** [`ClassVar`][typing.ClassVar], [`Final`][typing.Final], [`InitVar`][dataclasses-init-only-variables].
227+
""" # noqa: E501
228+
229+
TYPED_DICT = auto()
214230
"""An annotation used in the body of a [`TypedDict`][typing.TypedDict]:
215231
216232
```python
@@ -223,7 +239,7 @@ class TD(TypedDict):
223239
[`NotRequired`][typing.NotRequired].
224240
"""
225241

226-
NAMED_TUPLE = 4
242+
NAMED_TUPLE = auto()
227243
"""An annotation used in the body of a [`NamedTuple`][typing.NamedTuple].
228244
229245
```python
@@ -235,7 +251,7 @@ class NT(NamedTuple):
235251
**Allowed type qualifiers:** none.
236252
"""
237253

238-
FUNCTION = 5
254+
FUNCTION = auto()
239255
"""An annotation used in a function, either for a parameter or the return value.
240256
241257
```python
@@ -246,13 +262,13 @@ def func(a: int) -> str:
246262
**Allowed type qualifiers:** none.
247263
"""
248264

249-
ANY = 6
265+
ANY = auto()
250266
"""An annotation that might come from any source.
251267
252268
**Allowed type qualifiers:** all.
253269
"""
254270

255-
BARE = 7
271+
BARE = auto()
256272
"""An annotation that is inspected as is.
257273
258274
**Allowed type qualifiers:** none.
@@ -266,12 +282,14 @@ def allowed_qualifiers(self) -> set[Qualifier]:
266282
return {'final'}
267283
elif self is AnnotationSource.CLASS:
268284
return {'final', 'class_var'}
285+
elif self is AnnotationSource.DATACLASS:
286+
return {'final', 'class_var', 'init_var'}
269287
elif self is AnnotationSource.TYPED_DICT:
270288
return {'required', 'not_required', 'read_only'}
271289
elif self in (AnnotationSource.NAMED_TUPLE, AnnotationSource.FUNCTION, AnnotationSource.BARE):
272290
return set()
273291
elif self is AnnotationSource.ANY:
274-
return {'required', 'not_required', 'read_only', 'class_var', 'final'}
292+
return _all_qualifiers
275293
else: # pragma: no cover
276294
assert_never(self)
277295

@@ -327,7 +345,7 @@ class C:
327345
"""The annotated metadata."""
328346

329347

330-
def inspect_annotation(
348+
def inspect_annotation( # noqa: PLR0915
331349
annotation: Any,
332350
/,
333351
*,
@@ -423,11 +441,15 @@ def inspect_annotation(
423441
else:
424442
# origin is not None but not a type qualifier nor `Annotated` (e.g. `list[int]`):
425443
break
444+
elif isinstance(annotation, InitVar):
445+
if 'init_var' not in allowed_qualifiers:
446+
raise ForbiddenQualifier('init_var')
447+
qualifiers.add('init_var')
448+
annotation = cast(Any, annotation.type)
426449
else:
427450
break
428451

429-
# `Final` and `ClassVar` are type qualifiers allowed to be used as a bare annotation
430-
# (`ClassVar` is not explicitly specified, but will be: https://discuss.python.org/t/81705).
452+
# `Final`, `ClassVar` and `InitVar` are type qualifiers allowed to be used as a bare annotation:
431453
if typing_objects.is_final(annotation):
432454
if 'final' not in allowed_qualifiers:
433455
raise ForbiddenQualifier('final')
@@ -438,6 +460,11 @@ def inspect_annotation(
438460
raise ForbiddenQualifier('class_var')
439461
qualifiers.add('class_var')
440462
annotation = UNKNOWN
463+
elif annotation is InitVar:
464+
if 'init_var' not in allowed_qualifiers:
465+
raise ForbiddenQualifier('init_var')
466+
qualifiers.add('init_var')
467+
annotation = UNKNOWN
441468

442469
return InspectedAnnotation(annotation, qualifiers, metadata)
443470

tests/introspection/test_inspect_annotation.py

+19-20
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import sys
22
import typing as t
3+
from dataclasses import InitVar
34
from textwrap import dedent
45
from typing import Any, Literal
56

@@ -20,6 +21,7 @@ def test_unknown_repr() -> None:
2021
t_e.ReadOnly[int],
2122
t_e.Required[int],
2223
t_e.NotRequired[int],
24+
InitVar[int],
2325
]
2426

2527

@@ -28,6 +30,7 @@ def test_unknown_repr() -> None:
2830
[
2931
(AnnotationSource.ASSIGNMENT_OR_VARIABLE, [t_e.Final[int]]),
3032
(AnnotationSource.CLASS, [t_e.ClassVar[int], t.Final[int]]),
33+
(AnnotationSource.DATACLASS, [t_e.ClassVar[int], t.Final[int], InitVar[int]]),
3134
(AnnotationSource.TYPED_DICT, [t_e.ReadOnly[int], t_e.Required[int], t_e.NotRequired[int]]),
3235
(AnnotationSource.ANY, _all_qualifiers),
3336
],
@@ -42,10 +45,11 @@ def test_annotation_source_valid_qualifiers(source: AnnotationSource, annotation
4245
[
4346
(
4447
AnnotationSource.ASSIGNMENT_OR_VARIABLE,
45-
[t_e.ClassVar[int], t_e.ReadOnly[int], t_e.Required[int], t_e.NotRequired[int]],
48+
[t_e.ClassVar[int], t_e.ReadOnly[int], t_e.Required[int], t_e.NotRequired[int], InitVar[int]],
4649
),
47-
(AnnotationSource.CLASS, [t_e.ReadOnly[int], t_e.Required[int], t_e.NotRequired[int]]),
48-
(AnnotationSource.TYPED_DICT, [t_e.ClassVar[int], t_e.Final[int]]),
50+
(AnnotationSource.CLASS, [t_e.ReadOnly[int], t_e.Required[int], t_e.NotRequired[int], InitVar[int]]),
51+
(AnnotationSource.DATACLASS, [t_e.ReadOnly[int], t_e.Required[int], t_e.NotRequired[int]]),
52+
(AnnotationSource.TYPED_DICT, [t_e.ClassVar[int], t_e.Final[int], InitVar[int]]),
4953
(AnnotationSource.NAMED_TUPLE, _all_qualifiers),
5054
(AnnotationSource.FUNCTION, _all_qualifiers),
5155
(AnnotationSource.BARE, _all_qualifiers),
@@ -57,30 +61,25 @@ def test_annotation_source_invalid_qualifiers(source: AnnotationSource, annotati
5761
inspect_annotation(annotation, annotation_source=source)
5862

5963

60-
def test_final_bare_final_qualifier() -> None:
61-
result = inspect_annotation(
62-
t.Final,
63-
annotation_source=AnnotationSource.ANY,
64-
)
65-
66-
assert result.qualifiers == {'final'}
67-
assert result.type is UNKNOWN
68-
69-
with pytest.raises(ForbiddenQualifier):
70-
inspect_annotation(t.Final, annotation_source=AnnotationSource.BARE)
71-
72-
73-
def test_class_var_bare_final_qualifier() -> None:
64+
@pytest.mark.parametrize(
65+
['qualifier_obj', 'qualifier_str'],
66+
[
67+
(t.Final, 'final'),
68+
(t.ClassVar, 'class_var'),
69+
(InitVar, 'init_var'),
70+
],
71+
)
72+
def test_bare_qualifier(qualifier_obj: Any, qualifier_str: str) -> None:
7473
result = inspect_annotation(
75-
t.ClassVar,
74+
qualifier_obj,
7675
annotation_source=AnnotationSource.ANY,
7776
)
7877

79-
assert result.qualifiers == {'class_var'}
78+
assert result.qualifiers == {qualifier_str}
8079
assert result.type is UNKNOWN
8180

8281
with pytest.raises(ForbiddenQualifier):
83-
inspect_annotation(t.ClassVar, annotation_source=AnnotationSource.BARE)
82+
inspect_annotation(qualifier_obj, annotation_source=AnnotationSource.BARE)
8483

8584

8685
def test_nested_metadata_and_qualifiers() -> None:

0 commit comments

Comments
 (0)