mirror of
https://github.com/django/django.git
synced 2024-12-22 17:16:24 +00:00
Fixed #31679 -- Delayed annotating aggregations.
By avoiding to annotate aggregations meant to be possibly pushed to an outer query until their references are resolved it is possible to aggregate over a query with the same alias. Even if #34176 is a convoluted case to support, this refactor seems worth it given the reduction in complexity it brings with regards to annotation removal when performing a subquery pushdown.
This commit is contained in:
parent
d526d1569c
commit
1297c0d0d7
@ -23,7 +23,7 @@ from django.db import (
|
|||||||
from django.db.models import AutoField, DateField, DateTimeField, Field, sql
|
from django.db.models import AutoField, DateField, DateTimeField, Field, sql
|
||||||
from django.db.models.constants import LOOKUP_SEP, OnConflict
|
from django.db.models.constants import LOOKUP_SEP, OnConflict
|
||||||
from django.db.models.deletion import Collector
|
from django.db.models.deletion import Collector
|
||||||
from django.db.models.expressions import Case, F, Ref, Value, When
|
from django.db.models.expressions import Case, F, Value, When
|
||||||
from django.db.models.functions import Cast, Trunc
|
from django.db.models.functions import Cast, Trunc
|
||||||
from django.db.models.query_utils import FilteredRelation, Q
|
from django.db.models.query_utils import FilteredRelation, Q
|
||||||
from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE
|
from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE
|
||||||
@ -589,24 +589,7 @@ class QuerySet(AltersData):
|
|||||||
raise TypeError("Complex aggregates require an alias")
|
raise TypeError("Complex aggregates require an alias")
|
||||||
kwargs[arg.default_alias] = arg
|
kwargs[arg.default_alias] = arg
|
||||||
|
|
||||||
query = self.query.chain()
|
return self.query.chain().get_aggregation(self.db, kwargs)
|
||||||
for (alias, aggregate_expr) in kwargs.items():
|
|
||||||
query.add_annotation(aggregate_expr, alias, is_summary=True)
|
|
||||||
annotation = query.annotations[alias]
|
|
||||||
if not annotation.contains_aggregate:
|
|
||||||
raise TypeError("%s is not an aggregate expression" % alias)
|
|
||||||
for expr in annotation.get_source_expressions():
|
|
||||||
if (
|
|
||||||
expr.contains_aggregate
|
|
||||||
and isinstance(expr, Ref)
|
|
||||||
and expr.refs in kwargs
|
|
||||||
):
|
|
||||||
name = expr.refs
|
|
||||||
raise exceptions.FieldError(
|
|
||||||
"Cannot compute %s('%s'): '%s' is an aggregate"
|
|
||||||
% (annotation.name, name, name)
|
|
||||||
)
|
|
||||||
return query.get_aggregation(self.db, kwargs)
|
|
||||||
|
|
||||||
async def aaggregate(self, *args, **kwargs):
|
async def aaggregate(self, *args, **kwargs):
|
||||||
return await sync_to_async(self.aggregate)(*args, **kwargs)
|
return await sync_to_async(self.aggregate)(*args, **kwargs)
|
||||||
@ -1655,7 +1638,6 @@ class QuerySet(AltersData):
|
|||||||
clone.query.add_annotation(
|
clone.query.add_annotation(
|
||||||
annotation,
|
annotation,
|
||||||
alias,
|
alias,
|
||||||
is_summary=False,
|
|
||||||
select=select,
|
select=select,
|
||||||
)
|
)
|
||||||
for alias, annotation in clone.query.annotations.items():
|
for alias, annotation in clone.query.annotations.items():
|
||||||
|
@ -381,24 +381,28 @@ class Query(BaseExpression):
|
|||||||
alias = None
|
alias = None
|
||||||
return target.get_col(alias, field)
|
return target.get_col(alias, field)
|
||||||
|
|
||||||
def get_aggregation(self, using, added_aggregate_names):
|
def get_aggregation(self, using, aggregate_exprs):
|
||||||
"""
|
"""
|
||||||
Return the dictionary with the values of the existing aggregations.
|
Return the dictionary with the values of the existing aggregations.
|
||||||
"""
|
"""
|
||||||
if not self.annotation_select:
|
if not aggregate_exprs:
|
||||||
return {}
|
return {}
|
||||||
existing_annotations = {
|
aggregates = {}
|
||||||
alias: annotation
|
for alias, aggregate_expr in aggregate_exprs.items():
|
||||||
for alias, annotation in self.annotations.items()
|
self.check_alias(alias)
|
||||||
if alias not in added_aggregate_names
|
aggregate = aggregate_expr.resolve_expression(
|
||||||
}
|
self, allow_joins=True, reuse=None, summarize=True
|
||||||
|
)
|
||||||
|
if not aggregate.contains_aggregate:
|
||||||
|
raise TypeError("%s is not an aggregate expression" % alias)
|
||||||
|
aggregates[alias] = aggregate
|
||||||
# Existing usage of aggregation can be determined by the presence of
|
# Existing usage of aggregation can be determined by the presence of
|
||||||
# selected aggregates but also by filters against aliased aggregates.
|
# selected aggregates but also by filters against aliased aggregates.
|
||||||
_, having, qualify = self.where.split_having_qualify()
|
_, having, qualify = self.where.split_having_qualify()
|
||||||
has_existing_aggregation = (
|
has_existing_aggregation = (
|
||||||
any(
|
any(
|
||||||
getattr(annotation, "contains_aggregate", True)
|
getattr(annotation, "contains_aggregate", True)
|
||||||
for annotation in existing_annotations.values()
|
for annotation in self.annotations.values()
|
||||||
)
|
)
|
||||||
or having
|
or having
|
||||||
)
|
)
|
||||||
@ -449,25 +453,19 @@ class Query(BaseExpression):
|
|||||||
# filtering against window functions is involved as it
|
# filtering against window functions is involved as it
|
||||||
# requires complex realising.
|
# requires complex realising.
|
||||||
annotation_mask = set()
|
annotation_mask = set()
|
||||||
for name in added_aggregate_names:
|
for aggregate in aggregates.values():
|
||||||
annotation_mask.add(name)
|
annotation_mask |= aggregate.get_refs()
|
||||||
annotation_mask |= inner_query.annotations[name].get_refs()
|
|
||||||
inner_query.set_annotation_mask(annotation_mask)
|
inner_query.set_annotation_mask(annotation_mask)
|
||||||
|
|
||||||
# Remove any aggregates marked for reduction from the subquery and
|
# Add aggregates to the outer AggregateQuery. This requires making
|
||||||
# move them to the outer AggregateQuery. This requires making sure
|
# sure all columns referenced by the aggregates are selected in the
|
||||||
# all columns referenced by the aggregates are selected in the
|
# inner query. It is achieved by retrieving all column references
|
||||||
# subquery. It is achieved by retrieving all column references from
|
# by the aggregates, explicitly selecting them in the inner query,
|
||||||
# the aggregates, explicitly selecting them if they are not
|
# and making sure the aggregates are repointed to them.
|
||||||
# already, and making sure the aggregates are repointed to
|
|
||||||
# referenced to them.
|
|
||||||
col_refs = {}
|
col_refs = {}
|
||||||
for alias, expression in list(inner_query.annotation_select.items()):
|
for alias, aggregate in aggregates.items():
|
||||||
if not expression.is_summary:
|
|
||||||
continue
|
|
||||||
annotation_select_mask = inner_query.annotation_select_mask
|
|
||||||
replacements = {}
|
replacements = {}
|
||||||
for col in self._gen_cols([expression], resolve_refs=False):
|
for col in self._gen_cols([aggregate], resolve_refs=False):
|
||||||
if not (col_ref := col_refs.get(col)):
|
if not (col_ref := col_refs.get(col)):
|
||||||
index = len(col_refs) + 1
|
index = len(col_refs) + 1
|
||||||
col_alias = f"__col{index}"
|
col_alias = f"__col{index}"
|
||||||
@ -476,13 +474,9 @@ class Query(BaseExpression):
|
|||||||
inner_query.annotations[col_alias] = col
|
inner_query.annotations[col_alias] = col
|
||||||
inner_query.append_annotation_mask([col_alias])
|
inner_query.append_annotation_mask([col_alias])
|
||||||
replacements[col] = col_ref
|
replacements[col] = col_ref
|
||||||
outer_query.annotations[alias] = expression.replace_expressions(
|
outer_query.annotations[alias] = aggregate.replace_expressions(
|
||||||
replacements
|
replacements
|
||||||
)
|
)
|
||||||
del inner_query.annotations[alias]
|
|
||||||
annotation_select_mask.remove(alias)
|
|
||||||
# Make sure the annotation_select wont use cached results.
|
|
||||||
inner_query.set_annotation_mask(inner_query.annotation_select_mask)
|
|
||||||
if (
|
if (
|
||||||
inner_query.select == ()
|
inner_query.select == ()
|
||||||
and not inner_query.default_cols
|
and not inner_query.default_cols
|
||||||
@ -499,19 +493,21 @@ class Query(BaseExpression):
|
|||||||
self.select = ()
|
self.select = ()
|
||||||
self.default_cols = False
|
self.default_cols = False
|
||||||
self.extra = {}
|
self.extra = {}
|
||||||
if existing_annotations:
|
if self.annotations:
|
||||||
# Inline reference to existing annotations and mask them as
|
# Inline reference to existing annotations and mask them as
|
||||||
# they are unnecessary given only the summarized aggregations
|
# they are unnecessary given only the summarized aggregations
|
||||||
# are requested.
|
# are requested.
|
||||||
replacements = {
|
replacements = {
|
||||||
Ref(alias, annotation): annotation
|
Ref(alias, annotation): annotation
|
||||||
for alias, annotation in existing_annotations.items()
|
for alias, annotation in self.annotations.items()
|
||||||
}
|
}
|
||||||
for name in added_aggregate_names:
|
self.annotations = {
|
||||||
self.annotations[name] = self.annotations[name].replace_expressions(
|
alias: aggregate.replace_expressions(replacements)
|
||||||
replacements
|
for alias, aggregate in aggregates.items()
|
||||||
)
|
}
|
||||||
self.set_annotation_mask(added_aggregate_names)
|
else:
|
||||||
|
self.annotations = aggregates
|
||||||
|
self.set_annotation_mask(aggregates)
|
||||||
|
|
||||||
empty_set_result = [
|
empty_set_result = [
|
||||||
expression.empty_result_set_value
|
expression.empty_result_set_value
|
||||||
@ -537,8 +533,7 @@ class Query(BaseExpression):
|
|||||||
Perform a COUNT() query using the current filter constraints.
|
Perform a COUNT() query using the current filter constraints.
|
||||||
"""
|
"""
|
||||||
obj = self.clone()
|
obj = self.clone()
|
||||||
obj.add_annotation(Count("*"), alias="__count", is_summary=True)
|
return obj.get_aggregation(using, {"__count": Count("*")})["__count"]
|
||||||
return obj.get_aggregation(using, ["__count"])["__count"]
|
|
||||||
|
|
||||||
def has_filters(self):
|
def has_filters(self):
|
||||||
return self.where
|
return self.where
|
||||||
@ -1085,12 +1080,10 @@ class Query(BaseExpression):
|
|||||||
"semicolons, or SQL comments."
|
"semicolons, or SQL comments."
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_annotation(self, annotation, alias, is_summary=False, select=True):
|
def add_annotation(self, annotation, alias, select=True):
|
||||||
"""Add a single annotation expression to the Query."""
|
"""Add a single annotation expression to the Query."""
|
||||||
self.check_alias(alias)
|
self.check_alias(alias)
|
||||||
annotation = annotation.resolve_expression(
|
annotation = annotation.resolve_expression(self, allow_joins=True, reuse=None)
|
||||||
self, allow_joins=True, reuse=None, summarize=is_summary
|
|
||||||
)
|
|
||||||
if select:
|
if select:
|
||||||
self.append_annotation_mask([alias])
|
self.append_annotation_mask([alias])
|
||||||
else:
|
else:
|
||||||
|
@ -395,6 +395,9 @@ Miscellaneous
|
|||||||
* The undocumented ``negated`` parameter of the
|
* The undocumented ``negated`` parameter of the
|
||||||
:class:`~django.db.models.Exists` expression is removed.
|
:class:`~django.db.models.Exists` expression is removed.
|
||||||
|
|
||||||
|
* The ``is_summary`` argument of the undocumented ``Query.add_annotation()``
|
||||||
|
method is removed.
|
||||||
|
|
||||||
.. _deprecated-features-4.2:
|
.. _deprecated-features-4.2:
|
||||||
|
|
||||||
Features deprecated in 4.2
|
Features deprecated in 4.2
|
||||||
|
@ -1258,11 +1258,11 @@ class AggregateTestCase(TestCase):
|
|||||||
self.assertEqual(author.sum_age, other_author.sum_age)
|
self.assertEqual(author.sum_age, other_author.sum_age)
|
||||||
|
|
||||||
def test_aggregate_over_aggregate(self):
|
def test_aggregate_over_aggregate(self):
|
||||||
msg = "Cannot compute Avg('age'): 'age' is an aggregate"
|
msg = "Cannot resolve keyword 'age_agg' into field."
|
||||||
with self.assertRaisesMessage(FieldError, msg):
|
with self.assertRaisesMessage(FieldError, msg):
|
||||||
Author.objects.annotate(age_alias=F("age"),).aggregate(
|
Author.objects.aggregate(
|
||||||
age=Sum(F("age")),
|
age_agg=Sum(F("age")),
|
||||||
avg_age=Avg(F("age")),
|
avg_age=Avg(F("age_agg")),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_annotated_aggregate_over_annotated_aggregate(self):
|
def test_annotated_aggregate_over_annotated_aggregate(self):
|
||||||
@ -2086,6 +2086,14 @@ class AggregateTestCase(TestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(len(qs), 6)
|
self.assertEqual(len(qs), 6)
|
||||||
|
|
||||||
|
def test_aggregation_over_annotation_shared_alias(self):
|
||||||
|
self.assertEqual(
|
||||||
|
Publisher.objects.annotate(agg=Count("book__authors"),).aggregate(
|
||||||
|
agg=Count("agg"),
|
||||||
|
),
|
||||||
|
{"agg": 5},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AggregateAnnotationPruningTests(TestCase):
|
class AggregateAnnotationPruningTests(TestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
|
Loading…
Reference in New Issue
Block a user