-
-
Notifications
You must be signed in to change notification settings - Fork 607
unit tests for Union type look-up during queries #4088
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
himanshu007-creator
wants to merge
2
commits into
strawberry-graphql:main
Choose a base branch
from
himanshu007-creator:1468-himanshu007-creator
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+372
−0
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,372 @@ | ||
| """Tests for StrawberryUnion.get_type_resolver method and the resolver function it returns.""" | ||
|
|
||
| from dataclasses import dataclass | ||
| from typing import Annotated, Generic, TypeVar | ||
| from unittest.mock import MagicMock | ||
|
|
||
| import pytest | ||
|
|
||
| import strawberry | ||
| from strawberry.annotation import StrawberryAnnotation | ||
| from strawberry.exceptions import UnallowedReturnTypeForUnion, WrongReturnTypeForUnion | ||
| from strawberry.schema.types.concrete_type import TypeMap | ||
| from strawberry.types.union import StrawberryUnion | ||
|
|
||
|
|
||
| def test_get_type_resolver_returns_function(): | ||
| """Test that get_type_resolver returns a callable function.""" | ||
|
|
||
| @strawberry.type | ||
| class A: | ||
| a: int | ||
|
|
||
| @strawberry.type | ||
| class B: | ||
| b: int | ||
|
|
||
| union = StrawberryUnion( | ||
| type_annotations=( | ||
| StrawberryAnnotation(A), | ||
| StrawberryAnnotation(B), | ||
| ) | ||
| ) | ||
|
|
||
| type_map: TypeMap = {} | ||
| resolver = union.get_type_resolver(type_map) | ||
|
|
||
| assert callable(resolver) | ||
|
|
||
|
|
||
| def test_type_resolver_with_object_definition(): | ||
| """Test that the resolver correctly resolves types when root has object definition.""" | ||
|
|
||
| @strawberry.type | ||
| class A: | ||
| a: int | ||
|
|
||
| @strawberry.type | ||
| class B: | ||
| b: int | ||
|
|
||
| # Create a schema to get proper GraphQL types | ||
| @strawberry.type | ||
| class Query: | ||
| @strawberry.field | ||
| def test(self) -> A | B: | ||
| return A(a=1) | ||
|
|
||
| schema = strawberry.Schema(query=Query) | ||
|
|
||
| # Find the union in the schema's type map | ||
| union_name = None | ||
| union_def = None | ||
| for name, concrete_type in schema.schema_converter.type_map.items(): | ||
| if isinstance(concrete_type.definition, StrawberryUnion): | ||
| union_name = name | ||
| union_def = concrete_type.definition | ||
| break | ||
|
|
||
| assert union_name is not None, "Union should be in type map" | ||
| assert union_def is not None, "Union definition should be found" | ||
| graphql_union = schema.schema_converter.type_map[union_name].implementation | ||
|
|
||
| # Get the resolver from the union | ||
| resolver = union_def.get_type_resolver(schema.schema_converter.type_map) | ||
|
|
||
| # Create mock info | ||
| info = MagicMock() | ||
| info.field_name = "test" | ||
|
|
||
| # Test with A instance | ||
| a_instance = A(a=1) | ||
| result = resolver(a_instance, info, graphql_union) | ||
| assert result == "A" | ||
|
|
||
| # Test with B instance | ||
| b_instance = B(b=2) | ||
| result = resolver(b_instance, info, graphql_union) | ||
| assert result == "B" | ||
|
|
||
|
|
||
| def test_type_resolver_with_is_type_of(): | ||
| """Test that the resolver uses is_type_of when root doesn't have object definition.""" | ||
|
|
||
| @dataclass | ||
| class ADataclass: | ||
| a: int | ||
|
|
||
| @dataclass | ||
| class BDataclass: | ||
| b: int | ||
|
|
||
| @strawberry.type | ||
| class A: | ||
| a: int | ||
|
|
||
| @classmethod | ||
| def is_type_of(cls, obj, _info) -> bool: | ||
| return isinstance(obj, ADataclass) | ||
|
|
||
| @strawberry.type | ||
| class B: | ||
| b: int | ||
|
|
||
| @classmethod | ||
| def is_type_of(cls, obj, _info) -> bool: | ||
| return isinstance(obj, BDataclass) | ||
|
|
||
| # Create a schema to get proper GraphQL types | ||
| @strawberry.type | ||
| class Query: | ||
| @strawberry.field | ||
| def test(self) -> A | B: | ||
| return ADataclass(a=1) # type: ignore | ||
|
|
||
| schema = strawberry.Schema(query=Query) | ||
|
|
||
| # Find the union in the schema's type map | ||
| union_name = None | ||
| union_def = None | ||
| for name, concrete_type in schema.schema_converter.type_map.items(): | ||
| if isinstance(concrete_type.definition, StrawberryUnion): | ||
| union_name = name | ||
| union_def = concrete_type.definition | ||
| break | ||
|
|
||
| assert union_name is not None, "Union should be in type map" | ||
| assert union_def is not None, "Union definition should be found" | ||
| graphql_union = schema.schema_converter.type_map[union_name].implementation | ||
|
|
||
| # Get the resolver from the union | ||
| resolver = union_def.get_type_resolver(schema.schema_converter.type_map) | ||
|
|
||
| # Create mock info | ||
| info = MagicMock() | ||
| info.field_name = "test" | ||
|
|
||
| # Test with ADataclass instance (should resolve to A) | ||
| a_dataclass = ADataclass(a=1) | ||
| result = resolver(a_dataclass, info, graphql_union) | ||
| assert result == "A" | ||
|
|
||
| # Test with BDataclass instance (should resolve to B) | ||
| b_dataclass = BDataclass(b=2) | ||
| result = resolver(b_dataclass, info, graphql_union) | ||
| assert result == "B" | ||
|
|
||
|
|
||
| def test_type_resolver_raises_wrong_return_type_when_no_is_type_of_matches(): | ||
| """Test that the resolver raises WrongReturnTypeForUnion when is_type_of doesn't match.""" | ||
|
|
||
| @dataclass | ||
| class UnknownDataclass: | ||
| x: int | ||
|
|
||
| @strawberry.type | ||
| class A: | ||
| a: int | ||
|
|
||
| @classmethod | ||
| def is_type_of(cls, obj, _info) -> bool: | ||
| return False # This won't match | ||
|
|
||
| @strawberry.type | ||
| class B: | ||
| b: int | ||
|
|
||
| @classmethod | ||
| def is_type_of(cls, obj, _info) -> bool: | ||
| return False # This won't match either | ||
|
|
||
| # Create a schema to get proper GraphQL types | ||
| @strawberry.type | ||
| class Query: | ||
| @strawberry.field | ||
| def test(self) -> A | B: | ||
| return UnknownDataclass(x=1) # type: ignore | ||
|
|
||
| schema = strawberry.Schema(query=Query) | ||
|
|
||
| # Find the union in the schema's type map | ||
| union_name = None | ||
| union_def = None | ||
| for name, concrete_type in schema.schema_converter.type_map.items(): | ||
| if isinstance(concrete_type.definition, StrawberryUnion): | ||
| union_name = name | ||
| union_def = concrete_type.definition | ||
| break | ||
|
|
||
| assert union_name is not None, "Union should be in type map" | ||
| assert union_def is not None, "Union definition should be found" | ||
| graphql_union = schema.schema_converter.type_map[union_name].implementation | ||
|
|
||
| # Get the resolver from the union | ||
| resolver = union_def.get_type_resolver(schema.schema_converter.type_map) | ||
|
|
||
| # Create mock info | ||
| info = MagicMock() | ||
| info.field_name = "test" | ||
|
|
||
| # Test with UnknownDataclass instance (should raise error) | ||
| unknown_instance = UnknownDataclass(x=1) | ||
| with pytest.raises(WrongReturnTypeForUnion) as exc_info: | ||
| resolver(unknown_instance, info, graphql_union) | ||
|
|
||
| assert "test" in str(exc_info.value) | ||
| assert "UnknownDataclass" in str(exc_info.value) | ||
|
|
||
|
|
||
| def test_type_resolver_raises_unallowed_return_type_when_not_in_union(): | ||
| """Test that the resolver raises UnallowedReturnTypeForUnion when type is not in union.""" | ||
|
|
||
| @strawberry.type | ||
| class A: | ||
| a: int | ||
|
|
||
| @strawberry.type | ||
| class B: | ||
| b: int | ||
|
|
||
| @strawberry.type | ||
| class Outside: | ||
| c: int | ||
|
|
||
| # Create a schema that includes Outside but not in the union | ||
| @strawberry.type | ||
| class Query: | ||
| @strawberry.field | ||
| def test(self) -> A | B: | ||
| return Outside(c=1) # type: ignore | ||
|
|
||
| schema = strawberry.Schema(query=Query, types=[Outside]) | ||
|
|
||
| # Find the union in the schema's type map | ||
| union_name = None | ||
| union_def = None | ||
| for name, concrete_type in schema.schema_converter.type_map.items(): | ||
| if isinstance(concrete_type.definition, StrawberryUnion): | ||
| union_name = name | ||
| union_def = concrete_type.definition | ||
| break | ||
|
|
||
| assert union_name is not None, "Union should be in type map" | ||
| assert union_def is not None, "Union definition should be found" | ||
| graphql_union = schema.schema_converter.type_map[union_name].implementation | ||
|
|
||
| # Get the resolver from the union | ||
| resolver = union_def.get_type_resolver(schema.schema_converter.type_map) | ||
|
|
||
| # Create mock info | ||
| info = MagicMock() | ||
| info.field_name = "test" | ||
|
|
||
| # Test with Outside instance (should raise error) | ||
| outside_instance = Outside(c=1) | ||
| with pytest.raises(UnallowedReturnTypeForUnion) as exc_info: | ||
| resolver(outside_instance, info, graphql_union) | ||
|
|
||
| assert "test" in str(exc_info.value) | ||
| assert "Outside" in str(exc_info.value) | ||
| assert "A" in str(exc_info.value) or "B" in str(exc_info.value) | ||
|
|
||
|
|
||
| def test_type_resolver_prioritizes_union_types(): | ||
| """Test that the resolver prioritizes union types over other types in type_map. | ||
|
|
||
| This tests the bug fix from PR #1463 where union types should be checked first | ||
| before falling back to all types in the type_map. | ||
| """ | ||
| T = TypeVar("T") | ||
|
|
||
| @strawberry.type | ||
| class Container(Generic[T]): | ||
| items: list[T] | ||
|
|
||
| @strawberry.type | ||
| class A: | ||
| a: str | ||
|
|
||
| @strawberry.type | ||
| class B: | ||
| b: int | ||
|
|
||
| @strawberry.type | ||
| class Query: | ||
| @strawberry.field | ||
| def container_a(self) -> Container[A] | A: | ||
| return Container(items=[A(a="hello")]) | ||
|
|
||
| @strawberry.field | ||
| def container_b(self) -> Container[B] | B: | ||
| return Container(items=[B(b=3)]) | ||
|
|
||
| schema = strawberry.Schema(query=Query) | ||
|
|
||
| # Find the union for Container[A] | A in the schema's type map | ||
| # We'll look for a union that has "AContainer" in its name (generated by schema) | ||
| union_name = None | ||
| union_def = None | ||
| for name, concrete_type in schema.schema_converter.type_map.items(): | ||
| if isinstance(concrete_type.definition, StrawberryUnion): | ||
| # Check if this union name contains "AContainer" (indicating Container[A] | A) | ||
| # and check if A is in the union types | ||
| union_types = concrete_type.definition.types | ||
| has_a = A in union_types | ||
| if "AContainer" in name and has_a: | ||
| union_name = name | ||
| union_def = concrete_type.definition | ||
| break | ||
|
|
||
| assert union_name is not None, "Union Container[A] | A should be in type map" | ||
| assert union_def is not None, "Union definition should be found" | ||
| graphql_union = schema.schema_converter.type_map[union_name].implementation | ||
|
|
||
| # Get the resolver from the union | ||
| resolver = union_def.get_type_resolver(schema.schema_converter.type_map) | ||
|
|
||
| # Create mock info | ||
| info = MagicMock() | ||
| info.field_name = "containerA" | ||
|
|
||
| # Test with Container[A] instance - should resolve to Container[A], not Container[B] | ||
| container_a_instance = Container(items=[A(a="hello")]) | ||
| result = resolver(container_a_instance, info, graphql_union) | ||
|
|
||
| # Should resolve to the Container[A] type name | ||
| assert result == "AContainer" | ||
|
|
||
| # Test with A instance | ||
| a_instance = A(a="hello") | ||
| result = resolver(a_instance, info, graphql_union) | ||
| assert result == "A" | ||
|
|
||
|
|
||
| def test_type_resolver_with_single_type_union(): | ||
| """Test that the resolver works correctly with a single-type union.""" | ||
|
|
||
| @strawberry.type | ||
| class A: | ||
| a: int | ||
|
|
||
| @strawberry.type | ||
| class Query: | ||
| @strawberry.field | ||
| def test(self) -> Annotated[A, strawberry.union(name="SingleUnion")]: | ||
| return A(a=1) | ||
|
|
||
| schema = strawberry.Schema(query=Query) | ||
| graphql_union = schema.schema_converter.type_map["SingleUnion"].implementation | ||
| union_def = schema.schema_converter.type_map["SingleUnion"].definition | ||
| assert isinstance(union_def, StrawberryUnion), "Should be a StrawberryUnion" | ||
|
|
||
| # Get the resolver from the union | ||
| resolver = union_def.get_type_resolver(schema.schema_converter.type_map) | ||
|
|
||
| # Create mock info | ||
| info = MagicMock() | ||
| info.field_name = "test" | ||
|
|
||
| # Test with A instance | ||
| a_instance = A(a=1) | ||
| result = resolver(a_instance, info, graphql_union) | ||
| assert result == "A" | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need this iteration? Since this is a test, we might be able to just hardcode what we are expecting, so it is easier to understand what we are doing
Same for similar for loops in this file
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure will look into this