From e709301000edddc48ebfe9535e2e2177d3bcedb1 Mon Sep 17 00:00:00 2001 From: Take Weiland Date: Sun, 30 Mar 2025 11:53:32 +0200 Subject: [PATCH] Fixed #36282 -- Used prefetched values in ForwardManyToOneDescriptor from indirect ancestors. When looking for cached values in ManyRelatedManager and ForwardManyToOneDescriptor walk up the whole chain of ancestors (as long as they are cached) to find the prefetched relation. --- .../db/models/fields/related_descriptors.py | 61 +++--- tests/prefetch_related/models.py | 4 + tests/prefetch_related/tests.py | 182 ++++++++++++++++++ 3 files changed, 225 insertions(+), 22 deletions(-) diff --git a/django/db/models/fields/related_descriptors.py b/django/db/models/fields/related_descriptors.py index 56a464f2d2..8da7aaef91 100644 --- a/django/db/models/fields/related_descriptors.py +++ b/django/db/models/fields/related_descriptors.py @@ -117,6 +117,18 @@ def _filter_prefetch_queryset(queryset, field_name, instances): return queryset +def _traverse_ancestors(model, starting_instance): + current_instance = starting_instance + while current_instance is not None: + ancestor_link = current_instance._meta.get_ancestor_link(model) + if not ancestor_link: + yield current_instance, None + break + ancestor = ancestor_link.get_cached_value(current_instance, None) + yield current_instance, ancestor + current_instance = ancestor + + class ForwardManyToOneDescriptor: """ Accessor to the related object on the forward side of a many-to-one or @@ -228,21 +240,19 @@ class ForwardManyToOneDescriptor: try: rel_obj = self.field.get_cached_value(instance) except KeyError: + rel_obj = None has_value = None not in self.field.get_local_related_value(instance) - ancestor_link = ( - instance._meta.get_ancestor_link(self.field.model) - if has_value - else None - ) - if ancestor_link and ancestor_link.is_cached(instance): - # An ancestor link will exist if this field is defined on a - # multi-table inheritance parent of the instance's class. - ancestor = ancestor_link.get_cached_value(instance) - # The value might be cached on an ancestor if the instance - # originated from walking down the inheritance chain. - rel_obj = self.field.get_cached_value(ancestor, default=None) - else: - rel_obj = None + if has_value: + model = self.field.model + for current_instance, ancestor in _traverse_ancestors(model, instance): + if ancestor: + # The value might be cached on an ancestor if the + # instance originated from walking down the inheritance + # chain. + rel_obj = self.field.get_cached_value(ancestor, default=None) + if rel_obj is not None: + break + if rel_obj is None and has_value: rel_obj = self.get_object(instance) remote_field = self.field.remote_field @@ -1095,16 +1105,23 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): return queryset._next_is_sticky().filter(**self.core_filters) def get_prefetch_cache(self): - try: - return self.instance._prefetched_objects_cache[self.prefetch_cache_name] - except (AttributeError, KeyError): - return None + # Walk up the ancestor-chain (if cached) to try and find a prefetch + # in an ancestor. + for instance, _ in _traverse_ancestors(rel.field.model, self.instance): + try: + return instance._prefetched_objects_cache[self.prefetch_cache_name] + except (AttributeError, KeyError): + pass + return None def _remove_prefetched_objects(self): - try: - self.instance._prefetched_objects_cache.pop(self.prefetch_cache_name) - except (AttributeError, KeyError): - pass # nothing to clear from cache + # Walk up the ancestor-chain (if cached) to try and find a prefetch + # in an ancestor. + for instance, _ in _traverse_ancestors(rel.field.model, self.instance): + try: + instance._prefetched_objects_cache.pop(self.prefetch_cache_name) + except (AttributeError, KeyError): + pass # nothing to clear from cache def get_queryset(self): if (cache := self.get_prefetch_cache()) is not None: diff --git a/tests/prefetch_related/models.py b/tests/prefetch_related/models.py index 405e9bba00..de2d88395d 100644 --- a/tests/prefetch_related/models.py +++ b/tests/prefetch_related/models.py @@ -28,6 +28,10 @@ class AuthorWithAge(Author): age = models.IntegerField() +class AuthorWithAgeChild(AuthorWithAge): + pass + + class FavoriteAuthors(models.Model): author = models.ForeignKey( Author, models.CASCADE, to_field="name", related_name="i_like" diff --git a/tests/prefetch_related/tests.py b/tests/prefetch_related/tests.py index bd37ca0ec3..1955809aec 100644 --- a/tests/prefetch_related/tests.py +++ b/tests/prefetch_related/tests.py @@ -21,6 +21,7 @@ from .models import ( Author2, AuthorAddress, AuthorWithAge, + AuthorWithAgeChild, Bio, Book, Bookmark, @@ -2086,3 +2087,184 @@ class PrefetchLimitTests(TestDataMixin, TestCase): with self.subTest(book=book): self.assertEqual(len(book.authors_sliced), 1) self.assertIn(book.authors_sliced[0], list(book.authors.all())) + + +class PrefetchRelatedMTICacheTests(TestCase): + @classmethod + def setUpTestData(cls): + cls.book = Book.objects.create(title="Book1") + cls.related1 = Author.objects.create(name="related1", first_book=cls.book) + cls.related2 = Author.objects.create(name="related2", first_book=cls.book) + cls.related3 = Author.objects.create(name="related3", first_book=cls.book) + cls.related4 = Author.objects.create(name="related4", first_book=cls.book) + + cls.child = AuthorWithAgeChild.objects.create( + name="child", + age=31, + first_book=cls.book, + ) + cls.m2m_child = AuthorWithAgeChild.objects.create( + name="m2m_child", + age=31, + first_book=cls.book, + ) + cls.m2m_child.favorite_authors.set([cls.related1, cls.related2, cls.related3]) + + def test_parent_fk_available_in_child(self): + qs = ( + Author.objects.select_related("authorwithage") + .prefetch_related("first_book") + .filter(pk=self.child.pk) + ) + with self.assertNumQueries(2): + results = list(qs) + self.assertEqual(len(results), 1) + self.assertEqual(results[0].authorwithage.first_book, self.book) + + def test_grandparent_fk_available_in_child(self): + qs = ( + Author.objects.select_related( + "authorwithage", "authorwithage__authorwithagechild" + ) + .prefetch_related("first_book") + .filter(pk=self.child.pk) + ) + with self.assertNumQueries(2): + results = list(qs) + self.assertEqual(len(results), 1) + self.assertEqual( + results[0].authorwithage.authorwithagechild.first_book, self.book + ) + + def test_parent_m2m_available_in_child(self): + qs = ( + Author.objects.select_related("authorwithage") + .prefetch_related("favorite_authors") + .filter(pk=self.m2m_child.pk) + ) + with self.assertNumQueries(2): + results = list(qs) + self.assertEqual(len(results), 1) + self.assertQuerySetEqual( + results[0].authorwithage.favorite_authors.all(), + [self.related1, self.related2, self.related3], + ) + + def test_grandparent_m2m_available_in_child(self): + qs = ( + Author.objects.select_related( + "authorwithage", "authorwithage__authorwithagechild" + ) + .prefetch_related("favorite_authors") + .filter(pk=self.m2m_child.pk) + ) + with self.assertNumQueries(2): + results = list(qs) + self.assertEqual(len(results), 1) + self.assertQuerySetEqual( + set(results[0].authorwithage.authorwithagechild.favorite_authors.all()), + {self.related1, self.related2, self.related3}, + ) + + def test_add_clears_prefetched_objects_in_parent(self): + gp = ( + Author.objects.select_related("authorwithage") + .prefetch_related("favorite_authors") + .get(pk=self.m2m_child.pk) + ) + self.assertQuerySetEqual( + gp.favorite_authors.all(), + {self.related1, self.related2, self.related3}, + ) + self.assertQuerySetEqual( + gp.authorwithage.favorite_authors.all(), + {self.related1, self.related2, self.related3}, + ) + gp.authorwithage.favorite_authors.add(self.related4) + self.assertQuerySetEqual( + gp.favorite_authors.all(), + {self.related1, self.related2, self.related3, self.related4}, + ) + self.assertQuerySetEqual( + gp.authorwithage.favorite_authors.all(), + {self.related1, self.related2, self.related3, self.related4}, + ) + + def test_add_clears_prefetched_objects_in_grandparent(self): + gp = ( + Author.objects.select_related( + "authorwithage", "authorwithage__authorwithagechild" + ) + .prefetch_related("favorite_authors") + .get(pk=self.m2m_child.pk) + ) + self.assertQuerySetEqual( + gp.favorite_authors.all(), + [self.related1, self.related2, self.related3], + ) + self.assertQuerySetEqual( + gp.authorwithage.favorite_authors.all(), + [self.related1, self.related2, self.related3], + ) + self.assertQuerySetEqual( + gp.authorwithage.authorwithagechild.favorite_authors.all(), + [self.related1, self.related2, self.related3], + ) + gp.authorwithage.authorwithagechild.favorite_authors.add(self.related4) + self.assertQuerySetEqual( + gp.favorite_authors.all(), + [self.related1, self.related2, self.related3, self.related4], + ) + self.assertQuerySetEqual( + gp.authorwithage.favorite_authors.all(), + [self.related1, self.related2, self.related3, self.related4], + ) + self.assertQuerySetEqual( + gp.authorwithage.authorwithagechild.favorite_authors.all(), + [self.related1, self.related2, self.related3, self.related4], + ) + + def test_remove_clears_prefetched_objects_in_parent(self): + gp = ( + Author.objects.select_related("authorwithage") + .prefetch_related("favorite_authors") + .get(pk=self.m2m_child.pk) + ) + self.assertQuerySetEqual( + gp.favorite_authors.all(), + [self.related1, self.related2, self.related3], + ) + self.assertQuerySetEqual( + gp.authorwithage.favorite_authors.all(), + [self.related1, self.related2, self.related3], + ) + gp.authorwithage.favorite_authors.clear() + self.assertQuerySetEqual(gp.favorite_authors.all(), []) + self.assertQuerySetEqual(gp.authorwithage.favorite_authors.all(), []) + + def test_remove_clears_prefetched_objects_in_grandparent(self): + gp = ( + Author.objects.select_related( + "authorwithage", "authorwithage__authorwithagechild" + ) + .prefetch_related("favorite_authors") + .get(pk=self.m2m_child.pk) + ) + self.assertQuerySetEqual( + gp.favorite_authors.all(), + [self.related1, self.related2, self.related3], + ) + self.assertQuerySetEqual( + gp.authorwithage.favorite_authors.all(), + [self.related1, self.related2, self.related3], + ) + self.assertQuerySetEqual( + gp.authorwithage.authorwithagechild.favorite_authors.all(), + [self.related1, self.related2, self.related3], + ) + gp.authorwithage.favorite_authors.clear() + self.assertQuerySetEqual(gp.favorite_authors.all(), []) + self.assertQuerySetEqual(gp.authorwithage.favorite_authors.all(), []) + self.assertQuerySetEqual( + gp.authorwithage.authorwithagechild.favorite_authors.all(), [] + )