diff --git a/django/db/models/fields/related_descriptors.py b/django/db/models/fields/related_descriptors.py index 56a464f2d2..8da7aaef91 100644 --- a/django/db/models/fields/related_descriptors.py +++ b/django/db/models/fields/related_descriptors.py @@ -117,6 +117,18 @@ def _filter_prefetch_queryset(queryset, field_name, instances): return queryset +def _traverse_ancestors(model, starting_instance): + current_instance = starting_instance + while current_instance is not None: + ancestor_link = current_instance._meta.get_ancestor_link(model) + if not ancestor_link: + yield current_instance, None + break + ancestor = ancestor_link.get_cached_value(current_instance, None) + yield current_instance, ancestor + current_instance = ancestor + + class ForwardManyToOneDescriptor: """ Accessor to the related object on the forward side of a many-to-one or @@ -228,21 +240,19 @@ class ForwardManyToOneDescriptor: try: rel_obj = self.field.get_cached_value(instance) except KeyError: + rel_obj = None has_value = None not in self.field.get_local_related_value(instance) - ancestor_link = ( - instance._meta.get_ancestor_link(self.field.model) - if has_value - else None - ) - if ancestor_link and ancestor_link.is_cached(instance): - # An ancestor link will exist if this field is defined on a - # multi-table inheritance parent of the instance's class. - ancestor = ancestor_link.get_cached_value(instance) - # The value might be cached on an ancestor if the instance - # originated from walking down the inheritance chain. - rel_obj = self.field.get_cached_value(ancestor, default=None) - else: - rel_obj = None + if has_value: + model = self.field.model + for current_instance, ancestor in _traverse_ancestors(model, instance): + if ancestor: + # The value might be cached on an ancestor if the + # instance originated from walking down the inheritance + # chain. + rel_obj = self.field.get_cached_value(ancestor, default=None) + if rel_obj is not None: + break + if rel_obj is None and has_value: rel_obj = self.get_object(instance) remote_field = self.field.remote_field @@ -1095,16 +1105,23 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): return queryset._next_is_sticky().filter(**self.core_filters) def get_prefetch_cache(self): - try: - return self.instance._prefetched_objects_cache[self.prefetch_cache_name] - except (AttributeError, KeyError): - return None + # Walk up the ancestor-chain (if cached) to try and find a prefetch + # in an ancestor. + for instance, _ in _traverse_ancestors(rel.field.model, self.instance): + try: + return instance._prefetched_objects_cache[self.prefetch_cache_name] + except (AttributeError, KeyError): + pass + return None def _remove_prefetched_objects(self): - try: - self.instance._prefetched_objects_cache.pop(self.prefetch_cache_name) - except (AttributeError, KeyError): - pass # nothing to clear from cache + # Walk up the ancestor-chain (if cached) to try and find a prefetch + # in an ancestor. + for instance, _ in _traverse_ancestors(rel.field.model, self.instance): + try: + instance._prefetched_objects_cache.pop(self.prefetch_cache_name) + except (AttributeError, KeyError): + pass # nothing to clear from cache def get_queryset(self): if (cache := self.get_prefetch_cache()) is not None: diff --git a/tests/prefetch_related/models.py b/tests/prefetch_related/models.py index 405e9bba00..de2d88395d 100644 --- a/tests/prefetch_related/models.py +++ b/tests/prefetch_related/models.py @@ -28,6 +28,10 @@ class AuthorWithAge(Author): age = models.IntegerField() +class AuthorWithAgeChild(AuthorWithAge): + pass + + class FavoriteAuthors(models.Model): author = models.ForeignKey( Author, models.CASCADE, to_field="name", related_name="i_like" diff --git a/tests/prefetch_related/tests.py b/tests/prefetch_related/tests.py index bd37ca0ec3..1955809aec 100644 --- a/tests/prefetch_related/tests.py +++ b/tests/prefetch_related/tests.py @@ -21,6 +21,7 @@ from .models import ( Author2, AuthorAddress, AuthorWithAge, + AuthorWithAgeChild, Bio, Book, Bookmark, @@ -2086,3 +2087,184 @@ class PrefetchLimitTests(TestDataMixin, TestCase): with self.subTest(book=book): self.assertEqual(len(book.authors_sliced), 1) self.assertIn(book.authors_sliced[0], list(book.authors.all())) + + +class PrefetchRelatedMTICacheTests(TestCase): + @classmethod + def setUpTestData(cls): + cls.book = Book.objects.create(title="Book1") + cls.related1 = Author.objects.create(name="related1", first_book=cls.book) + cls.related2 = Author.objects.create(name="related2", first_book=cls.book) + cls.related3 = Author.objects.create(name="related3", first_book=cls.book) + cls.related4 = Author.objects.create(name="related4", first_book=cls.book) + + cls.child = AuthorWithAgeChild.objects.create( + name="child", + age=31, + first_book=cls.book, + ) + cls.m2m_child = AuthorWithAgeChild.objects.create( + name="m2m_child", + age=31, + first_book=cls.book, + ) + cls.m2m_child.favorite_authors.set([cls.related1, cls.related2, cls.related3]) + + def test_parent_fk_available_in_child(self): + qs = ( + Author.objects.select_related("authorwithage") + .prefetch_related("first_book") + .filter(pk=self.child.pk) + ) + with self.assertNumQueries(2): + results = list(qs) + self.assertEqual(len(results), 1) + self.assertEqual(results[0].authorwithage.first_book, self.book) + + def test_grandparent_fk_available_in_child(self): + qs = ( + Author.objects.select_related( + "authorwithage", "authorwithage__authorwithagechild" + ) + .prefetch_related("first_book") + .filter(pk=self.child.pk) + ) + with self.assertNumQueries(2): + results = list(qs) + self.assertEqual(len(results), 1) + self.assertEqual( + results[0].authorwithage.authorwithagechild.first_book, self.book + ) + + def test_parent_m2m_available_in_child(self): + qs = ( + Author.objects.select_related("authorwithage") + .prefetch_related("favorite_authors") + .filter(pk=self.m2m_child.pk) + ) + with self.assertNumQueries(2): + results = list(qs) + self.assertEqual(len(results), 1) + self.assertQuerySetEqual( + results[0].authorwithage.favorite_authors.all(), + [self.related1, self.related2, self.related3], + ) + + def test_grandparent_m2m_available_in_child(self): + qs = ( + Author.objects.select_related( + "authorwithage", "authorwithage__authorwithagechild" + ) + .prefetch_related("favorite_authors") + .filter(pk=self.m2m_child.pk) + ) + with self.assertNumQueries(2): + results = list(qs) + self.assertEqual(len(results), 1) + self.assertQuerySetEqual( + set(results[0].authorwithage.authorwithagechild.favorite_authors.all()), + {self.related1, self.related2, self.related3}, + ) + + def test_add_clears_prefetched_objects_in_parent(self): + gp = ( + Author.objects.select_related("authorwithage") + .prefetch_related("favorite_authors") + .get(pk=self.m2m_child.pk) + ) + self.assertQuerySetEqual( + gp.favorite_authors.all(), + {self.related1, self.related2, self.related3}, + ) + self.assertQuerySetEqual( + gp.authorwithage.favorite_authors.all(), + {self.related1, self.related2, self.related3}, + ) + gp.authorwithage.favorite_authors.add(self.related4) + self.assertQuerySetEqual( + gp.favorite_authors.all(), + {self.related1, self.related2, self.related3, self.related4}, + ) + self.assertQuerySetEqual( + gp.authorwithage.favorite_authors.all(), + {self.related1, self.related2, self.related3, self.related4}, + ) + + def test_add_clears_prefetched_objects_in_grandparent(self): + gp = ( + Author.objects.select_related( + "authorwithage", "authorwithage__authorwithagechild" + ) + .prefetch_related("favorite_authors") + .get(pk=self.m2m_child.pk) + ) + self.assertQuerySetEqual( + gp.favorite_authors.all(), + [self.related1, self.related2, self.related3], + ) + self.assertQuerySetEqual( + gp.authorwithage.favorite_authors.all(), + [self.related1, self.related2, self.related3], + ) + self.assertQuerySetEqual( + gp.authorwithage.authorwithagechild.favorite_authors.all(), + [self.related1, self.related2, self.related3], + ) + gp.authorwithage.authorwithagechild.favorite_authors.add(self.related4) + self.assertQuerySetEqual( + gp.favorite_authors.all(), + [self.related1, self.related2, self.related3, self.related4], + ) + self.assertQuerySetEqual( + gp.authorwithage.favorite_authors.all(), + [self.related1, self.related2, self.related3, self.related4], + ) + self.assertQuerySetEqual( + gp.authorwithage.authorwithagechild.favorite_authors.all(), + [self.related1, self.related2, self.related3, self.related4], + ) + + def test_remove_clears_prefetched_objects_in_parent(self): + gp = ( + Author.objects.select_related("authorwithage") + .prefetch_related("favorite_authors") + .get(pk=self.m2m_child.pk) + ) + self.assertQuerySetEqual( + gp.favorite_authors.all(), + [self.related1, self.related2, self.related3], + ) + self.assertQuerySetEqual( + gp.authorwithage.favorite_authors.all(), + [self.related1, self.related2, self.related3], + ) + gp.authorwithage.favorite_authors.clear() + self.assertQuerySetEqual(gp.favorite_authors.all(), []) + self.assertQuerySetEqual(gp.authorwithage.favorite_authors.all(), []) + + def test_remove_clears_prefetched_objects_in_grandparent(self): + gp = ( + Author.objects.select_related( + "authorwithage", "authorwithage__authorwithagechild" + ) + .prefetch_related("favorite_authors") + .get(pk=self.m2m_child.pk) + ) + self.assertQuerySetEqual( + gp.favorite_authors.all(), + [self.related1, self.related2, self.related3], + ) + self.assertQuerySetEqual( + gp.authorwithage.favorite_authors.all(), + [self.related1, self.related2, self.related3], + ) + self.assertQuerySetEqual( + gp.authorwithage.authorwithagechild.favorite_authors.all(), + [self.related1, self.related2, self.related3], + ) + gp.authorwithage.favorite_authors.clear() + self.assertQuerySetEqual(gp.favorite_authors.all(), []) + self.assertQuerySetEqual(gp.authorwithage.favorite_authors.all(), []) + self.assertQuerySetEqual( + gp.authorwithage.authorwithagechild.favorite_authors.all(), [] + )