Refs #25307 -- Replaced SQLQuery.rewrite_cols() by replace_expressions().

The latter offers a more generic interface that doesn't require
specialized expression types handling.
This commit is contained in:
Simon Charette 2022-11-12 10:06:00 -05:00 committed by Mariusz Felisiak
parent 1771998c09
commit b181cae2e3
1 changed files with 25 additions and 62 deletions

View File

@ -381,60 +381,6 @@ class Query(BaseExpression):
alias = None
return target.get_col(alias, field)
def rewrite_cols(self, annotation, col_cnt):
# We must make sure the inner query has the referred columns in it.
# If we are aggregating over an annotation, then Django uses Ref()
# instances to note this. However, if we are annotating over a column
# of a related model, then it might be that column isn't part of the
# SELECT clause of the inner query, and we must manually make sure
# the column is selected. An example case is:
# .aggregate(Sum('author__awards'))
# Resolving this expression results in a join to author, but there
# is no guarantee the awards column of author is in the select clause
# of the query. Thus we must manually add the column to the inner
# query.
orig_exprs = annotation.get_source_expressions()
new_exprs = []
for expr in orig_exprs:
# FIXME: These conditions are fairly arbitrary. Identify a better
# method of having expressions decide which code path they should
# take.
if isinstance(expr, Ref):
# Its already a Ref to subquery (see resolve_ref() for
# details)
new_exprs.append(expr)
elif isinstance(expr, (WhereNode, Lookup)):
# Decompose the subexpressions further. The code here is
# copied from the else clause, but this condition must appear
# before the contains_aggregate/is_summary condition below.
new_expr, col_cnt = self.rewrite_cols(expr, col_cnt)
new_exprs.append(new_expr)
else:
# Reuse aliases of expressions already selected in subquery.
for col_alias, selected_annotation in self.annotation_select.items():
if selected_annotation is expr:
new_expr = Ref(col_alias, expr)
break
else:
# An expression that is not selected the subquery.
if isinstance(expr, Col) or (
expr.contains_aggregate and not expr.is_summary
):
# Reference column or another aggregate. Select it
# under a non-conflicting alias.
col_cnt += 1
col_alias = "__col%d" % col_cnt
self.annotations[col_alias] = expr
self.append_annotation_mask([col_alias])
new_expr = Ref(col_alias, expr)
else:
# Some other expression not referencing database values
# directly. Its subexpression might contain Cols.
new_expr, col_cnt = self.rewrite_cols(expr, col_cnt)
new_exprs.append(new_expr)
annotation.set_source_expressions(new_exprs)
return annotation, col_cnt
def get_aggregation(self, using, added_aggregate_names):
"""
Return the dictionary with the values of the existing aggregations.
@ -508,17 +454,31 @@ class Query(BaseExpression):
annotation_mask |= inner_query.annotations[name].get_refs()
inner_query.set_annotation_mask(annotation_mask)
relabels = {t: "subquery" for t in inner_query.alias_map}
relabels[None] = "subquery"
# Remove any aggregates marked for reduction from the subquery
# and move them to the outer AggregateQuery.
col_cnt = 0
# Remove any aggregates marked for reduction from the subquery and
# move them to the outer AggregateQuery. This requires making sure
# all columns referenced by the aggregates are selected in the
# subquery. It is achieved by retrieving all column references from
# the aggregates, explicitly selecting them if they are not
# already, and making sure the aggregates are repointed to
# referenced to them.
col_refs = {}
for alias, expression in list(inner_query.annotation_select.items()):
if not expression.is_summary:
continue
annotation_select_mask = inner_query.annotation_select_mask
expression, col_cnt = inner_query.rewrite_cols(expression, col_cnt)
outer_query.annotations[alias] = expression.relabeled_clone(relabels)
replacements = {}
for col in self._gen_cols([expression], resolve_refs=False):
if not (col_ref := col_refs.get(col)):
index = len(col_refs) + 1
col_alias = f"__col{index}"
col_ref = Ref(col_alias, col)
col_refs[col] = col_ref
inner_query.annotations[col_alias] = col
inner_query.append_annotation_mask([col_alias])
replacements[col] = col_ref
outer_query.annotations[alias] = expression.replace_expressions(
replacements
)
del inner_query.annotations[alias]
annotation_select_mask.remove(alias)
# Make sure the annotation_select wont use cached results.
@ -1923,7 +1883,7 @@ class Query(BaseExpression):
return targets, joins[-1], joins
@classmethod
def _gen_cols(cls, exprs, include_external=False):
def _gen_cols(cls, exprs, include_external=False, resolve_refs=True):
for expr in exprs:
if isinstance(expr, Col):
yield expr
@ -1932,9 +1892,12 @@ class Query(BaseExpression):
):
yield from expr.get_external_cols()
elif hasattr(expr, "get_source_expressions"):
if not resolve_refs and isinstance(expr, Ref):
continue
yield from cls._gen_cols(
expr.get_source_expressions(),
include_external=include_external,
resolve_refs=resolve_refs,
)
@classmethod