diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index aef26444f0..23f8cd154b 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -375,24 +375,29 @@ class SingleRelatedObjectDescriptor(six.with_metaclass(RenameRelatedObjectDescri return hasattr(instance, self.cache_name) def get_queryset(self, **hints): + # Gotcha: we return a `Manager` instance (i.e. not a `QuerySet`)! return self.related.model._base_manager.db_manager(hints=hints) def get_prefetch_queryset(self, instances, queryset=None): - if queryset is not None: - raise ValueError("Custom queryset can't be used for this lookup.") + if queryset is None: + # Despite its name `get_queryset()` returns an instance of + # `Manager`, therefore we call `all()` to normalize to `QuerySet`. + queryset = self.get_queryset().all() + queryset._add_hints(instance=instances[0]) rel_obj_attr = attrgetter(self.related.field.attname) instance_attr = lambda obj: obj._get_pk_val() instances_dict = dict((instance_attr(inst), inst) for inst in instances) query = {'%s__in' % self.related.field.name: instances} - qs = self.get_queryset(instance=instances[0]).filter(**query) + queryset = queryset.filter(**query) + # Since we're going to assign directly in the cache, # we must manage the reverse relation cache manually. rel_obj_cache_name = self.related.field.get_cache_name() - for rel_obj in qs: + for rel_obj in queryset: instance = instances_dict[rel_obj_attr(rel_obj)] setattr(rel_obj, rel_obj_cache_name, instance) - return qs, rel_obj_attr, instance_attr, True, self.cache_name + return queryset, rel_obj_attr, instance_attr, True, self.cache_name def __get__(self, instance, instance_type=None): if instance is None: @@ -503,13 +508,17 @@ class ReverseSingleRelatedObjectDescriptor(six.with_metaclass(RenameRelatedObjec # If the related manager indicates that it should be used for # related fields, respect that. if getattr(rel_mgr, 'use_for_related_fields', False): + # Gotcha: we return a `Manager` instance (i.e. not a `QuerySet`)! return rel_mgr else: return QuerySet(self.field.rel.to, hints=hints) def get_prefetch_queryset(self, instances, queryset=None): - if queryset is not None: - raise ValueError("Custom queryset can't be used for this lookup.") + if queryset is None: + # Despite its name `get_queryset()` may return an instance of + # `Manager`, therefore we call `all()` to normalize to `QuerySet`. + queryset = self.get_queryset().all() + queryset._add_hints(instance=instances[0]) rel_obj_attr = self.field.get_foreign_related_value instance_attr = self.field.get_local_related_value @@ -524,16 +533,16 @@ class ReverseSingleRelatedObjectDescriptor(six.with_metaclass(RenameRelatedObjec query = {'%s__in' % related_field.name: set(instance_attr(inst)[0] for inst in instances)} else: query = {'%s__in' % self.field.related_query_name(): instances} + queryset = queryset.filter(**query) - qs = self.get_queryset(instance=instances[0]).filter(**query) # Since we're going to assign directly in the cache, # we must manage the reverse relation cache manually. if not self.field.rel.multiple: rel_obj_cache_name = self.field.related.get_cache_name() - for rel_obj in qs: + for rel_obj in queryset: instance = instances_dict[rel_obj_attr(rel_obj)] setattr(rel_obj, rel_obj_cache_name, instance) - return qs, rel_obj_attr, instance_attr, True, self.cache_name + return queryset, rel_obj_attr, instance_attr, True, self.cache_name def __get__(self, instance, instance_type=None): if instance is None: diff --git a/django/db/models/query.py b/django/db/models/query.py index 954ca65f3a..15339ccbdb 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -1667,8 +1667,8 @@ class Prefetch(object): def get_current_to_attr(self, level): parts = self.prefetch_to.split(LOOKUP_SEP) to_attr = parts[level] - to_list = self.to_attr and level == len(parts) - 1 - return to_attr, to_list + as_attr = self.to_attr and level == len(parts) - 1 + return to_attr, as_attr def get_current_queryset(self, level): if self.get_current_prefetch_to(level) == self.prefetch_to: @@ -1913,12 +1913,13 @@ def prefetch_one_level(instances, prefetcher, lookup, level): for obj in instances: instance_attr_val = instance_attr(obj) vals = rel_obj_cache.get(instance_attr_val, []) + to_attr, as_attr = lookup.get_current_to_attr(level) if single: - # Need to assign to single cache on instance - setattr(obj, cache_name, vals[0] if vals else None) + val = vals[0] if vals else None + to_attr = to_attr if as_attr else cache_name + setattr(obj, to_attr, val) else: - to_attr, to_list = lookup.get_current_to_attr(level) - if to_list: + if as_attr: setattr(obj, to_attr, vals) else: # Cache in the QuerySet.all(). diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt index a160587953..22ccda4b8e 100644 --- a/docs/ref/models/querysets.txt +++ b/docs/ref/models/querysets.txt @@ -989,6 +989,24 @@ less ambiguous than storing a filtered result in the related manager's cache: ... Prefetch('pizzas', queryset=queryset)) >>> vegetarian_pizzas = restaurants[0].pizzas.all() +Custom prefetching also works with single related relations like +forward ``ForeignKey`` or ``OneToOneField``. Generally you'll want to use +:meth:`select_related()` for these relations, but there are a number of cases +where prefetching with a custom ``QuerySet`` is useful: + +* You want to use a ``QuerySet`` that performs further prefetching + on related models. + +* You want to prefetch only a subset of the related objects. + +* You want to use performance optimization techniques like + :meth:`deferred fields `: + + >>> queryset = Pizza.objects.only('name') + >>> + >>> restaurants = Restaurant.objects.prefetch_related( + ... Prefetch('best_pizza', queryset=queryset)) + .. note:: The ordering of lookups matters. diff --git a/tests/prefetch_related/models.py b/tests/prefetch_related/models.py index 35be3d6804..7d300ea84c 100644 --- a/tests/prefetch_related/models.py +++ b/tests/prefetch_related/models.py @@ -170,8 +170,10 @@ class Comment(models.Model): ## Models for lookup ordering tests class House(models.Model): + name = models.CharField(max_length=50) address = models.CharField(max_length=255) owner = models.ForeignKey('Person', null=True) + main_room = models.OneToOneField('Room', related_name='main_room_of', null=True) class Meta: ordering = ['id'] diff --git a/tests/prefetch_related/tests.py b/tests/prefetch_related/tests.py index 5dfca0fa99..1bb64c95fb 100644 --- a/tests/prefetch_related/tests.py +++ b/tests/prefetch_related/tests.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals +from django.core.exceptions import ObjectDoesNotExist from django.contrib.contenttypes.models import ContentType from django.db import connection from django.db.models import Prefetch @@ -228,17 +229,22 @@ class CustomPrefetchTests(TestCase): for part in path: if not part: continue - rel_objs.extend(cls.traverse_qs(getattr(obj, part[0]), [part[1:]])) + try: + related = getattr(obj, part[0]) + except ObjectDoesNotExist: + continue + if related is not None: + rel_objs.extend(cls.traverse_qs(related, [part[1:]])) ret_val.append((obj, rel_objs)) return ret_val def setUp(self): self.person1 = Person.objects.create(name="Joe") self.person2 = Person.objects.create(name="Mary") - self.house1 = House.objects.create(address="123 Main St", owner=self.person1) - self.house2 = House.objects.create(address="45 Side St", owner=self.person1) - self.house3 = House.objects.create(address="6 Downing St", owner=self.person2) - self.house4 = House.objects.create(address="7 Regents St", owner=self.person2) + self.house1 = House.objects.create(name='House 1', address="123 Main St", owner=self.person1) + self.house2 = House.objects.create(name='House 2', address="45 Side St", owner=self.person1) + self.house3 = House.objects.create(name='House 3', address="6 Downing St", owner=self.person2) + self.house4 = House.objects.create(name='house 4', address="7 Regents St", owner=self.person2) self.room1_1 = Room.objects.create(name="Dining room", house=self.house1) self.room1_2 = Room.objects.create(name="Lounge", house=self.house1) self.room1_3 = Room.objects.create(name="Kitchen", house=self.house1) @@ -253,6 +259,14 @@ class CustomPrefetchTests(TestCase): self.room4_3 = Room.objects.create(name="Kitchen", house=self.house4) self.person1.houses.add(self.house1, self.house2) self.person2.houses.add(self.house3, self.house4) + self.house1.main_room = self.room1_1 + self.house1.save() + self.house2.main_room = self.room2_1 + self.house2.save() + self.house3.main_room = self.room3_1 + self.house3.save() + self.house4.main_room = self.room4_1 + self.house4.save() def test_traverse_qs(self): qs = Person.objects.prefetch_related('houses') @@ -262,16 +276,17 @@ class CustomPrefetchTests(TestCase): self.assertEqual(related_objs_normal, (related_objs_from_traverse,)) def test_ambiguous(self): - # Ambiguous. + # Ambiguous: Lookup was already seen with a different queryset. with self.assertRaises(ValueError): self.traverse_qs( Person.objects.prefetch_related('houses__rooms', Prefetch('houses', queryset=House.objects.all())), [['houses', 'rooms']] ) + # Ambiguous: Lookup houses_lst doesn't yet exist when performing houses_lst__rooms. with self.assertRaises(AttributeError): self.traverse_qs( - Person.objects.prefetch_related('houses_list__rooms', Prefetch('houses', queryset=House.objects.all(), to_attr='houses_lst')), + Person.objects.prefetch_related('houses_lst__rooms', Prefetch('houses', queryset=House.objects.all(), to_attr='houses_lst')), [['houses', 'rooms']] ) @@ -491,6 +506,50 @@ class CustomPrefetchTests(TestCase): self.assertEqual(lst2[0].houses_lst[0].rooms_lst[1], self.room1_2) self.assertEqual(len(lst2[1].houses_lst), 0) + # Test ReverseSingleRelatedObjectDescriptor. + houses = House.objects.select_related('owner') + with self.assertNumQueries(6): + rooms = Room.objects.all().prefetch_related('house') + lst1 = self.traverse_qs(rooms, [['house', 'owner']]) + with self.assertNumQueries(2): + rooms = Room.objects.all().prefetch_related(Prefetch('house', queryset=houses.all())) + lst2 = self.traverse_qs(rooms, [['house', 'owner']]) + self.assertEqual(lst1, lst2) + with self.assertNumQueries(2): + houses = House.objects.select_related('owner') + rooms = Room.objects.all().prefetch_related(Prefetch('house', queryset=houses.all(), to_attr='house_attr')) + lst2 = self.traverse_qs(rooms, [['house_attr', 'owner']]) + self.assertEqual(lst1, lst2) + room = Room.objects.all().prefetch_related(Prefetch('house', queryset=houses.filter(address='DoesNotExist'))).first() + with self.assertRaises(ObjectDoesNotExist): + getattr(room, 'house') + room = Room.objects.all().prefetch_related(Prefetch('house', queryset=houses.filter(address='DoesNotExist'), to_attr='house_attr')).first() + self.assertIsNone(room.house_attr) + rooms = Room.objects.all().prefetch_related(Prefetch('house', queryset=House.objects.only('name'))) + with self.assertNumQueries(2): + getattr(rooms.first().house, 'name') + with self.assertNumQueries(3): + getattr(rooms.first().house, 'address') + + # Test SingleRelatedObjectDescriptor. + houses = House.objects.select_related('owner') + with self.assertNumQueries(6): + rooms = Room.objects.all().prefetch_related('main_room_of') + lst1 = self.traverse_qs(rooms, [['main_room_of', 'owner']]) + with self.assertNumQueries(2): + rooms = Room.objects.all().prefetch_related(Prefetch('main_room_of', queryset=houses.all())) + lst2 = self.traverse_qs(rooms, [['main_room_of', 'owner']]) + self.assertEqual(lst1, lst2) + with self.assertNumQueries(2): + rooms = list(Room.objects.all().prefetch_related(Prefetch('main_room_of', queryset=houses.all(), to_attr='main_room_of_attr'))) + lst2 = self.traverse_qs(rooms, [['main_room_of_attr', 'owner']]) + self.assertEqual(lst1, lst2) + room = Room.objects.filter(main_room_of__isnull=False).prefetch_related(Prefetch('main_room_of', queryset=houses.filter(address='DoesNotExist'))).first() + with self.assertRaises(ObjectDoesNotExist): + getattr(room, 'main_room_of') + room = Room.objects.filter(main_room_of__isnull=False).prefetch_related(Prefetch('main_room_of', queryset=houses.filter(address='DoesNotExist'), to_attr='main_room_of_attr')).first() + self.assertIsNone(room.main_room_of_attr) + class DefaultManagerTests(TestCase):