Skip to content

Commit 5d2efdb

Browse files
authored
Merge pull request #119 from mkurnikov/subclass-queryset-proper-typing
Allow to subclass queryset without loss of typing
2 parents ac40b80 + 27793ec commit 5d2efdb

File tree

4 files changed

+43
-14
lines changed

4 files changed

+43
-14
lines changed

Diff for: mypy_django_plugin/transformers/meta.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from mypy.types import TypeOfAny
66

77
from mypy_django_plugin.django.context import DjangoContext
8-
from mypy_django_plugin.lib import fullnames, helpers
8+
from mypy_django_plugin.lib import helpers
99

1010

1111
def _get_field_instance(ctx: MethodContext, field_fullname: str) -> MypyType:
@@ -20,21 +20,25 @@ def return_proper_field_type_from_get_field(ctx: MethodContext, django_context:
2020
# Options instance
2121
assert isinstance(ctx.type, Instance)
2222

23+
# bail if list of generic params is empty
24+
if len(ctx.type.args) == 0:
25+
return ctx.default_return_type
26+
2327
model_type = ctx.type.args[0]
2428
if not isinstance(model_type, Instance):
25-
return _get_field_instance(ctx, fullnames.FIELD_FULLNAME)
29+
return ctx.default_return_type
2630

2731
model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname())
2832
if model_cls is None:
29-
return _get_field_instance(ctx, fullnames.FIELD_FULLNAME)
33+
return ctx.default_return_type
3034

3135
field_name_expr = helpers.get_call_argument_by_name(ctx, 'field_name')
3236
if field_name_expr is None:
33-
return _get_field_instance(ctx, fullnames.FIELD_FULLNAME)
37+
return ctx.default_return_type
3438

3539
field_name = helpers.resolve_string_attribute_value(field_name_expr, ctx, django_context)
3640
if field_name is None:
37-
return _get_field_instance(ctx, fullnames.FIELD_FULLNAME)
41+
return ctx.default_return_type
3842

3943
try:
4044
field = model_cls._meta.get_field(field_name)

Diff for: mypy_django_plugin/transformers/querysets.py

+14-7
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,27 @@
33

44
from django.core.exceptions import FieldError
55
from django.db.models.base import Model
6+
from django.db.models.fields.related import RelatedField
67
from mypy.newsemanal.typeanal import TypeAnalyser
78
from mypy.nodes import Expression, NameExpr, TypeInfo
89
from mypy.plugin import AnalyzeTypeContext, FunctionContext, MethodContext
910
from mypy.types import AnyType, Instance
1011
from mypy.types import Type as MypyType
1112
from mypy.types import TypeOfAny
1213

13-
from django.db.models.fields.related import RelatedField
1414
from mypy_django_plugin.django.context import DjangoContext
1515
from mypy_django_plugin.lib import fullnames, helpers
1616

1717

18+
def _extract_model_type_from_queryset(queryset_type: Instance) -> Optional[Instance]:
19+
for base_type in [queryset_type, *queryset_type.type.bases]:
20+
if (len(base_type.args)
21+
and isinstance(base_type.args[0], Instance)
22+
and base_type.args[0].type.has_base(fullnames.MODEL_CLASS_FULLNAME)):
23+
return base_type.args[0]
24+
return None
25+
26+
1827
def determine_proper_manager_type(ctx: FunctionContext) -> MypyType:
1928
default_return_type = ctx.default_return_type
2029
assert isinstance(default_return_type, Instance)
@@ -98,11 +107,10 @@ def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context:
98107
assert isinstance(ctx.type, Instance)
99108
assert isinstance(ctx.default_return_type, Instance)
100109

101-
# bail if queryset of Any or other non-instances
102-
if not isinstance(ctx.type.args[0], Instance):
110+
model_type = _extract_model_type_from_queryset(ctx.type)
111+
if model_type is None:
103112
return AnyType(TypeOfAny.from_omitted_generics)
104113

105-
model_type = ctx.type.args[0]
106114
model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname())
107115
if model_cls is None:
108116
return ctx.default_return_type
@@ -148,11 +156,10 @@ def extract_proper_type_queryset_values(ctx: MethodContext, django_context: Djan
148156
assert isinstance(ctx.type, Instance)
149157
assert isinstance(ctx.default_return_type, Instance)
150158

151-
# if queryset of non-instance type
152-
if not isinstance(ctx.type.args[0], Instance):
159+
model_type = _extract_model_type_from_queryset(ctx.type)
160+
if model_type is None:
153161
return AnyType(TypeOfAny.from_omitted_generics)
154162

155-
model_type = ctx.type.args[0]
156163
model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname())
157164
if model_cls is None:
158165
return ctx.default_return_type

Diff for: setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def find_stub_files(name: str) -> List[str]:
2828

2929
setup(
3030
name="django-stubs",
31-
version="1.0.1",
31+
version="1.0.2",
3232
description='Mypy stubs for Django',
3333
long_description=readme,
3434
long_description_content_type='text/markdown',

Diff for: test-data/typecheck/managers/querysets/test_values_list.yml

+19-1
Original file line numberDiff line numberDiff line change
@@ -204,4 +204,22 @@
204204
class Publisher(models.Model):
205205
pass
206206
class Blog(models.Model):
207-
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE)
207+
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE)
208+
209+
- case: subclass_of_queryset_has_proper_typings_on_methods
210+
main: |
211+
from myapp.models import TransactionQuerySet
212+
reveal_type(TransactionQuerySet()) # N: Revealed type is 'myapp.models.TransactionQuerySet'
213+
reveal_type(TransactionQuerySet().values()) # N: Revealed type is 'django.db.models.query.QuerySet[myapp.models.Transaction, TypedDict({'id': builtins.int, 'total': builtins.int})]'
214+
reveal_type(TransactionQuerySet().values_list()) # N: Revealed type is 'django.db.models.query.QuerySet[myapp.models.Transaction, Tuple[builtins.int, builtins.int]]'
215+
installed_apps:
216+
- myapp
217+
files:
218+
- path: myapp/__init__.py
219+
- path: myapp/models.py
220+
content: |
221+
from django.db import models
222+
class TransactionQuerySet(models.QuerySet['Transaction']):
223+
pass
224+
class Transaction(models.Model):
225+
total = models.IntegerField()

0 commit comments

Comments
 (0)