mirror of
https://github.com/django/django.git
synced 2024-12-23 01:25:58 +00:00
Refs #27149 -- Moved subquery expression resolving to Query.
This makes Subquery a thin wrapper over Query and makes sure it respects the Expression source expression API by accepting the same number of expressions as it returns. Refs #30188. It also makes OuterRef usable in Query without Subquery wrapping. This should allow Query's internals to more easily perform subquery push downs during split_exclude(). Refs #21703.
This commit is contained in:
parent
96b6ad94d9
commit
3543129822
@ -1002,70 +1002,28 @@ class Subquery(Expression):
|
|||||||
self.extra = extra
|
self.extra = extra
|
||||||
super().__init__(output_field)
|
super().__init__(output_field)
|
||||||
|
|
||||||
|
def get_source_expressions(self):
|
||||||
|
return [self.query]
|
||||||
|
|
||||||
|
def set_source_expressions(self, exprs):
|
||||||
|
self.query = exprs[0]
|
||||||
|
|
||||||
def _resolve_output_field(self):
|
def _resolve_output_field(self):
|
||||||
if len(self.query.select) == 1:
|
return self.query.output_field
|
||||||
return self.query.select[0].field
|
|
||||||
return super()._resolve_output_field()
|
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
clone = super().copy()
|
clone = super().copy()
|
||||||
clone.query = clone.query.clone()
|
clone.query = clone.query.clone()
|
||||||
return clone
|
return clone
|
||||||
|
|
||||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
@property
|
||||||
clone = self.copy()
|
def external_aliases(self):
|
||||||
clone.is_summary = summarize
|
return self.query.external_aliases
|
||||||
clone.query.bump_prefix(query)
|
|
||||||
|
|
||||||
# Need to recursively resolve these.
|
|
||||||
def resolve_all(child):
|
|
||||||
if hasattr(child, 'children'):
|
|
||||||
[resolve_all(_child) for _child in child.children]
|
|
||||||
if hasattr(child, 'rhs'):
|
|
||||||
child.rhs = resolve(child.rhs)
|
|
||||||
|
|
||||||
def resolve(child):
|
|
||||||
if hasattr(child, 'resolve_expression'):
|
|
||||||
resolved = child.resolve_expression(
|
|
||||||
query=query, allow_joins=allow_joins, reuse=reuse,
|
|
||||||
summarize=summarize, for_save=for_save,
|
|
||||||
)
|
|
||||||
# Add table alias to the parent query's aliases to prevent
|
|
||||||
# quoting.
|
|
||||||
if hasattr(resolved, 'alias') and resolved.alias != resolved.target.model._meta.db_table:
|
|
||||||
clone.query.external_aliases.add(resolved.alias)
|
|
||||||
return resolved
|
|
||||||
return child
|
|
||||||
|
|
||||||
resolve_all(clone.query.where)
|
|
||||||
|
|
||||||
for key, value in clone.query.annotations.items():
|
|
||||||
if isinstance(value, Subquery):
|
|
||||||
clone.query.annotations[key] = resolve(value)
|
|
||||||
|
|
||||||
return clone
|
|
||||||
|
|
||||||
def get_source_expressions(self):
|
|
||||||
return [
|
|
||||||
x for x in [
|
|
||||||
getattr(expr, 'lhs', None)
|
|
||||||
for expr in self.query.where.children
|
|
||||||
] if x
|
|
||||||
]
|
|
||||||
|
|
||||||
def relabeled_clone(self, change_map):
|
|
||||||
clone = self.copy()
|
|
||||||
clone.query = clone.query.relabeled_clone(change_map)
|
|
||||||
clone.query.external_aliases.update(
|
|
||||||
alias for alias in change_map.values()
|
|
||||||
if alias not in clone.query.alias_map
|
|
||||||
)
|
|
||||||
return clone
|
|
||||||
|
|
||||||
def as_sql(self, compiler, connection, template=None, **extra_context):
|
def as_sql(self, compiler, connection, template=None, **extra_context):
|
||||||
connection.ops.check_expression_support(self)
|
connection.ops.check_expression_support(self)
|
||||||
template_params = {**self.extra, **extra_context}
|
template_params = {**self.extra, **extra_context}
|
||||||
template_params['subquery'], sql_params = self.query.get_compiler(connection=connection).as_sql()
|
template_params['subquery'], sql_params = self.query.as_sql(compiler, connection)
|
||||||
|
|
||||||
template = template or template_params.get('template', self.template)
|
template = template or template_params.get('template', self.template)
|
||||||
sql = template % template_params
|
sql = template % template_params
|
||||||
|
@ -21,7 +21,7 @@ from django.core.exceptions import (
|
|||||||
from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections
|
from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections
|
||||||
from django.db.models.aggregates import Count
|
from django.db.models.aggregates import Count
|
||||||
from django.db.models.constants import LOOKUP_SEP
|
from django.db.models.constants import LOOKUP_SEP
|
||||||
from django.db.models.expressions import Col, F, Ref, SimpleCol
|
from django.db.models.expressions import BaseExpression, Col, F, Ref, SimpleCol
|
||||||
from django.db.models.fields import Field
|
from django.db.models.fields import Field
|
||||||
from django.db.models.fields.related_lookups import MultiColSource
|
from django.db.models.fields.related_lookups import MultiColSource
|
||||||
from django.db.models.lookups import Lookup
|
from django.db.models.lookups import Lookup
|
||||||
@ -139,7 +139,7 @@ class RawQuery:
|
|||||||
self.cursor.execute(self.sql, params)
|
self.cursor.execute(self.sql, params)
|
||||||
|
|
||||||
|
|
||||||
class Query:
|
class Query(BaseExpression):
|
||||||
"""A single SQL query."""
|
"""A single SQL query."""
|
||||||
|
|
||||||
alias_prefix = 'T'
|
alias_prefix = 'T'
|
||||||
@ -231,6 +231,13 @@ class Query:
|
|||||||
self.explain_format = None
|
self.explain_format = None
|
||||||
self.explain_options = {}
|
self.explain_options = {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_field(self):
|
||||||
|
if len(self.select) == 1:
|
||||||
|
return self.select[0].field
|
||||||
|
elif len(self.annotation_select) == 1:
|
||||||
|
return next(iter(self.annotation_select.values())).output_field
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def has_select_fields(self):
|
def has_select_fields(self):
|
||||||
return bool(self.select or self.annotation_select_mask or self.extra_select_mask)
|
return bool(self.select or self.annotation_select_mask or self.extra_select_mask)
|
||||||
@ -862,7 +869,7 @@ class Query:
|
|||||||
# No clashes between self and outer query should be possible.
|
# No clashes between self and outer query should be possible.
|
||||||
return
|
return
|
||||||
|
|
||||||
local_recursion_limit = 127 # explicitly avoid infinite loop
|
local_recursion_limit = 67 # explicitly avoid infinite loop
|
||||||
for pos, prefix in enumerate(prefix_gen()):
|
for pos, prefix in enumerate(prefix_gen()):
|
||||||
if prefix not in self.subq_aliases:
|
if prefix not in self.subq_aliases:
|
||||||
self.alias_prefix = prefix
|
self.alias_prefix = prefix
|
||||||
@ -997,6 +1004,21 @@ class Query:
|
|||||||
not self.distinct_fields and
|
not self.distinct_fields and
|
||||||
not self.select_for_update):
|
not self.select_for_update):
|
||||||
clone.clear_ordering(True)
|
clone.clear_ordering(True)
|
||||||
|
clone.where.resolve_expression(query, *args, **kwargs)
|
||||||
|
for key, value in clone.annotations.items():
|
||||||
|
resolved = value.resolve_expression(query, *args, **kwargs)
|
||||||
|
if hasattr(resolved, 'external_aliases'):
|
||||||
|
resolved.external_aliases.update(clone.alias_map)
|
||||||
|
clone.annotations[key] = resolved
|
||||||
|
# Outer query's aliases are considered external.
|
||||||
|
clone.external_aliases.update(
|
||||||
|
alias for alias, table in query.alias_map.items()
|
||||||
|
if (
|
||||||
|
isinstance(table, Join) and table.join_field.related_model._meta.db_table != alias
|
||||||
|
) or (
|
||||||
|
isinstance(table, BaseTable) and table.table_name != table.table_alias
|
||||||
|
)
|
||||||
|
)
|
||||||
return clone
|
return clone
|
||||||
|
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection):
|
||||||
|
@ -183,8 +183,23 @@ class WhereNode(tree.Node):
|
|||||||
def is_summary(self):
|
def is_summary(self):
|
||||||
return any(child.is_summary for child in self.children)
|
return any(child.is_summary for child in self.children)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _resolve_rhs(rhs, query, *args, **kwargs):
|
||||||
|
if hasattr(rhs, 'resolve_expression'):
|
||||||
|
rhs = rhs.resolve_expression(query, *args, **kwargs)
|
||||||
|
return rhs
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _resolve_node(cls, node, query, *args, **kwargs):
|
||||||
|
if hasattr(node, 'children'):
|
||||||
|
for child in node.children:
|
||||||
|
cls._resolve_node(child, query, *args, **kwargs)
|
||||||
|
if hasattr(node, 'rhs'):
|
||||||
|
node.rhs = cls._resolve_rhs(node.rhs, query, *args, **kwargs)
|
||||||
|
|
||||||
def resolve_expression(self, *args, **kwargs):
|
def resolve_expression(self, *args, **kwargs):
|
||||||
clone = self.clone()
|
clone = self.clone()
|
||||||
|
clone._resolve_node(clone, *args, **kwargs)
|
||||||
clone.resolved = True
|
clone.resolved = True
|
||||||
return clone
|
return clone
|
||||||
|
|
||||||
|
@ -402,7 +402,7 @@ class Queries1Tests(TestCase):
|
|||||||
|
|
||||||
def test_avoid_infinite_loop_on_too_many_subqueries(self):
|
def test_avoid_infinite_loop_on_too_many_subqueries(self):
|
||||||
x = Tag.objects.filter(pk=1)
|
x = Tag.objects.filter(pk=1)
|
||||||
local_recursion_limit = 127
|
local_recursion_limit = 67
|
||||||
msg = 'Maximum recursion depth exceeded: too many subqueries.'
|
msg = 'Maximum recursion depth exceeded: too many subqueries.'
|
||||||
with self.assertRaisesMessage(RuntimeError, msg):
|
with self.assertRaisesMessage(RuntimeError, msg):
|
||||||
for i in range(local_recursion_limit * 2):
|
for i in range(local_recursion_limit * 2):
|
||||||
|
Loading…
Reference in New Issue
Block a user