Skip to content

Ensure field extensions are only applied once #3832

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Release type: minor

This release fixes an issue where field extensions were being applied multiple times when a field was used in multiple schemas. This could lead to unexpected behavior or errors if the extension's `apply` method wasn't idempotent.

The issue has been resolved by introducing a caching mechanism that ensures each field extension is applied only once, regardless of how many schemas the field appears in. Test cases have been added to validate this behavior and ensure that extensions are applied correctly.
13 changes: 12 additions & 1 deletion strawberry/extensions/field_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import itertools
from collections.abc import Awaitable
from functools import cached_property
from functools import cache, cached_property
from typing import TYPE_CHECKING, Any, Callable, Union

if TYPE_CHECKING:
Expand Down Expand Up @@ -156,4 +156,15 @@ def build_field_extension_resolvers(
)


@cache
def apply_field_extensions(field: StrawberryField) -> None:
"""Applies the field extensions to the field.

This function is cached to avoid applying the extensions multiple times in the case
of multiple schema generation passes.
"""
for extension in field.extensions:
extension.apply(field)


__all__ = ["FieldExtension"]
8 changes: 5 additions & 3 deletions strawberry/schema/schema_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@
ScalarAlreadyRegisteredError,
UnresolvedFieldTypeError,
)
from strawberry.extensions.field_extension import build_field_extension_resolvers
from strawberry.extensions.field_extension import (
apply_field_extensions,
build_field_extension_resolvers,
)
from strawberry.schema.types.scalar import _make_scalar_type
from strawberry.types.arguments import StrawberryArgument, convert_arguments
from strawberry.types.base import (
Expand Down Expand Up @@ -683,8 +686,7 @@ def _get_result(

def wrap_field_extensions() -> Callable[..., Any]:
"""Wrap the provided field resolver with the middleware."""
for extension in field.extensions:
extension.apply(field)
apply_field_extensions(field)

extension_functions = build_field_extension_resolvers(field)

Expand Down
89 changes: 89 additions & 0 deletions tests/relay/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,3 +401,92 @@
schema = strawberry.Schema(query=Query)

assert str(schema) == textwrap.dedent(expected_type).strip()


def test_multiple_schemas(mocker: MockerFixture):
"""Avoid regression of https://github.com/strawberry-graphql/strawberry/issues/3823"""
# Avoid E501 errors
mocker.patch.object(
DEFAULT_SCALAR_REGISTRY[relay.GlobalID],
"description",
"__GLOBAL_ID_DESC__",
)

@strawberry.type
class Query:
node: relay.Node = relay.node()

@relay.connection(relay.ListConnection[relay.Node])
def connection(self) -> list[relay.Node]:
return []

Check warning on line 421 in tests/relay/test_schema.py

View check run for this annotation

Codecov / codecov/patch

tests/relay/test_schema.py#L421

Added line #L421 was not covered by tests

schema_1 = strawberry.Schema(query=Query)
schema_2 = strawberry.Schema(query=Query)

expected = textwrap.dedent(
'''
"""__GLOBAL_ID_DESC__"""
scalar GlobalID @specifiedBy(url: "https://relay.dev/graphql/objectidentification.htm")

"""An object with a Globally Unique ID"""
interface Node {
"""The Globally Unique ID of this object"""
id: GlobalID!
}

"""A connection to a list of items."""
type NodeConnection {
"""Pagination data for this connection"""
pageInfo: PageInfo!

"""Contains the nodes in this connection"""
edges: [NodeEdge!]!
}

"""An edge in a connection."""
type NodeEdge {
"""A cursor for use in pagination"""
cursor: String!

"""The item at the end of the edge"""
node: Node!
}

"""Information to aid in pagination."""
type PageInfo {
"""When paginating forwards, are there more items?"""
hasNextPage: Boolean!

"""When paginating backwards, are there more items?"""
hasPreviousPage: Boolean!

"""When paginating backwards, the cursor to continue."""
startCursor: String

"""When paginating forwards, the cursor to continue."""
endCursor: String
}

type Query {
node(
"""The ID of the object."""
id: GlobalID!
): Node!
connection(
"""Returns the items in the list that come before the specified cursor."""
before: String = null

"""Returns the items in the list that come after the specified cursor."""
after: String = null

"""Returns the first n items from the list."""
first: Int = null

"""Returns the items in the list that come after the specified cursor."""
last: Int = null
): NodeConnection!
}
'''
).strip()

assert str(schema_1) == str(schema_2) == expected
31 changes: 31 additions & 0 deletions tests/schema/extensions/test_field_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
SyncExtensionResolver,
)
from strawberry.schema.config import StrawberryConfig
from strawberry.types.field import StrawberryField


class UpperCaseExtension(FieldExtension):
Expand Down Expand Up @@ -410,3 +411,33 @@
result = schema.execute_sync(query)
assert result.data, result.errors
assert result.data["string"] == "This is a test!!"


def test_extension_applied_once():
"""Avoid regression of https://github.com/strawberry-graphql/strawberry/issues/3823"""
applied = 0

class CustomExtension(FieldExtension):
def apply(self, field: StrawberryField):
nonlocal applied
applied += 1

def resolve(
self,
next_: Callable[..., Any],
source: Any,
info: strawberry.Info,
**kwargs: Any,
):
return next_(source, info, **kwargs)

Check warning on line 432 in tests/schema/extensions/test_field_extensions.py

View check run for this annotation

Codecov / codecov/patch

tests/schema/extensions/test_field_extensions.py#L432

Added line #L432 was not covered by tests

@strawberry.type
class Query:
@strawberry.field(extensions=[CustomExtension()])
def string(self) -> str:
return "This is a test!!"

Check warning on line 438 in tests/schema/extensions/test_field_extensions.py

View check run for this annotation

Codecov / codecov/patch

tests/schema/extensions/test_field_extensions.py#L438

Added line #L438 was not covered by tests

for _ in range(5):
strawberry.Schema(query=Query)

assert applied == 1
Loading