From 6dc9b04018032dccbb5ad8347f7ddf4341316166 Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Mon, 14 Apr 2025 15:12:28 +0100 Subject: [PATCH] Refs #28586 -- Copied fetch modes to related objects. This change ensures that behavior and performance remain consistent when traversing relationships. --- django/contrib/contenttypes/fields.py | 14 ++++-- .../db/models/fields/related_descriptors.py | 7 ++- django/db/models/query.py | 22 ++++++--- docs/topics/db/fetch-modes.txt | 5 ++ tests/foreign_object/tests.py | 37 ++++++++++++++ tests/generic_relations/tests.py | 32 +++++++++++- tests/many_to_many/tests.py | 41 ++++++++++++++++ tests/many_to_one/tests.py | 49 +++++++++++++++++++ tests/model_inheritance_regress/tests.py | 17 +++++++ tests/one_to_one/tests.py | 38 ++++++++++++++ tests/prefetch_related/tests.py | 30 +++++++++++- tests/select_related/tests.py | 32 ++++++++++++ 12 files changed, 310 insertions(+), 14 deletions(-) diff --git a/django/contrib/contenttypes/fields.py b/django/contrib/contenttypes/fields.py index aa41eab370..62239dc715 100644 --- a/django/contrib/contenttypes/fields.py +++ b/django/contrib/contenttypes/fields.py @@ -201,11 +201,13 @@ class GenericForeignKeyDescriptor: for ct_id, fkeys in fk_dict.items(): if ct_id in custom_queryset_dict: # Return values from the custom queryset, if provided. - ret_val.extend(custom_queryset_dict[ct_id].filter(pk__in=fkeys)) + queryset = custom_queryset_dict[ct_id].filter(pk__in=fkeys) else: instance = instance_dict[ct_id] ct = self.field.get_content_type(id=ct_id, using=instance._state.db) - ret_val.extend(ct.get_all_objects_for_this_type(pk__in=fkeys)) + queryset = ct.get_all_objects_for_this_type(pk__in=fkeys) + + ret_val.extend(queryset.fetch_mode(instances[0]._state.fetch_mode)) # For doing the join in Python, we have to match both the FK val and # the content type, so we use a callable that returns a (fk, class) @@ -271,6 +273,8 @@ class GenericForeignKeyDescriptor: ) except ObjectDoesNotExist: pass + else: + rel_obj._state.fetch_mode = instance._state.fetch_mode self.field.set_cached_value(instance, rel_obj) def fetch_many(self, instances): @@ -636,7 +640,11 @@ def create_generic_related_manager(superclass, rel): Filter the queryset for the instance this manager is bound to. """ db = self._db or router.db_for_read(self.model, instance=self.instance) - return queryset.using(db).filter(**self.core_filters) + return ( + queryset.using(db) + .fetch_mode(self.instance._state.fetch_mode) + .filter(**self.core_filters) + ) def _remove_prefetched_objects(self): try: diff --git a/django/db/models/fields/related_descriptors.py b/django/db/models/fields/related_descriptors.py index 7df96491a0..4728233a6a 100644 --- a/django/db/models/fields/related_descriptors.py +++ b/django/db/models/fields/related_descriptors.py @@ -169,7 +169,7 @@ class ForwardManyToOneDescriptor: def get_queryset(self, *, instance): return self.field.remote_field.model._base_manager.db_manager( hints={"instance": instance} - ).all() + ).fetch_mode(instance._state.fetch_mode) def get_prefetch_querysets(self, instances, querysets=None): if querysets and len(querysets) != 1: @@ -398,6 +398,7 @@ class ForwardOneToOneDescriptor(ForwardManyToOneDescriptor): obj = rel_model(**kwargs) obj._state.adding = instance._state.adding obj._state.db = instance._state.db + obj._state.fetch_mode = instance._state.fetch_mode return obj return super().get_object(instance) @@ -462,7 +463,7 @@ class ReverseOneToOneDescriptor: def get_queryset(self, *, instance): return self.related.related_model._base_manager.db_manager( hints={"instance": instance} - ).all() + ).fetch_mode(instance._state.fetch_mode) def get_prefetch_querysets(self, instances, querysets=None): if querysets and len(querysets) != 1: @@ -740,6 +741,7 @@ def create_reverse_many_to_one_manager(superclass, rel): queryset._add_hints(instance=self.instance) if self._db: queryset = queryset.using(self._db) + queryset._fetch_mode = self.instance._state.fetch_mode queryset._defer_next_filter = True queryset = queryset.filter(**self.core_filters) for field in self.field.foreign_related_fields: @@ -1141,6 +1143,7 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): queryset._add_hints(instance=self.instance) if self._db: queryset = queryset.using(self._db) + queryset._fetch_mode = self.instance._state.fetch_mode queryset._defer_next_filter = True return queryset._next_is_sticky().filter(**self.core_filters) diff --git a/django/db/models/query.py b/django/db/models/query.py index 0811b90b5e..0a577f8c2d 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -90,6 +90,7 @@ class ModelIterable(BaseIterable): queryset = self.queryset db = queryset.db compiler = queryset.query.get_compiler(using=db) + fetch_mode = queryset._fetch_mode # Execute the query. This will also fill compiler.select, klass_info, # and annotations. results = compiler.execute_sql( @@ -106,7 +107,7 @@ class ModelIterable(BaseIterable): init_list = [ f[0].target.attname for f in select[model_fields_start:model_fields_end] ] - related_populators = get_related_populators(klass_info, select, db) + related_populators = get_related_populators(klass_info, select, db, fetch_mode) known_related_objects = [ ( field, @@ -124,7 +125,6 @@ class ModelIterable(BaseIterable): ) for field, related_objs in queryset._known_related_objects.items() ] - fetch_mode = queryset._fetch_mode peers = [] for row in compiler.results_iter(results): obj = model_cls.from_db( @@ -2787,8 +2787,9 @@ class RelatedPopulator: model instance. """ - def __init__(self, klass_info, select, db): + def __init__(self, klass_info, select, db, fetch_mode): self.db = db + self.fetch_mode = fetch_mode # Pre-compute needed attributes. The attributes are: # - model_cls: the possibly deferred model class to instantiate # - either: @@ -2841,7 +2842,9 @@ class RelatedPopulator: # relationship. Therefore checking for a single member of the primary # key is enough to determine if the referenced object exists or not. self.pk_idx = self.init_list.index(self.model_cls._meta.pk_fields[0].attname) - self.related_populators = get_related_populators(klass_info, select, self.db) + self.related_populators = get_related_populators( + klass_info, select, self.db, fetch_mode + ) self.local_setter = klass_info["local_setter"] self.remote_setter = klass_info["remote_setter"] @@ -2853,7 +2856,12 @@ class RelatedPopulator: if obj_data[self.pk_idx] is None: obj = None else: - obj = self.model_cls.from_db(self.db, self.init_list, obj_data) + obj = self.model_cls.from_db( + self.db, + self.init_list, + obj_data, + fetch_mode=self.fetch_mode, + ) for rel_iter in self.related_populators: rel_iter.populate(row, obj) self.local_setter(from_obj, obj) @@ -2861,10 +2869,10 @@ class RelatedPopulator: self.remote_setter(obj, from_obj) -def get_related_populators(klass_info, select, db): +def get_related_populators(klass_info, select, db, fetch_mode): iterators = [] related_klass_infos = klass_info.get("related_klass_infos", []) for rel_klass_info in related_klass_infos: - rel_cls = RelatedPopulator(rel_klass_info, select, db) + rel_cls = RelatedPopulator(rel_klass_info, select, db, fetch_mode) iterators.append(rel_cls) return iterators diff --git a/docs/topics/db/fetch-modes.txt b/docs/topics/db/fetch-modes.txt index e76bb28a59..da7a07a0d4 100644 --- a/docs/topics/db/fetch-modes.txt +++ b/docs/topics/db/fetch-modes.txt @@ -29,6 +29,11 @@ Fetch modes apply to: * Fields deferred with :meth:`.QuerySet.defer` or :meth:`.QuerySet.only` * :ref:`generic-relations` +Django copies the fetch mode of an instance to any related objects it fetches, +so the mode applies to a whole tree of relationships, not just the top-level +model in the initial ``QuerySet``. This copying is also done in related +managers, even though fetch modes don't affect such managers' queries. + Available modes =============== diff --git a/tests/foreign_object/tests.py b/tests/foreign_object/tests.py index 09fb47e771..233c596885 100644 --- a/tests/foreign_object/tests.py +++ b/tests/foreign_object/tests.py @@ -5,6 +5,7 @@ from operator import attrgetter from django.core.exceptions import FieldError, ValidationError from django.db import connection, models +from django.db.models import FETCH_PEERS from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature from django.test.utils import CaptureQueriesContext, isolate_apps from django.utils import translation @@ -603,6 +604,42 @@ class MultiColumnFKTests(TestCase): [m4], ) + def test_fetch_mode_copied_forward_fetching_one(self): + person = Person.objects.fetch_mode(FETCH_PEERS).get(pk=self.bob.pk) + self.assertEqual(person._state.fetch_mode, FETCH_PEERS) + self.assertEqual( + person.person_country._state.fetch_mode, + FETCH_PEERS, + ) + + def test_fetch_mode_copied_forward_fetching_many(self): + people = list(Person.objects.fetch_mode(FETCH_PEERS)) + person = people[0] + self.assertEqual(person._state.fetch_mode, FETCH_PEERS) + self.assertEqual( + person.person_country._state.fetch_mode, + FETCH_PEERS, + ) + + def test_fetch_mode_copied_reverse_fetching_one(self): + country = Country.objects.fetch_mode(FETCH_PEERS).get(pk=self.usa.pk) + self.assertEqual(country._state.fetch_mode, FETCH_PEERS) + person = country.person_set.get(pk=self.bob.pk) + self.assertEqual( + person._state.fetch_mode, + FETCH_PEERS, + ) + + def test_fetch_mode_copied_reverse_fetching_many(self): + countries = list(Country.objects.fetch_mode(FETCH_PEERS)) + country = countries[0] + self.assertEqual(country._state.fetch_mode, FETCH_PEERS) + person = country.person_set.earliest("pk") + self.assertEqual( + person._state.fetch_mode, + FETCH_PEERS, + ) + class TestModelCheckTests(SimpleTestCase): @isolate_apps("foreign_object") diff --git a/tests/generic_relations/tests.py b/tests/generic_relations/tests.py index 3de243d7b8..dceb8f4bae 100644 --- a/tests/generic_relations/tests.py +++ b/tests/generic_relations/tests.py @@ -813,7 +813,6 @@ class GenericRelationsTests(TestCase): self.assertEqual(quartz_tag.content_object, self.quartz) def test_fetch_mode_raise(self): - TaggedItem.objects.create(tag="lion", content_object=self.lion) tag = TaggedItem.objects.fetch_mode(RAISE).get(tag="yellow") msg = "Fetching of TaggedItem.content_object blocked." with self.assertRaisesMessage(FieldFetchBlocked, msg) as cm: @@ -821,6 +820,37 @@ class GenericRelationsTests(TestCase): self.assertIsNone(cm.exception.__cause__) self.assertTrue(cm.exception.__suppress_context__) + def test_fetch_mode_copied_forward_fetching_one(self): + tag = TaggedItem.objects.fetch_mode(FETCH_PEERS).get(tag="yellow") + self.assertEqual(tag.content_object, self.lion) + self.assertEqual( + tag.content_object._state.fetch_mode, + FETCH_PEERS, + ) + + def test_fetch_mode_copied_forward_fetching_many(self): + tags = list(TaggedItem.objects.fetch_mode(FETCH_PEERS).order_by("tag")) + tag = [t for t in tags if t.tag == "yellow"][0] + self.assertEqual(tag.content_object, self.lion) + self.assertEqual( + tag.content_object._state.fetch_mode, + FETCH_PEERS, + ) + + def test_fetch_mode_copied_reverse_fetching_one(self): + animal = Animal.objects.fetch_mode(FETCH_PEERS).get(pk=self.lion.pk) + self.assertEqual(animal._state.fetch_mode, FETCH_PEERS) + tag = animal.tags.get(tag="yellow") + self.assertEqual(tag._state.fetch_mode, FETCH_PEERS) + + def test_fetch_mode_copied_reverse_fetching_many(self): + animals = list(Animal.objects.fetch_mode(FETCH_PEERS)) + animal = animals[0] + self.assertEqual(animal._state.fetch_mode, FETCH_PEERS) + tags = list(animal.tags.all()) + tag = tags[0] + self.assertEqual(tag._state.fetch_mode, FETCH_PEERS) + class ProxyRelatedModelTest(TestCase): def test_default_behavior(self): diff --git a/tests/many_to_many/tests.py b/tests/many_to_many/tests.py index 34b7ffc67d..30fbde873e 100644 --- a/tests/many_to_many/tests.py +++ b/tests/many_to_many/tests.py @@ -1,6 +1,7 @@ from unittest import mock from django.db import connection, transaction +from django.db.models import FETCH_PEERS from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature from .models import ( @@ -589,6 +590,46 @@ class ManyToManyTests(TestCase): querysets=[Publication.objects.all(), Publication.objects.all()], ) + def test_fetch_mode_copied_forward_fetching_one(self): + a = Article.objects.fetch_mode(FETCH_PEERS).get(pk=self.a1.pk) + self.assertEqual(a._state.fetch_mode, FETCH_PEERS) + p = a.publications.earliest("pk") + self.assertEqual( + p._state.fetch_mode, + FETCH_PEERS, + ) + + def test_fetch_mode_copied_forward_fetching_many(self): + articles = list(Article.objects.fetch_mode(FETCH_PEERS)) + a = articles[0] + self.assertEqual(a._state.fetch_mode, FETCH_PEERS) + publications = list(a.publications.all()) + p = publications[0] + self.assertEqual( + p._state.fetch_mode, + FETCH_PEERS, + ) + + def test_fetch_mode_copied_reverse_fetching_one(self): + p1 = Publication.objects.fetch_mode(FETCH_PEERS).get(pk=self.p1.pk) + self.assertEqual(p1._state.fetch_mode, FETCH_PEERS) + a = p1.article_set.earliest("pk") + self.assertEqual( + a._state.fetch_mode, + FETCH_PEERS, + ) + + def test_fetch_mode_copied_reverse_fetching_many(self): + publications = list(Publication.objects.fetch_mode(FETCH_PEERS)) + p = publications[0] + self.assertEqual(p._state.fetch_mode, FETCH_PEERS) + articles = list(p.article_set.all()) + a = articles[0] + self.assertEqual( + a._state.fetch_mode, + FETCH_PEERS, + ) + class ManyToManyQueryTests(TestCase): """ diff --git a/tests/many_to_one/tests.py b/tests/many_to_one/tests.py index c5fa458570..4d2343e304 100644 --- a/tests/many_to_one/tests.py +++ b/tests/many_to_one/tests.py @@ -941,3 +941,52 @@ class ManyToOneTests(TestCase): a.reporter self.assertIsNone(cm.exception.__cause__) self.assertTrue(cm.exception.__suppress_context__) + + def test_fetch_mode_copied_forward_fetching_one(self): + a1 = Article.objects.fetch_mode(FETCH_PEERS).get() + self.assertEqual(a1._state.fetch_mode, FETCH_PEERS) + self.assertEqual( + a1.reporter._state.fetch_mode, + FETCH_PEERS, + ) + + def test_fetch_mode_copied_forward_fetching_many(self): + Article.objects.create( + headline="This is another test", + pub_date=datetime.date(2005, 7, 27), + reporter=self.r2, + ) + a1, a2 = Article.objects.fetch_mode(FETCH_PEERS) + self.assertEqual(a1._state.fetch_mode, FETCH_PEERS) + self.assertEqual( + a1.reporter._state.fetch_mode, + FETCH_PEERS, + ) + + def test_fetch_mode_copied_reverse_fetching_one(self): + r1 = Reporter.objects.fetch_mode(FETCH_PEERS).get(pk=self.r.pk) + self.assertEqual(r1._state.fetch_mode, FETCH_PEERS) + article = r1.article_set.get() + self.assertEqual( + article._state.fetch_mode, + FETCH_PEERS, + ) + + def test_fetch_mode_copied_reverse_fetching_many(self): + Article.objects.create( + headline="This is another test", + pub_date=datetime.date(2005, 7, 27), + reporter=self.r2, + ) + r1, r2 = Reporter.objects.fetch_mode(FETCH_PEERS) + self.assertEqual(r1._state.fetch_mode, FETCH_PEERS) + a1 = r1.article_set.get() + self.assertEqual( + a1._state.fetch_mode, + FETCH_PEERS, + ) + a2 = r2.article_set.get() + self.assertEqual( + a2._state.fetch_mode, + FETCH_PEERS, + ) diff --git a/tests/model_inheritance_regress/tests.py b/tests/model_inheritance_regress/tests.py index 3310497de1..adc2a22fc4 100644 --- a/tests/model_inheritance_regress/tests.py +++ b/tests/model_inheritance_regress/tests.py @@ -7,6 +7,7 @@ from operator import attrgetter from unittest import expectedFailure from django import forms +from django.db.models import FETCH_PEERS from django.test import TestCase from .models import ( @@ -600,6 +601,22 @@ class ModelInheritanceTest(TestCase): self.assertEqual(restaurant.place_ptr.restaurant, restaurant) self.assertEqual(restaurant.italianrestaurant, italian_restaurant) + def test_parent_access_copies_fetch_mode(self): + italian_restaurant = ItalianRestaurant.objects.create( + name="Mom's Spaghetti", + address="2131 Woodward Ave", + serves_hot_dogs=False, + serves_pizza=False, + serves_gnocchi=True, + ) + + # No queries are made when accessing the parent objects. + italian_restaurant = ItalianRestaurant.objects.fetch_mode(FETCH_PEERS).get( + pk=italian_restaurant.pk + ) + restaurant = italian_restaurant.restaurant_ptr + self.assertEqual(restaurant._state.fetch_mode, FETCH_PEERS) + def test_id_field_update_on_ancestor_change(self): place1 = Place.objects.create(name="House of Pasta", address="944 Fullerton") place2 = Place.objects.create(name="House of Pizza", address="954 Fullerton") diff --git a/tests/one_to_one/tests.py b/tests/one_to_one/tests.py index da7bd992c0..39f24d6b10 100644 --- a/tests/one_to_one/tests.py +++ b/tests/one_to_one/tests.py @@ -657,3 +657,41 @@ class OneToOneTests(TestCase): p.restaurant self.assertIsNone(cm.exception.__cause__) self.assertTrue(cm.exception.__suppress_context__) + + def test_fetch_mode_copied_forward_fetching_one(self): + r1 = Restaurant.objects.fetch_mode(FETCH_PEERS).get(pk=self.r1.pk) + self.assertEqual(r1._state.fetch_mode, FETCH_PEERS) + self.assertEqual( + r1.place._state.fetch_mode, + FETCH_PEERS, + ) + + def test_fetch_mode_copied_forward_fetching_many(self): + Restaurant.objects.create( + place=self.p2, serves_hot_dogs=True, serves_pizza=False + ) + r1, r2 = Restaurant.objects.fetch_mode(FETCH_PEERS) + self.assertEqual(r1._state.fetch_mode, FETCH_PEERS) + self.assertEqual( + r1.place._state.fetch_mode, + FETCH_PEERS, + ) + + def test_fetch_mode_copied_reverse_fetching_one(self): + p1 = Place.objects.fetch_mode(FETCH_PEERS).get(pk=self.p1.pk) + self.assertEqual(p1._state.fetch_mode, FETCH_PEERS) + self.assertEqual( + p1.restaurant._state.fetch_mode, + FETCH_PEERS, + ) + + def test_fetch_mode_copied_reverse_fetching_many(self): + Restaurant.objects.create( + place=self.p2, serves_hot_dogs=True, serves_pizza=False + ) + p1, p2 = Place.objects.fetch_mode(FETCH_PEERS) + self.assertEqual(p1._state.fetch_mode, FETCH_PEERS) + self.assertEqual( + p1.restaurant._state.fetch_mode, + FETCH_PEERS, + ) diff --git a/tests/prefetch_related/tests.py b/tests/prefetch_related/tests.py index 6e4acdddf6..bb6417b8ae 100644 --- a/tests/prefetch_related/tests.py +++ b/tests/prefetch_related/tests.py @@ -3,7 +3,13 @@ from unittest import mock from django.contrib.contenttypes.models import ContentType from django.core.exceptions import ObjectDoesNotExist from django.db import NotSupportedError, connection -from django.db.models import F, Prefetch, QuerySet, prefetch_related_objects +from django.db.models import ( + FETCH_PEERS, + F, + Prefetch, + QuerySet, + prefetch_related_objects, +) from django.db.models.fetch_modes import RAISE from django.db.models.query import get_prefetcher from django.db.models.sql import Query @@ -108,6 +114,28 @@ class PrefetchRelatedTests(TestDataMixin, TestCase): normal_books = [a.first_book for a in Author.objects.all()] self.assertEqual(books, normal_books) + def test_fetch_mode_copied_fetching_one(self): + author = ( + Author.objects.fetch_mode(FETCH_PEERS) + .prefetch_related("first_book") + .get(pk=self.author1.pk) + ) + self.assertEqual(author._state.fetch_mode, FETCH_PEERS) + self.assertEqual( + author.first_book._state.fetch_mode, + FETCH_PEERS, + ) + + def test_fetch_mode_copied_fetching_many(self): + authors = list( + Author.objects.fetch_mode(FETCH_PEERS).prefetch_related("first_book") + ) + self.assertEqual(authors[0]._state.fetch_mode, FETCH_PEERS) + self.assertEqual( + authors[0].first_book._state.fetch_mode, + FETCH_PEERS, + ) + def test_fetch_mode_raise(self): authors = list(Author.objects.fetch_mode(RAISE).prefetch_related("first_book")) authors[0].first_book # No exception, already loaded diff --git a/tests/select_related/tests.py b/tests/select_related/tests.py index 68fe7a906f..41ed350cf3 100644 --- a/tests/select_related/tests.py +++ b/tests/select_related/tests.py @@ -1,4 +1,5 @@ from django.core.exceptions import FieldError +from django.db.models import FETCH_PEERS from django.test import SimpleTestCase, TestCase from .models import ( @@ -210,6 +211,37 @@ class SelectRelatedTests(TestCase): with self.assertRaisesMessage(TypeError, message): list(Species.objects.values_list("name").select_related("genus")) + def test_fetch_mode_copied_fetching_one(self): + fly = ( + Species.objects.fetch_mode(FETCH_PEERS) + .select_related("genus__family") + .get(name="melanogaster") + ) + self.assertEqual(fly._state.fetch_mode, FETCH_PEERS) + self.assertEqual( + fly.genus._state.fetch_mode, + FETCH_PEERS, + ) + self.assertEqual( + fly.genus.family._state.fetch_mode, + FETCH_PEERS, + ) + + def test_fetch_mode_copied_fetching_many(self): + specieses = list( + Species.objects.fetch_mode(FETCH_PEERS).select_related("genus__family") + ) + species = specieses[0] + self.assertEqual(species._state.fetch_mode, FETCH_PEERS) + self.assertEqual( + species.genus._state.fetch_mode, + FETCH_PEERS, + ) + self.assertEqual( + species.genus.family._state.fetch_mode, + FETCH_PEERS, + ) + class SelectRelatedValidationTests(SimpleTestCase): """