1
0
mirror of https://github.com/django/django.git synced 2025-01-03 06:55:47 +00:00

Fixed #21204 -- Tracked field deferrals by field instead of models.

This ensures field deferral works properly when a model is involved
more than once in the same query with a distinct deferral mask.
This commit is contained in:
Simon Charette 2022-08-18 12:30:20 -04:00 committed by Mariusz Felisiak
parent 5d12650ed9
commit b3db6c8dcb
6 changed files with 121 additions and 126 deletions

View File

@ -259,7 +259,7 @@ class RegisterLookupMixin:
cls._clear_cached_lookups() cls._clear_cached_lookups()
def select_related_descend(field, restricted, requested, load_fields, reverse=False): def select_related_descend(field, restricted, requested, select_mask, reverse=False):
""" """
Return True if this field should be used to descend deeper for Return True if this field should be used to descend deeper for
select_related() purposes. Used by both the query construction code select_related() purposes. Used by both the query construction code
@ -271,7 +271,7 @@ def select_related_descend(field, restricted, requested, load_fields, reverse=Fa
* restricted - a boolean field, indicating if the field list has been * restricted - a boolean field, indicating if the field list has been
manually restricted using a requested clause) manually restricted using a requested clause)
* requested - The select_related() dictionary. * requested - The select_related() dictionary.
* load_fields - the set of fields to be loaded on this model * select_mask - the dictionary of selected fields.
* reverse - boolean, True if we are checking a reverse select related * reverse - boolean, True if we are checking a reverse select related
""" """
if not field.remote_field: if not field.remote_field:
@ -287,9 +287,9 @@ def select_related_descend(field, restricted, requested, load_fields, reverse=Fa
return False return False
if ( if (
restricted restricted
and load_fields and select_mask
and field.name in requested and field.name in requested
and field.attname not in load_fields and field not in select_mask
): ):
raise FieldError( raise FieldError(
f"Field {field.model._meta.object_name}.{field.name} cannot be both " f"Field {field.model._meta.object_name}.{field.name} cannot be both "

View File

@ -256,8 +256,9 @@ class SQLCompiler:
select.append((RawSQL(sql, params), alias)) select.append((RawSQL(sql, params), alias))
select_idx += 1 select_idx += 1
assert not (self.query.select and self.query.default_cols) assert not (self.query.select and self.query.default_cols)
select_mask = self.query.get_select_mask()
if self.query.default_cols: if self.query.default_cols:
cols = self.get_default_columns() cols = self.get_default_columns(select_mask)
else: else:
# self.query.select is a special case. These columns never go to # self.query.select is a special case. These columns never go to
# any model. # any model.
@ -278,7 +279,7 @@ class SQLCompiler:
select_idx += 1 select_idx += 1
if self.query.select_related: if self.query.select_related:
related_klass_infos = self.get_related_selections(select) related_klass_infos = self.get_related_selections(select, select_mask)
klass_info["related_klass_infos"] = related_klass_infos klass_info["related_klass_infos"] = related_klass_infos
def get_select_from_parent(klass_info): def get_select_from_parent(klass_info):
@ -870,7 +871,9 @@ class SQLCompiler:
# Finally do cleanup - get rid of the joins we created above. # Finally do cleanup - get rid of the joins we created above.
self.query.reset_refcounts(refcounts_before) self.query.reset_refcounts(refcounts_before)
def get_default_columns(self, start_alias=None, opts=None, from_parent=None): def get_default_columns(
self, select_mask, start_alias=None, opts=None, from_parent=None
):
""" """
Compute the default columns for selecting every field in the base Compute the default columns for selecting every field in the base
model. Will sometimes be called to pull in related models (e.g. via model. Will sometimes be called to pull in related models (e.g. via
@ -886,7 +889,6 @@ class SQLCompiler:
if opts is None: if opts is None:
if (opts := self.query.get_meta()) is None: if (opts := self.query.get_meta()) is None:
return result return result
only_load = self.deferred_to_columns()
start_alias = start_alias or self.query.get_initial_alias() start_alias = start_alias or self.query.get_initial_alias()
# The 'seen_models' is used to optimize checking the needed parent # The 'seen_models' is used to optimize checking the needed parent
# alias for a given field. This also includes None -> start_alias to # alias for a given field. This also includes None -> start_alias to
@ -912,7 +914,7 @@ class SQLCompiler:
# parent model data is already present in the SELECT clause, # parent model data is already present in the SELECT clause,
# and we want to avoid reloading the same data again. # and we want to avoid reloading the same data again.
continue continue
if field.model in only_load and field.attname not in only_load[field.model]: if select_mask and field not in select_mask:
continue continue
alias = self.query.join_parent_model(opts, model, start_alias, seen_models) alias = self.query.join_parent_model(opts, model, start_alias, seen_models)
column = field.get_col(alias) column = field.get_col(alias)
@ -1063,6 +1065,7 @@ class SQLCompiler:
def get_related_selections( def get_related_selections(
self, self,
select, select,
select_mask,
opts=None, opts=None,
root_alias=None, root_alias=None,
cur_depth=1, cur_depth=1,
@ -1095,7 +1098,6 @@ class SQLCompiler:
if not opts: if not opts:
opts = self.query.get_meta() opts = self.query.get_meta()
root_alias = self.query.get_initial_alias() root_alias = self.query.get_initial_alias()
only_load = self.deferred_to_columns()
# Setup for the case when only particular related fields should be # Setup for the case when only particular related fields should be
# included in the related selection. # included in the related selection.
@ -1109,7 +1111,6 @@ class SQLCompiler:
klass_info["related_klass_infos"] = related_klass_infos klass_info["related_klass_infos"] = related_klass_infos
for f in opts.fields: for f in opts.fields:
field_model = f.model._meta.concrete_model
fields_found.add(f.name) fields_found.add(f.name)
if restricted: if restricted:
@ -1129,10 +1130,9 @@ class SQLCompiler:
else: else:
next = False next = False
if not select_related_descend( if not select_related_descend(f, restricted, requested, select_mask):
f, restricted, requested, only_load.get(field_model)
):
continue continue
related_select_mask = select_mask.get(f) or {}
klass_info = { klass_info = {
"model": f.remote_field.model, "model": f.remote_field.model,
"field": f, "field": f,
@ -1148,7 +1148,7 @@ class SQLCompiler:
_, _, _, joins, _, _ = self.query.setup_joins([f.name], opts, root_alias) _, _, _, joins, _, _ = self.query.setup_joins([f.name], opts, root_alias)
alias = joins[-1] alias = joins[-1]
columns = self.get_default_columns( columns = self.get_default_columns(
start_alias=alias, opts=f.remote_field.model._meta related_select_mask, start_alias=alias, opts=f.remote_field.model._meta
) )
for col in columns: for col in columns:
select_fields.append(len(select)) select_fields.append(len(select))
@ -1156,6 +1156,7 @@ class SQLCompiler:
klass_info["select_fields"] = select_fields klass_info["select_fields"] = select_fields
next_klass_infos = self.get_related_selections( next_klass_infos = self.get_related_selections(
select, select,
related_select_mask,
f.remote_field.model._meta, f.remote_field.model._meta,
alias, alias,
cur_depth + 1, cur_depth + 1,
@ -1171,8 +1172,9 @@ class SQLCompiler:
if o.field.unique and not o.many_to_many if o.field.unique and not o.many_to_many
] ]
for f, model in related_fields: for f, model in related_fields:
related_select_mask = select_mask.get(f) or {}
if not select_related_descend( if not select_related_descend(
f, restricted, requested, only_load.get(model), reverse=True f, restricted, requested, related_select_mask, reverse=True
): ):
continue continue
@ -1195,7 +1197,10 @@ class SQLCompiler:
related_klass_infos.append(klass_info) related_klass_infos.append(klass_info)
select_fields = [] select_fields = []
columns = self.get_default_columns( columns = self.get_default_columns(
start_alias=alias, opts=model._meta, from_parent=opts.model related_select_mask,
start_alias=alias,
opts=model._meta,
from_parent=opts.model,
) )
for col in columns: for col in columns:
select_fields.append(len(select)) select_fields.append(len(select))
@ -1203,7 +1208,13 @@ class SQLCompiler:
klass_info["select_fields"] = select_fields klass_info["select_fields"] = select_fields
next = requested.get(f.related_query_name(), {}) next = requested.get(f.related_query_name(), {})
next_klass_infos = self.get_related_selections( next_klass_infos = self.get_related_selections(
select, model._meta, alias, cur_depth + 1, next, restricted select,
related_select_mask,
model._meta,
alias,
cur_depth + 1,
next,
restricted,
) )
get_related_klass_infos(klass_info, next_klass_infos) get_related_klass_infos(klass_info, next_klass_infos)
@ -1239,7 +1250,9 @@ class SQLCompiler:
} }
related_klass_infos.append(klass_info) related_klass_infos.append(klass_info)
select_fields = [] select_fields = []
field_select_mask = select_mask.get((name, f)) or {}
columns = self.get_default_columns( columns = self.get_default_columns(
field_select_mask,
start_alias=alias, start_alias=alias,
opts=model._meta, opts=model._meta,
from_parent=opts.model, from_parent=opts.model,
@ -1251,6 +1264,7 @@ class SQLCompiler:
next_requested = requested.get(name, {}) next_requested = requested.get(name, {})
next_klass_infos = self.get_related_selections( next_klass_infos = self.get_related_selections(
select, select,
field_select_mask,
opts=model._meta, opts=model._meta,
root_alias=alias, root_alias=alias,
cur_depth=cur_depth + 1, cur_depth=cur_depth + 1,
@ -1377,16 +1391,6 @@ class SQLCompiler:
) )
return result return result
def deferred_to_columns(self):
"""
Convert the self.deferred_loading data structure to mapping of table
names to sets of column names which are to be loaded. Return the
dictionary.
"""
columns = {}
self.query.deferred_to_data(columns)
return columns
def get_converters(self, expressions): def get_converters(self, expressions):
converters = {} converters = {}
for i, expression in enumerate(expressions): for i, expression in enumerate(expressions):

View File

@ -718,7 +718,61 @@ class Query(BaseExpression):
self.order_by = rhs.order_by or self.order_by self.order_by = rhs.order_by or self.order_by
self.extra_order_by = rhs.extra_order_by or self.extra_order_by self.extra_order_by = rhs.extra_order_by or self.extra_order_by
def deferred_to_data(self, target): def _get_defer_select_mask(self, opts, mask, select_mask=None):
if select_mask is None:
select_mask = {}
select_mask[opts.pk] = {}
# All concrete fields that are not part of the defer mask must be
# loaded. If a relational field is encountered it gets added to the
# mask for it be considered if `select_related` and the cycle continues
# by recursively caling this function.
for field in opts.concrete_fields:
field_mask = mask.pop(field.name, None)
if field_mask is None:
select_mask.setdefault(field, {})
elif field_mask:
if not field.is_relation:
raise FieldError(next(iter(field_mask)))
field_select_mask = select_mask.setdefault(field, {})
related_model = field.remote_field.model._meta.concrete_model
self._get_defer_select_mask(
related_model._meta, field_mask, field_select_mask
)
# Remaining defer entries must be references to reverse relationships.
# The following code is expected to raise FieldError if it encounters
# a malformed defer entry.
for field_name, field_mask in mask.items():
if filtered_relation := self._filtered_relations.get(field_name):
relation = opts.get_field(filtered_relation.relation_name)
field_select_mask = select_mask.setdefault((field_name, relation), {})
field = relation.field
else:
field = opts.get_field(field_name).field
field_select_mask = select_mask.setdefault(field, {})
related_model = field.model._meta.concrete_model
self._get_defer_select_mask(
related_model._meta, field_mask, field_select_mask
)
return select_mask
def _get_only_select_mask(self, opts, mask, select_mask=None):
if select_mask is None:
select_mask = {}
select_mask[opts.pk] = {}
# Only include fields mentioned in the mask.
for field_name, field_mask in mask.items():
field = opts.get_field(field_name)
field_select_mask = select_mask.setdefault(field, {})
if field_mask:
if not field.is_relation:
raise FieldError(next(iter(field_mask)))
related_model = field.remote_field.model._meta.concrete_model
self._get_only_select_mask(
related_model._meta, field_mask, field_select_mask
)
return select_mask
def get_select_mask(self):
""" """
Convert the self.deferred_loading data structure to an alternate data Convert the self.deferred_loading data structure to an alternate data
structure, describing the field that *will* be loaded. This is used to structure, describing the field that *will* be loaded. This is used to
@ -726,81 +780,19 @@ class Query(BaseExpression):
QuerySet class to work out which fields are being initialized on each QuerySet class to work out which fields are being initialized on each
model. Models that have all their fields included aren't mentioned in model. Models that have all their fields included aren't mentioned in
the result, only those that have field restrictions in place. the result, only those that have field restrictions in place.
The "target" parameter is the instance that is populated (in place).
""" """
field_names, defer = self.deferred_loading field_names, defer = self.deferred_loading
if not field_names: if not field_names:
return return {}
orig_opts = self.get_meta() mask = {}
seen = {}
must_include = {orig_opts.concrete_model: {orig_opts.pk}}
for field_name in field_names: for field_name in field_names:
parts = field_name.split(LOOKUP_SEP) part_mask = mask
cur_model = self.model._meta.concrete_model for part in field_name.split(LOOKUP_SEP):
opts = orig_opts part_mask = part_mask.setdefault(part, {})
for name in parts[:-1]: opts = self.get_meta()
old_model = cur_model
if name in self._filtered_relations:
name = self._filtered_relations[name].relation_name
source = opts.get_field(name)
if is_reverse_o2o(source):
cur_model = source.related_model
else:
cur_model = source.remote_field.model
cur_model = cur_model._meta.concrete_model
opts = cur_model._meta
# Even if we're "just passing through" this model, we must add
# both the current model's pk and the related reference field
# (if it's not a reverse relation) to the things we select.
if not is_reverse_o2o(source):
must_include[old_model].add(source)
add_to_dict(must_include, cur_model, opts.pk)
field = opts.get_field(parts[-1])
is_reverse_object = field.auto_created and not field.concrete
model = field.related_model if is_reverse_object else field.model
model = model._meta.concrete_model
if model == opts.model:
model = cur_model
if not is_reverse_o2o(field):
add_to_dict(seen, model, field)
if defer: if defer:
# We need to load all fields for each model, except those that return self._get_defer_select_mask(opts, mask)
# appear in "seen" (for all models that appear in "seen"). The only return self._get_only_select_mask(opts, mask)
# slight complexity here is handling fields that exist on parent
# models.
workset = {}
for model, values in seen.items():
for field in model._meta.local_fields:
if field not in values:
m = field.model._meta.concrete_model
add_to_dict(workset, m, field)
for model, values in must_include.items():
# If we haven't included a model in workset, we don't add the
# corresponding must_include fields for that model, since an
# empty set means "include all fields". That's why there's no
# "else" branch here.
if model in workset:
workset[model].update(values)
for model, fields in workset.items():
target[model] = {f.attname for f in fields}
else:
for model, values in must_include.items():
if model in seen:
seen[model].update(values)
else:
# As we've passed through this model, but not explicitly
# included any fields, we have to make sure it's mentioned
# so that only the "must include" fields are pulled in.
seen[model] = values
# Now ensure that every model in the inheritance chain is mentioned
# in the parent list. Again, it must be mentioned to ensure that
# only "must include" fields are pulled in.
for model in orig_opts.get_parent_list():
seen.setdefault(model, set())
for model, fields in seen.items():
target[model] = {f.attname for f in fields}
def table_alias(self, table_name, create=False, filtered_relation=None): def table_alias(self, table_name, create=False, filtered_relation=None):
""" """
@ -2583,25 +2575,6 @@ def get_order_dir(field, default="ASC"):
return field, dirn[0] return field, dirn[0]
def add_to_dict(data, key, value):
"""
Add "value" to the set of values for "key", whether or not "key" already
exists.
"""
if key in data:
data[key].add(value)
else:
data[key] = {value}
def is_reverse_o2o(field):
"""
Check if the given field is reverse-o2o. The field is expected to be some
sort of relation field or related object.
"""
return field.is_relation and field.one_to_one and not field.concrete
class JoinPromoter: class JoinPromoter:
""" """
A class to abstract away join promotion problems for complex filter A class to abstract away join promotion problems for complex filter

View File

@ -290,6 +290,8 @@ class InvalidDeferTests(SimpleTestCase):
msg = "Primary has no field named 'missing'" msg = "Primary has no field named 'missing'"
with self.assertRaisesMessage(FieldDoesNotExist, msg): with self.assertRaisesMessage(FieldDoesNotExist, msg):
list(Primary.objects.defer("missing")) list(Primary.objects.defer("missing"))
with self.assertRaisesMessage(FieldError, "missing"):
list(Primary.objects.defer("value__missing"))
msg = "Secondary has no field named 'missing'" msg = "Secondary has no field named 'missing'"
with self.assertRaisesMessage(FieldDoesNotExist, msg): with self.assertRaisesMessage(FieldDoesNotExist, msg):
list(Primary.objects.defer("related__missing")) list(Primary.objects.defer("related__missing"))
@ -298,6 +300,8 @@ class InvalidDeferTests(SimpleTestCase):
msg = "Primary has no field named 'missing'" msg = "Primary has no field named 'missing'"
with self.assertRaisesMessage(FieldDoesNotExist, msg): with self.assertRaisesMessage(FieldDoesNotExist, msg):
list(Primary.objects.only("missing")) list(Primary.objects.only("missing"))
with self.assertRaisesMessage(FieldError, "missing"):
list(Primary.objects.only("value__missing"))
msg = "Secondary has no field named 'missing'" msg = "Secondary has no field named 'missing'"
with self.assertRaisesMessage(FieldDoesNotExist, msg): with self.assertRaisesMessage(FieldDoesNotExist, msg):
list(Primary.objects.only("related__missing")) list(Primary.objects.only("related__missing"))

View File

@ -246,8 +246,6 @@ class DeferRegressionTest(TestCase):
) )
self.assertEqual(len(qs), 1) self.assertEqual(len(qs), 1)
class DeferAnnotateSelectRelatedTest(TestCase):
def test_defer_annotate_select_related(self): def test_defer_annotate_select_related(self):
location = Location.objects.create() location = Location.objects.create()
Request.objects.create(location=location) Request.objects.create(location=location)
@ -276,6 +274,28 @@ class DeferAnnotateSelectRelatedTest(TestCase):
list, list,
) )
def test_common_model_different_mask(self):
child = Child.objects.create(name="Child", value=42)
second_child = Child.objects.create(name="Second", value=64)
Leaf.objects.create(child=child, second_child=second_child)
with self.assertNumQueries(1):
leaf = (
Leaf.objects.select_related("child", "second_child")
.defer("child__name", "second_child__value")
.get()
)
self.assertEqual(leaf.child, child)
self.assertEqual(leaf.second_child, second_child)
self.assertEqual(leaf.child.get_deferred_fields(), {"name"})
self.assertEqual(leaf.second_child.get_deferred_fields(), {"value"})
with self.assertNumQueries(0):
self.assertEqual(leaf.child.value, 42)
self.assertEqual(leaf.second_child.name, "Second")
with self.assertNumQueries(1):
self.assertEqual(leaf.child.name, "Child")
with self.assertNumQueries(1):
self.assertEqual(leaf.second_child.value, 64)
class DeferDeletionSignalsTests(TestCase): class DeferDeletionSignalsTests(TestCase):
senders = [Item, Proxy] senders = [Item, Proxy]

View File

@ -3594,12 +3594,6 @@ class WhereNodeTest(SimpleTestCase):
class QuerySetExceptionTests(SimpleTestCase): class QuerySetExceptionTests(SimpleTestCase):
def test_iter_exceptions(self):
qs = ExtraInfo.objects.only("author")
msg = "'ManyToOneRel' object has no attribute 'attname'"
with self.assertRaisesMessage(AttributeError, msg):
list(qs)
def test_invalid_order_by(self): def test_invalid_order_by(self):
msg = "Cannot resolve keyword '*' into field. Choices are: created, id, name" msg = "Cannot resolve keyword '*' into field. Choices are: created, id, name"
with self.assertRaisesMessage(FieldError, msg): with self.assertRaisesMessage(FieldError, msg):