diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py index 6e82edb6e7..f4215ed48e 100644 --- a/django/db/models/query_utils.py +++ b/django/db/models/query_utils.py @@ -259,7 +259,7 @@ class RegisterLookupMixin: cls._clear_cached_lookups() -def select_related_descend(field, restricted, requested, load_fields, reverse=False): +def select_related_descend(field, restricted, requested, select_mask, reverse=False): """ Return True if this field should be used to descend deeper for select_related() purposes. Used by both the query construction code @@ -271,7 +271,7 @@ def select_related_descend(field, restricted, requested, load_fields, reverse=Fa * restricted - a boolean field, indicating if the field list has been manually restricted using a requested clause) * requested - The select_related() dictionary. - * load_fields - the set of fields to be loaded on this model + * select_mask - the dictionary of selected fields. * reverse - boolean, True if we are checking a reverse select related """ if not field.remote_field: @@ -287,9 +287,9 @@ def select_related_descend(field, restricted, requested, load_fields, reverse=Fa return False if ( restricted - and load_fields + and select_mask and field.name in requested - and field.attname not in load_fields + and field not in select_mask ): raise FieldError( f"Field {field.model._meta.object_name}.{field.name} cannot be both " diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 858142913b..96d10b9eda 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -256,8 +256,9 @@ class SQLCompiler: select.append((RawSQL(sql, params), alias)) select_idx += 1 assert not (self.query.select and self.query.default_cols) + select_mask = self.query.get_select_mask() if self.query.default_cols: - cols = self.get_default_columns() + cols = self.get_default_columns(select_mask) else: # self.query.select is a special case. These columns never go to # any model. @@ -278,7 +279,7 @@ class SQLCompiler: select_idx += 1 if self.query.select_related: - related_klass_infos = self.get_related_selections(select) + related_klass_infos = self.get_related_selections(select, select_mask) klass_info["related_klass_infos"] = related_klass_infos def get_select_from_parent(klass_info): @@ -870,7 +871,9 @@ class SQLCompiler: # Finally do cleanup - get rid of the joins we created above. self.query.reset_refcounts(refcounts_before) - def get_default_columns(self, start_alias=None, opts=None, from_parent=None): + def get_default_columns( + self, select_mask, start_alias=None, opts=None, from_parent=None + ): """ Compute the default columns for selecting every field in the base model. Will sometimes be called to pull in related models (e.g. via @@ -886,7 +889,6 @@ class SQLCompiler: if opts is None: if (opts := self.query.get_meta()) is None: return result - only_load = self.deferred_to_columns() start_alias = start_alias or self.query.get_initial_alias() # The 'seen_models' is used to optimize checking the needed parent # alias for a given field. This also includes None -> start_alias to @@ -912,7 +914,7 @@ class SQLCompiler: # parent model data is already present in the SELECT clause, # and we want to avoid reloading the same data again. continue - if field.model in only_load and field.attname not in only_load[field.model]: + if select_mask and field not in select_mask: continue alias = self.query.join_parent_model(opts, model, start_alias, seen_models) column = field.get_col(alias) @@ -1063,6 +1065,7 @@ class SQLCompiler: def get_related_selections( self, select, + select_mask, opts=None, root_alias=None, cur_depth=1, @@ -1095,7 +1098,6 @@ class SQLCompiler: if not opts: opts = self.query.get_meta() root_alias = self.query.get_initial_alias() - only_load = self.deferred_to_columns() # Setup for the case when only particular related fields should be # included in the related selection. @@ -1109,7 +1111,6 @@ class SQLCompiler: klass_info["related_klass_infos"] = related_klass_infos for f in opts.fields: - field_model = f.model._meta.concrete_model fields_found.add(f.name) if restricted: @@ -1129,10 +1130,9 @@ class SQLCompiler: else: next = False - if not select_related_descend( - f, restricted, requested, only_load.get(field_model) - ): + if not select_related_descend(f, restricted, requested, select_mask): continue + related_select_mask = select_mask.get(f) or {} klass_info = { "model": f.remote_field.model, "field": f, @@ -1148,7 +1148,7 @@ class SQLCompiler: _, _, _, joins, _, _ = self.query.setup_joins([f.name], opts, root_alias) alias = joins[-1] columns = self.get_default_columns( - start_alias=alias, opts=f.remote_field.model._meta + related_select_mask, start_alias=alias, opts=f.remote_field.model._meta ) for col in columns: select_fields.append(len(select)) @@ -1156,6 +1156,7 @@ class SQLCompiler: klass_info["select_fields"] = select_fields next_klass_infos = self.get_related_selections( select, + related_select_mask, f.remote_field.model._meta, alias, cur_depth + 1, @@ -1171,8 +1172,9 @@ class SQLCompiler: if o.field.unique and not o.many_to_many ] for f, model in related_fields: + related_select_mask = select_mask.get(f) or {} if not select_related_descend( - f, restricted, requested, only_load.get(model), reverse=True + f, restricted, requested, related_select_mask, reverse=True ): continue @@ -1195,7 +1197,10 @@ class SQLCompiler: related_klass_infos.append(klass_info) select_fields = [] columns = self.get_default_columns( - start_alias=alias, opts=model._meta, from_parent=opts.model + related_select_mask, + start_alias=alias, + opts=model._meta, + from_parent=opts.model, ) for col in columns: select_fields.append(len(select)) @@ -1203,7 +1208,13 @@ class SQLCompiler: klass_info["select_fields"] = select_fields next = requested.get(f.related_query_name(), {}) next_klass_infos = self.get_related_selections( - select, model._meta, alias, cur_depth + 1, next, restricted + select, + related_select_mask, + model._meta, + alias, + cur_depth + 1, + next, + restricted, ) get_related_klass_infos(klass_info, next_klass_infos) @@ -1239,7 +1250,9 @@ class SQLCompiler: } related_klass_infos.append(klass_info) select_fields = [] + field_select_mask = select_mask.get((name, f)) or {} columns = self.get_default_columns( + field_select_mask, start_alias=alias, opts=model._meta, from_parent=opts.model, @@ -1251,6 +1264,7 @@ class SQLCompiler: next_requested = requested.get(name, {}) next_klass_infos = self.get_related_selections( select, + field_select_mask, opts=model._meta, root_alias=alias, cur_depth=cur_depth + 1, @@ -1377,16 +1391,6 @@ class SQLCompiler: ) return result - def deferred_to_columns(self): - """ - Convert the self.deferred_loading data structure to mapping of table - names to sets of column names which are to be loaded. Return the - dictionary. - """ - columns = {} - self.query.deferred_to_data(columns) - return columns - def get_converters(self, expressions): converters = {} for i, expression in enumerate(expressions): diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 14ed0c0a63..8419dc0d54 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -718,7 +718,61 @@ class Query(BaseExpression): self.order_by = rhs.order_by or self.order_by self.extra_order_by = rhs.extra_order_by or self.extra_order_by - def deferred_to_data(self, target): + def _get_defer_select_mask(self, opts, mask, select_mask=None): + if select_mask is None: + select_mask = {} + select_mask[opts.pk] = {} + # All concrete fields that are not part of the defer mask must be + # loaded. If a relational field is encountered it gets added to the + # mask for it be considered if `select_related` and the cycle continues + # by recursively caling this function. + for field in opts.concrete_fields: + field_mask = mask.pop(field.name, None) + if field_mask is None: + select_mask.setdefault(field, {}) + elif field_mask: + if not field.is_relation: + raise FieldError(next(iter(field_mask))) + field_select_mask = select_mask.setdefault(field, {}) + related_model = field.remote_field.model._meta.concrete_model + self._get_defer_select_mask( + related_model._meta, field_mask, field_select_mask + ) + # Remaining defer entries must be references to reverse relationships. + # The following code is expected to raise FieldError if it encounters + # a malformed defer entry. + for field_name, field_mask in mask.items(): + if filtered_relation := self._filtered_relations.get(field_name): + relation = opts.get_field(filtered_relation.relation_name) + field_select_mask = select_mask.setdefault((field_name, relation), {}) + field = relation.field + else: + field = opts.get_field(field_name).field + field_select_mask = select_mask.setdefault(field, {}) + related_model = field.model._meta.concrete_model + self._get_defer_select_mask( + related_model._meta, field_mask, field_select_mask + ) + return select_mask + + def _get_only_select_mask(self, opts, mask, select_mask=None): + if select_mask is None: + select_mask = {} + select_mask[opts.pk] = {} + # Only include fields mentioned in the mask. + for field_name, field_mask in mask.items(): + field = opts.get_field(field_name) + field_select_mask = select_mask.setdefault(field, {}) + if field_mask: + if not field.is_relation: + raise FieldError(next(iter(field_mask))) + related_model = field.remote_field.model._meta.concrete_model + self._get_only_select_mask( + related_model._meta, field_mask, field_select_mask + ) + return select_mask + + def get_select_mask(self): """ Convert the self.deferred_loading data structure to an alternate data structure, describing the field that *will* be loaded. This is used to @@ -726,81 +780,19 @@ class Query(BaseExpression): QuerySet class to work out which fields are being initialized on each model. Models that have all their fields included aren't mentioned in the result, only those that have field restrictions in place. - - The "target" parameter is the instance that is populated (in place). """ field_names, defer = self.deferred_loading if not field_names: - return - orig_opts = self.get_meta() - seen = {} - must_include = {orig_opts.concrete_model: {orig_opts.pk}} + return {} + mask = {} for field_name in field_names: - parts = field_name.split(LOOKUP_SEP) - cur_model = self.model._meta.concrete_model - opts = orig_opts - for name in parts[:-1]: - old_model = cur_model - if name in self._filtered_relations: - name = self._filtered_relations[name].relation_name - source = opts.get_field(name) - if is_reverse_o2o(source): - cur_model = source.related_model - else: - cur_model = source.remote_field.model - cur_model = cur_model._meta.concrete_model - opts = cur_model._meta - # Even if we're "just passing through" this model, we must add - # both the current model's pk and the related reference field - # (if it's not a reverse relation) to the things we select. - if not is_reverse_o2o(source): - must_include[old_model].add(source) - add_to_dict(must_include, cur_model, opts.pk) - field = opts.get_field(parts[-1]) - is_reverse_object = field.auto_created and not field.concrete - model = field.related_model if is_reverse_object else field.model - model = model._meta.concrete_model - if model == opts.model: - model = cur_model - if not is_reverse_o2o(field): - add_to_dict(seen, model, field) - + part_mask = mask + for part in field_name.split(LOOKUP_SEP): + part_mask = part_mask.setdefault(part, {}) + opts = self.get_meta() if defer: - # We need to load all fields for each model, except those that - # appear in "seen" (for all models that appear in "seen"). The only - # slight complexity here is handling fields that exist on parent - # models. - workset = {} - for model, values in seen.items(): - for field in model._meta.local_fields: - if field not in values: - m = field.model._meta.concrete_model - add_to_dict(workset, m, field) - for model, values in must_include.items(): - # If we haven't included a model in workset, we don't add the - # corresponding must_include fields for that model, since an - # empty set means "include all fields". That's why there's no - # "else" branch here. - if model in workset: - workset[model].update(values) - for model, fields in workset.items(): - target[model] = {f.attname for f in fields} - else: - for model, values in must_include.items(): - if model in seen: - seen[model].update(values) - else: - # As we've passed through this model, but not explicitly - # included any fields, we have to make sure it's mentioned - # so that only the "must include" fields are pulled in. - seen[model] = values - # Now ensure that every model in the inheritance chain is mentioned - # in the parent list. Again, it must be mentioned to ensure that - # only "must include" fields are pulled in. - for model in orig_opts.get_parent_list(): - seen.setdefault(model, set()) - for model, fields in seen.items(): - target[model] = {f.attname for f in fields} + return self._get_defer_select_mask(opts, mask) + return self._get_only_select_mask(opts, mask) def table_alias(self, table_name, create=False, filtered_relation=None): """ @@ -2583,25 +2575,6 @@ def get_order_dir(field, default="ASC"): return field, dirn[0] -def add_to_dict(data, key, value): - """ - Add "value" to the set of values for "key", whether or not "key" already - exists. - """ - if key in data: - data[key].add(value) - else: - data[key] = {value} - - -def is_reverse_o2o(field): - """ - Check if the given field is reverse-o2o. The field is expected to be some - sort of relation field or related object. - """ - return field.is_relation and field.one_to_one and not field.concrete - - class JoinPromoter: """ A class to abstract away join promotion problems for complex filter diff --git a/tests/defer/tests.py b/tests/defer/tests.py index fe9637c7f1..c2319b54ec 100644 --- a/tests/defer/tests.py +++ b/tests/defer/tests.py @@ -290,6 +290,8 @@ class InvalidDeferTests(SimpleTestCase): msg = "Primary has no field named 'missing'" with self.assertRaisesMessage(FieldDoesNotExist, msg): list(Primary.objects.defer("missing")) + with self.assertRaisesMessage(FieldError, "missing"): + list(Primary.objects.defer("value__missing")) msg = "Secondary has no field named 'missing'" with self.assertRaisesMessage(FieldDoesNotExist, msg): list(Primary.objects.defer("related__missing")) @@ -298,6 +300,8 @@ class InvalidDeferTests(SimpleTestCase): msg = "Primary has no field named 'missing'" with self.assertRaisesMessage(FieldDoesNotExist, msg): list(Primary.objects.only("missing")) + with self.assertRaisesMessage(FieldError, "missing"): + list(Primary.objects.only("value__missing")) msg = "Secondary has no field named 'missing'" with self.assertRaisesMessage(FieldDoesNotExist, msg): list(Primary.objects.only("related__missing")) diff --git a/tests/defer_regress/tests.py b/tests/defer_regress/tests.py index c7a61f53a3..fb9dcdb297 100644 --- a/tests/defer_regress/tests.py +++ b/tests/defer_regress/tests.py @@ -246,8 +246,6 @@ class DeferRegressionTest(TestCase): ) self.assertEqual(len(qs), 1) - -class DeferAnnotateSelectRelatedTest(TestCase): def test_defer_annotate_select_related(self): location = Location.objects.create() Request.objects.create(location=location) @@ -276,6 +274,28 @@ class DeferAnnotateSelectRelatedTest(TestCase): list, ) + def test_common_model_different_mask(self): + child = Child.objects.create(name="Child", value=42) + second_child = Child.objects.create(name="Second", value=64) + Leaf.objects.create(child=child, second_child=second_child) + with self.assertNumQueries(1): + leaf = ( + Leaf.objects.select_related("child", "second_child") + .defer("child__name", "second_child__value") + .get() + ) + self.assertEqual(leaf.child, child) + self.assertEqual(leaf.second_child, second_child) + self.assertEqual(leaf.child.get_deferred_fields(), {"name"}) + self.assertEqual(leaf.second_child.get_deferred_fields(), {"value"}) + with self.assertNumQueries(0): + self.assertEqual(leaf.child.value, 42) + self.assertEqual(leaf.second_child.name, "Second") + with self.assertNumQueries(1): + self.assertEqual(leaf.child.name, "Child") + with self.assertNumQueries(1): + self.assertEqual(leaf.second_child.value, 64) + class DeferDeletionSignalsTests(TestCase): senders = [Item, Proxy] diff --git a/tests/queries/tests.py b/tests/queries/tests.py index 1238f021be..1bd72dd8b8 100644 --- a/tests/queries/tests.py +++ b/tests/queries/tests.py @@ -3594,12 +3594,6 @@ class WhereNodeTest(SimpleTestCase): class QuerySetExceptionTests(SimpleTestCase): - def test_iter_exceptions(self): - qs = ExtraInfo.objects.only("author") - msg = "'ManyToOneRel' object has no attribute 'attname'" - with self.assertRaisesMessage(AttributeError, msg): - list(qs) - def test_invalid_order_by(self): msg = "Cannot resolve keyword '*' into field. Choices are: created, id, name" with self.assertRaisesMessage(FieldError, msg):