Skip to content

Commit f1068b5

Browse files
Discard type aliasing and inject a generic EntityModelType
1 parent 0410f50 commit f1068b5

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+93
-100
lines changed

src/aiida/common/pydantic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ def MetadataField( # noqa: N802
3333
priority: int = 0,
3434
short_name: str | None = None,
3535
option_cls: t.Any | None = None,
36-
orm_class: type[Entity[t.Any, t.Any]] | str | None = None,
37-
orm_to_model: t.Callable[[Entity[t.Any, t.Any]], t.Any] # without arguments
38-
| t.Callable[[Entity[t.Any, t.Any], dict[str, t.Any]], t.Any] # with arguments
36+
orm_class: type[Entity[t.Any, t.Any, t.Any]] | str | None = None,
37+
orm_to_model: t.Callable[[Entity[t.Any, t.Any, t.Any]], t.Any] # without arguments
38+
| t.Callable[[Entity[t.Any, t.Any, t.Any], dict[str, t.Any]], t.Any] # with arguments
3939
| None = None,
4040
model_to_orm: t.Callable[[BaseModel], t.Any] | None = None,
4141
exclude_to_orm: bool = False,

src/aiida/orm/authinfos.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
from aiida.common import exceptions
1616
from aiida.common.pydantic import MetadataField
17-
from aiida.common.typing import TypeAlias
1817
from aiida.manage import get_manager
1918
from aiida.plugins import TransportFactory
2019

@@ -75,10 +74,10 @@ class AuthInfoModel(entities.EntityModel):
7574
)
7675

7776

78-
class AuthInfo(entities.Entity['BackendAuthInfo', AuthInfoCollection]):
77+
class AuthInfo(entities.Entity['BackendAuthInfo', AuthInfoCollection, AuthInfoModel]):
7978
"""ORM class that models the authorization information that allows a `User` to connect to a `Computer`."""
8079

81-
Model: TypeAlias = AuthInfoModel
80+
Model = AuthInfoModel
8281

8382
_CLS_COLLECTION = AuthInfoCollection
8483

src/aiida/orm/comments.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from uuid import UUID
1616

1717
from aiida.common.pydantic import MetadataField
18-
from aiida.common.typing import TypeAlias
1918
from aiida.manage import get_manager
2019

2120
from . import entities
@@ -100,10 +99,10 @@ class CommentModel(entities.EntityModel):
10099
)
101100

102101

103-
class Comment(entities.Entity['BackendComment', CommentCollection]):
102+
class Comment(entities.Entity['BackendComment', CommentCollection, CommentModel]):
104103
"""Base class to map a DbComment that represents a comment attached to a certain Node."""
105104

106-
Model: TypeAlias = CommentModel
105+
Model = CommentModel
107106

108107
_CLS_COLLECTION = CommentCollection
109108

src/aiida/orm/computers.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from aiida.common import exceptions
1818
from aiida.common.log import AIIDA_LOGGER, AiidaLoggerType
1919
from aiida.common.pydantic import MetadataField
20-
from aiida.common.typing import TypeAlias
2120
from aiida.manage import get_manager
2221
from aiida.plugins import SchedulerFactory, TransportFactory
2322

@@ -99,10 +98,10 @@ class ComputerModel(entities.EntityModel):
9998
)
10099

101100

102-
class Computer(entities.Entity['BackendComputer', ComputerCollection]):
101+
class Computer(entities.Entity['BackendComputer', ComputerCollection, ComputerModel]):
103102
"""Computer entity."""
104103

105-
Model: TypeAlias = ComputerModel
104+
Model = ComputerModel
106105

107106
_logger = AIIDA_LOGGER.getChild('orm.computers')
108107

src/aiida/orm/entities.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,13 @@
1919
from plumpy.base.utils import call_with_super_check, super_check
2020
from pydantic import BaseModel, ConfigDict, create_model
2121
from pydantic.fields import FieldInfo
22+
from typing_extensions import overload
2223

2324
from aiida.common import exceptions, log
2425
from aiida.common.exceptions import EntryPointError, InvalidOperation, NotExistent
2526
from aiida.common.lang import classproperty, type_check
2627
from aiida.common.pydantic import MetadataField, get_metadata
27-
from aiida.common.typing import Self, TypeAlias
28+
from aiida.common.typing import Self
2829
from aiida.common.warnings import warn_deprecation
2930
from aiida.manage import get_manager
3031

@@ -38,7 +39,8 @@
3839

3940
BackendEntityType = TypeVar('BackendEntityType', bound='BackendEntity')
4041
EntityCollectionType = TypeVar('EntityCollectionType', bound='EntityCollection[Any]')
41-
EntityType = TypeVar('EntityType', bound='Entity[Any,Any]')
42+
EntityModelType = TypeVar('EntityModelType', bound='EntityModel')
43+
EntityType = TypeVar('EntityType', bound='Entity[Any,Any,Any]')
4244

4345

4446
class EntityTypes(Enum):
@@ -189,7 +191,7 @@ class EntityModel(BaseModel, defer_build=True):
189191
)
190192

191193
@classmethod
192-
def as_create_model(cls: Type[EntityModel]) -> Type[EntityModel]:
194+
def as_create_model(cls: type[EntityModel]) -> type[BaseModel]:
193195
"""Return a derived creation model class with read-only fields removed.
194196
195197
This also removes any serializers/validators defined on those fields.
@@ -238,16 +240,16 @@ def _prune_field_decorators(field_decorators: dict[str, Any]) -> dict[str, Any]:
238240
return CreateModel
239241

240242

241-
class Entity(abc.ABC, Generic[BackendEntityType, EntityCollectionType], metaclass=EntityFieldMeta):
243+
class Entity(abc.ABC, Generic[BackendEntityType, EntityCollectionType, EntityModelType], metaclass=EntityFieldMeta):
242244
"""An AiiDA entity"""
243245

244-
Model: TypeAlias = EntityModel
246+
Model: type[EntityModelType] = EntityModel # type: ignore[assignment]
245247

246248
_CLS_COLLECTION: type[EntityCollectionType] = EntityCollection # type: ignore[assignment]
247249
_logger = log.AIIDA_LOGGER.getChild('orm.entities')
248250

249251
@classproperty
250-
def CreateModel(cls) -> Type[Model]: # noqa: N802, N805
252+
def CreateModel(cls) -> type[BaseModel]: # noqa: N802, N805
251253
"""Return the creation version of the model class for this entity.
252254
253255
:return: The creation model class, with read-only fields removed.
@@ -261,7 +263,7 @@ def model_to_orm_fields(cls) -> dict[str, FieldInfo]:
261263
}
262264

263265
@classmethod
264-
def model_to_orm_field_values(cls, model: EntityModel) -> dict[str, Any]:
266+
def model_to_orm_field_values(cls, model: BaseModel) -> dict[str, Any]:
265267
from aiida.plugins.factories import BaseFactory
266268

267269
fields = {}
@@ -291,13 +293,40 @@ def model_to_orm_field_values(cls, model: EntityModel) -> dict[str, Any]:
291293

292294
return fields
293295

296+
@overload
297+
def to_model(
298+
self,
299+
*,
300+
repository_path: Optional[pathlib.Path] = None,
301+
serialize_repository_content: bool = False,
302+
unstored: Literal[False] = False,
303+
) -> EntityModelType: ...
304+
305+
@overload
306+
def to_model(
307+
self,
308+
*,
309+
repository_path: Optional[pathlib.Path] = None,
310+
serialize_repository_content: bool = False,
311+
unstored: Literal[True] = True,
312+
) -> BaseModel: ...
313+
314+
@overload
315+
def to_model(
316+
self,
317+
*,
318+
repository_path: Optional[pathlib.Path] = None,
319+
serialize_repository_content: bool = False,
320+
unstored: bool = False,
321+
) -> BaseModel: ...
322+
294323
def to_model(
295324
self,
296325
*,
297326
repository_path: Optional[pathlib.Path] = None,
298327
serialize_repository_content: bool = False,
299328
unstored: bool = False,
300-
) -> EntityModel:
329+
) -> BaseModel:
301330
"""Return the entity instance as an instance of its model.
302331
303332
:param repository_path: If the orm node has files in the repository, this path is used to read the repository
@@ -318,7 +347,7 @@ def to_model(
318347
return Model(**fields)
319348

320349
@classmethod
321-
def from_model(cls, model: EntityModel) -> Self:
350+
def from_model(cls, model: BaseModel) -> Self:
322351
"""Return an entity instance from an instance of its model.
323352
324353
:param model: An instance of the entity's model class.

src/aiida/orm/groups.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from aiida.common import exceptions
2121
from aiida.common.lang import classproperty, type_check
2222
from aiida.common.pydantic import MetadataField
23-
from aiida.common.typing import Self, TypeAlias
23+
from aiida.common.typing import Self
2424
from aiida.common.warnings import warn_deprecation
2525
from aiida.manage import get_manager
2626

@@ -148,10 +148,10 @@ class GroupModel(entities.EntityModel):
148148
)
149149

150150

151-
class Group(entities.Entity['BackendGroup', GroupCollection]):
151+
class Group(entities.Entity['BackendGroup', GroupCollection, GroupModel]):
152152
"""An AiiDA ORM implementation of group of nodes."""
153153

154-
Model: TypeAlias = GroupModel
154+
Model = GroupModel
155155

156156
__type_string: ClassVar[Optional[str]]
157157

src/aiida/orm/logs.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
from aiida.common import timezone
1919
from aiida.common.pydantic import MetadataField
20-
from aiida.common.typing import TypeAlias
2120
from aiida.manage import get_manager
2221

2322
from . import entities
@@ -161,10 +160,10 @@ class LogModel(entities.EntityModel):
161160
)
162161

163162

164-
class Log(entities.Entity['BackendLog', LogCollection]):
163+
class Log(entities.Entity['BackendLog', LogCollection, LogModel]):
165164
"""An AiiDA Log entity. Corresponds to a logged message against a particular AiiDA node."""
166165

167-
Model: TypeAlias = LogModel
166+
Model = LogModel
168167

169168
_CLS_COLLECTION = LogCollection
170169

src/aiida/orm/nodes/data/array/array.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from pydantic import field_validator
2020

2121
from aiida.common.pydantic import MetadataField
22-
from aiida.common.typing import TypeAlias
2322

2423
from .. import data
2524
from ..base import to_aiida_type
@@ -75,7 +74,7 @@ class ArrayData(data.Data):
7574
7675
"""
7776

78-
Model: TypeAlias = ArrayDataModel
77+
Model = ArrayDataModel
7978

8079
array_prefix = 'array|'
8180
default_array_name = 'default'

src/aiida/orm/nodes/data/array/bands.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
from aiida.common.exceptions import ValidationError
2222
from aiida.common.pydantic import MetadataField
23-
from aiida.common.typing import TypeAlias
2423
from aiida.common.utils import join_labels, prettify_labels
2524

2625
from . import kpoints
@@ -228,7 +227,7 @@ class BandsDataModel(kpoints.KpointsDataModel):
228227
class BandsData(kpoints.KpointsData):
229228
"""Class to handle bands data"""
230229

231-
Model: TypeAlias = BandsDataModel
230+
Model = BandsDataModel
232231

233232
def __init__(
234233
self,
@@ -244,7 +243,7 @@ def set_kpointsdata(self, kpointsdata):
244243
"""Load the kpoints from a kpoint object.
245244
:param kpointsdata: an instance of KpointsData class
246245
"""
247-
if not isinstance(kpointsdata, KpointsData):
246+
if not isinstance(kpointsdata, kpoints.KpointsData):
248247
raise ValueError('kpointsdata must be of the KpointsData class')
249248
try:
250249
self.cell = kpointsdata.cell

src/aiida/orm/nodes/data/array/kpoints.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import numpy
1919

2020
from aiida.common.pydantic import MetadataField
21-
from aiida.common.typing import TypeAlias
2221

2322
from . import array
2423

@@ -84,7 +83,7 @@ class KpointsData(array.ArrayData):
8483
set_cell_from_structure methods.
8584
"""
8685

87-
Model: TypeAlias = KpointsDataModel
86+
Model = KpointsDataModel
8887

8988
def __init__(
9089
self,

0 commit comments

Comments
 (0)