From 705acf4b8408112cf1832d073960603a5fe75c34 Mon Sep 17 00:00:00 2001 From: otto001 Date: Wed, 23 Feb 2022 22:03:32 +0100 Subject: [PATCH] polymorphic accessors now use builtin caching from underlying fields --- polymorphic/models.py | 18 +++++++++++++----- polymorphic/tests/test_orm.py | 23 +++++++++++++++++++++++ 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/polymorphic/models.py b/polymorphic/models.py index df0cc485..e47103e7 100644 --- a/polymorphic/models.py +++ b/polymorphic/models.py @@ -197,11 +197,15 @@ def __init__(self, *args, **kwargs): return self.__class__.polymorphic_super_sub_accessors_replaced = True - def create_accessor_function_for_model(model, accessor_name): + def create_accessor_function_for_model(model, field): def accessor_function(self): - objects = getattr(model, "_base_objects", model.objects) - attr = objects.get(pk=self.pk) - return attr + try: + rel_obj = field.get_cached_value(self) + except KeyError: + objects = getattr(model, "_base_objects", model.objects) + rel_obj = objects.get(pk=self.pk) + field.set_cached_value(self, rel_obj) + return rel_obj return accessor_function @@ -214,10 +218,14 @@ def accessor_function(self): type(orig_accessor), (ReverseOneToOneDescriptor, ForwardManyToOneDescriptor), ): + + field = orig_accessor.related \ + if isinstance(orig_accessor, ReverseOneToOneDescriptor) else orig_accessor.field + setattr( self.__class__, name, - property(create_accessor_function_for_model(model, name)), + property(create_accessor_function_for_model(model, field)), ) def _get_inheritance_relation_fields_and_models(self): diff --git a/polymorphic/tests/test_orm.py b/polymorphic/tests/test_orm.py index f9ead453..195af0ec 100644 --- a/polymorphic/tests/test_orm.py +++ b/polymorphic/tests/test_orm.py @@ -985,6 +985,29 @@ def test_parent_link_and_related_name(self): # test that we can delete the object t.delete() + def test_polymorphic__accessor_caching(self): + blog_a = BlogA.objects.create(name="blog") + + blog_base = BlogBase.objects.non_polymorphic().get(id=blog_a.id) + blog_a = BlogA.objects.get(id=blog_a.id) + + # test reverse accessor & check that we get back cached object on repeated access + self.assertEqual(blog_base.bloga, blog_a) + self.assertIs(blog_base.bloga, blog_base.bloga) + cached_blog_a = blog_base.bloga + + # test forward accessor & check that we get back cached object on repeated access + self.assertEqual(blog_a.blogbase_ptr, blog_base) + self.assertIs(blog_a.blogbase_ptr, blog_a.blogbase_ptr) + cached_blog_base = blog_a.blogbase_ptr + + # check that refresh_from_db correctly clears cached related objects + blog_base.refresh_from_db() + blog_a.refresh_from_db() + + self.assertIsNot(cached_blog_a, blog_base.bloga) + self.assertIsNot(cached_blog_base, blog_a.blogbase_ptr) + def test_polymorphic__aggregate(self): """test ModelX___field syntax on aggregate (should work for annotate either)"""