Skip to content

Commit b7b16fe

Browse files
committed
core: fix support for annotations defined with python3.12 'type' keyword
1 parent 5b8ed03 commit b7b16fe

File tree

4 files changed

+70
-3
lines changed

4 files changed

+70
-3
lines changed

src/cachew/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
TYPE_CHECKING,
1515
Any,
1616
Literal,
17+
TypeAliasType,
1718
cast,
1819
get_args,
1920
get_origin,
@@ -160,6 +161,16 @@ def infer_return_type(func) -> Failure | Inferred:
160161
>>> infer_return_type(union_provider)
161162
('multiple', str | int)
162163
164+
>>> from typing import Iterator
165+
>>> type Str = str
166+
>>> type Int = int
167+
>>> type IteratorStrInt = Iterator[Str | Int]
168+
>>> def iterator_str_int() -> IteratorStrInt:
169+
... yield 1
170+
... yield 'aaa'
171+
>>> infer_return_type(iterator_str_int)
172+
('multiple', Str | Int)
173+
163174
# a bit of an edge case
164175
>>> from typing import Tuple
165176
>>> def empty_tuple() -> Iterator[Tuple[()]]:
@@ -197,6 +208,10 @@ def infer_return_type(func) -> Failure | Inferred:
197208
if rtype is None:
198209
return f"no return type annotation on {func}"
199210

211+
if isinstance(rtype, TypeAliasType):
212+
# handle 'type ... = ...' aliases
213+
rtype = rtype.__value__
214+
200215
def bail(reason: str) -> str:
201216
return f"can't infer type from {rtype}: " + reason
202217

src/cachew/marshall/cachew.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
NamedTuple,
1515
Optional,
1616
Tuple,
17+
TypeAliasType,
1718
Union,
1819
get_args,
1920
get_origin,
@@ -289,6 +290,10 @@ def load(self, dct: str):
289290

290291

291292
def build_schema(Type) -> Schema:
293+
if isinstance(Type, TypeAliasType):
294+
# handle 'type ... = ...' aliases
295+
Type = Type.__value__
296+
292297
# just to avoid confusion in case of weirdness with stringish type annotations
293298
assert not isinstance(Type, str), Type
294299

@@ -414,6 +419,13 @@ def normalise(x):
414419
return (j, obj2)
415420

416421

422+
## this is used for test below...
423+
# however if we define this inside the test function, it fails if from __future__ import annotations is present on the file..
424+
type _IntType = int
425+
type _StrIntType = str | int
426+
##
427+
428+
417429
# TODO customise with cattrs
418430
def test_serialize_and_deserialize() -> None:
419431
import pytest
@@ -504,6 +516,22 @@ class WithJson:
504516
id: int
505517
raw_data: dict[str, Any]
506518

519+
## type aliases including new 3.12 type aliases
520+
# this works..
521+
StrInt = str | int
522+
helper('aaa', StrInt)
523+
524+
helper('aaa', _StrIntType)
525+
helper([1, 2, 3], list[_IntType])
526+
527+
@dataclass
528+
class TestTypeAlias:
529+
x: _IntType
530+
value: _StrIntType
531+
532+
helper(TestTypeAlias(x=1, value='aaa'), TestTypeAlias)
533+
##
534+
507535
# json-ish stuff
508536
helper({}, dict[str, Any])
509537
helper(WithJson(id=123, raw_data={'payload': 'whatever', 'tags': ['a', 'b', 'c']}), WithJson)

src/cachew/tests/test_cachew.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1439,3 +1439,25 @@ def fun_multiple() -> Iterable[int]:
14391439

14401440
assert (tmp_path / callable_name(fun_single)).exists()
14411441
assert (tmp_path / callable_name(fun_multiple)).exists()
1442+
1443+
1444+
def test_type_alias_type_1(tmp_path: Path) -> None:
1445+
type Int = int
1446+
1447+
@cachew(tmp_path)
1448+
def fun() -> Iterator[Int]:
1449+
yield 123
1450+
1451+
assert list(fun()) == [123]
1452+
assert list(fun()) == [123]
1453+
1454+
1455+
def test_type_alias_type_2(tmp_path: Path) -> None:
1456+
type IteratorInt = Iterator[int]
1457+
1458+
@cachew(tmp_path)
1459+
def fun() -> IteratorInt:
1460+
yield 123
1461+
1462+
assert list(fun()) == [123]
1463+
assert list(fun()) == [123]

src/cachew/tests/test_future_annotations.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,16 @@
1414

1515
from .. import cachew
1616

17+
type _Str = str # deliberate, to test 3.12 'type ... = ...' type definitions
18+
1719

1820
# fmt: off
1921
@dataclass
2022
class NewStyleTypes1:
2123
a_str : str
2224
a_dict : dict[str, Any]
2325
a_list : list[Any]
24-
a_tuple : tuple[float, str]
26+
a_tuple : tuple[float, _Str]
2527
# fmt: on
2628

2729

@@ -45,7 +47,7 @@ def get() -> Iterator[NewStyleTypes1]:
4547
@dataclass
4648
class NewStyleTypes2:
4749
an_opt : str | None
48-
a_union : str | int
50+
a_union : _Str | int
4951
# fmt: on
5052

5153

@@ -102,7 +104,7 @@ def test_future_annotations(
102104
'''
103105

104106
_TEST = '''
105-
T = int
107+
type T = int
106108
107109
@cachew(td)
108110
def fun() -> list[T]:

0 commit comments

Comments
 (0)