1
0
mirror of https://github.com/django/django.git synced 2025-01-04 23:46:09 +00:00

Fixed #35356 -- Deferred self-referential foreign key fields adequately.

While refs #34612 surfaced issues with reverse one-to-one fields
deferrals, it missed that switching to storing remote fields would break
self-referential relationships.

This change switches to storing related objects in the select mask
instead of remote fields to prevent collisions when dealing with
self-referential relationships that might have a different directional
mask.

Despite fixing #21204 introduced a crash under some self-referential
deferral conditions, it was simply not working even before that as it
aggregated the sets of deferred fields by model.

Thanks Joshua van Besouw for the report and Mariusz Felisiak for the
review.
This commit is contained in:
Simon Charette 2024-04-05 23:08:49 -04:00 committed by nessita
parent bcad5ad92b
commit 83f5478225
4 changed files with 36 additions and 17 deletions

View File

@ -1253,21 +1253,21 @@ class SQLCompiler:
if restricted: if restricted:
related_fields = [ related_fields = [
(o.field, o.related_model) (o, o.field, o.related_model)
for o in opts.related_objects for o in opts.related_objects
if o.field.unique and not o.many_to_many if o.field.unique and not o.many_to_many
] ]
for related_field, model in related_fields: for related_object, related_field, model in related_fields:
related_select_mask = select_mask.get(related_field) or {}
if not select_related_descend( if not select_related_descend(
related_field, related_field,
restricted, restricted,
requested, requested,
related_select_mask, select_mask,
reverse=True, reverse=True,
): ):
continue continue
related_select_mask = select_mask.get(related_object) or {}
related_field_name = related_field.related_query_name() related_field_name = related_field.related_query_name()
fields_found.add(related_field_name) fields_found.add(related_field_name)

View File

@ -815,19 +815,17 @@ class Query(BaseExpression):
if filtered_relation := self._filtered_relations.get(field_name): if filtered_relation := self._filtered_relations.get(field_name):
relation = opts.get_field(filtered_relation.relation_name) relation = opts.get_field(filtered_relation.relation_name)
field_select_mask = select_mask.setdefault((field_name, relation), {}) field_select_mask = select_mask.setdefault((field_name, relation), {})
field = relation.field
else: else:
reverse_rel = opts.get_field(field_name) relation = opts.get_field(field_name)
# While virtual fields such as many-to-many and generic foreign # While virtual fields such as many-to-many and generic foreign
# keys cannot be effectively deferred we've historically # keys cannot be effectively deferred we've historically
# allowed them to be passed to QuerySet.defer(). Ignore such # allowed them to be passed to QuerySet.defer(). Ignore such
# field references until a layer of validation at mask # field references until a layer of validation at mask
# alteration time will be implemented eventually. # alteration time will be implemented eventually.
if not hasattr(reverse_rel, "field"): if not hasattr(relation, "field"):
continue continue
field = reverse_rel.field field_select_mask = select_mask.setdefault(relation, {})
field_select_mask = select_mask.setdefault(field, {}) related_model = relation.related_model._meta.concrete_model
related_model = field.model._meta.concrete_model
self._get_defer_select_mask( self._get_defer_select_mask(
related_model._meta, field_mask, field_select_mask related_model._meta, field_mask, field_select_mask
) )
@ -840,13 +838,7 @@ class Query(BaseExpression):
# Only include fields mentioned in the mask. # Only include fields mentioned in the mask.
for field_name, field_mask in mask.items(): for field_name, field_mask in mask.items():
field = opts.get_field(field_name) field = opts.get_field(field_name)
# Retrieve the actual field associated with reverse relationships field_select_mask = select_mask.setdefault(field, {})
# as that's what is expected in the select mask.
if field in opts.related_objects:
field_key = field.field
else:
field_key = field
field_select_mask = select_mask.setdefault(field_key, {})
if field_mask: if field_mask:
if not field.is_relation: if not field.is_relation:
raise FieldError(next(iter(field_mask))) raise FieldError(next(iter(field_mask)))

View File

@ -10,6 +10,12 @@ class Item(models.Model):
text = models.TextField(default="xyzzy") text = models.TextField(default="xyzzy")
value = models.IntegerField() value = models.IntegerField()
other_value = models.IntegerField(default=0) other_value = models.IntegerField(default=0)
source = models.OneToOneField(
"self",
related_name="destination",
on_delete=models.CASCADE,
null=True,
)
class RelatedItem(models.Model): class RelatedItem(models.Model):

View File

@ -309,6 +309,27 @@ class DeferRegressionTest(TestCase):
with self.assertNumQueries(1): with self.assertNumQueries(1):
self.assertEqual(Item.objects.only("request").get(), item) self.assertEqual(Item.objects.only("request").get(), item)
def test_self_referential_one_to_one(self):
first = Item.objects.create(name="first", value=1)
second = Item.objects.create(name="second", value=2, source=first)
with self.assertNumQueries(1):
deferred_first, deferred_second = (
Item.objects.select_related("source", "destination")
.only("name", "source__name", "destination__value")
.order_by("pk")
)
with self.assertNumQueries(0):
self.assertEqual(deferred_first.name, first.name)
self.assertEqual(deferred_second.name, second.name)
self.assertEqual(deferred_second.source.name, first.name)
self.assertEqual(deferred_first.destination.value, second.value)
with self.assertNumQueries(1):
self.assertEqual(deferred_first.value, first.value)
with self.assertNumQueries(1):
self.assertEqual(deferred_second.source.value, first.value)
with self.assertNumQueries(1):
self.assertEqual(deferred_first.destination.name, second.name)
class DeferDeletionSignalsTests(TestCase): class DeferDeletionSignalsTests(TestCase):
senders = [Item, Proxy] senders = [Item, Proxy]