Skip to content

Commit c1a0548

Browse files
authored
Merge pull request #270 from mirumee/fix-tests
Fix blank line between class and first method not being formatted out.
2 parents 3b8164a + 29dee20 commit c1a0548

File tree

6 files changed

+39
-25
lines changed

6 files changed

+39
-25
lines changed

ariadne_codegen/client_generators/client.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -141,15 +141,15 @@ def add_method(
141141
raise NotSupported(
142142
"Subscriptions are only available when using async client."
143143
)
144-
method_def: Union[
145-
ast.FunctionDef, ast.AsyncFunctionDef
146-
] = self._generate_subscription_method_def(
147-
name=name,
148-
operation_name=operation_name,
149-
return_type=return_type,
150-
arguments=arguments,
151-
arguments_dict=arguments_dict,
152-
operation_str=operation_str,
144+
method_def: Union[ast.FunctionDef, ast.AsyncFunctionDef] = (
145+
self._generate_subscription_method_def(
146+
name=name,
147+
operation_name=operation_name,
148+
return_type=return_type,
149+
arguments=arguments,
150+
arguments_dict=arguments_dict,
151+
operation_str=operation_str,
152+
)
153153
)
154154
elif async_:
155155
method_def = self._generate_async_method(

ariadne_codegen/client_generators/result_fields.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,9 @@ def parse_operation_field(
8989
schema=schema,
9090
field_node=field,
9191
custom_scalars=custom_scalars if custom_scalars else {},
92-
fragments_definitions=fragments_definitions
93-
if fragments_definitions
94-
else {},
92+
fragments_definitions=(
93+
fragments_definitions if fragments_definitions else {}
94+
),
9595
)
9696
)
9797

ariadne_codegen/utils.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import re
33
from keyword import iskeyword
44
from textwrap import indent
5-
from typing import Optional
5+
from typing import List, Optional
66

77
import isort
88
from autoflake import fix_code # type: ignore
@@ -25,13 +25,31 @@ def ast_to_str(
2525
) -> str:
2626
"""Convert ast object into string."""
2727
code = ast.unparse(ast_obj)
28+
code = remove_blank_line_between_class_and_content(code)
2829
if remove_unused_imports:
2930
code = fix_code(code, remove_all_unused_imports=True)
3031
if multiline_strings:
3132
code = format_multiline_strings(code, offset=multiline_strings_offset)
3233
return format_str(isort.code(code), mode=Mode())
3334

3435

36+
def remove_blank_line_between_class_and_content(code: str) -> str:
37+
"""Removes blank lines between class and first method.
38+
39+
We are doing this for code style consistency and backwards compatibility.
40+
"""
41+
code_lines: List[str] = []
42+
skip_blank_lines = False
43+
for line in code.splitlines():
44+
if skip_blank_lines and line:
45+
skip_blank_lines = False
46+
elif line.startswith("class "):
47+
skip_blank_lines = True
48+
if not skip_blank_lines or line:
49+
code_lines.append(line)
50+
return "\n".join(code_lines)
51+
52+
3553
def str_to_snake_case(name: str) -> str:
3654
"""Converts camelCase or PascalCase string into snake_case."""
3755
# lower-case letters that optionally start with a single upper-case letter

tests/main/clients/inline_fragments/expected_client/union_a.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66

77

88
class UnionA(BaseModel):
9-
query_u: Union[
10-
"UnionAQueryUTypeA", "UnionAQueryUTypeB", "UnionAQueryUTypeC"
11-
] = Field(alias="queryU", discriminator="typename__")
9+
query_u: Union["UnionAQueryUTypeA", "UnionAQueryUTypeB", "UnionAQueryUTypeC"] = (
10+
Field(alias="queryU", discriminator="typename__")
11+
)
1212

1313

1414
class UnionAQueryUTypeA(BaseModel):

tests/main/clients/inline_fragments/expected_client/union_b.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66

77

88
class UnionB(BaseModel):
9-
query_u: Union[
10-
"UnionBQueryUTypeA", "UnionBQueryUTypeB", "UnionBQueryUTypeC"
11-
] = Field(alias="queryU", discriminator="typename__")
9+
query_u: Union["UnionBQueryUTypeA", "UnionBQueryUTypeB", "UnionBQueryUTypeC"] = (
10+
Field(alias="queryU", discriminator="typename__")
11+
)
1212

1313

1414
class UnionBQueryUTypeA(BaseModel):

tests/main/clients/shorter_results/expected_client/client.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,7 @@ async def list_type_a(
133133
data = self.get_data(response)
134134
return ListTypeA.model_validate(data).list_optional_type_a
135135

136-
async def get_animal_by_name(
137-
self, name: str, **kwargs: Any
138-
) -> Union[
136+
async def get_animal_by_name(self, name: str, **kwargs: Any) -> Union[
139137
GetAnimalByNameAnimalByNameAnimal,
140138
GetAnimalByNameAnimalByNameCat,
141139
GetAnimalByNameAnimalByNameDog,
@@ -163,9 +161,7 @@ async def get_animal_by_name(
163161
data = self.get_data(response)
164162
return GetAnimalByName.model_validate(data).animal_by_name
165163

166-
async def list_animals(
167-
self, **kwargs: Any
168-
) -> List[
164+
async def list_animals(self, **kwargs: Any) -> List[
169165
Union[
170166
ListAnimalsListAnimalsAnimal,
171167
ListAnimalsListAnimalsCat,

0 commit comments

Comments
 (0)