diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..87c5c15797 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,5 @@ +Release type: minor + +This release fixes an issue where field extensions were being applied multiple times when a field was used in multiple schemas. This could lead to unexpected behavior or errors if the extension's `apply` method wasn't idempotent. + +The issue has been resolved by introducing a caching mechanism that ensures each field extension is applied only once, regardless of how many schemas the field appears in. Test cases have been added to validate this behavior and ensure that extensions are applied correctly. diff --git a/strawberry/extensions/field_extension.py b/strawberry/extensions/field_extension.py index 7f97c07222..967ead6547 100644 --- a/strawberry/extensions/field_extension.py +++ b/strawberry/extensions/field_extension.py @@ -2,7 +2,7 @@ import itertools from collections.abc import Awaitable -from functools import cached_property +from functools import cache, cached_property from typing import TYPE_CHECKING, Any, Callable, Union if TYPE_CHECKING: @@ -156,4 +156,15 @@ def build_field_extension_resolvers( ) +@cache +def apply_field_extensions(field: StrawberryField) -> None: + """Applies the field extensions to the field. + + This function is cached to avoid applying the extensions multiple times in the case + of multiple schema generation passes. + """ + for extension in field.extensions: + extension.apply(field) + + __all__ = ["FieldExtension"] diff --git a/strawberry/schema/schema_converter.py b/strawberry/schema/schema_converter.py index 332a9bdb5f..7334e1f1be 100644 --- a/strawberry/schema/schema_converter.py +++ b/strawberry/schema/schema_converter.py @@ -47,7 +47,10 @@ ScalarAlreadyRegisteredError, UnresolvedFieldTypeError, ) -from strawberry.extensions.field_extension import build_field_extension_resolvers +from strawberry.extensions.field_extension import ( + apply_field_extensions, + build_field_extension_resolvers, +) from strawberry.schema.types.scalar import _make_scalar_type from strawberry.types.arguments import StrawberryArgument, convert_arguments from strawberry.types.base import ( @@ -683,8 +686,7 @@ def _get_result( def wrap_field_extensions() -> Callable[..., Any]: """Wrap the provided field resolver with the middleware.""" - for extension in field.extensions: - extension.apply(field) + apply_field_extensions(field) extension_functions = build_field_extension_resolvers(field) diff --git a/tests/relay/test_schema.py b/tests/relay/test_schema.py index 22c8897e08..b7c2f56444 100644 --- a/tests/relay/test_schema.py +++ b/tests/relay/test_schema.py @@ -401,3 +401,92 @@ class Query: schema = strawberry.Schema(query=Query) assert str(schema) == textwrap.dedent(expected_type).strip() + + +def test_multiple_schemas(mocker: MockerFixture): + """Avoid regression of https://github.com/strawberry-graphql/strawberry/issues/3823""" + # Avoid E501 errors + mocker.patch.object( + DEFAULT_SCALAR_REGISTRY[relay.GlobalID], + "description", + "__GLOBAL_ID_DESC__", + ) + + @strawberry.type + class Query: + node: relay.Node = relay.node() + + @relay.connection(relay.ListConnection[relay.Node]) + def connection(self) -> list[relay.Node]: + return [] + + schema_1 = strawberry.Schema(query=Query) + schema_2 = strawberry.Schema(query=Query) + + expected = textwrap.dedent( + ''' + """__GLOBAL_ID_DESC__""" + scalar GlobalID @specifiedBy(url: "https://relay.dev/graphql/objectidentification.htm") + + """An object with a Globally Unique ID""" + interface Node { + """The Globally Unique ID of this object""" + id: GlobalID! + } + + """A connection to a list of items.""" + type NodeConnection { + """Pagination data for this connection""" + pageInfo: PageInfo! + + """Contains the nodes in this connection""" + edges: [NodeEdge!]! + } + + """An edge in a connection.""" + type NodeEdge { + """A cursor for use in pagination""" + cursor: String! + + """The item at the end of the edge""" + node: Node! + } + + """Information to aid in pagination.""" + type PageInfo { + """When paginating forwards, are there more items?""" + hasNextPage: Boolean! + + """When paginating backwards, are there more items?""" + hasPreviousPage: Boolean! + + """When paginating backwards, the cursor to continue.""" + startCursor: String + + """When paginating forwards, the cursor to continue.""" + endCursor: String + } + + type Query { + node( + """The ID of the object.""" + id: GlobalID! + ): Node! + connection( + """Returns the items in the list that come before the specified cursor.""" + before: String = null + + """Returns the items in the list that come after the specified cursor.""" + after: String = null + + """Returns the first n items from the list.""" + first: Int = null + + """Returns the items in the list that come after the specified cursor.""" + last: Int = null + ): NodeConnection! + } + ''' + ).strip() + + assert str(schema_1) == str(schema_2) == expected diff --git a/tests/schema/extensions/test_field_extensions.py b/tests/schema/extensions/test_field_extensions.py index ae9bc3848d..2a523c4ae0 100644 --- a/tests/schema/extensions/test_field_extensions.py +++ b/tests/schema/extensions/test_field_extensions.py @@ -10,6 +10,7 @@ SyncExtensionResolver, ) from strawberry.schema.config import StrawberryConfig +from strawberry.types.field import StrawberryField class UpperCaseExtension(FieldExtension): @@ -410,3 +411,33 @@ def string(self) -> str: result = schema.execute_sync(query) assert result.data, result.errors assert result.data["string"] == "This is a test!!" + + +def test_extension_applied_once(): + """Avoid regression of https://github.com/strawberry-graphql/strawberry/issues/3823""" + applied = 0 + + class CustomExtension(FieldExtension): + def apply(self, field: StrawberryField): + nonlocal applied + applied += 1 + + def resolve( + self, + next_: Callable[..., Any], + source: Any, + info: strawberry.Info, + **kwargs: Any, + ): + return next_(source, info, **kwargs) + + @strawberry.type + class Query: + @strawberry.field(extensions=[CustomExtension()]) + def string(self) -> str: + return "This is a test!!" + + for _ in range(5): + strawberry.Schema(query=Query) + + assert applied == 1