Skip to content

Commit ada7f96

Browse files
Add Model class property on Entity
1 parent dea30ff commit ada7f96

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+227
-232
lines changed

src/aiida/cmdline/groups/dynamic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def call_command(self, ctx: click.Context, cls: t.Any, non_interactive: bool, **
9595
"""Call the ``command`` after validating the provided inputs."""
9696
from pydantic import ValidationError
9797

98-
if hasattr(cls, 'Model') or hasattr(cls, '_CLS_MODEL'):
98+
if hasattr(cls, 'Model'):
9999
# The plugin defines a pydantic model: use it to validate the provided arguments
100100
Model = cls.InputModel if hasattr(cls, 'InputModel') else cls.Model # noqa: N806
101101
try:
@@ -158,7 +158,7 @@ def list_options(self, entry_point: str) -> list[t.Callable[[FC], FC]]:
158158

159159
cls = self.factory(entry_point)
160160

161-
if not (hasattr(cls, 'Model') or hasattr(cls, '_CLS_MODEL')):
161+
if not hasattr(cls, 'Model'):
162162
from aiida.common.warnings import warn_deprecation
163163

164164
warn_deprecation(

src/aiida/orm/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,14 @@
4545
'CifData',
4646
'Code',
4747
'CodeEntityLoader',
48-
'Collection',
4948
'Comment',
5049
'Computer',
5150
'ComputerEntityLoader',
5251
'ContainerizedCode',
5352
'Data',
5453
'Dict',
5554
'Entity',
55+
'EntityCollection',
5656
'EntityExtras',
5757
'EntityTypes',
5858
'EnumData',

src/aiida/orm/authinfos.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
__all__ = ('AuthInfo',)
3030

3131

32-
class AuthInfoCollection(entities.Collection['AuthInfo']):
32+
class AuthInfoCollection(entities.EntityCollection['AuthInfo']):
3333
"""The collection of `AuthInfo` entries."""
3434

3535
@staticmethod
@@ -44,7 +44,7 @@ def delete(self, pk: int) -> None:
4444
self._backend.authinfos.delete(pk)
4545

4646

47-
class Model(entities.Model):
47+
class AuthInfoModel(entities.EntityModel):
4848
computer: int = MetadataField(
4949
description='The PK of the computer',
5050
is_attribute=False,
@@ -74,11 +74,11 @@ class Model(entities.Model):
7474
)
7575

7676

77-
class AuthInfo(entities.Entity['BackendAuthInfo', AuthInfoCollection, Model]):
77+
class AuthInfo(entities.Entity['BackendAuthInfo', AuthInfoCollection, AuthInfoModel]):
7878
"""ORM class that models the authorization information that allows a `User` to connect to a `Computer`."""
7979

8080
_CLS_COLLECTION = AuthInfoCollection
81-
_CLS_MODEL = Model
81+
_CLS_MODEL = AuthInfoModel
8282

8383
PROPERTY_WORKDIR = 'workdir'
8484

src/aiida/orm/comments.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
__all__ = ('Comment',)
2929

3030

31-
class CommentCollection(entities.Collection['Comment']):
31+
class CommentCollection(entities.EntityCollection['Comment']):
3232
"""The collection of Comment entries."""
3333

3434
@staticmethod
@@ -65,7 +65,7 @@ def delete_many(self, filters: dict) -> List[int]:
6565
return self._backend.comments.delete_many(filters)
6666

6767

68-
class Model(entities.Model):
68+
class CommentModel(entities.EntityModel):
6969
uuid: UUID = MetadataField(
7070
description='The UUID of the comment',
7171
is_attribute=False,
@@ -99,11 +99,11 @@ class Model(entities.Model):
9999
)
100100

101101

102-
class Comment(entities.Entity['BackendComment', CommentCollection, Model]):
102+
class Comment(entities.Entity['BackendComment', CommentCollection, CommentModel]):
103103
"""Base class to map a DbComment that represents a comment attached to a certain Node."""
104104

105105
_CLS_COLLECTION = CommentCollection
106-
_CLS_MODEL = Model
106+
_CLS_MODEL = CommentModel
107107

108108
def __init__(
109109
self, node: 'Node', user: 'User', content: Optional[str] = None, backend: Optional['StorageBackend'] = None

src/aiida/orm/computers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
__all__ = ('Computer',)
3232

3333

34-
class ComputerCollection(entities.Collection['Computer']):
34+
class ComputerCollection(entities.EntityCollection['Computer']):
3535
"""The collection of Computer entries."""
3636

3737
@staticmethod
@@ -64,7 +64,7 @@ def delete(self, pk: int) -> None:
6464
return self._backend.computers.delete(pk)
6565

6666

67-
class Model(entities.Model):
67+
class ComputerModel(entities.EntityModel):
6868
uuid: UUID = MetadataField(
6969
description='The UUID of the computer',
7070
is_attribute=False,
@@ -98,7 +98,7 @@ class Model(entities.Model):
9898
)
9999

100100

101-
class Computer(entities.Entity['BackendComputer', ComputerCollection, Model]):
101+
class Computer(entities.Entity['BackendComputer', ComputerCollection, ComputerModel]):
102102
"""Computer entity."""
103103

104104
_logger = AIIDA_LOGGER.getChild('orm.computers')
@@ -109,7 +109,7 @@ class Computer(entities.Entity['BackendComputer', ComputerCollection, Model]):
109109
PROPERTY_SHEBANG = 'shebang'
110110

111111
_CLS_COLLECTION = ComputerCollection
112-
_CLS_MODEL = Model
112+
_CLS_MODEL = ComputerModel
113113

114114
def __init__(
115115
self,

src/aiida/orm/entities.py

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,14 @@
3535
from aiida.orm.querybuilder import FilterType, OrderByType, QueryBuilder
3636

3737
__all__ = (
38-
'Collection',
3938
'Entity',
39+
'EntityCollection',
40+
'EntityModel',
4041
'EntityTypes',
41-
'Model',
4242
)
4343

44-
CollectionType = TypeVar('CollectionType', bound='Collection[Any]')
45-
EntityModelType = TypeVar('EntityModelType', bound='Model')
44+
CollectionType = TypeVar('CollectionType', bound='EntityCollection[Any]')
45+
EntityModelType = TypeVar('EntityModelType', bound='EntityModel')
4646
EntityType = TypeVar('EntityType', bound='Entity[Any,Any,Any]')
4747
BackendEntityType = TypeVar('BackendEntityType', bound='BackendEntity')
4848

@@ -61,7 +61,7 @@ class EntityTypes(Enum):
6161
GROUP_NODE = 'group_node'
6262

6363

64-
class Collection(abc.ABC, Generic[EntityType]):
64+
class EntityCollection(abc.ABC, Generic[EntityType]):
6565
"""Container class that represents the collection of objects of a particular entity type."""
6666

6767
@staticmethod
@@ -184,7 +184,7 @@ def count(self, filters: Optional['FilterType'] = None) -> int:
184184
return self.query(filters=filters).count()
185185

186186

187-
class Model(BaseModel, defer_build=True):
187+
class EntityModel(BaseModel, defer_build=True):
188188
model_config = ConfigDict(extra='forbid')
189189

190190
pk: int = MetadataField(
@@ -194,16 +194,6 @@ class Model(BaseModel, defer_build=True):
194194
exclude_from_cli=True,
195195
)
196196

197-
@classmethod
198-
def __pydantic_init_subclass__(cls, **kwargs: Any) -> None:
199-
"""Sets the JSON schema title of the model.
200-
201-
The qualified name of the class is used, with dots removed. For example, `Node.Model` becomes `NodeModel`
202-
in the JSON schema.
203-
"""
204-
super().__pydantic_init_subclass__(**kwargs)
205-
cls.model_config['title'] = cls.__qualname__.replace('.', '')
206-
207197
@classmethod
208198
def as_input_model(cls: Type[EntityModelType]) -> Type[EntityModelType]:
209199
"""Return a derived model class with read-only fields removed.
@@ -214,7 +204,7 @@ def as_input_model(cls: Type[EntityModelType]) -> Type[EntityModelType]:
214204
"""
215205

216206
# Derive the input model from the original model
217-
new_name = cls.__qualname__.replace('.Model', 'InputModel')
207+
new_name = cls.__qualname__.replace('Model', 'InputModel')
218208
InputModel = create_model( # noqa: N806
219209
new_name,
220210
__base__=cls,
@@ -257,10 +247,18 @@ def _prune_field_decorators(field_decorators: dict[str, Any]) -> dict[str, Any]:
257247
class Entity(abc.ABC, Generic[BackendEntityType, CollectionType, EntityModelType], metaclass=EntityFieldMeta):
258248
"""An AiiDA entity"""
259249

260-
_CLS_COLLECTION: type[CollectionType] = Collection # type: ignore[assignment]
261-
_CLS_MODEL: type[EntityModelType] = Model # type: ignore[assignment]
250+
_CLS_COLLECTION: type[CollectionType] = EntityCollection # type: ignore[assignment]
251+
_CLS_MODEL: type[EntityModelType] = EntityModel # type: ignore[assignment]
262252
_logger = log.AIIDA_LOGGER.getChild('orm.entities')
263253

254+
@classproperty
255+
def Model(cls) -> Type[EntityModelType]: # noqa: N802, N805
256+
"""Return the model class for this entity.
257+
258+
:return: The model class.
259+
"""
260+
return cls._CLS_MODEL
261+
264262
@classproperty
265263
def InputModel(cls) -> Type[EntityModelType]: # noqa: N802, N805
266264
"""Return the input version of the model class for this entity.
@@ -272,9 +270,7 @@ def InputModel(cls) -> Type[EntityModelType]: # noqa: N802, N805
272270
@classmethod
273271
def model_to_orm_fields(cls) -> dict[str, FieldInfo]:
274272
return {
275-
key: field
276-
for key, field in cls._CLS_MODEL.model_fields.items()
277-
if not get_metadata(field, 'exclude_to_orm')
273+
key: field for key, field in cls.Model.model_fields.items() if not get_metadata(field, 'exclude_to_orm')
278274
}
279275

280276
@classmethod
@@ -325,7 +321,7 @@ def to_model(
325321
with `exclude_to_orm=True`.
326322
:return: An instance of the entity's model class.
327323
"""
328-
Model = self.InputModel if unstored else self._CLS_MODEL # noqa: N806
324+
Model = self.InputModel if unstored else self.Model # noqa: N806
329325
fields = self._collect_model_field_values(
330326
repository_path=repository_path,
331327
serialize_repository_content=serialize_repository_content,
@@ -396,7 +392,7 @@ def from_serialized(cls, serialized: dict[str, Any], unstored: bool = False) ->
396392
cls._logger.warning(
397393
'Serialization through pydantic is still an experimental feature and might break in future releases.'
398394
)
399-
Model = cls.InputModel if unstored else cls._CLS_MODEL # noqa: N806
395+
Model = cls.InputModel if unstored else cls.Model # noqa: N806
400396
return cls.from_model(Model(**serialized))
401397

402398
@classproperty
@@ -542,7 +538,7 @@ def _collect_model_field_values(
542538
"""
543539
fields: dict[str, Any] = {}
544540

545-
Model = self.InputModel if unstored else self._CLS_MODEL # noqa: N806
541+
Model = self.InputModel if unstored else self.Model # noqa: N806
546542

547543
for key, field in Model.model_fields.items():
548544
if skip_cli_excluded and get_metadata(field, 'exclude_from_cli'):

src/aiida/orm/fields.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -410,65 +410,65 @@ def __init__(cls, name, bases, classdict):
410410

411411
fields = {}
412412

413-
# If the class has an attribute ``_CLS_MODEL`` that is a subclass of :class:`pydantic.BaseModel`, parse the
414-
# model fields to build up the ``fields`` class attribute, which is used to allow specifying ``QueryBuilder``
415-
# filters programmatically.
416-
if hasattr(cls, '_CLS_MODEL') and issubclass(cls._CLS_MODEL, BaseModel):
417-
# If the class itself directly specifies the ``_CLS_MODEL`` attribute, check that it is valid. Here, the
418-
# check ``cls.__dict__`` is used instead of ``hasattr`` as the former only returns true if the class itself
413+
# If the class has an attribute ``Model`` that is a subclass of :class:`pydantic.BaseModel`, parse the model
414+
# fields to build up the ``fields`` class attribute, which is used to allow specifying ``QueryBuilder`` filters
415+
# programmatically.
416+
if hasattr(cls, 'Model') and issubclass(cls.Model, BaseModel):
417+
# If the class itself directly specifies the ``Model`` attribute, check that it is valid. Here, the check
418+
# ``cls.__dict__`` is used instead of ``hasattr`` as the former only returns true if the class itself
419419
# defines the attribute and does not just inherit it from a base class. In that case, this check will
420420
# already have been performed for that subclass.
421421

422-
# When a class defines a ``_CLS_MODEL``, the following check ensures that the model inherits from the same
423-
# bases as the class containing the attribute itself. For example, if ``cls`` inherits from ``ClassA`` and
424-
# ``ClassB`` that each define a ``_CLS_MODEL``, the ``cls._CLS_MODEL`` class should inherit from both
425-
# ``ClassA._CLS_MODEL`` and ``ClassB._CLS_MODEL`` or it will be losing the attributes of some of the models.
426-
if '_CLS_MODEL' in cls.__dict__:
427-
# Get all the base classes in the MRO of this class that define a class attribute ``_CLS_MODEL`` that
428-
# is a subclass of pydantic's ``BaseModel`` and not the class itself.
422+
# When a class defines a ``Model``, the following check ensures that the model inherits from the same bases
423+
# as the class containing the attribute itself. For example, if ``cls`` inherits from ``ClassA`` and
424+
# ``ClassB`` that each define a ``Model``, the ``cls.Model`` class should inherit from both ``ClassA.Model``
425+
# and ``ClassBModel`` or it will be losing the attributes of some of the models.
426+
if 'Model' in cls.__dict__:
427+
# Get all the base classes in the MRO of this class that define a class attribute ``Model`` that is a
428+
# subclass of pydantic's ``BaseModel`` and not the class itself
429429
cls_bases_with_model = [
430430
base
431431
for base in cls.__mro__
432-
if base is not cls and '_CLS_MODEL' in base.__dict__ and issubclass(base._CLS_MODEL, BaseModel) # type: ignore[attr-defined]
432+
if base is not cls and 'Model' in base.__dict__ and issubclass(base.Model, BaseModel) # type: ignore[attr-defined]
433433
]
434434

435435
# Now get the "leaf" bases, i.e., those base classes in the subclass list that themselves do not have a
436-
# subclass in the tree. This set should be the base classes for the class' ``_CLS_MODEL`` attribute.
436+
# subclass in the tree. This set should be the base classes for the class' ``Model`` attribute.
437437
cls_bases_with_model_leaves = {
438438
base
439439
for base in cls_bases_with_model
440440
if all(
441-
not issubclass(b._CLS_MODEL, base._CLS_MODEL) # type: ignore[attr-defined]
441+
not issubclass(b.Model, base.Model) # type: ignore[attr-defined]
442442
for b in cls_bases_with_model
443443
if b is not base
444444
)
445445
}
446446

447-
cls_model_bases = {base._CLS_MODEL for base in cls_bases_with_model_leaves} # type: ignore[attr-defined]
447+
cls_model_bases = {base.Model for base in cls_bases_with_model_leaves} # type: ignore[attr-defined]
448448

449-
# If the base class does not have a base that defines a model, it means the ``_CLS_MODEL`` should simply
450-
# have ``pydantic.BaseModel`` as its sole base.
449+
# If the base class does not have a base that defines a model, it means the ``Model`` should simply have
450+
# ``pydantic.BaseModel`` as its sole base.
451451
if not cls_model_bases:
452452
cls_model_bases = {
453453
BaseModel,
454454
}
455455

456-
# Get the set of bases of ``cls._CLS_MODEL`` that are a subclass of :class:`pydantic.BaseModel`.
457-
model_bases = {base for base in cls._CLS_MODEL.__bases__ if issubclass(base, BaseModel)}
456+
# Get the set of bases of ``cls.Model`` that are a subclass of :class:`pydantic.BaseModel`.
457+
model_bases = {base for base in cls.Model.__bases__ if issubclass(base, BaseModel)}
458458

459-
# For ``cls._CLS_MODEL`` to be valid, the bases that contain a model, should equal to the leaf bases of
460-
# the ``cls`` itself that also define a model.
459+
# For ``cls.Model`` to be valid, the bases that contain a model, should equal to the leaf bases of the
460+
# ``cls`` itself that also define a model.
461461
if model_bases != cls_model_bases and not getattr(cls, '_SKIP_MODEL_INHERITANCE_CHECK', False):
462-
bases = [f'{e.__module__}.Model' for e in cls_bases_with_model_leaves]
462+
bases = [f'{e.__module__}.{e.__name__}.Model' for e in cls_bases_with_model_leaves]
463463
raise RuntimeError(
464-
f'`{cls.__module__}.Model` does not subclass all necessary base classes. It should be: '
464+
f'`{cls.__name__}.Model` does not subclass all necessary base classes. It should be: '
465465
f'`class Model({", ".join(sorted(bases))}):`'
466466
)
467467

468-
for key, field in cls._CLS_MODEL.model_fields.items():
468+
for key, field in cls.Model.model_fields.items():
469469
fields[key] = add_field(
470470
key,
471-
alias=field.alias,
471+
alias=get_metadata(field, 'alias', None),
472472
dtype=field.annotation,
473473
doc=field.description,
474474
is_attribute=get_metadata(field, 'is_attribute', False),

src/aiida/orm/groups.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def load_group_class(type_string: str) -> Type[Group]:
5757
return group_class
5858

5959

60-
class GroupCollection(entities.Collection['Group']):
60+
class GroupCollection(entities.EntityCollection['Group']):
6161
"""Collection of Groups"""
6262

6363
@staticmethod
@@ -107,7 +107,7 @@ def extras(self) -> extras.EntityExtras:
107107
return extras.EntityExtras(self._group)
108108

109109

110-
class Model(entities.Model):
110+
class GroupModel(entities.EntityModel):
111111
uuid: UUID = MetadataField(
112112
description='The UUID of the group',
113113
is_attribute=False,
@@ -148,13 +148,13 @@ class Model(entities.Model):
148148
)
149149

150150

151-
class Group(entities.Entity['BackendGroup', GroupCollection, Model]):
151+
class Group(entities.Entity['BackendGroup', GroupCollection, GroupModel]):
152152
"""An AiiDA ORM implementation of group of nodes."""
153153

154154
__type_string: ClassVar[Optional[str]]
155155

156156
_CLS_COLLECTION = GroupCollection
157-
_CLS_MODEL = Model
157+
_CLS_MODEL = GroupModel
158158

159159
def __init__(
160160
self,

0 commit comments

Comments
 (0)