Skip to content

Commit b2b97d6

Browse files
Rename InputModel and as_input_model to CreateModel and as_create_model
1 parent 510755a commit b2b97d6

File tree

3 files changed

+25
-25
lines changed

3 files changed

+25
-25
lines changed

src/aiida/cmdline/commands/cmd_code.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def verdi_code():
4040
def create_code(ctx: click.Context, cls: Code, **kwargs) -> None:
4141
"""Create a new `Code` instance."""
4242
try:
43-
model = cls.InputModel(**kwargs)
43+
model = cls.CreateModel(**kwargs)
4444
instance = cls.from_model(model) # type: ignore[arg-type]
4545
except (TypeError, ValueError) as exception:
4646
echo.echo_critical(f'Failed to create instance `{cls}`: {exception}')

src/aiida/cmdline/groups/dynamic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def call_command(self, ctx: click.Context, cls: t.Any, non_interactive: bool, **
9797

9898
if hasattr(cls, 'Model'):
9999
# The plugin defines a pydantic model: use it to validate the provided arguments
100-
Model = cls.InputModel if hasattr(cls, 'InputModel') else cls.Model # noqa: N806
100+
Model = cls.CreateModel if hasattr(cls, 'CreateModel') else cls.Model # noqa: N806
101101
try:
102102
Model(**kwargs)
103103
except ValidationError as exception:
@@ -169,7 +169,7 @@ def list_options(self, entry_point: str) -> list[t.Callable[[FC], FC]]:
169169
options_spec = self.factory(entry_point).get_cli_options() # type: ignore[union-attr]
170170
return [self.create_option(*item) for item in options_spec]
171171

172-
Model = cls.InputModel if hasattr(cls, 'InputModel') else cls.Model # noqa: N806
172+
Model = cls.CreateModel if hasattr(cls, 'CreateModel') else cls.Model # noqa: N806
173173

174174
options_spec = {}
175175

src/aiida/orm/entities.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -206,41 +206,41 @@ def __pydantic_init_subclass__(cls, **kwargs: Any) -> None:
206206
cls.model_config['title'] = cls.__qualname__.replace('.', '')
207207

208208
@classmethod
209-
def as_input_model(cls: Type[EntityModelType]) -> Type[EntityModelType]:
210-
"""Return a derived model class with read-only fields removed.
209+
def as_create_model(cls: Type[EntityModelType]) -> Type[EntityModelType]:
210+
"""Return a derived creation model class with read-only fields removed.
211211
212212
This also removes any serializers/validators defined on those fields.
213213
214-
:return: The derived input model class.
214+
:return: The derived creation model class.
215215
"""
216216

217-
# Derive the input model from the original model
218-
new_name = cls.__qualname__.replace('.Model', 'InputModel')
219-
InputModel = create_model( # noqa: N806
217+
# Derive the creation model from the original model
218+
new_name = cls.__qualname__.replace('.Model', 'CreateModel')
219+
CreateModel = create_model( # noqa: N806
220220
new_name,
221221
__base__=cls,
222-
__doc__=f'Input version of {cls.__name__}.',
222+
__doc__=f'Creation version of {cls.__name__}.',
223223
)
224-
InputModel.__qualname__ = new_name
225-
InputModel.__module__ = cls.__module__
224+
CreateModel.__qualname__ = new_name
225+
CreateModel.__module__ = cls.__module__
226226

227227
# Identify read-only fields
228228
readonly_fields = [
229229
name
230-
for name, field in InputModel.model_fields.items()
230+
for name, field in CreateModel.model_fields.items()
231231
if hasattr(field, 'json_schema_extra')
232232
and isinstance(field.json_schema_extra, dict)
233233
and field.json_schema_extra.get('readOnly')
234234
]
235235

236236
# Remove read-only fields
237237
for name in readonly_fields:
238-
InputModel.model_fields.pop(name, None)
239-
if hasattr(InputModel, name):
240-
delattr(InputModel, name)
238+
CreateModel.model_fields.pop(name, None)
239+
if hasattr(CreateModel, name):
240+
delattr(CreateModel, name)
241241

242242
# Prune field validators/serializers referring to read-only fields
243-
decorators = InputModel.__pydantic_decorators__
243+
decorators = CreateModel.__pydantic_decorators__
244244

245245
def _prune_field_decorators(field_decorators: dict[str, Any]) -> dict[str, Any]:
246246
return {
@@ -252,15 +252,15 @@ def _prune_field_decorators(field_decorators: dict[str, Any]) -> dict[str, Any]:
252252
decorators.field_validators = _prune_field_decorators(decorators.field_validators)
253253
decorators.field_serializers = _prune_field_decorators(decorators.field_serializers)
254254

255-
return InputModel
255+
return CreateModel
256256

257257
@classproperty
258-
def InputModel(cls) -> Type[Model]: # noqa: N802, N805
259-
"""Return the input version of the model class for this entity.
258+
def CreateModel(cls) -> Type[Model]: # noqa: N802, N805
259+
"""Return the creation version of the model class for this entity.
260260
261-
:return: The input model class, with read-only fields removed.
261+
:return: The creation model class, with read-only fields removed.
262262
"""
263-
return cls.Model.as_input_model()
263+
return cls.Model.as_create_model()
264264

265265
@classmethod
266266
def model_to_orm_fields(cls) -> dict[str, FieldInfo]:
@@ -316,7 +316,7 @@ def to_model(
316316
with `exclude_to_orm=True`.
317317
:return: An instance of the entity's model class.
318318
"""
319-
Model = self.InputModel if unstored else self.Model # noqa: N806
319+
Model = self.CreateModel if unstored else self.Model # noqa: N806
320320
fields = self._collect_model_field_values(
321321
repository_path=repository_path,
322322
serialize_repository_content=serialize_repository_content,
@@ -387,7 +387,7 @@ def from_serialized(cls, serialized: dict[str, Any], unstored: bool = False) ->
387387
cls._logger.warning(
388388
'Serialization through pydantic is still an experimental feature and might break in future releases.'
389389
)
390-
Model = cls.InputModel if unstored else cls.Model # noqa: N806
390+
Model = cls.CreateModel if unstored else cls.Model # noqa: N806
391391
return cls.from_model(Model(**serialized))
392392

393393
@classproperty
@@ -533,7 +533,7 @@ def _collect_model_field_values(
533533
"""
534534
fields: dict[str, Any] = {}
535535

536-
Model = self.InputModel if unstored else self.Model # noqa: N806
536+
Model = self.CreateModel if unstored else self.Model # noqa: N806
537537

538538
for key, field in Model.model_fields.items():
539539
if skip_cli_excluded and get_metadata(field, 'exclude_from_cli'):

0 commit comments

Comments
 (0)