Skip to content

Commit 5eb5e3e

Browse files
committed
Merge remote-tracking branch 'otto001/related_accessor' into select_related
2 parents 7ab11a3 + 705acf4 commit 5eb5e3e

File tree

2 files changed

+33
-11
lines changed

2 files changed

+33
-11
lines changed

polymorphic/models.py

+10-11
Original file line numberDiff line numberDiff line change
@@ -197,20 +197,15 @@ def __init__(self, *args, **kwargs):
197197
return
198198
self.__class__.polymorphic_super_sub_accessors_replaced = True
199199

200-
def create_accessor_function_for_model(model, accessor_name):
201-
NOT_PROVIDED = object()
202-
200+
def create_accessor_function_for_model(model, field):
203201
def accessor_function(self):
204-
attr = NOT_PROVIDED
205202
try:
206-
attr = self._state.fields_cache[accessor_name]
207-
pass
203+
rel_obj = field.get_cached_value(self)
208204
except KeyError:
209-
pass
210-
if attr is NOT_PROVIDED:
211205
objects = getattr(model, "_base_objects", model.objects)
212-
attr = objects.get(pk=self.pk)
213-
return attr
206+
rel_obj = objects.get(pk=self.pk)
207+
field.set_cached_value(self, rel_obj)
208+
return rel_obj
214209

215210
return accessor_function
216211

@@ -223,10 +218,14 @@ def accessor_function(self):
223218
type(orig_accessor),
224219
(ReverseOneToOneDescriptor, ForwardManyToOneDescriptor),
225220
):
221+
222+
field = orig_accessor.related \
223+
if isinstance(orig_accessor, ReverseOneToOneDescriptor) else orig_accessor.field
224+
226225
setattr(
227226
self.__class__,
228227
name,
229-
property(create_accessor_function_for_model(model, name)),
228+
property(create_accessor_function_for_model(model, field)),
230229
)
231230

232231
def _get_inheritance_relation_fields_and_models(self):

polymorphic/tests/test_orm.py

+23
Original file line numberDiff line numberDiff line change
@@ -997,6 +997,29 @@ def test_parent_link_and_related_name(self):
997997
# test that we can delete the object
998998
t.delete()
999999

1000+
def test_polymorphic__accessor_caching(self):
1001+
blog_a = BlogA.objects.create(name="blog")
1002+
1003+
blog_base = BlogBase.objects.non_polymorphic().get(id=blog_a.id)
1004+
blog_a = BlogA.objects.get(id=blog_a.id)
1005+
1006+
# test reverse accessor & check that we get back cached object on repeated access
1007+
self.assertEqual(blog_base.bloga, blog_a)
1008+
self.assertIs(blog_base.bloga, blog_base.bloga)
1009+
cached_blog_a = blog_base.bloga
1010+
1011+
# test forward accessor & check that we get back cached object on repeated access
1012+
self.assertEqual(blog_a.blogbase_ptr, blog_base)
1013+
self.assertIs(blog_a.blogbase_ptr, blog_a.blogbase_ptr)
1014+
cached_blog_base = blog_a.blogbase_ptr
1015+
1016+
# check that refresh_from_db correctly clears cached related objects
1017+
blog_base.refresh_from_db()
1018+
blog_a.refresh_from_db()
1019+
1020+
self.assertIsNot(cached_blog_a, blog_base.bloga)
1021+
self.assertIsNot(cached_blog_base, blog_a.blogbase_ptr)
1022+
10001023
def test_polymorphic__aggregate(self):
10011024
"""test ModelX___field syntax on aggregate (should work for annotate either)"""
10021025

0 commit comments

Comments
 (0)