From 35431298226165986ad07e91f9d3aca721ff38ec Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Wed, 6 Mar 2019 01:05:55 -0500 Subject: [PATCH] Refs #27149 -- Moved subquery expression resolving to Query. This makes Subquery a thin wrapper over Query and makes sure it respects the Expression source expression API by accepting the same number of expressions as it returns. Refs #30188. It also makes OuterRef usable in Query without Subquery wrapping. This should allow Query's internals to more easily perform subquery push downs during split_exclude(). Refs #21703. --- django/db/models/expressions.py | 64 ++++++--------------------------- django/db/models/sql/query.py | 28 +++++++++++++-- django/db/models/sql/where.py | 15 ++++++++ tests/queries/tests.py | 2 +- 4 files changed, 52 insertions(+), 57 deletions(-) diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index bb060fd893..7b703c8f1c 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -1002,70 +1002,28 @@ class Subquery(Expression): self.extra = extra super().__init__(output_field) + def get_source_expressions(self): + return [self.query] + + def set_source_expressions(self, exprs): + self.query = exprs[0] + def _resolve_output_field(self): - if len(self.query.select) == 1: - return self.query.select[0].field - return super()._resolve_output_field() + return self.query.output_field def copy(self): clone = super().copy() clone.query = clone.query.clone() return clone - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): - clone = self.copy() - clone.is_summary = summarize - clone.query.bump_prefix(query) - - # Need to recursively resolve these. - def resolve_all(child): - if hasattr(child, 'children'): - [resolve_all(_child) for _child in child.children] - if hasattr(child, 'rhs'): - child.rhs = resolve(child.rhs) - - def resolve(child): - if hasattr(child, 'resolve_expression'): - resolved = child.resolve_expression( - query=query, allow_joins=allow_joins, reuse=reuse, - summarize=summarize, for_save=for_save, - ) - # Add table alias to the parent query's aliases to prevent - # quoting. - if hasattr(resolved, 'alias') and resolved.alias != resolved.target.model._meta.db_table: - clone.query.external_aliases.add(resolved.alias) - return resolved - return child - - resolve_all(clone.query.where) - - for key, value in clone.query.annotations.items(): - if isinstance(value, Subquery): - clone.query.annotations[key] = resolve(value) - - return clone - - def get_source_expressions(self): - return [ - x for x in [ - getattr(expr, 'lhs', None) - for expr in self.query.where.children - ] if x - ] - - def relabeled_clone(self, change_map): - clone = self.copy() - clone.query = clone.query.relabeled_clone(change_map) - clone.query.external_aliases.update( - alias for alias in change_map.values() - if alias not in clone.query.alias_map - ) - return clone + @property + def external_aliases(self): + return self.query.external_aliases def as_sql(self, compiler, connection, template=None, **extra_context): connection.ops.check_expression_support(self) template_params = {**self.extra, **extra_context} - template_params['subquery'], sql_params = self.query.get_compiler(connection=connection).as_sql() + template_params['subquery'], sql_params = self.query.as_sql(compiler, connection) template = template or template_params.get('template', self.template) sql = template % template_params diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index ba4baca2b8..c192530573 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -21,7 +21,7 @@ from django.core.exceptions import ( from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections from django.db.models.aggregates import Count from django.db.models.constants import LOOKUP_SEP -from django.db.models.expressions import Col, F, Ref, SimpleCol +from django.db.models.expressions import BaseExpression, Col, F, Ref, SimpleCol from django.db.models.fields import Field from django.db.models.fields.related_lookups import MultiColSource from django.db.models.lookups import Lookup @@ -139,7 +139,7 @@ class RawQuery: self.cursor.execute(self.sql, params) -class Query: +class Query(BaseExpression): """A single SQL query.""" alias_prefix = 'T' @@ -231,6 +231,13 @@ class Query: self.explain_format = None self.explain_options = {} + @property + def output_field(self): + if len(self.select) == 1: + return self.select[0].field + elif len(self.annotation_select) == 1: + return next(iter(self.annotation_select.values())).output_field + @property def has_select_fields(self): return bool(self.select or self.annotation_select_mask or self.extra_select_mask) @@ -862,7 +869,7 @@ class Query: # No clashes between self and outer query should be possible. return - local_recursion_limit = 127 # explicitly avoid infinite loop + local_recursion_limit = 67 # explicitly avoid infinite loop for pos, prefix in enumerate(prefix_gen()): if prefix not in self.subq_aliases: self.alias_prefix = prefix @@ -997,6 +1004,21 @@ class Query: not self.distinct_fields and not self.select_for_update): clone.clear_ordering(True) + clone.where.resolve_expression(query, *args, **kwargs) + for key, value in clone.annotations.items(): + resolved = value.resolve_expression(query, *args, **kwargs) + if hasattr(resolved, 'external_aliases'): + resolved.external_aliases.update(clone.alias_map) + clone.annotations[key] = resolved + # Outer query's aliases are considered external. + clone.external_aliases.update( + alias for alias, table in query.alias_map.items() + if ( + isinstance(table, Join) and table.join_field.related_model._meta.db_table != alias + ) or ( + isinstance(table, BaseTable) and table.table_name != table.table_alias + ) + ) return clone def as_sql(self, compiler, connection): diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 9d3d6a9366..496822c58b 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -183,8 +183,23 @@ class WhereNode(tree.Node): def is_summary(self): return any(child.is_summary for child in self.children) + @staticmethod + def _resolve_rhs(rhs, query, *args, **kwargs): + if hasattr(rhs, 'resolve_expression'): + rhs = rhs.resolve_expression(query, *args, **kwargs) + return rhs + + @classmethod + def _resolve_node(cls, node, query, *args, **kwargs): + if hasattr(node, 'children'): + for child in node.children: + cls._resolve_node(child, query, *args, **kwargs) + if hasattr(node, 'rhs'): + node.rhs = cls._resolve_rhs(node.rhs, query, *args, **kwargs) + def resolve_expression(self, *args, **kwargs): clone = self.clone() + clone._resolve_node(clone, *args, **kwargs) clone.resolved = True return clone diff --git a/tests/queries/tests.py b/tests/queries/tests.py index 8128deca8b..76f42fee73 100644 --- a/tests/queries/tests.py +++ b/tests/queries/tests.py @@ -402,7 +402,7 @@ class Queries1Tests(TestCase): def test_avoid_infinite_loop_on_too_many_subqueries(self): x = Tag.objects.filter(pk=1) - local_recursion_limit = 127 + local_recursion_limit = 67 msg = 'Maximum recursion depth exceeded: too many subqueries.' with self.assertRaisesMessage(RuntimeError, msg): for i in range(local_recursion_limit * 2):