diff --git a/ariadne_codegen/client_generators/custom_fields.py b/ariadne_codegen/client_generators/custom_fields.py index bfc324c..bac9371 100644 --- a/ariadne_codegen/client_generators/custom_fields.py +++ b/ariadne_codegen/client_generators/custom_fields.py @@ -114,6 +114,7 @@ def _parse_object_type_definitions( class_def = self._generate_class_def_body( definition=graphql_type, class_name=f"{graphql_type.name}{self._get_suffix(graphql_type)}", + description=graphql_type.description, ) if isinstance(graphql_type, GraphQLInterfaceType): class_def.body.append( @@ -129,6 +130,7 @@ def _generate_class_def_body( self, definition: Union[GraphQLObjectType, GraphQLInterfaceType], class_name: str, + description: Optional[str] = None, ) -> ast.ClassDef: """ Generates the body of a class definition for a given GraphQL object @@ -136,10 +138,12 @@ def _generate_class_def_body( """ base_names = [GRAPHQL_BASE_FIELD_CLASS] additional_fields_typing = set() - class_def = generate_class_def(name=class_name, base_names=base_names) - for lineno, (org_name, field) in enumerate( - self._get_combined_fields(definition).items(), start=1 - ): + class_def = generate_class_def( + name=class_name, base_names=base_names, description=description + ) + lineno = 0 + for org_name, field in self._get_combined_fields(definition).items(): + lineno += 1 name = process_name( org_name, convert_to_snake_case=self.convert_to_snake_case ) @@ -154,6 +158,12 @@ def _generate_class_def_body( name, field_name, org_name, field, method_required, lineno ) ) + if field.description and not method_required: + lineno += 1 + docstring = ast.Expr( + value=ast.Constant(field.description), lineno=lineno + ) + class_def.body.append(docstring) class_def.body.append( self._generate_fields_method( @@ -216,7 +226,10 @@ def _generate_class_field( """Handles the generation of field types.""" if getattr(field, "args") or method_required: return self.generate_product_type_method( - name, field_name, getattr(field, "args") + name, + field_name, + getattr(field, "args"), + description=getattr(field, "description"), ) return generate_ann_assign( target=generate_name(name), @@ -311,7 +324,11 @@ def _generate_on_method(self, class_name: str) -> ast.FunctionDef: ) def generate_product_type_method( - self, name: str, class_name: str, arguments: Optional[Dict[str, Any]] = None + self, + name: str, + class_name: str, + arguments: Optional[Dict[str, Any]] = None, + description: Optional[str] = None, ) -> ast.FunctionDef: """Generates a method for a product type.""" arguments = arguments or {} @@ -351,6 +368,7 @@ def generate_product_type_method( ), return_type=generate_name(f'"{class_name}"'), decorator_list=[generate_name("classmethod")], + description=description, ) def _get_suffix( diff --git a/ariadne_codegen/client_generators/custom_operation.py b/ariadne_codegen/client_generators/custom_operation.py index f9ee64f..f2c861e 100644 --- a/ariadne_codegen/client_generators/custom_operation.py +++ b/ariadne_codegen/client_generators/custom_operation.py @@ -84,6 +84,7 @@ def generate(self) -> ast.Module: operation_name=name, operation_args=field.args, final_type=final_type, + description=field.description, ) method_def.lineno = len(self._class_def.body) + 1 self._class_def.body.append(method_def) @@ -115,6 +116,7 @@ def _generate_method( operation_name: str, operation_args, final_type, + description: Optional[str] = None, ) -> ast.FunctionDef: """Generates a method definition for a given operation.""" ( @@ -141,6 +143,7 @@ def _generate_method( name=str_to_snake_case(operation_name), arguments=method_arguments, return_type=generate_name(return_type_name), + description=description, body=[ *arguments_body, generate_return( diff --git a/ariadne_codegen/client_generators/input_types.py b/ariadne_codegen/client_generators/input_types.py index 0a5797d..3be20b8 100644 --- a/ariadne_codegen/client_generators/input_types.py +++ b/ariadne_codegen/client_generators/input_types.py @@ -157,10 +157,13 @@ def _parse_input_definition( self, definition: GraphQLInputObjectType ) -> ast.ClassDef: class_def = generate_class_def( - name=definition.name, base_names=[BASE_MODEL_CLASS_NAME] + name=definition.name, + base_names=[BASE_MODEL_CLASS_NAME], + description=definition.description, ) - - for lineno, (org_name, field) in enumerate(definition.fields.items(), start=1): + lineno = 0 + for org_name, field in definition.fields.items(): + lineno += 1 name = process_name( org_name, convert_to_snake_case=self.convert_to_snake_case, @@ -190,6 +193,10 @@ def _parse_input_definition( field_implementation, input_field=field, field_name=org_name ) class_def.body.append(field_implementation) + if field.description: + lineno += 1 + docstring = ast.Expr(value=ast.Constant(value=field.description)) + class_def.body.append(docstring) self._save_dependencies(root_type=definition.name, field_type=field_type) if self.plugin_manager: diff --git a/ariadne_codegen/codegen.py b/ariadne_codegen/codegen.py index 3d6cd84..e361156 100644 --- a/ariadne_codegen/codegen.py +++ b/ariadne_codegen/codegen.py @@ -111,16 +111,21 @@ def generate_class_def( name: str, base_names: Optional[List[str]] = None, body: Optional[List[ast.stmt]] = None, + description: str = "", ) -> ast.ClassDef: """Generate class definition.""" bases = cast( List[ast.expr], [ast.Name(id=name) for name in base_names] if base_names else [] ) + body = body if body else [] + if description: + docstring = ast.Expr(value=ast.Constant(value=description)) + body.insert(0, docstring) params: Dict[str, Any] = { "name": name, "bases": bases, "keywords": [], - "body": body if body else [], + "body": body, "decorator_list": [], } if sys.version_info >= (3, 12): @@ -354,10 +359,15 @@ def generate_method_definition( name: str, arguments: ast.arguments, return_type: Union[ast.Name, ast.Subscript], + description: str = "", body: Optional[List[ast.stmt]] = None, lineno: int = 1, decorator_list: Optional[List[ast.expr]] = None, ) -> ast.FunctionDef: + body = body if body else [ast.Pass()] + if description: + docstring = ast.Expr(value=ast.Constant(value=description)) + body.insert(0, docstring) params: Dict[str, Any] = { "name": name, "args": arguments,