@@ -69,6 +69,20 @@ def get_param_model(annotated_type: Any) -> type[ParamModel] | ForwardRef | list
69
69
return base_type
70
70
71
71
72
+ def get_param_model_name (param_model : type [ParamModel ] | ForwardRef ) -> str :
73
+ """Get the model name
74
+
75
+ :param param_model: Param model. This can be a forward ref
76
+ """
77
+ if not is_param_model (param_model ):
78
+ raise ValueError (f"{ param_model } is not a param model" )
79
+
80
+ if isinstance (param_model , ForwardRef ):
81
+ return param_model .__forward_arg__
82
+ else :
83
+ return param_model .__name__
84
+
85
+
72
86
@lru_cache
73
87
def generate_model_name (field_name : str , field_type : str | Any ) -> str :
74
88
"""Generate model name from the given field
@@ -130,7 +144,7 @@ def create_model_from_param_def(
130
144
)
131
145
for inner_param_name , inner_param_obj in param_def .get ("properties" , {}).items ()
132
146
]
133
- alias_illegal_model_field_names (fields )
147
+ alias_illegal_model_field_names (model_name , fields )
134
148
return cast (
135
149
type [ParamModel ],
136
150
make_dataclass (
@@ -299,9 +313,10 @@ def visit(model_name: str) -> None:
299
313
return sorted (models , key = lambda x : sorted_models_names .index (x .__name__ ))
300
314
301
315
302
- def alias_illegal_model_field_names (model_fields : list [DataclassModelField ]) -> str :
316
+ def alias_illegal_model_field_names (model_name : str , model_fields : list [DataclassModelField ]) -> None :
303
317
"""Clean illegal model field name and annotate the field type with Alias class
304
318
319
+ :param model_name: Model name
305
320
:param model_fields: fields value to be passed to make_dataclass()
306
321
"""
307
322
@@ -315,7 +330,7 @@ def make_alias(name: str, param_type: Any) -> str:
315
330
else :
316
331
name = clean_obj_name (name )
317
332
# NOTE: The escaping of kwargs is already is handled in endpoint model
318
- reserved_param_names = [* get_supported_request_parameters () , "validate" ]
333
+ reserved_param_names = ["self" , "validate" , * get_supported_request_parameters () ]
319
334
if name in get_reserved_model_names () + reserved_param_names :
320
335
# The field name conflicts with one of reserved names
321
336
name += "_"
0 commit comments