From 1297c0d0d76a708017fe196b61a0ab324df76954 Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Tue, 22 Nov 2022 21:49:12 -0500 Subject: [PATCH] Fixed #31679 -- Delayed annotating aggregations. By avoiding to annotate aggregations meant to be possibly pushed to an outer query until their references are resolved it is possible to aggregate over a query with the same alias. Even if #34176 is a convoluted case to support, this refactor seems worth it given the reduction in complexity it brings with regards to annotation removal when performing a subquery pushdown. --- django/db/models/query.py | 22 +--------- django/db/models/sql/query.py | 75 ++++++++++++++++------------------- docs/releases/4.2.txt | 3 ++ tests/aggregation/tests.py | 16 ++++++-- 4 files changed, 51 insertions(+), 65 deletions(-) diff --git a/django/db/models/query.py b/django/db/models/query.py index cf419cb8cf..13d24bb871 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -23,7 +23,7 @@ from django.db import ( from django.db.models import AutoField, DateField, DateTimeField, Field, sql from django.db.models.constants import LOOKUP_SEP, OnConflict from django.db.models.deletion import Collector -from django.db.models.expressions import Case, F, Ref, Value, When +from django.db.models.expressions import Case, F, Value, When from django.db.models.functions import Cast, Trunc from django.db.models.query_utils import FilteredRelation, Q from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE @@ -589,24 +589,7 @@ class QuerySet(AltersData): raise TypeError("Complex aggregates require an alias") kwargs[arg.default_alias] = arg - query = self.query.chain() - for (alias, aggregate_expr) in kwargs.items(): - query.add_annotation(aggregate_expr, alias, is_summary=True) - annotation = query.annotations[alias] - if not annotation.contains_aggregate: - raise TypeError("%s is not an aggregate expression" % alias) - for expr in annotation.get_source_expressions(): - if ( - expr.contains_aggregate - and isinstance(expr, Ref) - and expr.refs in kwargs - ): - name = expr.refs - raise exceptions.FieldError( - "Cannot compute %s('%s'): '%s' is an aggregate" - % (annotation.name, name, name) - ) - return query.get_aggregation(self.db, kwargs) + return self.query.chain().get_aggregation(self.db, kwargs) async def aaggregate(self, *args, **kwargs): return await sync_to_async(self.aggregate)(*args, **kwargs) @@ -1655,7 +1638,6 @@ class QuerySet(AltersData): clone.query.add_annotation( annotation, alias, - is_summary=False, select=select, ) for alias, annotation in clone.query.annotations.items(): diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 2d150ed6d8..521054f69e 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -381,24 +381,28 @@ class Query(BaseExpression): alias = None return target.get_col(alias, field) - def get_aggregation(self, using, added_aggregate_names): + def get_aggregation(self, using, aggregate_exprs): """ Return the dictionary with the values of the existing aggregations. """ - if not self.annotation_select: + if not aggregate_exprs: return {} - existing_annotations = { - alias: annotation - for alias, annotation in self.annotations.items() - if alias not in added_aggregate_names - } + aggregates = {} + for alias, aggregate_expr in aggregate_exprs.items(): + self.check_alias(alias) + aggregate = aggregate_expr.resolve_expression( + self, allow_joins=True, reuse=None, summarize=True + ) + if not aggregate.contains_aggregate: + raise TypeError("%s is not an aggregate expression" % alias) + aggregates[alias] = aggregate # Existing usage of aggregation can be determined by the presence of # selected aggregates but also by filters against aliased aggregates. _, having, qualify = self.where.split_having_qualify() has_existing_aggregation = ( any( getattr(annotation, "contains_aggregate", True) - for annotation in existing_annotations.values() + for annotation in self.annotations.values() ) or having ) @@ -449,25 +453,19 @@ class Query(BaseExpression): # filtering against window functions is involved as it # requires complex realising. annotation_mask = set() - for name in added_aggregate_names: - annotation_mask.add(name) - annotation_mask |= inner_query.annotations[name].get_refs() + for aggregate in aggregates.values(): + annotation_mask |= aggregate.get_refs() inner_query.set_annotation_mask(annotation_mask) - # Remove any aggregates marked for reduction from the subquery and - # move them to the outer AggregateQuery. This requires making sure - # all columns referenced by the aggregates are selected in the - # subquery. It is achieved by retrieving all column references from - # the aggregates, explicitly selecting them if they are not - # already, and making sure the aggregates are repointed to - # referenced to them. + # Add aggregates to the outer AggregateQuery. This requires making + # sure all columns referenced by the aggregates are selected in the + # inner query. It is achieved by retrieving all column references + # by the aggregates, explicitly selecting them in the inner query, + # and making sure the aggregates are repointed to them. col_refs = {} - for alias, expression in list(inner_query.annotation_select.items()): - if not expression.is_summary: - continue - annotation_select_mask = inner_query.annotation_select_mask + for alias, aggregate in aggregates.items(): replacements = {} - for col in self._gen_cols([expression], resolve_refs=False): + for col in self._gen_cols([aggregate], resolve_refs=False): if not (col_ref := col_refs.get(col)): index = len(col_refs) + 1 col_alias = f"__col{index}" @@ -476,13 +474,9 @@ class Query(BaseExpression): inner_query.annotations[col_alias] = col inner_query.append_annotation_mask([col_alias]) replacements[col] = col_ref - outer_query.annotations[alias] = expression.replace_expressions( + outer_query.annotations[alias] = aggregate.replace_expressions( replacements ) - del inner_query.annotations[alias] - annotation_select_mask.remove(alias) - # Make sure the annotation_select wont use cached results. - inner_query.set_annotation_mask(inner_query.annotation_select_mask) if ( inner_query.select == () and not inner_query.default_cols @@ -499,19 +493,21 @@ class Query(BaseExpression): self.select = () self.default_cols = False self.extra = {} - if existing_annotations: + if self.annotations: # Inline reference to existing annotations and mask them as # they are unnecessary given only the summarized aggregations # are requested. replacements = { Ref(alias, annotation): annotation - for alias, annotation in existing_annotations.items() + for alias, annotation in self.annotations.items() } - for name in added_aggregate_names: - self.annotations[name] = self.annotations[name].replace_expressions( - replacements - ) - self.set_annotation_mask(added_aggregate_names) + self.annotations = { + alias: aggregate.replace_expressions(replacements) + for alias, aggregate in aggregates.items() + } + else: + self.annotations = aggregates + self.set_annotation_mask(aggregates) empty_set_result = [ expression.empty_result_set_value @@ -537,8 +533,7 @@ class Query(BaseExpression): Perform a COUNT() query using the current filter constraints. """ obj = self.clone() - obj.add_annotation(Count("*"), alias="__count", is_summary=True) - return obj.get_aggregation(using, ["__count"])["__count"] + return obj.get_aggregation(using, {"__count": Count("*")})["__count"] def has_filters(self): return self.where @@ -1085,12 +1080,10 @@ class Query(BaseExpression): "semicolons, or SQL comments." ) - def add_annotation(self, annotation, alias, is_summary=False, select=True): + def add_annotation(self, annotation, alias, select=True): """Add a single annotation expression to the Query.""" self.check_alias(alias) - annotation = annotation.resolve_expression( - self, allow_joins=True, reuse=None, summarize=is_summary - ) + annotation = annotation.resolve_expression(self, allow_joins=True, reuse=None) if select: self.append_annotation_mask([alias]) else: diff --git a/docs/releases/4.2.txt b/docs/releases/4.2.txt index e7d2882edb..7b3ca37b35 100644 --- a/docs/releases/4.2.txt +++ b/docs/releases/4.2.txt @@ -395,6 +395,9 @@ Miscellaneous * The undocumented ``negated`` parameter of the :class:`~django.db.models.Exists` expression is removed. +* The ``is_summary`` argument of the undocumented ``Query.add_annotation()`` + method is removed. + .. _deprecated-features-4.2: Features deprecated in 4.2 diff --git a/tests/aggregation/tests.py b/tests/aggregation/tests.py index c098716cca..4e860a7aa3 100644 --- a/tests/aggregation/tests.py +++ b/tests/aggregation/tests.py @@ -1258,11 +1258,11 @@ class AggregateTestCase(TestCase): self.assertEqual(author.sum_age, other_author.sum_age) def test_aggregate_over_aggregate(self): - msg = "Cannot compute Avg('age'): 'age' is an aggregate" + msg = "Cannot resolve keyword 'age_agg' into field." with self.assertRaisesMessage(FieldError, msg): - Author.objects.annotate(age_alias=F("age"),).aggregate( - age=Sum(F("age")), - avg_age=Avg(F("age")), + Author.objects.aggregate( + age_agg=Sum(F("age")), + avg_age=Avg(F("age_agg")), ) def test_annotated_aggregate_over_annotated_aggregate(self): @@ -2086,6 +2086,14 @@ class AggregateTestCase(TestCase): ) self.assertEqual(len(qs), 6) + def test_aggregation_over_annotation_shared_alias(self): + self.assertEqual( + Publisher.objects.annotate(agg=Count("book__authors"),).aggregate( + agg=Count("agg"), + ), + {"agg": 5}, + ) + class AggregateAnnotationPruningTests(TestCase): @classmethod