1919from plumpy .base .utils import call_with_super_check , super_check
2020from pydantic import BaseModel , ConfigDict , create_model
2121from pydantic .fields import FieldInfo
22+ from typing_extensions import overload
2223
2324from aiida .common import exceptions , log
2425from aiida .common .exceptions import EntryPointError , InvalidOperation , NotExistent
2526from aiida .common .lang import classproperty , type_check
2627from aiida .common .pydantic import MetadataField , get_metadata
27- from aiida .common .typing import Self , TypeAlias
28+ from aiida .common .typing import Self
2829from aiida .common .warnings import warn_deprecation
2930from aiida .manage import get_manager
3031
3839
3940BackendEntityType = TypeVar ('BackendEntityType' , bound = 'BackendEntity' )
4041EntityCollectionType = TypeVar ('EntityCollectionType' , bound = 'EntityCollection[Any]' )
41- EntityType = TypeVar ('EntityType' , bound = 'Entity[Any,Any]' )
42+ EntityModelType = TypeVar ('EntityModelType' , bound = 'EntityModel' )
43+ EntityType = TypeVar ('EntityType' , bound = 'Entity[Any,Any,Any]' )
4244
4345
4446class EntityTypes (Enum ):
@@ -189,7 +191,7 @@ class EntityModel(BaseModel, defer_build=True):
189191 )
190192
191193 @classmethod
192- def as_create_model (cls : Type [EntityModel ]) -> Type [ EntityModel ]:
194+ def as_create_model (cls : type [EntityModel ]) -> type [ BaseModel ]:
193195 """Return a derived creation model class with read-only fields removed.
194196
195197 This also removes any serializers/validators defined on those fields.
@@ -238,16 +240,16 @@ def _prune_field_decorators(field_decorators: dict[str, Any]) -> dict[str, Any]:
238240 return CreateModel
239241
240242
241- class Entity (abc .ABC , Generic [BackendEntityType , EntityCollectionType ], metaclass = EntityFieldMeta ):
243+ class Entity (abc .ABC , Generic [BackendEntityType , EntityCollectionType , EntityModelType ], metaclass = EntityFieldMeta ):
242244 """An AiiDA entity"""
243245
244- Model : TypeAlias = EntityModel
246+ Model : type [ EntityModelType ] = EntityModel # type: ignore[assignment]
245247
246248 _CLS_COLLECTION : type [EntityCollectionType ] = EntityCollection # type: ignore[assignment]
247249 _logger = log .AIIDA_LOGGER .getChild ('orm.entities' )
248250
249251 @classproperty
250- def CreateModel (cls ) -> Type [ Model ]: # noqa: N802, N805
252+ def CreateModel (cls ) -> type [ BaseModel ]: # noqa: N802, N805
251253 """Return the creation version of the model class for this entity.
252254
253255 :return: The creation model class, with read-only fields removed.
@@ -261,7 +263,7 @@ def model_to_orm_fields(cls) -> dict[str, FieldInfo]:
261263 }
262264
263265 @classmethod
264- def model_to_orm_field_values (cls , model : EntityModel ) -> dict [str , Any ]:
266+ def model_to_orm_field_values (cls , model : BaseModel ) -> dict [str , Any ]:
265267 from aiida .plugins .factories import BaseFactory
266268
267269 fields = {}
@@ -291,13 +293,40 @@ def model_to_orm_field_values(cls, model: EntityModel) -> dict[str, Any]:
291293
292294 return fields
293295
296+ @overload
297+ def to_model (
298+ self ,
299+ * ,
300+ repository_path : Optional [pathlib .Path ] = None ,
301+ serialize_repository_content : bool = False ,
302+ unstored : Literal [False ] = False ,
303+ ) -> EntityModelType : ...
304+
305+ @overload
306+ def to_model (
307+ self ,
308+ * ,
309+ repository_path : Optional [pathlib .Path ] = None ,
310+ serialize_repository_content : bool = False ,
311+ unstored : Literal [True ] = True ,
312+ ) -> BaseModel : ...
313+
314+ @overload
315+ def to_model (
316+ self ,
317+ * ,
318+ repository_path : Optional [pathlib .Path ] = None ,
319+ serialize_repository_content : bool = False ,
320+ unstored : bool = False ,
321+ ) -> BaseModel : ...
322+
294323 def to_model (
295324 self ,
296325 * ,
297326 repository_path : Optional [pathlib .Path ] = None ,
298327 serialize_repository_content : bool = False ,
299328 unstored : bool = False ,
300- ) -> EntityModel :
329+ ) -> BaseModel :
301330 """Return the entity instance as an instance of its model.
302331
303332 :param repository_path: If the orm node has files in the repository, this path is used to read the repository
@@ -318,7 +347,7 @@ def to_model(
318347 return Model (** fields )
319348
320349 @classmethod
321- def from_model (cls , model : EntityModel ) -> Self :
350+ def from_model (cls , model : BaseModel ) -> Self :
322351 """Return an entity instance from an instance of its model.
323352
324353 :param model: An instance of the entity's model class.
0 commit comments