From d3cf24e9b415b41f570c9f426b2cd113b5fdb4de Mon Sep 17 00:00:00 2001 From: Jacob Walls Date: Tue, 3 Jun 2025 21:53:10 -0400 Subject: [PATCH] Refs #36430, #36416, #34378 -- Simplified batch size calculation in QuerySet.in_bulk(). --- django/db/models/query.py | 4 +--- tests/composite_pk/tests.py | 30 +++++++++++++++++------------- tests/lookup/tests.py | 19 ++++++------------- 3 files changed, 24 insertions(+), 29 deletions(-) diff --git a/django/db/models/query.py b/django/db/models/query.py index 9f245b02ca..7ae9f53bfd 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -1187,10 +1187,8 @@ class QuerySet(AltersData): if not id_list: return {} filter_key = "{}__in".format(field_name) - max_params = connections[self.db].features.max_query_params or 0 - num_fields = len(opts.pk_fields) if field_name == "pk" else 1 - batch_size = max_params // num_fields id_list = tuple(id_list) + batch_size = connections[self.db].ops.bulk_batch_size([opts.pk], id_list) # If the database has a limit on the number of query parameters # (e.g. SQLite), retrieve objects in batches if necessary. if batch_size and batch_size < len(id_list): diff --git a/tests/composite_pk/tests.py b/tests/composite_pk/tests.py index cc78f3495a..c4a8e6ca8c 100644 --- a/tests/composite_pk/tests.py +++ b/tests/composite_pk/tests.py @@ -147,20 +147,24 @@ class CompositePKTests(TestCase): result = Comment.objects.in_bulk([self.comment.pk]) self.assertEqual(result, {self.comment.pk: self.comment}) - @unittest.mock.patch.object( - type(connection.features), "max_query_params", new_callable=lambda: 10 - ) - def test_in_bulk_batching(self, mocked_max_query_params): + def test_in_bulk_batching(self): Comment.objects.all().delete() - num_requiring_batching = (connection.features.max_query_params // 2) + 1 - comments = [ - Comment(id=i, tenant=self.tenant, user=self.user) - for i in range(1, num_requiring_batching + 1) - ] - Comment.objects.bulk_create(comments) - id_list = list(Comment.objects.values_list("pk", flat=True)) - with self.assertNumQueries(2): - comment_dict = Comment.objects.in_bulk(id_list=id_list) + batching_required = connection.features.max_query_params is not None + expected_queries = 2 if batching_required else 1 + with unittest.mock.patch.object( + type(connection.features), "max_query_params", 10 + ): + num_requiring_batching = ( + connection.ops.bulk_batch_size([Comment._meta.pk], []) + 1 + ) + comments = [ + Comment(id=i, tenant=self.tenant, user=self.user) + for i in range(1, num_requiring_batching + 1) + ] + Comment.objects.bulk_create(comments) + id_list = list(Comment.objects.values_list("pk", flat=True)) + with self.assertNumQueries(expected_queries): + comment_dict = Comment.objects.in_bulk(id_list=id_list) self.assertQuerySetEqual(comment_dict, id_list) def test_iterator(self): diff --git a/tests/lookup/tests.py b/tests/lookup/tests.py index 25336cbee7..ef54472e54 100644 --- a/tests/lookup/tests.py +++ b/tests/lookup/tests.py @@ -248,28 +248,21 @@ class LookupTests(TestCase): with self.assertRaisesMessage(ValueError, msg): Article.objects.in_bulk([self.au1], field_name="author") - @skipUnlessDBFeature("can_distinct_on_fields") def test_in_bulk_preserve_ordering(self): - articles = ( - Article.objects.order_by("author_id", "-pub_date") - .distinct("author_id") - .in_bulk([self.au1.id, self.au2.id], field_name="author_id") - ) self.assertEqual( - articles, - {self.au1.id: self.a4, self.au2.id: self.a5}, + list(Article.objects.in_bulk([self.au2.id, self.au1.id])), + [self.au2.id, self.au1.id], ) - @skipUnlessDBFeature("can_distinct_on_fields") def test_in_bulk_preserve_ordering_with_batch_size(self): - qs = Article.objects.order_by("author_id", "-pub_date").distinct("author_id") + qs = Article.objects.all() with ( - mock.patch.object(connection.features.__class__, "max_query_params", 1), + mock.patch.object(connection.ops, "bulk_batch_size", return_value=2), self.assertNumQueries(2), ): self.assertEqual( - qs.in_bulk([self.au1.id, self.au2.id], field_name="author_id"), - {self.au1.id: self.a4, self.au2.id: self.a5}, + list(qs.in_bulk([self.a4.id, self.a3.id, self.a2.id, self.a1.id])), + [self.a4.id, self.a3.id, self.a2.id, self.a1.id], ) @skipUnlessDBFeature("can_distinct_on_fields")