Skip to content

Commit 826287f

Browse files
fix(printer): Fix printer ignoring input arguments using snake_case (#3780)
* fix(printer): Fix printer ignoring input arguments using snake_case Fix #3760 * Update strawberry/printer/printer.py Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> * Update RELEASE.md Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> * one missing propagation * skip test on gql2 due to different formatting --------- Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
1 parent b8bda78 commit 826287f

File tree

3 files changed

+149
-8
lines changed

3 files changed

+149
-8
lines changed

RELEASE.md

+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
Release type: patch
2+
3+
This release fixes an issue where directives with input types using snake_case
4+
would not be printed in the schema.
5+
6+
For example, the following:
7+
8+
```python
9+
@strawberry.input
10+
class FooInput:
11+
hello: str
12+
hello_world: str
13+
14+
15+
@strawberry.schema_directive(locations=[Location.FIELD_DEFINITION])
16+
class FooDirective:
17+
input: FooInput
18+
19+
20+
@strawberry.type
21+
class Query:
22+
@strawberry.field(
23+
directives=[
24+
FooDirective(input=FooInput(hello="hello", hello_world="hello world")),
25+
]
26+
)
27+
def foo(self, info) -> str: ...
28+
```
29+
30+
Would previously print as:
31+
32+
```graphql
33+
directive @fooDirective(
34+
input: FooInput!
35+
optionalInput: FooInput
36+
) on FIELD_DEFINITION
37+
38+
type Query {
39+
foo: String! @fooDirective(input: { hello: "hello" })
40+
}
41+
42+
input FooInput {
43+
hello: String!
44+
hello_world: String!
45+
}
46+
```
47+
48+
Now it will be correctly printed as:
49+
50+
```graphql
51+
directive @fooDirective(
52+
input: FooInput!
53+
optionalInput: FooInput
54+
) on FIELD_DEFINITION
55+
56+
type Query {
57+
foo: String!
58+
@fooDirective(input: { hello: "hello", helloWorld: "hello world" })
59+
}
60+
61+
input FooInput {
62+
hello: String!
63+
hello_world: String!
64+
}
65+
```

strawberry/printer/printer.py

+43-8
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import (
66
TYPE_CHECKING,
77
Any,
8+
Callable,
89
Optional,
910
TypeVar,
1011
Union,
@@ -68,40 +69,73 @@ class PrintExtras:
6869

6970

7071
@overload
71-
def _serialize_dataclasses(value: dict[_T, object]) -> dict[_T, object]: ...
72+
def _serialize_dataclasses(
73+
value: dict[_T, object],
74+
*,
75+
name_converter: Callable[[str], str] | None = None,
76+
) -> dict[_T, object]: ...
7277

7378

7479
@overload
7580
def _serialize_dataclasses(
7681
value: Union[list[object], tuple[object]],
82+
*,
83+
name_converter: Callable[[str], str] | None = None,
7784
) -> list[object]: ...
7885

7986

8087
@overload
81-
def _serialize_dataclasses(value: object) -> object: ...
88+
def _serialize_dataclasses(
89+
value: object,
90+
*,
91+
name_converter: Callable[[str], str] | None = None,
92+
) -> object: ...
8293

8394

84-
def _serialize_dataclasses(value):
95+
def _serialize_dataclasses(
96+
value,
97+
*,
98+
name_converter: Callable[[str], str] | None = None,
99+
):
100+
if name_converter is None:
101+
name_converter = lambda x: x # noqa: E731
102+
85103
if dataclasses.is_dataclass(value):
86-
return {k: v for k, v in dataclasses.asdict(value).items() if v is not UNSET} # type: ignore
104+
return {
105+
name_converter(k): v
106+
for k, v in dataclasses.asdict(value).items() # type: ignore
107+
if v is not UNSET
108+
}
87109
if isinstance(value, (list, tuple)):
88-
return [_serialize_dataclasses(v) for v in value]
110+
return [_serialize_dataclasses(v, name_converter=name_converter) for v in value]
89111
if isinstance(value, dict):
90-
return {k: _serialize_dataclasses(v) for k, v in value.items()}
112+
return {
113+
name_converter(k): _serialize_dataclasses(v, name_converter=name_converter)
114+
for k, v in value.items()
115+
}
91116

92117
return value
93118

94119

95120
def print_schema_directive_params(
96-
directive: GraphQLDirective, values: dict[str, Any]
121+
directive: GraphQLDirective,
122+
values: dict[str, Any],
123+
*,
124+
schema: BaseSchema,
97125
) -> str:
98126
params = []
99127
for name, arg in directive.args.items():
100128
value = values.get(name, arg.default_value)
101129
if value is UNSET:
102130
value = None
103131
else:
104-
ast = ast_from_value(_serialize_dataclasses(value), arg.type)
132+
ast = ast_from_value(
133+
_serialize_dataclasses(
134+
value,
135+
name_converter=schema.config.name_converter.apply_naming_config,
136+
),
137+
arg.type,
138+
)
105139
value = ast and f"{name}: {print_ast(ast)}"
106140

107141
if value:
@@ -129,6 +163,7 @@ def print_schema_directive(
129163
)
130164
for f in strawberry_directive.fields
131165
},
166+
schema=schema,
132167
)
133168

134169
printed_directive = print_directive(gql_directive, schema=schema)

tests/test_printer/test_schema_directives.py

+41
Original file line numberDiff line numberDiff line change
@@ -749,3 +749,44 @@ def foo(self, info) -> str: ...
749749
schema = strawberry.Schema(query=Query)
750750

751751
assert print_schema(schema) == textwrap.dedent(expected_output).strip()
752+
753+
754+
@skip_if_gql_32("formatting is different in gql 3.2")
755+
def test_print_directive_with_snake_case_arguments():
756+
@strawberry.input
757+
class FooInput:
758+
hello: str
759+
hello_world: str
760+
761+
@strawberry.schema_directive(locations=[Location.FIELD_DEFINITION])
762+
class FooDirective:
763+
input: FooInput
764+
optional_input: Optional[FooInput] = strawberry.UNSET
765+
766+
@strawberry.type
767+
class Query:
768+
@strawberry.field(
769+
directives=[
770+
FooDirective(input=FooInput(hello="hello", hello_world="hello world"))
771+
]
772+
)
773+
def foo(self, info) -> str: ...
774+
775+
schema = strawberry.Schema(query=Query)
776+
777+
expected_output = """
778+
directive @fooDirective(input: FooInput!, optionalInput: FooInput) on FIELD_DEFINITION
779+
780+
type Query {
781+
foo: String! @fooDirective(input: { hello: "hello", helloWorld: "hello world" })
782+
}
783+
784+
input FooInput {
785+
hello: String!
786+
helloWorld: String!
787+
}
788+
"""
789+
790+
schema = strawberry.Schema(query=Query)
791+
792+
assert print_schema(schema) == textwrap.dedent(expected_output).strip()

0 commit comments

Comments
 (0)