Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add docstring to generated code based on schema description #362

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions ariadne_codegen/client_generators/custom_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def _parse_object_type_definitions(
class_def = self._generate_class_def_body(
definition=graphql_type,
class_name=f"{graphql_type.name}{self._get_suffix(graphql_type)}",
description=graphql_type.description,
)
if isinstance(graphql_type, GraphQLInterfaceType):
class_def.body.append(
Expand All @@ -129,17 +130,20 @@ def _generate_class_def_body(
self,
definition: Union[GraphQLObjectType, GraphQLInterfaceType],
class_name: str,
description: Optional[str] = None,
) -> ast.ClassDef:
"""
Generates the body of a class definition for a given GraphQL object
or interface type.
"""
base_names = [GRAPHQL_BASE_FIELD_CLASS]
additional_fields_typing = set()
class_def = generate_class_def(name=class_name, base_names=base_names)
for lineno, (org_name, field) in enumerate(
self._get_combined_fields(definition).items(), start=1
):
class_def = generate_class_def(
name=class_name, base_names=base_names, description=description
)
lineno = 0
for org_name, field in self._get_combined_fields(definition).items():
lineno += 1
name = process_name(
org_name, convert_to_snake_case=self.convert_to_snake_case
)
Expand All @@ -154,6 +158,12 @@ def _generate_class_def_body(
name, field_name, org_name, field, method_required, lineno
)
)
if field.description and not method_required:
lineno += 1
docstring = ast.Expr(
value=ast.Constant(field.description), lineno=lineno
)
class_def.body.append(docstring)

class_def.body.append(
self._generate_fields_method(
Expand Down Expand Up @@ -216,7 +226,10 @@ def _generate_class_field(
"""Handles the generation of field types."""
if getattr(field, "args") or method_required:
return self.generate_product_type_method(
name, field_name, getattr(field, "args")
name,
field_name,
getattr(field, "args"),
description=getattr(field, "description"),
)
return generate_ann_assign(
target=generate_name(name),
Expand Down Expand Up @@ -311,7 +324,11 @@ def _generate_on_method(self, class_name: str) -> ast.FunctionDef:
)

def generate_product_type_method(
self, name: str, class_name: str, arguments: Optional[Dict[str, Any]] = None
self,
name: str,
class_name: str,
arguments: Optional[Dict[str, Any]] = None,
description: Optional[str] = None,
) -> ast.FunctionDef:
"""Generates a method for a product type."""
arguments = arguments or {}
Expand Down Expand Up @@ -351,6 +368,7 @@ def generate_product_type_method(
),
return_type=generate_name(f'"{class_name}"'),
decorator_list=[generate_name("classmethod")],
description=description,
)

def _get_suffix(
Expand Down
3 changes: 3 additions & 0 deletions ariadne_codegen/client_generators/custom_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def generate(self) -> ast.Module:
operation_name=name,
operation_args=field.args,
final_type=final_type,
description=field.description,
)
method_def.lineno = len(self._class_def.body) + 1
self._class_def.body.append(method_def)
Expand Down Expand Up @@ -115,6 +116,7 @@ def _generate_method(
operation_name: str,
operation_args,
final_type,
description: Optional[str] = None,
) -> ast.FunctionDef:
"""Generates a method definition for a given operation."""
(
Expand All @@ -141,6 +143,7 @@ def _generate_method(
name=str_to_snake_case(operation_name),
arguments=method_arguments,
return_type=generate_name(return_type_name),
description=description,
body=[
*arguments_body,
generate_return(
Expand Down
13 changes: 10 additions & 3 deletions ariadne_codegen/client_generators/input_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,13 @@ def _parse_input_definition(
self, definition: GraphQLInputObjectType
) -> ast.ClassDef:
class_def = generate_class_def(
name=definition.name, base_names=[BASE_MODEL_CLASS_NAME]
name=definition.name,
base_names=[BASE_MODEL_CLASS_NAME],
description=definition.description,
)

for lineno, (org_name, field) in enumerate(definition.fields.items(), start=1):
lineno = 0
for org_name, field in definition.fields.items():
lineno += 1
name = process_name(
org_name,
convert_to_snake_case=self.convert_to_snake_case,
Expand Down Expand Up @@ -190,6 +193,10 @@ def _parse_input_definition(
field_implementation, input_field=field, field_name=org_name
)
class_def.body.append(field_implementation)
if field.description:
lineno += 1
docstring = ast.Expr(value=ast.Constant(value=field.description))
class_def.body.append(docstring)
self._save_dependencies(root_type=definition.name, field_type=field_type)

if self.plugin_manager:
Expand Down
12 changes: 11 additions & 1 deletion ariadne_codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,16 +111,21 @@ def generate_class_def(
name: str,
base_names: Optional[List[str]] = None,
body: Optional[List[ast.stmt]] = None,
description: str = "",
) -> ast.ClassDef:
"""Generate class definition."""
bases = cast(
List[ast.expr], [ast.Name(id=name) for name in base_names] if base_names else []
)
body = body if body else []
if description:
docstring = ast.Expr(value=ast.Constant(value=description))
body.insert(0, docstring)
params: Dict[str, Any] = {
"name": name,
"bases": bases,
"keywords": [],
"body": body if body else [],
"body": body,
"decorator_list": [],
}
if sys.version_info >= (3, 12):
Expand Down Expand Up @@ -354,10 +359,15 @@ def generate_method_definition(
name: str,
arguments: ast.arguments,
return_type: Union[ast.Name, ast.Subscript],
description: str = "",
body: Optional[List[ast.stmt]] = None,
lineno: int = 1,
decorator_list: Optional[List[ast.expr]] = None,
) -> ast.FunctionDef:
body = body if body else [ast.Pass()]
if description:
docstring = ast.Expr(value=ast.Constant(value=description))
body.insert(0, docstring)
params: Dict[str, Any] = {
"name": name,
"args": arguments,
Expand Down