Skip to content

Escape field arg names if they are conflicting with generated client method variables #281

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 4, 2024
124 changes: 86 additions & 38 deletions ariadne_codegen/client_generators/client.py
Original file line number Diff line number Diff line change
@@ -135,6 +135,9 @@ def add_method(
arguments, arguments_dict = self.arguments_generator.generate(
definition.variable_definitions
)

variable_names = self.get_variable_names(arguments)

operation_name = definition.name.value if definition.name else ""
if definition.operation == OperationType.SUBSCRIPTION:
if not async_:
@@ -149,6 +152,7 @@ def add_method(
arguments=arguments,
arguments_dict=arguments_dict,
operation_str=operation_str,
variable_names=variable_names,
)
)
elif async_:
@@ -159,6 +163,7 @@ def add_method(
arguments_dict=arguments_dict,
operation_str=operation_str,
operation_name=operation_name,
variable_names=variable_names,
)
else:
method_def = self._generate_method(
@@ -168,6 +173,7 @@ def add_method(
arguments_dict=arguments_dict,
operation_str=operation_str,
operation_name=operation_name,
variable_names=variable_names,
)

method_def.lineno = len(self._class_def.body) + 1
@@ -181,6 +187,23 @@ def add_method(
generate_import_from(names=[return_type], from_=return_type_module, level=1)
)

def get_variable_names(self, arguments: ast.arguments) -> Dict[str, str]:
mapped_variable_names = [
self._operation_str_variable,
self._variables_dict_variable,
self._response_variable,
self._data_variable,
]
variable_names = {}
argument_names = set(arg.arg for arg in arguments.args)

for variable in mapped_variable_names:
variable_names[variable] = (
"_" + variable if variable in argument_names else variable
)

return variable_names

def _add_import(self, import_: Optional[ast.ImportFrom] = None):
if not import_:
return
@@ -197,6 +220,7 @@ def _generate_subscription_method_def(
arguments: ast.arguments,
arguments_dict: ast.Dict,
operation_str: str,
variable_names: Dict[str, str],
) -> ast.AsyncFunctionDef:
return generate_async_method_definition(
name=name,
@@ -205,9 +229,11 @@ def _generate_subscription_method_def(
value=generate_name(ASYNC_ITERATOR), slice_=generate_name(return_type)
),
body=[
self._generate_operation_str_assign(operation_str, 1),
self._generate_variables_assign(arguments_dict, 2),
self._generate_async_generator_loop(operation_name, return_type, 3),
self._generate_operation_str_assign(variable_names, operation_str, 1),
self._generate_variables_assign(variable_names, arguments_dict, 2),
self._generate_async_generator_loop(
variable_names, operation_name, return_type, 3
),
],
)

@@ -219,17 +245,18 @@ def _generate_async_method(
arguments_dict: ast.Dict,
operation_str: str,
operation_name: str,
variable_names: Dict[str, str],
) -> ast.AsyncFunctionDef:
return generate_async_method_definition(
name=name,
arguments=arguments,
return_type=generate_name(return_type),
body=[
self._generate_operation_str_assign(operation_str, 1),
self._generate_variables_assign(arguments_dict, 2),
self._generate_async_response_assign(operation_name, 3),
self._generate_data_retrieval(),
self._generate_return_parsed_obj(return_type),
self._generate_operation_str_assign(variable_names, operation_str, 1),
self._generate_variables_assign(variable_names, arguments_dict, 2),
self._generate_async_response_assign(variable_names, operation_name, 3),
self._generate_data_retrieval(variable_names),
self._generate_return_parsed_obj(variable_names, return_type),
],
)

@@ -241,25 +268,26 @@ def _generate_method(
arguments_dict: ast.Dict,
operation_str: str,
operation_name: str,
variable_names: Dict[str, str],
) -> ast.FunctionDef:
return generate_method_definition(
name=name,
arguments=arguments,
return_type=generate_name(return_type),
body=[
self._generate_operation_str_assign(operation_str, 1),
self._generate_variables_assign(arguments_dict, 2),
self._generate_response_assign(operation_name, 3),
self._generate_data_retrieval(),
self._generate_return_parsed_obj(return_type),
self._generate_operation_str_assign(variable_names, operation_str, 1),
self._generate_variables_assign(variable_names, arguments_dict, 2),
self._generate_response_assign(variable_names, operation_name, 3),
self._generate_data_retrieval(variable_names),
self._generate_return_parsed_obj(variable_names, return_type),
],
)

def _generate_operation_str_assign(
self, operation_str: str, lineno: int = 1
self, variable_names: Dict[str, str], operation_str: str, lineno: int = 1
) -> ast.Assign:
return generate_assign(
targets=[self._operation_str_variable],
targets=[variable_names[self._operation_str_variable]],
value=generate_call(
func=generate_name(self._gql_func_name),
args=[
@@ -270,10 +298,10 @@ def _generate_operation_str_assign(
)

def _generate_variables_assign(
self, arguments_dict: ast.Dict, lineno: int = 1
self, variable_names: Dict[str, str], arguments_dict: ast.Dict, lineno: int = 1
) -> ast.AnnAssign:
return generate_ann_assign(
target=self._variables_dict_variable,
target=variable_names[self._variables_dict_variable],
annotation=generate_subscript(
generate_name(DICT),
generate_tuple([generate_name("str"), generate_name("object")]),
@@ -283,95 +311,115 @@ def _generate_variables_assign(
)

def _generate_async_response_assign(
self, operation_name: str, lineno: int = 1
self, variable_names: Dict[str, str], operation_name: str, lineno: int = 1
) -> ast.Assign:
return generate_assign(
targets=[self._response_variable],
targets=[variable_names[self._response_variable]],
value=generate_await(
self._generate_execute_call(operation_name=operation_name)
self._generate_execute_call(variable_names, operation_name)
),
lineno=lineno,
)

def _generate_response_assign(
self, operation_name: str, lineno: int = 1
self,
variable_names: Dict[str, str],
operation_name: str,
lineno: int = 1,
) -> ast.Assign:
return generate_assign(
targets=[self._response_variable],
value=self._generate_execute_call(operation_name=operation_name),
targets=[variable_names[self._response_variable]],
value=self._generate_execute_call(variable_names, operation_name),
lineno=lineno,
)

def _generate_execute_call(self, operation_name: str) -> ast.Call:
def _generate_execute_call(
self, variable_names: Dict[str, str], operation_name: str
) -> ast.Call:
return generate_call(
func=generate_attribute(generate_name("self"), "execute"),
keywords=[
generate_keyword(
value=generate_name(self._operation_str_variable), arg="query"
value=generate_name(variable_names[self._operation_str_variable]),
arg="query",
),
generate_keyword(
value=generate_constant(operation_name), arg="operation_name"
),
generate_keyword(
value=generate_name(self._variables_dict_variable), arg="variables"
value=generate_name(variable_names[self._variables_dict_variable]),
arg="variables",
),
generate_keyword(value=generate_name(KWARGS_NAMES)),
],
)

def _generate_data_retrieval(self) -> ast.Assign:
def _generate_data_retrieval(self, variable_names: Dict[str, str]) -> ast.Assign:
return generate_assign(
targets=[self._data_variable],
targets=[variable_names[self._data_variable]],
value=generate_call(
func=generate_attribute(value=generate_name("self"), attr="get_data"),
args=[generate_name(self._response_variable)],
args=[generate_name(variable_names[self._response_variable])],
),
)

def _generate_return_parsed_obj(self, return_type: str) -> ast.Return:
def _generate_return_parsed_obj(
self, variable_names: Dict[str, str], return_type: str
) -> ast.Return:
return generate_return(
generate_call(
func=generate_attribute(
generate_name(return_type), MODEL_VALIDATE_METHOD
),
args=[generate_name(self._data_variable)],
args=[generate_name(variable_names[self._data_variable])],
)
)

def _generate_async_generator_loop(
self, operation_name: str, return_type: str, lineno: int = 1
self,
variable_names: Dict[str, str],
operation_name: str,
return_type: str,
lineno: int = 1,
) -> ast.AsyncFor:
return generate_async_for(
target=generate_name(self._data_variable),
target=generate_name(variable_names[self._data_variable]),
iter_=generate_call(
func=generate_attribute(value=generate_name("self"), attr="execute_ws"),
keywords=[
generate_keyword(
value=generate_name(self._operation_str_variable), arg="query"
value=generate_name(
variable_names[self._operation_str_variable]
),
arg="query",
),
generate_keyword(
value=generate_constant(operation_name), arg="operation_name"
),
generate_keyword(
value=generate_name(self._variables_dict_variable),
value=generate_name(
variable_names[self._variables_dict_variable]
),
arg="variables",
),
generate_keyword(value=generate_name(KWARGS_NAMES)),
],
),
body=[self._generate_yield_parsed_obj(return_type)],
body=[self._generate_yield_parsed_obj(variable_names, return_type)],
lineno=lineno,
)

def _generate_yield_parsed_obj(self, return_type: str) -> ast.Expr:
def _generate_yield_parsed_obj(
self, variable_names: Dict[str, str], return_type: str
) -> ast.Expr:
return generate_expr(
generate_yield(
generate_call(
func=generate_attribute(
value=generate_name(return_type),
attr=MODEL_VALIDATE_METHOD,
),
args=[generate_name(self._data_variable)],
args=[generate_name(variable_names[self._data_variable])],
)
)
)
99 changes: 99 additions & 0 deletions tests/client_generators/test_client_generator.py
Original file line number Diff line number Diff line change
@@ -617,3 +617,102 @@ def test_add_method_triggers_generate_client_method_hook(
generator.generate()

assert mocked_plugin_manager.generate_client_method.called


def test_add_method_generates_correct_method_body_for_shadowed_variables(
base_client_import,
):
schema_str = """
schema { query: Query }
type Query { xyz(query: String!, variables: String!, response: String!, data: String!): String }
"""
query_str = """
query GetXyz($query: String!, $variables: String!, $response: String!, $data: String! ) {
xyz(query: $query, variables: $variables, response: $response, data: $data)
}
"""
generator = ClientGenerator(
base_client_import=base_client_import,
arguments_generator=ArgumentsGenerator(schema=build_schema(schema_str)),
)
method_name = "list_xyz"
return_type = "GetXyz"
return_type_module_name = method_name
expected_method_body = [
ast.Assign(
targets=[ast.Name(id="_query")],
value=ast.Call(
func=ast.Name("gql"),
keywords=[],
args=[[ast.Constant(value=l + "\n") for l in query_str.splitlines()]],
),
),
ast.AnnAssign(
target=ast.Name(id="_variables"),
annotation=ast.Subscript(
value=ast.Name(id="Dict"),
slice=ast.Tuple(elts=[ast.Name(id="str"), ast.Name(id="object")]),
),
value=ast.Dict(
keys=[
ast.Constant(value="query"),
ast.Constant(value="variables"),
ast.Constant(value="response"),
ast.Constant(value="data"),
],
values=[
ast.Name(id="query"),
ast.Name(id="variables"),
ast.Name(id="response"),
ast.Name(id="data"),
],
),
simple=1,
),
ast.Assign(
targets=[ast.Name(id="_response")],
value=ast.Call(
func=ast.Attribute(value=ast.Name(id="self"), attr="execute"),
args=[],
keywords=[
ast.keyword(arg="query", value=ast.Name(id="_query")),
ast.keyword(
arg="operation_name", value=ast.Constant(value="GetXyz")
),
ast.keyword(arg="variables", value=ast.Name(id="_variables")),
ast.keyword(value=ast.Name(id="kwargs")),
],
),
),
ast.Assign(
targets=[ast.Name(id="_data")],
value=ast.Call(
func=ast.Attribute(value=ast.Name(id="self"), attr="get_data"),
args=[ast.Name(id="_response")],
keywords=[],
),
),
ast.Return(
value=ast.Call(
func=ast.Attribute(value=ast.Name(id="GetXyz"), attr="model_validate"),
args=[ast.Name(id="_data")],
keywords=[],
)
),
]

generator.add_method(
definition=cast(OperationDefinitionNode, parse(query_str).definitions[0]),
name=method_name,
return_type=return_type,
return_type_module=return_type_module_name,
operation_str=query_str,
async_=False,
)
module = generator.generate()

class_def = get_class_def(module)
assert class_def
method_def = class_def.body[0]
assert isinstance(method_def, ast.FunctionDef)
assert compare_ast(method_def.body, expected_method_body)