diff --git a/django/db/models/query.py b/django/db/models/query.py index 721bf33e57..4cccb383fd 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -1166,8 +1166,6 @@ class QuerySet(AltersData): """ if self.query.is_sliced: raise TypeError("Cannot use 'limit' or 'offset' with in_bulk().") - if not issubclass(self._iterable_class, ModelIterable): - raise TypeError("in_bulk() cannot be used with values() or values_list().") opts = self.model._meta unique_fields = [ constraint.fields[0] @@ -1184,6 +1182,59 @@ class QuerySet(AltersData): "in_bulk()'s field_name must be a unique field but %r isn't." % field_name ) + + qs = self + + def get_obj(obj): + return obj + + if issubclass(self._iterable_class, ModelIterable): + # Raise an AttributeError if field_name is deferred. + get_key = operator.attrgetter(field_name) + + elif issubclass(self._iterable_class, ValuesIterable): + if field_name not in self.query.values_select: + qs = qs.values(field_name, *self.query.values_select) + + def get_obj(obj): # noqa: F811 + # We can safely mutate the dictionaries returned by + # ValuesIterable here, since they are limited to the scope + # of this function, and get_key runs before get_obj. + del obj[field_name] + return obj + + get_key = operator.itemgetter(field_name) + + elif issubclass(self._iterable_class, ValuesListIterable): + try: + field_index = self.query.values_select.index(field_name) + except ValueError: + # field_name is missing from values_select, so add it. + field_index = 0 + if issubclass(self._iterable_class, NamedValuesListIterable): + kwargs = {"named": True} + else: + kwargs = {} + get_obj = operator.itemgetter(slice(1, None)) + qs = qs.values_list(field_name, *self.query.values_select, **kwargs) + + get_key = operator.itemgetter(field_index) + + elif issubclass(self._iterable_class, FlatValuesListIterable): + if self.query.values_select == (field_name,): + # Mapping field_name to itself. + get_key = get_obj + else: + # Transform it back into a non-flat values_list(). + qs = qs.values_list(field_name, *self.query.values_select) + get_key = operator.itemgetter(0) + get_obj = operator.itemgetter(1) + + else: + raise TypeError( + f"in_bulk() cannot be used with {self._iterable_class.__name__}." + ) + if id_list is not None: if not id_list: return {} @@ -1193,15 +1244,16 @@ class QuerySet(AltersData): # If the database has a limit on the number of query parameters # (e.g. SQLite), retrieve objects in batches if necessary. if batch_size and batch_size < len(id_list): - qs = () + results = () for offset in range(0, len(id_list), batch_size): batch = id_list[offset : offset + batch_size] - qs += tuple(self.filter(**{filter_key: batch})) + results += tuple(qs.filter(**{filter_key: batch})) + qs = results else: - qs = self.filter(**{filter_key: id_list}) + qs = qs.filter(**{filter_key: id_list}) else: - qs = self._chain() - return {getattr(obj, field_name): obj for obj in qs} + qs = qs._chain() + return {get_key(obj): get_obj(obj) for obj in qs} async def ain_bulk(self, id_list=None, *, field_name="pk"): return await sync_to_async(self.in_bulk)( diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt index 5fb6cb33b3..4be1759af2 100644 --- a/docs/ref/models/querysets.txt +++ b/docs/ref/models/querysets.txt @@ -2588,6 +2588,11 @@ Example: If you pass ``in_bulk()`` an empty list, you'll get an empty dictionary. +.. versionchanged:: 6.1 + + Support for chaining ``in_bulk()`` after :meth:`values` or + :meth:`values_list` was added. + ``iterator()`` ~~~~~~~~~~~~~~ diff --git a/docs/releases/6.1.txt b/docs/releases/6.1.txt index edfdacf6b7..56a222f3e3 100644 --- a/docs/releases/6.1.txt +++ b/docs/releases/6.1.txt @@ -175,7 +175,8 @@ Migrations Models ~~~~~~ -* ... +* :meth:`.QuerySet.in_bulk` now supports chaining after + :meth:`.QuerySet.values` and :meth:`.QuerySet.values_list`. Pagination ~~~~~~~~~~ diff --git a/tests/composite_pk/tests.py b/tests/composite_pk/tests.py index c4a8e6ca8c..3001847455 100644 --- a/tests/composite_pk/tests.py +++ b/tests/composite_pk/tests.py @@ -167,6 +167,67 @@ class CompositePKTests(TestCase): comment_dict = Comment.objects.in_bulk(id_list=id_list) self.assertQuerySetEqual(comment_dict, id_list) + def test_in_bulk_values(self): + result = Comment.objects.values().in_bulk([self.comment.pk]) + self.assertEqual( + result, + { + self.comment.pk: { + "tenant_id": self.comment.tenant_id, + "id": self.comment.id, + "user_id": self.comment.user_id, + "text": self.comment.text, + "integer": self.comment.integer, + } + }, + ) + + def test_in_bulk_values_field(self): + result = Comment.objects.values("text").in_bulk([self.comment.pk]) + self.assertEqual( + result, + {self.comment.pk: {"text": self.comment.text}}, + ) + + def test_in_bulk_values_fields(self): + result = Comment.objects.values("pk", "text").in_bulk([self.comment.pk]) + self.assertEqual( + result, + {self.comment.pk: {"pk": self.comment.pk, "text": self.comment.text}}, + ) + + def test_in_bulk_values_list(self): + result = Comment.objects.values_list("text").in_bulk([self.comment.pk]) + self.assertEqual(result, {self.comment.pk: (self.comment.text,)}) + + def test_in_bulk_values_list_multiple_fields(self): + result = Comment.objects.values_list("pk", "text").in_bulk([self.comment.pk]) + self.assertEqual( + result, {self.comment.pk: (self.comment.pk, self.comment.text)} + ) + + def test_in_bulk_values_list_fields_are_pk(self): + result = Comment.objects.values_list("tenant", "id").in_bulk([self.comment.pk]) + self.assertEqual( + result, {self.comment.pk: (self.comment.tenant_id, self.comment.id)} + ) + + def test_in_bulk_values_list_flat(self): + result = Comment.objects.values_list("text", flat=True).in_bulk( + [self.comment.pk] + ) + self.assertEqual(result, {self.comment.pk: self.comment.text}) + + def test_in_bulk_values_list_flat_pk(self): + result = Comment.objects.values_list("pk", flat=True).in_bulk([self.comment.pk]) + self.assertEqual(result, {self.comment.pk: self.comment.pk}) + + def test_in_bulk_values_list_flat_tenant(self): + result = Comment.objects.values_list("tenant", flat=True).in_bulk( + [self.comment.pk] + ) + self.assertEqual(result, {self.comment.pk: self.tenant.id}) + def test_iterator(self): """ Test the .iterator() method of composite_pk models. diff --git a/tests/lookup/tests.py b/tests/lookup/tests.py index e013666fc4..f6f73e9fac 100644 --- a/tests/lookup/tests.py +++ b/tests/lookup/tests.py @@ -317,12 +317,246 @@ class LookupTests(TestCase): with self.assertRaisesMessage(TypeError, msg): Article.objects.all()[0:5].in_bulk([self.a1.id, self.a2.id]) - def test_in_bulk_not_model_iterable(self): - msg = "in_bulk() cannot be used with values() or values_list()." - with self.assertRaisesMessage(TypeError, msg): - Author.objects.values().in_bulk() - with self.assertRaisesMessage(TypeError, msg): - Author.objects.values_list().in_bulk() + def test_in_bulk_values_empty(self): + arts = Article.objects.values().in_bulk([]) + self.assertEqual(arts, {}) + + def test_in_bulk_values_all(self): + Article.objects.exclude(pk__in=[self.a1.pk, self.a2.pk]).delete() + arts = Article.objects.values().in_bulk() + self.assertEqual( + arts, + { + self.a1.pk: { + "id": self.a1.pk, + "author_id": self.au1.pk, + "headline": "Article 1", + "pub_date": self.a1.pub_date, + "slug": "a1", + }, + self.a2.pk: { + "id": self.a2.pk, + "author_id": self.au1.pk, + "headline": "Article 2", + "pub_date": self.a2.pub_date, + "slug": "a2", + }, + }, + ) + + def test_in_bulk_values_pks(self): + arts = Article.objects.values().in_bulk([self.a1.pk]) + self.assertEqual( + arts, + { + self.a1.pk: { + "id": self.a1.pk, + "author_id": self.au1.pk, + "headline": "Article 1", + "pub_date": self.a1.pub_date, + "slug": "a1", + } + }, + ) + + def test_in_bulk_values_fields(self): + arts = Article.objects.values("headline").in_bulk([self.a1.pk]) + self.assertEqual( + arts, + {self.a1.pk: {"headline": "Article 1"}}, + ) + + def test_in_bulk_values_fields_including_pk(self): + arts = Article.objects.values("pk", "headline").in_bulk([self.a1.pk]) + self.assertEqual( + arts, + {self.a1.pk: {"pk": self.a1.pk, "headline": "Article 1"}}, + ) + + def test_in_bulk_values_fields_pk(self): + arts = Article.objects.values("pk").in_bulk([self.a1.pk]) + self.assertEqual( + arts, + {self.a1.pk: {"pk": self.a1.pk}}, + ) + + def test_in_bulk_values_fields_id(self): + arts = Article.objects.values("id").in_bulk([self.a1.pk]) + self.assertEqual( + arts, + {self.a1.pk: {"id": self.a1.pk}}, + ) + + def test_in_bulk_values_alternative_field_name(self): + arts = Article.objects.values("headline").in_bulk( + [self.a1.slug], field_name="slug" + ) + self.assertEqual( + arts, + {self.a1.slug: {"headline": "Article 1"}}, + ) + + def test_in_bulk_values_list_empty(self): + arts = Article.objects.values_list().in_bulk([]) + self.assertEqual(arts, {}) + + def test_in_bulk_values_list_all(self): + Article.objects.exclude(pk__in=[self.a1.pk, self.a2.pk]).delete() + arts = Article.objects.values_list().in_bulk() + self.assertEqual( + arts, + { + self.a1.pk: ( + self.a1.pk, + "Article 1", + self.a1.pub_date, + self.au1.pk, + "a1", + ), + self.a2.pk: ( + self.a2.pk, + "Article 2", + self.a2.pub_date, + self.au1.pk, + "a2", + ), + }, + ) + + def test_in_bulk_values_list_fields(self): + arts = Article.objects.values_list("headline").in_bulk([self.a1.pk, self.a2.pk]) + self.assertEqual( + arts, + { + self.a1.pk: ("Article 1",), + self.a2.pk: ("Article 2",), + }, + ) + + def test_in_bulk_values_list_fields_including_pk(self): + arts = Article.objects.values_list("pk", "headline").in_bulk( + [self.a1.pk, self.a2.pk] + ) + self.assertEqual( + arts, + { + self.a1.pk: (self.a1.pk, "Article 1"), + self.a2.pk: (self.a2.pk, "Article 2"), + }, + ) + + def test_in_bulk_values_list_fields_pk(self): + arts = Article.objects.values_list("pk").in_bulk([self.a1.pk, self.a2.pk]) + self.assertEqual( + arts, + { + self.a1.pk: (self.a1.pk,), + self.a2.pk: (self.a2.pk,), + }, + ) + + def test_in_bulk_values_list_fields_id(self): + arts = Article.objects.values_list("id").in_bulk([self.a1.pk, self.a2.pk]) + self.assertEqual( + arts, + { + self.a1.pk: (self.a1.pk,), + self.a2.pk: (self.a2.pk,), + }, + ) + + def test_in_bulk_values_list_named(self): + arts = Article.objects.values_list(named=True).in_bulk([self.a1.pk, self.a2.pk]) + self.assertIsInstance(arts, dict) + self.assertEqual(len(arts), 2) + arts1 = arts[self.a1.pk] + self.assertEqual( + arts1._fields, ("pk", "id", "headline", "pub_date", "author_id", "slug") + ) + self.assertEqual(arts1.pk, self.a1.pk) + self.assertEqual(arts1.headline, "Article 1") + self.assertEqual(arts1.pub_date, self.a1.pub_date) + self.assertEqual(arts1.author_id, self.au1.pk) + self.assertEqual(arts1.slug, "a1") + + def test_in_bulk_values_list_named_fields(self): + arts = Article.objects.values_list("pk", "headline", named=True).in_bulk( + [self.a1.pk, self.a2.pk] + ) + self.assertIsInstance(arts, dict) + self.assertEqual(len(arts), 2) + arts1 = arts[self.a1.pk] + self.assertEqual(arts1._fields, ("pk", "headline")) + self.assertEqual(arts1.pk, self.a1.pk) + self.assertEqual(arts1.headline, "Article 1") + + def test_in_bulk_values_list_named_fields_alternative_field(self): + arts = Article.objects.values_list("headline", named=True).in_bulk( + [self.a1.slug, self.a2.slug], field_name="slug" + ) + self.assertEqual(len(arts), 2) + arts1 = arts[self.a1.slug] + self.assertEqual(arts1._fields, ("slug", "headline")) + self.assertEqual(arts1.slug, "a1") + self.assertEqual(arts1.headline, "Article 1") + + def test_in_bulk_values_list_flat_empty(self): + arts = Article.objects.values_list(flat=True).in_bulk([]) + self.assertEqual(arts, {}) + + def test_in_bulk_values_list_flat_all(self): + Article.objects.exclude(pk__in=[self.a1.pk, self.a2.pk]).delete() + arts = Article.objects.values_list(flat=True).in_bulk() + self.assertEqual( + arts, + { + self.a1.pk: self.a1.pk, + self.a2.pk: self.a2.pk, + }, + ) + + def test_in_bulk_values_list_flat_pks(self): + arts = Article.objects.values_list(flat=True).in_bulk([self.a1.pk, self.a2.pk]) + self.assertEqual( + arts, + { + self.a1.pk: self.a1.pk, + self.a2.pk: self.a2.pk, + }, + ) + + def test_in_bulk_values_list_flat_field(self): + arts = Article.objects.values_list("headline", flat=True).in_bulk( + [self.a1.pk, self.a2.pk] + ) + self.assertEqual( + arts, + {self.a1.pk: "Article 1", self.a2.pk: "Article 2"}, + ) + + def test_in_bulk_values_list_flat_field_pk(self): + arts = Article.objects.values_list("pk", flat=True).in_bulk( + [self.a1.pk, self.a2.pk] + ) + self.assertEqual( + arts, + { + self.a1.pk: self.a1.pk, + self.a2.pk: self.a2.pk, + }, + ) + + def test_in_bulk_values_list_flat_field_id(self): + arts = Article.objects.values_list("id", flat=True).in_bulk( + [self.a1.pk, self.a2.pk] + ) + self.assertEqual( + arts, + { + self.a1.pk: self.a1.pk, + self.a2.pk: self.a2.pk, + }, + ) def test_values(self): # values() returns a list of dictionaries instead of object instances,