diff --git a/mongoengine/base/datastructures.py b/mongoengine/base/datastructures.py index a3561b8b1..a175cd119 100644 --- a/mongoengine/base/datastructures.py +++ b/mongoengine/base/datastructures.py @@ -113,7 +113,11 @@ def __init__(self, list_items, instance, name): BaseDocument = _import_class("BaseDocument") if isinstance(instance, BaseDocument): - self._instance = weakref.proxy(instance) + if isinstance(instance, weakref.ProxyTypes): + self._instance = instance + else: + self._instance = weakref.proxy(instance) + self._name = name super().__init__(list_items) @@ -186,10 +190,6 @@ def _mark_as_changed(self, key=None): class EmbeddedDocumentList(BaseList): - def __init__(self, list_items, instance, name): - super().__init__(list_items, instance, name) - self._instance = instance - @classmethod def __match_all(cls, embedded_doc, kwargs): """Return True if a given embedded doc matches all the filter diff --git a/tests/fields/test_embedded_document_field.py b/tests/fields/test_embedded_document_field.py index fefee4efd..a892c0dcd 100644 --- a/tests/fields/test_embedded_document_field.py +++ b/tests/fields/test_embedded_document_field.py @@ -1,3 +1,4 @@ +import weakref from copy import deepcopy import pytest @@ -62,6 +63,36 @@ class MyFailingDoc(Document): class MyFailingdoc2(Document): emb = EmbeddedDocumentField("MyDoc") + def test_embedded_document_list_field__has__instance_weakref(self): + class Comment(EmbeddedDocument): + content = StringField() + + class Post(Document): + title = StringField() + comment = EmbeddedDocumentField(Comment) + comments = EmbeddedDocumentListField(Comment) + comments2 = ListField(EmbeddedDocumentField(Comment)) + + Post.drop_collection() + + for i in range(5): + Post( + title=f"{i}", + comment=Comment(content=f"{i}"), + comments=[Comment(content=f"{i}")], + comments2=[Comment(content=f"{i}")], + ).save() + + posts = list(Post.objects) + for post in posts: + assert isinstance(post.comments._instance, weakref.ProxyTypes) + assert isinstance(post.comments2._instance, weakref.ProxyTypes) + assert isinstance(post.comment._instance, weakref.ProxyTypes) + for comment in post.comments: + assert isinstance(comment._instance, weakref.ProxyTypes) + for comment2 in post.comments2: + assert isinstance(comment2._instance, weakref.ProxyTypes) + def test_embedded_document_field_validate_subclass(self): class BaseItem(EmbeddedDocument): f = IntField()