mirror of
https://github.com/django/django.git
synced 2024-12-22 17:16:24 +00:00
Fixed #31679 -- Delayed annotating aggregations.
By avoiding to annotate aggregations meant to be possibly pushed to an outer query until their references are resolved it is possible to aggregate over a query with the same alias. Even if #34176 is a convoluted case to support, this refactor seems worth it given the reduction in complexity it brings with regards to annotation removal when performing a subquery pushdown.
This commit is contained in:
parent
d526d1569c
commit
1297c0d0d7
@ -23,7 +23,7 @@ from django.db import (
|
||||
from django.db.models import AutoField, DateField, DateTimeField, Field, sql
|
||||
from django.db.models.constants import LOOKUP_SEP, OnConflict
|
||||
from django.db.models.deletion import Collector
|
||||
from django.db.models.expressions import Case, F, Ref, Value, When
|
||||
from django.db.models.expressions import Case, F, Value, When
|
||||
from django.db.models.functions import Cast, Trunc
|
||||
from django.db.models.query_utils import FilteredRelation, Q
|
||||
from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE
|
||||
@ -589,24 +589,7 @@ class QuerySet(AltersData):
|
||||
raise TypeError("Complex aggregates require an alias")
|
||||
kwargs[arg.default_alias] = arg
|
||||
|
||||
query = self.query.chain()
|
||||
for (alias, aggregate_expr) in kwargs.items():
|
||||
query.add_annotation(aggregate_expr, alias, is_summary=True)
|
||||
annotation = query.annotations[alias]
|
||||
if not annotation.contains_aggregate:
|
||||
raise TypeError("%s is not an aggregate expression" % alias)
|
||||
for expr in annotation.get_source_expressions():
|
||||
if (
|
||||
expr.contains_aggregate
|
||||
and isinstance(expr, Ref)
|
||||
and expr.refs in kwargs
|
||||
):
|
||||
name = expr.refs
|
||||
raise exceptions.FieldError(
|
||||
"Cannot compute %s('%s'): '%s' is an aggregate"
|
||||
% (annotation.name, name, name)
|
||||
)
|
||||
return query.get_aggregation(self.db, kwargs)
|
||||
return self.query.chain().get_aggregation(self.db, kwargs)
|
||||
|
||||
async def aaggregate(self, *args, **kwargs):
|
||||
return await sync_to_async(self.aggregate)(*args, **kwargs)
|
||||
@ -1655,7 +1638,6 @@ class QuerySet(AltersData):
|
||||
clone.query.add_annotation(
|
||||
annotation,
|
||||
alias,
|
||||
is_summary=False,
|
||||
select=select,
|
||||
)
|
||||
for alias, annotation in clone.query.annotations.items():
|
||||
|
@ -381,24 +381,28 @@ class Query(BaseExpression):
|
||||
alias = None
|
||||
return target.get_col(alias, field)
|
||||
|
||||
def get_aggregation(self, using, added_aggregate_names):
|
||||
def get_aggregation(self, using, aggregate_exprs):
|
||||
"""
|
||||
Return the dictionary with the values of the existing aggregations.
|
||||
"""
|
||||
if not self.annotation_select:
|
||||
if not aggregate_exprs:
|
||||
return {}
|
||||
existing_annotations = {
|
||||
alias: annotation
|
||||
for alias, annotation in self.annotations.items()
|
||||
if alias not in added_aggregate_names
|
||||
}
|
||||
aggregates = {}
|
||||
for alias, aggregate_expr in aggregate_exprs.items():
|
||||
self.check_alias(alias)
|
||||
aggregate = aggregate_expr.resolve_expression(
|
||||
self, allow_joins=True, reuse=None, summarize=True
|
||||
)
|
||||
if not aggregate.contains_aggregate:
|
||||
raise TypeError("%s is not an aggregate expression" % alias)
|
||||
aggregates[alias] = aggregate
|
||||
# Existing usage of aggregation can be determined by the presence of
|
||||
# selected aggregates but also by filters against aliased aggregates.
|
||||
_, having, qualify = self.where.split_having_qualify()
|
||||
has_existing_aggregation = (
|
||||
any(
|
||||
getattr(annotation, "contains_aggregate", True)
|
||||
for annotation in existing_annotations.values()
|
||||
for annotation in self.annotations.values()
|
||||
)
|
||||
or having
|
||||
)
|
||||
@ -449,25 +453,19 @@ class Query(BaseExpression):
|
||||
# filtering against window functions is involved as it
|
||||
# requires complex realising.
|
||||
annotation_mask = set()
|
||||
for name in added_aggregate_names:
|
||||
annotation_mask.add(name)
|
||||
annotation_mask |= inner_query.annotations[name].get_refs()
|
||||
for aggregate in aggregates.values():
|
||||
annotation_mask |= aggregate.get_refs()
|
||||
inner_query.set_annotation_mask(annotation_mask)
|
||||
|
||||
# Remove any aggregates marked for reduction from the subquery and
|
||||
# move them to the outer AggregateQuery. This requires making sure
|
||||
# all columns referenced by the aggregates are selected in the
|
||||
# subquery. It is achieved by retrieving all column references from
|
||||
# the aggregates, explicitly selecting them if they are not
|
||||
# already, and making sure the aggregates are repointed to
|
||||
# referenced to them.
|
||||
# Add aggregates to the outer AggregateQuery. This requires making
|
||||
# sure all columns referenced by the aggregates are selected in the
|
||||
# inner query. It is achieved by retrieving all column references
|
||||
# by the aggregates, explicitly selecting them in the inner query,
|
||||
# and making sure the aggregates are repointed to them.
|
||||
col_refs = {}
|
||||
for alias, expression in list(inner_query.annotation_select.items()):
|
||||
if not expression.is_summary:
|
||||
continue
|
||||
annotation_select_mask = inner_query.annotation_select_mask
|
||||
for alias, aggregate in aggregates.items():
|
||||
replacements = {}
|
||||
for col in self._gen_cols([expression], resolve_refs=False):
|
||||
for col in self._gen_cols([aggregate], resolve_refs=False):
|
||||
if not (col_ref := col_refs.get(col)):
|
||||
index = len(col_refs) + 1
|
||||
col_alias = f"__col{index}"
|
||||
@ -476,13 +474,9 @@ class Query(BaseExpression):
|
||||
inner_query.annotations[col_alias] = col
|
||||
inner_query.append_annotation_mask([col_alias])
|
||||
replacements[col] = col_ref
|
||||
outer_query.annotations[alias] = expression.replace_expressions(
|
||||
outer_query.annotations[alias] = aggregate.replace_expressions(
|
||||
replacements
|
||||
)
|
||||
del inner_query.annotations[alias]
|
||||
annotation_select_mask.remove(alias)
|
||||
# Make sure the annotation_select wont use cached results.
|
||||
inner_query.set_annotation_mask(inner_query.annotation_select_mask)
|
||||
if (
|
||||
inner_query.select == ()
|
||||
and not inner_query.default_cols
|
||||
@ -499,19 +493,21 @@ class Query(BaseExpression):
|
||||
self.select = ()
|
||||
self.default_cols = False
|
||||
self.extra = {}
|
||||
if existing_annotations:
|
||||
if self.annotations:
|
||||
# Inline reference to existing annotations and mask them as
|
||||
# they are unnecessary given only the summarized aggregations
|
||||
# are requested.
|
||||
replacements = {
|
||||
Ref(alias, annotation): annotation
|
||||
for alias, annotation in existing_annotations.items()
|
||||
for alias, annotation in self.annotations.items()
|
||||
}
|
||||
for name in added_aggregate_names:
|
||||
self.annotations[name] = self.annotations[name].replace_expressions(
|
||||
replacements
|
||||
)
|
||||
self.set_annotation_mask(added_aggregate_names)
|
||||
self.annotations = {
|
||||
alias: aggregate.replace_expressions(replacements)
|
||||
for alias, aggregate in aggregates.items()
|
||||
}
|
||||
else:
|
||||
self.annotations = aggregates
|
||||
self.set_annotation_mask(aggregates)
|
||||
|
||||
empty_set_result = [
|
||||
expression.empty_result_set_value
|
||||
@ -537,8 +533,7 @@ class Query(BaseExpression):
|
||||
Perform a COUNT() query using the current filter constraints.
|
||||
"""
|
||||
obj = self.clone()
|
||||
obj.add_annotation(Count("*"), alias="__count", is_summary=True)
|
||||
return obj.get_aggregation(using, ["__count"])["__count"]
|
||||
return obj.get_aggregation(using, {"__count": Count("*")})["__count"]
|
||||
|
||||
def has_filters(self):
|
||||
return self.where
|
||||
@ -1085,12 +1080,10 @@ class Query(BaseExpression):
|
||||
"semicolons, or SQL comments."
|
||||
)
|
||||
|
||||
def add_annotation(self, annotation, alias, is_summary=False, select=True):
|
||||
def add_annotation(self, annotation, alias, select=True):
|
||||
"""Add a single annotation expression to the Query."""
|
||||
self.check_alias(alias)
|
||||
annotation = annotation.resolve_expression(
|
||||
self, allow_joins=True, reuse=None, summarize=is_summary
|
||||
)
|
||||
annotation = annotation.resolve_expression(self, allow_joins=True, reuse=None)
|
||||
if select:
|
||||
self.append_annotation_mask([alias])
|
||||
else:
|
||||
|
@ -395,6 +395,9 @@ Miscellaneous
|
||||
* The undocumented ``negated`` parameter of the
|
||||
:class:`~django.db.models.Exists` expression is removed.
|
||||
|
||||
* The ``is_summary`` argument of the undocumented ``Query.add_annotation()``
|
||||
method is removed.
|
||||
|
||||
.. _deprecated-features-4.2:
|
||||
|
||||
Features deprecated in 4.2
|
||||
|
@ -1258,11 +1258,11 @@ class AggregateTestCase(TestCase):
|
||||
self.assertEqual(author.sum_age, other_author.sum_age)
|
||||
|
||||
def test_aggregate_over_aggregate(self):
|
||||
msg = "Cannot compute Avg('age'): 'age' is an aggregate"
|
||||
msg = "Cannot resolve keyword 'age_agg' into field."
|
||||
with self.assertRaisesMessage(FieldError, msg):
|
||||
Author.objects.annotate(age_alias=F("age"),).aggregate(
|
||||
age=Sum(F("age")),
|
||||
avg_age=Avg(F("age")),
|
||||
Author.objects.aggregate(
|
||||
age_agg=Sum(F("age")),
|
||||
avg_age=Avg(F("age_agg")),
|
||||
)
|
||||
|
||||
def test_annotated_aggregate_over_annotated_aggregate(self):
|
||||
@ -2086,6 +2086,14 @@ class AggregateTestCase(TestCase):
|
||||
)
|
||||
self.assertEqual(len(qs), 6)
|
||||
|
||||
def test_aggregation_over_annotation_shared_alias(self):
|
||||
self.assertEqual(
|
||||
Publisher.objects.annotate(agg=Count("book__authors"),).aggregate(
|
||||
agg=Count("agg"),
|
||||
),
|
||||
{"agg": 5},
|
||||
)
|
||||
|
||||
|
||||
class AggregateAnnotationPruningTests(TestCase):
|
||||
@classmethod
|
||||
|
Loading…
Reference in New Issue
Block a user