diff --git a/django/contrib/contenttypes/fields.py b/django/contrib/contenttypes/fields.py index 62284e980e..4015f16c6b 100644 --- a/django/contrib/contenttypes/fields.py +++ b/django/contrib/contenttypes/fields.py @@ -1,4 +1,6 @@ import functools +import itertools +import operator from collections import defaultdict from django.contrib.contenttypes.models import ContentType @@ -566,19 +568,28 @@ def create_generic_related_manager(superclass, rel): queryset._add_hints(instance=instances[0]) queryset = queryset.using(queryset._db or self._db) - - query = { - '%s__pk' % self.content_type_field_name: self.content_type.id, - '%s__in' % self.object_id_field_name: {obj.pk for obj in instances} - } - + # Group instances by content types. + content_type_queries = ( + models.Q(**{ + '%s__pk' % self.content_type_field_name: content_type_id, + '%s__in' % self.object_id_field_name: {obj.pk for obj in objs} + }) + for content_type_id, objs in itertools.groupby( + sorted(instances, key=lambda obj: self.get_content_type(obj).pk), + lambda obj: self.get_content_type(obj).pk, + ) + ) + query = functools.reduce(operator.or_, content_type_queries) # We (possibly) need to convert object IDs to the type of the # instances' PK in order to match up instances: object_id_converter = instances[0]._meta.pk.to_python return ( - queryset.filter(**query), - lambda relobj: object_id_converter(getattr(relobj, self.object_id_field_name)), - lambda obj: obj.pk, + queryset.filter(query), + lambda relobj: ( + object_id_converter(getattr(relobj, self.object_id_field_name)), + relobj.content_type_id + ), + lambda obj: (obj.pk, self.get_content_type(obj).pk), False, self.prefetch_cache_name, False, diff --git a/tests/generic_relations/tests.py b/tests/generic_relations/tests.py index 2e99c5b5cf..7c0db95908 100644 --- a/tests/generic_relations/tests.py +++ b/tests/generic_relations/tests.py @@ -546,6 +546,24 @@ class GenericRelationsTests(TestCase): platypus.tags.remove(weird_tag) self.assertSequenceEqual(platypus.tags.all(), [furry_tag]) + def test_prefetch_related_different_content_types(self): + TaggedItem.objects.create(content_object=self.platypus, tag='prefetch_tag_1') + TaggedItem.objects.create( + content_object=Vegetable.objects.create(name='Broccoli'), + tag='prefetch_tag_2', + ) + TaggedItem.objects.create( + content_object=Animal.objects.create(common_name='Bear'), + tag='prefetch_tag_3', + ) + qs = TaggedItem.objects.filter( + tag__startswith='prefetch_tag_', + ).prefetch_related('content_object', 'content_object__tags') + with self.assertNumQueries(4): + tags = list(qs) + for tag in tags: + self.assertSequenceEqual(tag.content_object.tags.all(), [tag]) + class ProxyRelatedModelTest(TestCase): def test_default_behavior(self):