diff --git a/django/db/models/fields/related_descriptors.py b/django/db/models/fields/related_descriptors.py index 62ddfc60b3..a8f298230a 100644 --- a/django/db/models/fields/related_descriptors.py +++ b/django/db/models/fields/related_descriptors.py @@ -75,7 +75,7 @@ from django.db import ( router, transaction, ) -from django.db.models import Q, Window, signals +from django.db.models import Manager, Q, Window, signals from django.db.models.functions import RowNumber from django.db.models.lookups import GreaterThan, LessThanOrEqual from django.db.models.query import QuerySet @@ -1121,6 +1121,12 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): queryset._defer_next_filter = True return queryset._next_is_sticky().filter(**self.core_filters) + def get_prefetch_cache(self): + try: + return self.instance._prefetched_objects_cache[self.prefetch_cache_name] + except (AttributeError, KeyError): + return None + def _remove_prefetched_objects(self): try: self.instance._prefetched_objects_cache.pop(self.prefetch_cache_name) @@ -1128,9 +1134,9 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): pass # nothing to clear from cache def get_queryset(self): - try: - return self.instance._prefetched_objects_cache[self.prefetch_cache_name] - except (AttributeError, KeyError): + if (cache := self.get_prefetch_cache()) is not None: + return cache + else: queryset = super().get_queryset() return self._apply_rel_filters(queryset) @@ -1195,6 +1201,45 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): False, ) + @property + def constrained_target(self): + # If the through relation's target field's foreign integrity is + # enforced, the query can be performed solely against the through + # table as the INNER JOIN'ing against target table is unnecessary. + if not self.target_field.db_constraint: + return None + db = router.db_for_read(self.through, instance=self.instance) + if not connections[db].features.supports_foreign_keys: + return None + hints = {"instance": self.instance} + manager = self.through._base_manager.db_manager(db, hints=hints) + filters = {self.source_field_name: self.instance.pk} + # Nullable target rows must be excluded as well as they would have + # been filtered out from an INNER JOIN. + if self.target_field.null: + filters["%s__isnull" % self.target_field_name] = False + return manager.filter(**filters) + + def exists(self): + if ( + superclass is Manager + and self.get_prefetch_cache() is None + and (constrained_target := self.constrained_target) is not None + ): + return constrained_target.exists() + else: + return super().exists() + + def count(self): + if ( + superclass is Manager + and self.get_prefetch_cache() is None + and (constrained_target := self.constrained_target) is not None + ): + return constrained_target.count() + else: + return super().count() + def add(self, *objs, through_defaults=None): self._remove_prefetched_objects() db = router.db_for_write(self.through, instance=self.instance) diff --git a/tests/many_to_many/models.py b/tests/many_to_many/models.py index 42fc426990..df7222e08d 100644 --- a/tests/many_to_many/models.py +++ b/tests/many_to_many/models.py @@ -78,3 +78,15 @@ class InheritedArticleA(AbstractArticle): class InheritedArticleB(AbstractArticle): pass + + +class NullableTargetArticle(models.Model): + headline = models.CharField(max_length=100) + publications = models.ManyToManyField( + Publication, through="NullablePublicationThrough" + ) + + +class NullablePublicationThrough(models.Model): + article = models.ForeignKey(NullableTargetArticle, models.CASCADE) + publication = models.ForeignKey(Publication, models.CASCADE, null=True) diff --git a/tests/many_to_many/tests.py b/tests/many_to_many/tests.py index 7ed3b80abc..351e4eb8cc 100644 --- a/tests/many_to_many/tests.py +++ b/tests/many_to_many/tests.py @@ -1,10 +1,18 @@ from unittest import mock -from django.db import transaction +from django.db import connection, transaction from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature from django.utils.deprecation import RemovedInDjango60Warning -from .models import Article, InheritedArticleA, InheritedArticleB, Publication, User +from .models import ( + Article, + InheritedArticleA, + InheritedArticleB, + NullablePublicationThrough, + NullableTargetArticle, + Publication, + User, +) class ManyToManyTests(TestCase): @@ -558,10 +566,16 @@ class ManyToManyTests(TestCase): def test_custom_default_manager_exists_count(self): a5 = Article.objects.create(headline="deleted") a5.publications.add(self.p2) - self.assertEqual(self.p2.article_set.count(), self.p2.article_set.all().count()) - self.assertEqual( - self.p3.article_set.exists(), self.p3.article_set.all().exists() - ) + with self.assertNumQueries(2) as ctx: + self.assertEqual( + self.p2.article_set.count(), self.p2.article_set.all().count() + ) + self.assertIn("JOIN", ctx.captured_queries[0]["sql"]) + with self.assertNumQueries(2) as ctx: + self.assertEqual( + self.p3.article_set.exists(), self.p3.article_set.all().exists() + ) + self.assertIn("JOIN", ctx.captured_queries[0]["sql"]) def test_get_prefetch_queryset_warning(self): articles = Article.objects.all() @@ -582,3 +596,73 @@ class ManyToManyTests(TestCase): instances=articles, querysets=[Publication.objects.all(), Publication.objects.all()], ) + + +class ManyToManyQueryTests(TestCase): + """ + SQL is optimized to reference the through table without joining against the + related table when using count() and exists() functions on a queryset for + many to many relations. The optimization applies to the case where there + are no filters. + """ + + @classmethod + def setUpTestData(cls): + cls.article = Article.objects.create( + headline="Django lets you build Web apps easily" + ) + cls.nullable_target_article = NullableTargetArticle.objects.create( + headline="The python is good" + ) + NullablePublicationThrough.objects.create( + article=cls.nullable_target_article, publication=None + ) + + @skipUnlessDBFeature("supports_foreign_keys") + def test_count_join_optimization(self): + with self.assertNumQueries(1) as ctx: + self.article.publications.count() + self.assertNotIn("JOIN", ctx.captured_queries[0]["sql"]) + + with self.assertNumQueries(1) as ctx: + self.article.publications.count() + self.assertNotIn("JOIN", ctx.captured_queries[0]["sql"]) + self.assertEqual(self.nullable_target_article.publications.count(), 0) + + def test_count_join_optimization_disabled(self): + with ( + mock.patch.object(connection.features, "supports_foreign_keys", False), + self.assertNumQueries(1) as ctx, + ): + self.article.publications.count() + + self.assertIn("JOIN", ctx.captured_queries[0]["sql"]) + + @skipUnlessDBFeature("supports_foreign_keys") + def test_exists_join_optimization(self): + with self.assertNumQueries(1) as ctx: + self.article.publications.exists() + self.assertNotIn("JOIN", ctx.captured_queries[0]["sql"]) + + self.article.publications.prefetch_related() + with self.assertNumQueries(1) as ctx: + self.article.publications.exists() + self.assertNotIn("JOIN", ctx.captured_queries[0]["sql"]) + self.assertIs(self.nullable_target_article.publications.exists(), False) + + def test_exists_join_optimization_disabled(self): + with ( + mock.patch.object(connection.features, "supports_foreign_keys", False), + self.assertNumQueries(1) as ctx, + ): + self.article.publications.exists() + + self.assertIn("JOIN", ctx.captured_queries[0]["sql"]) + + def test_prefetch_related_no_queries_optimization_disabled(self): + qs = Article.objects.prefetch_related("publications") + article = qs.get() + with self.assertNumQueries(0): + article.publications.count() + with self.assertNumQueries(0): + article.publications.exists()