|
| 1 | +import ast |
| 2 | +from typing import Any, Dict, List, Optional, Tuple, Union, cast |
| 3 | + |
| 4 | +from graphql import ( |
| 5 | + GraphQLEnumType, |
| 6 | + GraphQLInputObjectType, |
| 7 | + GraphQLInterfaceType, |
| 8 | + GraphQLNonNull, |
| 9 | + GraphQLObjectType, |
| 10 | + GraphQLScalarType, |
| 11 | + GraphQLUnionType, |
| 12 | +) |
| 13 | + |
| 14 | +from ..codegen import ( |
| 15 | + generate_ann_assign, |
| 16 | + generate_annotation_name, |
| 17 | + generate_arg, |
| 18 | + generate_arguments, |
| 19 | + generate_assign, |
| 20 | + generate_call, |
| 21 | + generate_comp, |
| 22 | + generate_constant, |
| 23 | + generate_dict, |
| 24 | + generate_import_from, |
| 25 | + generate_keyword, |
| 26 | + generate_name, |
| 27 | + generate_subscript, |
| 28 | + generate_tuple, |
| 29 | +) |
| 30 | +from ..exceptions import ParsingError |
| 31 | +from ..plugins.manager import PluginManager |
| 32 | +from ..utils import process_name |
| 33 | +from .constants import ( |
| 34 | + ANY, |
| 35 | + BASE_MODEL_FILE_PATH, |
| 36 | + DICT, |
| 37 | + INPUT_SCALARS_MAP, |
| 38 | + UPLOAD_CLASS_NAME, |
| 39 | +) |
| 40 | +from .custom_generator_utils import get_final_type |
| 41 | +from .scalars import ScalarData, generate_scalar_imports |
| 42 | + |
| 43 | + |
| 44 | +class ArgumentGenerator: |
| 45 | + """Generates method arguments for GraphQL fields.""" |
| 46 | + |
| 47 | + def __init__( |
| 48 | + self, |
| 49 | + custom_scalars: Dict[str, ScalarData], |
| 50 | + convert_to_snake_case: bool, |
| 51 | + plugin_manager: Optional[PluginManager] = None, |
| 52 | + ) -> None: |
| 53 | + self.custom_scalars = custom_scalars |
| 54 | + self.convert_to_snake_case = convert_to_snake_case |
| 55 | + self.plugin_manager = plugin_manager |
| 56 | + self.imports: List[ast.ImportFrom] = [] |
| 57 | + self._used_custom_scalars: List[str] = [] |
| 58 | + |
| 59 | + def _add_import(self, import_: Optional[ast.ImportFrom] = None) -> None: |
| 60 | + """Adds an import statement to the list of imports.""" |
| 61 | + if import_: |
| 62 | + if self.plugin_manager: |
| 63 | + import_ = self.plugin_manager.generate_client_import(import_) |
| 64 | + if import_.names: |
| 65 | + self.imports.append(import_) |
| 66 | + |
| 67 | + def generate_arguments( |
| 68 | + self, operation_args: Dict[str, Any] |
| 69 | + ) -> Tuple[ast.arguments, List[ast.expr], List[ast.expr]]: |
| 70 | + """Generates method arguments from operation arguments.""" |
| 71 | + cls_arg = generate_arg(name="cls") |
| 72 | + args: List[ast.arg] = [] |
| 73 | + kw_only_args: List[ast.arg] = [] |
| 74 | + kw_defaults: List[ast.expr] = [] |
| 75 | + return_arguments_keys: List[ast.expr] = [] |
| 76 | + return_arguments_values: List[ast.expr] = [] |
| 77 | + |
| 78 | + for arg_name, arg_value in operation_args.items(): |
| 79 | + final_type = get_final_type(arg_value) |
| 80 | + is_required = isinstance(arg_value.type, GraphQLNonNull) |
| 81 | + name = process_name( |
| 82 | + arg_name, convert_to_snake_case=self.convert_to_snake_case |
| 83 | + ) |
| 84 | + annotation, used_custom_scalar = self._parse_graphql_type_name( |
| 85 | + final_type, not is_required |
| 86 | + ) |
| 87 | + |
| 88 | + self._accumulate_method_arguments( |
| 89 | + args, kw_only_args, kw_defaults, name, annotation, is_required |
| 90 | + ) |
| 91 | + self._accumulate_return_arguments( |
| 92 | + return_arguments_keys, |
| 93 | + return_arguments_values, |
| 94 | + arg_name, |
| 95 | + name, |
| 96 | + final_type, |
| 97 | + is_required, |
| 98 | + used_custom_scalar, |
| 99 | + ) |
| 100 | + |
| 101 | + method_arguments = self._assemble_method_arguments( |
| 102 | + cls_arg, args, kw_only_args, kw_defaults |
| 103 | + ) |
| 104 | + return method_arguments, return_arguments_keys, return_arguments_values |
| 105 | + |
| 106 | + def _accumulate_method_arguments( |
| 107 | + self, |
| 108 | + args: List[ast.arg], |
| 109 | + kw_only_args: List[ast.arg], |
| 110 | + kw_defaults: List[ast.expr], |
| 111 | + name: str, |
| 112 | + annotation: Optional[Union[ast.Name, ast.Subscript]], |
| 113 | + is_required: bool, |
| 114 | + ) -> None: |
| 115 | + """Accumulates method arguments.""" |
| 116 | + if is_required: |
| 117 | + args.append(generate_arg(name=name, annotation=annotation)) |
| 118 | + else: |
| 119 | + kw_only_args.append(generate_arg(name=name, annotation=annotation)) |
| 120 | + kw_defaults.append(generate_constant(value=None)) |
| 121 | + |
| 122 | + def _accumulate_return_arguments( |
| 123 | + self, |
| 124 | + return_arguments_keys: List[ast.expr], |
| 125 | + return_arguments_values: List[ast.expr], |
| 126 | + arg_name: str, |
| 127 | + name: str, |
| 128 | + final_type: Union[GraphQLObjectType, GraphQLInterfaceType, GraphQLUnionType], |
| 129 | + is_required: bool, |
| 130 | + used_custom_scalar: Optional[str], |
| 131 | + ) -> None: |
| 132 | + """Accumulates return arguments.""" |
| 133 | + constant_value = f"{final_type.name}!" if is_required else final_type.name |
| 134 | + return_arg_dict_value = self._generate_return_arg_value( |
| 135 | + name, used_custom_scalar |
| 136 | + ) |
| 137 | + |
| 138 | + return_arguments_keys.append(generate_constant(arg_name)) |
| 139 | + return_arguments_values.append( |
| 140 | + generate_dict( |
| 141 | + keys=[generate_constant("type"), generate_constant("value")], |
| 142 | + values=[generate_constant(constant_value), return_arg_dict_value], |
| 143 | + ) |
| 144 | + ) |
| 145 | + |
| 146 | + def _generate_return_arg_value( |
| 147 | + self, name: str, used_custom_scalar: Optional[str] |
| 148 | + ) -> Union[ast.Call, ast.Name]: |
| 149 | + """Generates the return argument value.""" |
| 150 | + return_arg_dict_value: Union[ast.Call, ast.Name] = generate_name(name) |
| 151 | + |
| 152 | + if used_custom_scalar: |
| 153 | + self._used_custom_scalars.append(used_custom_scalar) |
| 154 | + scalar_data = self.custom_scalars[used_custom_scalar] |
| 155 | + if scalar_data.serialize_name: |
| 156 | + return_arg_dict_value = generate_call( |
| 157 | + func=generate_name(scalar_data.serialize_name), |
| 158 | + args=[generate_name(name)], |
| 159 | + ) |
| 160 | + |
| 161 | + return return_arg_dict_value |
| 162 | + |
| 163 | + def _assemble_method_arguments( |
| 164 | + self, |
| 165 | + cls_arg: ast.arg, |
| 166 | + args: List[ast.arg], |
| 167 | + kw_only_args: List[ast.arg], |
| 168 | + kw_defaults: List[ast.expr], |
| 169 | + ) -> ast.arguments: |
| 170 | + """Assembles method arguments.""" |
| 171 | + return generate_arguments( |
| 172 | + args=[cls_arg, *args], |
| 173 | + kwonlyargs=kw_only_args, |
| 174 | + kw_defaults=kw_defaults, # type: ignore |
| 175 | + ) |
| 176 | + |
| 177 | + def _parse_graphql_type_name( |
| 178 | + self, |
| 179 | + type_: Union[GraphQLScalarType, GraphQLInputObjectType, GraphQLEnumType], |
| 180 | + nullable: bool = True, |
| 181 | + ) -> Tuple[Union[ast.Name, ast.Subscript], Optional[str]]: |
| 182 | + """Parses the GraphQL type name and determines if it is a custom scalar.""" |
| 183 | + name = type_.name |
| 184 | + used_custom_scalar = None |
| 185 | + if isinstance(type_, GraphQLInputObjectType): |
| 186 | + self._add_import( |
| 187 | + generate_import_from(names=[name], from_="input_types", level=1) |
| 188 | + ) |
| 189 | + elif isinstance(type_, GraphQLEnumType): |
| 190 | + self._add_import(generate_import_from(names=[name], level=1)) |
| 191 | + elif isinstance(type_, GraphQLScalarType): |
| 192 | + if name not in self.custom_scalars: |
| 193 | + name = INPUT_SCALARS_MAP.get(name, ANY) |
| 194 | + if name == UPLOAD_CLASS_NAME: |
| 195 | + self._add_import( |
| 196 | + generate_import_from( |
| 197 | + names=[UPLOAD_CLASS_NAME], |
| 198 | + from_=BASE_MODEL_FILE_PATH.stem, |
| 199 | + level=1, |
| 200 | + ) |
| 201 | + ) |
| 202 | + else: |
| 203 | + used_custom_scalar = name |
| 204 | + name = self.custom_scalars[name].type_name |
| 205 | + self._used_custom_scalars.append(used_custom_scalar) |
| 206 | + else: |
| 207 | + raise ParsingError(f"Incorrect argument type {name}") |
| 208 | + return generate_annotation_name(name, nullable), used_custom_scalar |
| 209 | + |
| 210 | + def add_custom_scalar_imports(self) -> None: |
| 211 | + """Adds imports for custom scalars used in the schema.""" |
| 212 | + for custom_scalar_name in self._used_custom_scalars: |
| 213 | + scalar_data = self.custom_scalars[custom_scalar_name] |
| 214 | + for import_ in generate_scalar_imports(scalar_data): |
| 215 | + self._add_import(import_) |
| 216 | + |
| 217 | + def generate_clear_arguments_section( |
| 218 | + self, |
| 219 | + return_arguments_keys: List[ast.expr], |
| 220 | + return_arguments_values: List[ast.expr], |
| 221 | + ) -> Tuple[List[ast.stmt], List[ast.keyword]]: |
| 222 | + arguments_body = [ |
| 223 | + generate_ann_assign( |
| 224 | + "arguments", |
| 225 | + generate_subscript( |
| 226 | + generate_name(DICT), |
| 227 | + generate_tuple( |
| 228 | + [ |
| 229 | + generate_name("str"), |
| 230 | + generate_subscript( |
| 231 | + generate_name(DICT), |
| 232 | + generate_tuple( |
| 233 | + [ |
| 234 | + generate_name("str"), |
| 235 | + generate_name(ANY), |
| 236 | + ] |
| 237 | + ), |
| 238 | + ), |
| 239 | + ] |
| 240 | + ), |
| 241 | + ), |
| 242 | + generate_dict( |
| 243 | + return_arguments_keys, |
| 244 | + return_arguments_values, # type: ignore |
| 245 | + ), |
| 246 | + ), |
| 247 | + generate_assign( |
| 248 | + ["cleared_arguments"], |
| 249 | + ast.DictComp( |
| 250 | + key=generate_name("key"), |
| 251 | + value=generate_name("value"), |
| 252 | + generators=[ |
| 253 | + generate_comp( |
| 254 | + target="key, value", |
| 255 | + iter_="arguments.items()", |
| 256 | + ifs=cast( |
| 257 | + List[ast.expr], |
| 258 | + [ |
| 259 | + ast.Compare( |
| 260 | + left=generate_subscript( |
| 261 | + value=generate_name("value"), |
| 262 | + slice_=generate_constant("value"), |
| 263 | + ), |
| 264 | + ops=[ast.IsNot()], |
| 265 | + comparators=[generate_constant(None)], |
| 266 | + ) |
| 267 | + ], |
| 268 | + ), |
| 269 | + ) |
| 270 | + ], |
| 271 | + ), |
| 272 | + ), |
| 273 | + ] |
| 274 | + arguments_keyword = [ |
| 275 | + generate_keyword(arg="arguments", value=generate_name("cleared_arguments")) |
| 276 | + ] |
| 277 | + return arguments_body, arguments_keyword |
0 commit comments