Skip to content

Commit d95acd6

Browse files
authored
fix(optimizer): Avoid merging prefetches when using aliases (#698)
1 parent 610e12b commit d95acd6

File tree

2 files changed

+93
-40
lines changed

2 files changed

+93
-40
lines changed

Diff for: strawberry_django/optimizer.py

+46-11
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import copy
66
import dataclasses
77
import itertools
8-
from collections import defaultdict
8+
from collections import Counter, defaultdict
99
from collections.abc import Callable
1010
from typing import (
1111
TYPE_CHECKING,
@@ -58,6 +58,7 @@
5858
from .utils.inspect import (
5959
PrefetchInspector,
6060
get_model_field,
61+
get_model_fields,
6162
get_possible_type_definitions,
6263
)
6364
from .utils.typing import (
@@ -1035,19 +1036,29 @@ def _get_model_hints(
10351036
if pk is not None:
10361037
store.only.append(pk.attname)
10371038

1038-
for f_selections in _get_selections(info, parent_type).values():
1039-
field_data = _get_field_data(
1040-
f_selections,
1041-
object_definition,
1042-
schema,
1043-
parent_type=parent_type,
1044-
info=info,
1039+
selections = [
1040+
field_data
1041+
for f_selection in _get_selections(info, parent_type).values()
1042+
if (
1043+
field_data := _get_field_data(
1044+
f_selection,
1045+
object_definition,
1046+
schema,
1047+
parent_type=parent_type,
1048+
info=info,
1049+
)
10451050
)
1046-
if field_data is None:
1051+
is not None
1052+
]
1053+
fields_counter = Counter(field_data[0] for field_data in selections)
1054+
1055+
for field, f_definition, f_selection, f_info in selections:
1056+
# If a field is selected more than once in the query, that means it is being
1057+
# aliased. In this case, optimizing it would make one query to affect the other,
1058+
# resulting in wrong results for both.
1059+
if fields_counter[field] > 1:
10471060
continue
10481061

1049-
field, f_definition, f_selection, f_info = field_data
1050-
10511062
# Add annotations from the field if they exist
10521063
if field_store := _get_hints_from_field(field, f_info=f_info, prefix=prefix):
10531064
store |= field_store
@@ -1089,6 +1100,30 @@ def _get_model_hints(
10891100
store.only.extend(inner_store.only)
10901101
store.select_related.extend(inner_store.select_related)
10911102

1103+
# In case we skipped optimization for a relation, we might end up with a new QuerySet
1104+
# which would not select its parent relation field on `.only()`, causing n+1 issues.
1105+
# Make sure that in this case we also select it.
1106+
if level == 0 and store.only and info.path.prev:
1107+
own_fk_fields = [
1108+
field
1109+
for field in get_model_fields(model).values()
1110+
if isinstance(field, models.ForeignKey)
1111+
]
1112+
1113+
path = info.path
1114+
while path := path.prev:
1115+
type_ = schema.get_type_by_name(path.typename)
1116+
if not isinstance(type_, StrawberryObjectDefinition):
1117+
continue
1118+
1119+
if not (strawberry_django_type := get_django_definition(type_.origin)):
1120+
continue
1121+
1122+
for field in own_fk_fields:
1123+
if field.related_model is strawberry_django_type.model:
1124+
store.only.append(field.attname)
1125+
break
1126+
10921127
return store
10931128

10941129

Diff for: tests/test_optimizer.py

+47-29
Original file line numberDiff line numberDiff line change
@@ -308,14 +308,6 @@ def test_query_forward_with_fragments(db, gql_client: GraphQLTestClient):
308308
}
309309
... milestoneFrag
310310
}
311-
milestoneAgain: milestone {
312-
name
313-
project {
314-
id
315-
name
316-
}
317-
... milestoneFrag
318-
}
319311
}
320312
}
321313
}
@@ -341,7 +333,6 @@ def test_query_forward_with_fragments(db, gql_client: GraphQLTestClient):
341333
"nameWithKind": f"{i.kind}: {i.name}",
342334
"nameWithPriority": f"{i.kind}: {i.priority}",
343335
"milestone": m_res,
344-
"milestoneAgain": m_res,
345336
},
346337
)
347338

@@ -538,12 +529,6 @@ def test_query_prefetch_with_fragments(db, gql_client: GraphQLTestClient):
538529
... milestoneFrag
539530
}
540531
}
541-
otherIssues: issues {
542-
id
543-
milestone {
544-
... milestoneFrag
545-
}
546-
}
547532
}
548533
}
549534
}
@@ -566,7 +551,6 @@ def test_query_prefetch_with_fragments(db, gql_client: GraphQLTestClient):
566551
"name": p_res["name"],
567552
},
568553
"issues": [],
569-
"otherIssues": [],
570554
}
571555
p_res["milestones"].append(m_res)
572556
for i in IssueFactory.create_batch(3, milestone=m):
@@ -585,22 +569,10 @@ def test_query_prefetch_with_fragments(db, gql_client: GraphQLTestClient):
585569
},
586570
},
587571
)
588-
m_res["otherIssues"].append(
589-
{
590-
"id": to_base64("IssueType", i.id),
591-
"milestone": {
592-
"id": m_res["id"],
593-
"project": {
594-
"id": p_res["id"],
595-
"name": p_res["name"],
596-
},
597-
},
598-
},
599-
)
600572

601573
assert len(expected) == 3
602574
for e in expected:
603-
with assert_num_queries(3 if DjangoOptimizerExtension.enabled.get() else 8):
575+
with assert_num_queries(3 if DjangoOptimizerExtension.enabled.get() else 5):
604576
res = gql_client.query(query, {"node_id": e["id"]})
605577

606578
assert res.data == {"project": e}
@@ -1089,6 +1061,52 @@ def test_query_nested_connection_with_filter(db, gql_client: GraphQLTestClient):
10891061
} == expected
10901062

10911063

1064+
@pytest.mark.django_db(transaction=True)
1065+
def test_query_nested_connection_with_filter_and_alias(
1066+
db, gql_client: GraphQLTestClient
1067+
):
1068+
query = """
1069+
query TestQuery ($id: GlobalID!) {
1070+
milestone(id: $id) {
1071+
id
1072+
fooIssues: issuesWithFilters (filters: {search: "Foo"}) {
1073+
edges {
1074+
node {
1075+
id
1076+
}
1077+
}
1078+
}
1079+
barIssues: issuesWithFilters (filters: {search: "Bar"}) {
1080+
edges {
1081+
node {
1082+
id
1083+
}
1084+
}
1085+
}
1086+
}
1087+
}
1088+
"""
1089+
1090+
milestone = MilestoneFactory.create()
1091+
issue1 = IssueFactory.create(milestone=milestone, name="Foo")
1092+
issue2 = IssueFactory.create(milestone=milestone, name="Foo Bar")
1093+
issue3 = IssueFactory.create(milestone=milestone, name="Bar Foo")
1094+
issue4 = IssueFactory.create(milestone=milestone, name="Bar Bin")
1095+
1096+
with assert_num_queries(3):
1097+
res = gql_client.query(query, {"id": to_base64("MilestoneType", milestone.pk)})
1098+
1099+
assert isinstance(res.data, dict)
1100+
result = res.data["milestone"]
1101+
assert isinstance(result, dict)
1102+
1103+
foo_expected = {to_base64("IssueType", i.pk) for i in [issue1, issue2, issue3]}
1104+
assert {edge["node"]["id"] for edge in result["fooIssues"]["edges"]} == foo_expected
1105+
1106+
bar_expected = {to_base64("IssueType", i.pk) for i in [issue2, issue3, issue4]}
1107+
assert {edge["node"]["id"] for edge in result["barIssues"]["edges"]} == bar_expected
1108+
1109+
10921110
@pytest.mark.django_db(transaction=True)
10931111
def test_query_with_optimizer_paginated_prefetch():
10941112
@strawberry_django.type(Milestone, pagination=True)

0 commit comments

Comments
 (0)