Skip to content

Commit 379ba40

Browse files
authored
Merge pull request #281 from mirumee/277-proof-of-concept
Escape field arg names if they are conflicting with generated client method variables
2 parents c720473 + 7fcfd12 commit 379ba40

File tree

3 files changed

+542
-38
lines changed

3 files changed

+542
-38
lines changed

ariadne_codegen/client_generators/client.py

+86-38
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,9 @@ def add_method(
135135
arguments, arguments_dict = self.arguments_generator.generate(
136136
definition.variable_definitions
137137
)
138+
139+
variable_names = self.get_variable_names(arguments)
140+
138141
operation_name = definition.name.value if definition.name else ""
139142
if definition.operation == OperationType.SUBSCRIPTION:
140143
if not async_:
@@ -149,6 +152,7 @@ def add_method(
149152
arguments=arguments,
150153
arguments_dict=arguments_dict,
151154
operation_str=operation_str,
155+
variable_names=variable_names,
152156
)
153157
)
154158
elif async_:
@@ -159,6 +163,7 @@ def add_method(
159163
arguments_dict=arguments_dict,
160164
operation_str=operation_str,
161165
operation_name=operation_name,
166+
variable_names=variable_names,
162167
)
163168
else:
164169
method_def = self._generate_method(
@@ -168,6 +173,7 @@ def add_method(
168173
arguments_dict=arguments_dict,
169174
operation_str=operation_str,
170175
operation_name=operation_name,
176+
variable_names=variable_names,
171177
)
172178

173179
method_def.lineno = len(self._class_def.body) + 1
@@ -181,6 +187,23 @@ def add_method(
181187
generate_import_from(names=[return_type], from_=return_type_module, level=1)
182188
)
183189

190+
def get_variable_names(self, arguments: ast.arguments) -> Dict[str, str]:
191+
mapped_variable_names = [
192+
self._operation_str_variable,
193+
self._variables_dict_variable,
194+
self._response_variable,
195+
self._data_variable,
196+
]
197+
variable_names = {}
198+
argument_names = set(arg.arg for arg in arguments.args)
199+
200+
for variable in mapped_variable_names:
201+
variable_names[variable] = (
202+
f"_{variable}" if variable in argument_names else variable
203+
)
204+
205+
return variable_names
206+
184207
def _add_import(self, import_: Optional[ast.ImportFrom] = None):
185208
if not import_:
186209
return
@@ -197,6 +220,7 @@ def _generate_subscription_method_def(
197220
arguments: ast.arguments,
198221
arguments_dict: ast.Dict,
199222
operation_str: str,
223+
variable_names: Dict[str, str],
200224
) -> ast.AsyncFunctionDef:
201225
return generate_async_method_definition(
202226
name=name,
@@ -205,9 +229,11 @@ def _generate_subscription_method_def(
205229
value=generate_name(ASYNC_ITERATOR), slice_=generate_name(return_type)
206230
),
207231
body=[
208-
self._generate_operation_str_assign(operation_str, 1),
209-
self._generate_variables_assign(arguments_dict, 2),
210-
self._generate_async_generator_loop(operation_name, return_type, 3),
232+
self._generate_operation_str_assign(variable_names, operation_str, 1),
233+
self._generate_variables_assign(variable_names, arguments_dict, 2),
234+
self._generate_async_generator_loop(
235+
variable_names, operation_name, return_type, 3
236+
),
211237
],
212238
)
213239

@@ -219,17 +245,18 @@ def _generate_async_method(
219245
arguments_dict: ast.Dict,
220246
operation_str: str,
221247
operation_name: str,
248+
variable_names: Dict[str, str],
222249
) -> ast.AsyncFunctionDef:
223250
return generate_async_method_definition(
224251
name=name,
225252
arguments=arguments,
226253
return_type=generate_name(return_type),
227254
body=[
228-
self._generate_operation_str_assign(operation_str, 1),
229-
self._generate_variables_assign(arguments_dict, 2),
230-
self._generate_async_response_assign(operation_name, 3),
231-
self._generate_data_retrieval(),
232-
self._generate_return_parsed_obj(return_type),
255+
self._generate_operation_str_assign(variable_names, operation_str, 1),
256+
self._generate_variables_assign(variable_names, arguments_dict, 2),
257+
self._generate_async_response_assign(variable_names, operation_name, 3),
258+
self._generate_data_retrieval(variable_names),
259+
self._generate_return_parsed_obj(variable_names, return_type),
233260
],
234261
)
235262

@@ -241,25 +268,26 @@ def _generate_method(
241268
arguments_dict: ast.Dict,
242269
operation_str: str,
243270
operation_name: str,
271+
variable_names: Dict[str, str],
244272
) -> ast.FunctionDef:
245273
return generate_method_definition(
246274
name=name,
247275
arguments=arguments,
248276
return_type=generate_name(return_type),
249277
body=[
250-
self._generate_operation_str_assign(operation_str, 1),
251-
self._generate_variables_assign(arguments_dict, 2),
252-
self._generate_response_assign(operation_name, 3),
253-
self._generate_data_retrieval(),
254-
self._generate_return_parsed_obj(return_type),
278+
self._generate_operation_str_assign(variable_names, operation_str, 1),
279+
self._generate_variables_assign(variable_names, arguments_dict, 2),
280+
self._generate_response_assign(variable_names, operation_name, 3),
281+
self._generate_data_retrieval(variable_names),
282+
self._generate_return_parsed_obj(variable_names, return_type),
255283
],
256284
)
257285

258286
def _generate_operation_str_assign(
259-
self, operation_str: str, lineno: int = 1
287+
self, variable_names: Dict[str, str], operation_str: str, lineno: int = 1
260288
) -> ast.Assign:
261289
return generate_assign(
262-
targets=[self._operation_str_variable],
290+
targets=[variable_names[self._operation_str_variable]],
263291
value=generate_call(
264292
func=generate_name(self._gql_func_name),
265293
args=[
@@ -270,10 +298,10 @@ def _generate_operation_str_assign(
270298
)
271299

272300
def _generate_variables_assign(
273-
self, arguments_dict: ast.Dict, lineno: int = 1
301+
self, variable_names: Dict[str, str], arguments_dict: ast.Dict, lineno: int = 1
274302
) -> ast.AnnAssign:
275303
return generate_ann_assign(
276-
target=self._variables_dict_variable,
304+
target=variable_names[self._variables_dict_variable],
277305
annotation=generate_subscript(
278306
generate_name(DICT),
279307
generate_tuple([generate_name("str"), generate_name("object")]),
@@ -283,95 +311,115 @@ def _generate_variables_assign(
283311
)
284312

285313
def _generate_async_response_assign(
286-
self, operation_name: str, lineno: int = 1
314+
self, variable_names: Dict[str, str], operation_name: str, lineno: int = 1
287315
) -> ast.Assign:
288316
return generate_assign(
289-
targets=[self._response_variable],
317+
targets=[variable_names[self._response_variable]],
290318
value=generate_await(
291-
self._generate_execute_call(operation_name=operation_name)
319+
self._generate_execute_call(variable_names, operation_name)
292320
),
293321
lineno=lineno,
294322
)
295323

296324
def _generate_response_assign(
297-
self, operation_name: str, lineno: int = 1
325+
self,
326+
variable_names: Dict[str, str],
327+
operation_name: str,
328+
lineno: int = 1,
298329
) -> ast.Assign:
299330
return generate_assign(
300-
targets=[self._response_variable],
301-
value=self._generate_execute_call(operation_name=operation_name),
331+
targets=[variable_names[self._response_variable]],
332+
value=self._generate_execute_call(variable_names, operation_name),
302333
lineno=lineno,
303334
)
304335

305-
def _generate_execute_call(self, operation_name: str) -> ast.Call:
336+
def _generate_execute_call(
337+
self, variable_names: Dict[str, str], operation_name: str
338+
) -> ast.Call:
306339
return generate_call(
307340
func=generate_attribute(generate_name("self"), "execute"),
308341
keywords=[
309342
generate_keyword(
310-
value=generate_name(self._operation_str_variable), arg="query"
343+
value=generate_name(variable_names[self._operation_str_variable]),
344+
arg="query",
311345
),
312346
generate_keyword(
313347
value=generate_constant(operation_name), arg="operation_name"
314348
),
315349
generate_keyword(
316-
value=generate_name(self._variables_dict_variable), arg="variables"
350+
value=generate_name(variable_names[self._variables_dict_variable]),
351+
arg="variables",
317352
),
318353
generate_keyword(value=generate_name(KWARGS_NAMES)),
319354
],
320355
)
321356

322-
def _generate_data_retrieval(self) -> ast.Assign:
357+
def _generate_data_retrieval(self, variable_names: Dict[str, str]) -> ast.Assign:
323358
return generate_assign(
324-
targets=[self._data_variable],
359+
targets=[variable_names[self._data_variable]],
325360
value=generate_call(
326361
func=generate_attribute(value=generate_name("self"), attr="get_data"),
327-
args=[generate_name(self._response_variable)],
362+
args=[generate_name(variable_names[self._response_variable])],
328363
),
329364
)
330365

331-
def _generate_return_parsed_obj(self, return_type: str) -> ast.Return:
366+
def _generate_return_parsed_obj(
367+
self, variable_names: Dict[str, str], return_type: str
368+
) -> ast.Return:
332369
return generate_return(
333370
generate_call(
334371
func=generate_attribute(
335372
generate_name(return_type), MODEL_VALIDATE_METHOD
336373
),
337-
args=[generate_name(self._data_variable)],
374+
args=[generate_name(variable_names[self._data_variable])],
338375
)
339376
)
340377

341378
def _generate_async_generator_loop(
342-
self, operation_name: str, return_type: str, lineno: int = 1
379+
self,
380+
variable_names: Dict[str, str],
381+
operation_name: str,
382+
return_type: str,
383+
lineno: int = 1,
343384
) -> ast.AsyncFor:
344385
return generate_async_for(
345-
target=generate_name(self._data_variable),
386+
target=generate_name(variable_names[self._data_variable]),
346387
iter_=generate_call(
347388
func=generate_attribute(value=generate_name("self"), attr="execute_ws"),
348389
keywords=[
349390
generate_keyword(
350-
value=generate_name(self._operation_str_variable), arg="query"
391+
value=generate_name(
392+
variable_names[self._operation_str_variable]
393+
),
394+
arg="query",
351395
),
352396
generate_keyword(
353397
value=generate_constant(operation_name), arg="operation_name"
354398
),
355399
generate_keyword(
356-
value=generate_name(self._variables_dict_variable),
400+
value=generate_name(
401+
variable_names[self._variables_dict_variable]
402+
),
357403
arg="variables",
358404
),
359405
generate_keyword(value=generate_name(KWARGS_NAMES)),
360406
],
361407
),
362-
body=[self._generate_yield_parsed_obj(return_type)],
408+
body=[self._generate_yield_parsed_obj(variable_names, return_type)],
363409
lineno=lineno,
364410
)
365411

366-
def _generate_yield_parsed_obj(self, return_type: str) -> ast.Expr:
412+
def _generate_yield_parsed_obj(
413+
self, variable_names: Dict[str, str], return_type: str
414+
) -> ast.Expr:
367415
return generate_expr(
368416
generate_yield(
369417
generate_call(
370418
func=generate_attribute(
371419
value=generate_name(return_type),
372420
attr=MODEL_VALIDATE_METHOD,
373421
),
374-
args=[generate_name(self._data_variable)],
422+
args=[generate_name(variable_names[self._data_variable])],
375423
)
376424
)
377425
)

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ disable = [
6969
"duplicate-code",
7070
"no-name-in-module",
7171
"too-many-locals",
72+
"too-many-lines",
7273
]
7374

7475
[tool.pytest.ini_options]

0 commit comments

Comments
 (0)