diff --git a/django/db/models/query.py b/django/db/models/query.py index e5e1c1b9f4..36ebec1905 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -560,16 +560,19 @@ class QuerySet: return objects[0] return None - def in_bulk(self, id_list=None): + def in_bulk(self, id_list=None, *, field_name='pk'): """ Return a dictionary mapping each of the given IDs to the object with that ID. If `id_list` isn't provided, evaluate the entire QuerySet. """ assert self.query.can_filter(), \ "Cannot use 'limit' or 'offset' with in_bulk" + if field_name != 'pk' and not self.model._meta.get_field(field_name).unique: + raise ValueError("in_bulk()'s field_name must be a unique field but %r isn't." % field_name) if id_list is not None: if not id_list: return {} + filter_key = '{}__in'.format(field_name) batch_size = connections[self.db].features.max_query_params id_list = tuple(id_list) # If the database has a limit on the number of query parameters @@ -578,12 +581,12 @@ class QuerySet: qs = () for offset in range(0, len(id_list), batch_size): batch = id_list[offset:offset + batch_size] - qs += tuple(self.filter(pk__in=batch).order_by()) + qs += tuple(self.filter(**{filter_key: batch}).order_by()) else: - qs = self.filter(pk__in=id_list).order_by() + qs = self.filter(**{filter_key: id_list}).order_by() else: qs = self._clone() - return {obj.pk: obj for obj in qs} + return {getattr(obj, field_name): obj for obj in qs} def delete(self): """Delete the records in the current QuerySet.""" diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt index 2c1f68fc06..74f83ab8c5 100644 --- a/docs/ref/models/querysets.txt +++ b/docs/ref/models/querysets.txt @@ -1997,11 +1997,13 @@ database query like ``count()`` would. ``in_bulk()`` ~~~~~~~~~~~~~ -.. method:: in_bulk(id_list=None) +.. method:: in_bulk(id_list=None, field_name='pk') -Takes a list of primary-key values and returns a dictionary mapping each -primary-key value to an instance of the object with the given ID. If a list -isn't provided, all objects in the queryset are returned. +Takes a list of field values (``id_list``) and the ``field_name`` for those +values, and returns a dictionary mapping each value to an instance of the +object with the given field value. If ``id_list`` isn't provided, all objects +in the queryset are returned. ``field_name`` must be a unique field, and it +defaults to the primary key. Example:: @@ -2013,9 +2015,15 @@ Example:: {} >>> Blog.objects.in_bulk() {1: <Blog: Beatles Blog>, 2: <Blog: Cheddar Talk>, 3: <Blog: Django Weblog>} + >>> Blog.objects.in_bulk(['beatles_blog'], field_name='slug') + {'beatles_blog': <Blog: Beatles Blog>} If you pass ``in_bulk()`` an empty list, you'll get an empty dictionary. +.. versionchanged:: 2.0 + + The ``field_name`` parameter was added. + ``iterator()`` ~~~~~~~~~~~~~~ diff --git a/docs/releases/2.0.txt b/docs/releases/2.0.txt index 16024ae244..162ac42306 100644 --- a/docs/releases/2.0.txt +++ b/docs/releases/2.0.txt @@ -259,6 +259,9 @@ Models :meth:`~.QuerySet.select_for_update()` is used in conjunction with :meth:`~.QuerySet.select_related()`. +* The new ``field_name`` parameter of :meth:`.QuerySet.in_bulk` allows fetching + results based on any unique model field. + Requests and Responses ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/lookup/models.py b/tests/lookup/models.py index 14742e8a8c..2fa5b87755 100644 --- a/tests/lookup/models.py +++ b/tests/lookup/models.py @@ -26,6 +26,7 @@ class Article(models.Model): headline = models.CharField(max_length=100) pub_date = models.DateTimeField() author = models.ForeignKey(Author, models.SET_NULL, blank=True, null=True) + slug = models.SlugField(unique=True, blank=True, null=True) class Meta: ordering = ('-pub_date', 'headline') diff --git a/tests/lookup/tests.py b/tests/lookup/tests.py index 70ccf2c02d..289d8dde1b 100644 --- a/tests/lookup/tests.py +++ b/tests/lookup/tests.py @@ -16,14 +16,49 @@ class LookupTests(TestCase): # Create a few Authors. self.au1 = Author.objects.create(name='Author 1') self.au2 = Author.objects.create(name='Author 2') - # Create a couple of Articles. - self.a1 = Article.objects.create(headline='Article 1', pub_date=datetime(2005, 7, 26), author=self.au1) - self.a2 = Article.objects.create(headline='Article 2', pub_date=datetime(2005, 7, 27), author=self.au1) - self.a3 = Article.objects.create(headline='Article 3', pub_date=datetime(2005, 7, 27), author=self.au1) - self.a4 = Article.objects.create(headline='Article 4', pub_date=datetime(2005, 7, 28), author=self.au1) - self.a5 = Article.objects.create(headline='Article 5', pub_date=datetime(2005, 8, 1, 9, 0), author=self.au2) - self.a6 = Article.objects.create(headline='Article 6', pub_date=datetime(2005, 8, 1, 8, 0), author=self.au2) - self.a7 = Article.objects.create(headline='Article 7', pub_date=datetime(2005, 7, 27), author=self.au2) + # Create a few Articles. + self.a1 = Article.objects.create( + headline='Article 1', + pub_date=datetime(2005, 7, 26), + author=self.au1, + slug='a1', + ) + self.a2 = Article.objects.create( + headline='Article 2', + pub_date=datetime(2005, 7, 27), + author=self.au1, + slug='a2', + ) + self.a3 = Article.objects.create( + headline='Article 3', + pub_date=datetime(2005, 7, 27), + author=self.au1, + slug='a3', + ) + self.a4 = Article.objects.create( + headline='Article 4', + pub_date=datetime(2005, 7, 28), + author=self.au1, + slug='a4', + ) + self.a5 = Article.objects.create( + headline='Article 5', + pub_date=datetime(2005, 8, 1, 9, 0), + author=self.au2, + slug='a5', + ) + self.a6 = Article.objects.create( + headline='Article 6', + pub_date=datetime(2005, 8, 1, 8, 0), + author=self.au2, + slug='a6', + ) + self.a7 = Article.objects.create( + headline='Article 7', + pub_date=datetime(2005, 7, 27), + author=self.au2, + slug='a7', + ) # Create a few Tags. self.t1 = Tag.objects.create(name='Tag 1') self.t1.articles.add(self.a1, self.a2, self.a3) @@ -138,6 +173,21 @@ class LookupTests(TestCase): with self.assertNumQueries(expected_num_queries): self.assertEqual(Author.objects.in_bulk(authors), authors) + def test_in_bulk_with_field(self): + self.assertEqual( + Article.objects.in_bulk([self.a1.slug, self.a2.slug, self.a3.slug], field_name='slug'), + { + self.a1.slug: self.a1, + self.a2.slug: self.a2, + self.a3.slug: self.a3, + } + ) + + def test_in_bulk_non_unique_field(self): + msg = "in_bulk()'s field_name must be a unique field but 'author' isn't." + with self.assertRaisesMessage(ValueError, msg): + Article.objects.in_bulk([self.au1], field_name='author') + def test_values(self): # values() returns a list of dictionaries instead of object instances -- # and you can specify which fields you want to retrieve. @@ -274,7 +324,8 @@ class LookupTests(TestCase): 'id': self.a5.id, 'author_id': self.au2.id, 'headline': 'Article 5', - 'pub_date': datetime(2005, 8, 1, 9, 0) + 'pub_date': datetime(2005, 8, 1, 9, 0), + 'slug': 'a5', }], ) @@ -503,7 +554,7 @@ class LookupTests(TestCase): with self.assertRaisesMessage( FieldError, "Cannot resolve keyword 'pub_date_year' into field. Choices are: " - "author, author_id, headline, id, pub_date, tag" + "author, author_id, headline, id, pub_date, slug, tag" ): Article.objects.filter(pub_date_year='2005').count()