mirror of
				https://github.com/django/django.git
				synced 2025-10-30 17:16:10 +00:00 
			
		
		
		
	Fixed #31679 -- Delayed annotating aggregations.
By avoiding to annotate aggregations meant to be possibly pushed to an outer query until their references are resolved it is possible to aggregate over a query with the same alias. Even if #34176 is a convoluted case to support, this refactor seems worth it given the reduction in complexity it brings with regards to annotation removal when performing a subquery pushdown.
This commit is contained in:
		
				
					committed by
					
						 Mariusz Felisiak
						Mariusz Felisiak
					
				
			
			
				
	
			
			
			
						parent
						
							d526d1569c
						
					
				
				
					commit
					1297c0d0d7
				
			| @@ -23,7 +23,7 @@ from django.db import ( | |||||||
| from django.db.models import AutoField, DateField, DateTimeField, Field, sql | from django.db.models import AutoField, DateField, DateTimeField, Field, sql | ||||||
| from django.db.models.constants import LOOKUP_SEP, OnConflict | from django.db.models.constants import LOOKUP_SEP, OnConflict | ||||||
| from django.db.models.deletion import Collector | from django.db.models.deletion import Collector | ||||||
| from django.db.models.expressions import Case, F, Ref, Value, When | from django.db.models.expressions import Case, F, Value, When | ||||||
| from django.db.models.functions import Cast, Trunc | from django.db.models.functions import Cast, Trunc | ||||||
| from django.db.models.query_utils import FilteredRelation, Q | from django.db.models.query_utils import FilteredRelation, Q | ||||||
| from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE | from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE | ||||||
| @@ -589,24 +589,7 @@ class QuerySet(AltersData): | |||||||
|                 raise TypeError("Complex aggregates require an alias") |                 raise TypeError("Complex aggregates require an alias") | ||||||
|             kwargs[arg.default_alias] = arg |             kwargs[arg.default_alias] = arg | ||||||
|  |  | ||||||
|         query = self.query.chain() |         return self.query.chain().get_aggregation(self.db, kwargs) | ||||||
|         for (alias, aggregate_expr) in kwargs.items(): |  | ||||||
|             query.add_annotation(aggregate_expr, alias, is_summary=True) |  | ||||||
|             annotation = query.annotations[alias] |  | ||||||
|             if not annotation.contains_aggregate: |  | ||||||
|                 raise TypeError("%s is not an aggregate expression" % alias) |  | ||||||
|             for expr in annotation.get_source_expressions(): |  | ||||||
|                 if ( |  | ||||||
|                     expr.contains_aggregate |  | ||||||
|                     and isinstance(expr, Ref) |  | ||||||
|                     and expr.refs in kwargs |  | ||||||
|                 ): |  | ||||||
|                     name = expr.refs |  | ||||||
|                     raise exceptions.FieldError( |  | ||||||
|                         "Cannot compute %s('%s'): '%s' is an aggregate" |  | ||||||
|                         % (annotation.name, name, name) |  | ||||||
|                     ) |  | ||||||
|         return query.get_aggregation(self.db, kwargs) |  | ||||||
|  |  | ||||||
|     async def aaggregate(self, *args, **kwargs): |     async def aaggregate(self, *args, **kwargs): | ||||||
|         return await sync_to_async(self.aggregate)(*args, **kwargs) |         return await sync_to_async(self.aggregate)(*args, **kwargs) | ||||||
| @@ -1655,7 +1638,6 @@ class QuerySet(AltersData): | |||||||
|                 clone.query.add_annotation( |                 clone.query.add_annotation( | ||||||
|                     annotation, |                     annotation, | ||||||
|                     alias, |                     alias, | ||||||
|                     is_summary=False, |  | ||||||
|                     select=select, |                     select=select, | ||||||
|                 ) |                 ) | ||||||
|         for alias, annotation in clone.query.annotations.items(): |         for alias, annotation in clone.query.annotations.items(): | ||||||
|   | |||||||
| @@ -381,24 +381,28 @@ class Query(BaseExpression): | |||||||
|             alias = None |             alias = None | ||||||
|         return target.get_col(alias, field) |         return target.get_col(alias, field) | ||||||
|  |  | ||||||
|     def get_aggregation(self, using, added_aggregate_names): |     def get_aggregation(self, using, aggregate_exprs): | ||||||
|         """ |         """ | ||||||
|         Return the dictionary with the values of the existing aggregations. |         Return the dictionary with the values of the existing aggregations. | ||||||
|         """ |         """ | ||||||
|         if not self.annotation_select: |         if not aggregate_exprs: | ||||||
|             return {} |             return {} | ||||||
|         existing_annotations = { |         aggregates = {} | ||||||
|             alias: annotation |         for alias, aggregate_expr in aggregate_exprs.items(): | ||||||
|             for alias, annotation in self.annotations.items() |             self.check_alias(alias) | ||||||
|             if alias not in added_aggregate_names |             aggregate = aggregate_expr.resolve_expression( | ||||||
|         } |                 self, allow_joins=True, reuse=None, summarize=True | ||||||
|  |             ) | ||||||
|  |             if not aggregate.contains_aggregate: | ||||||
|  |                 raise TypeError("%s is not an aggregate expression" % alias) | ||||||
|  |             aggregates[alias] = aggregate | ||||||
|         # Existing usage of aggregation can be determined by the presence of |         # Existing usage of aggregation can be determined by the presence of | ||||||
|         # selected aggregates but also by filters against aliased aggregates. |         # selected aggregates but also by filters against aliased aggregates. | ||||||
|         _, having, qualify = self.where.split_having_qualify() |         _, having, qualify = self.where.split_having_qualify() | ||||||
|         has_existing_aggregation = ( |         has_existing_aggregation = ( | ||||||
|             any( |             any( | ||||||
|                 getattr(annotation, "contains_aggregate", True) |                 getattr(annotation, "contains_aggregate", True) | ||||||
|                 for annotation in existing_annotations.values() |                 for annotation in self.annotations.values() | ||||||
|             ) |             ) | ||||||
|             or having |             or having | ||||||
|         ) |         ) | ||||||
| @@ -449,25 +453,19 @@ class Query(BaseExpression): | |||||||
|                     # filtering against window functions is involved as it |                     # filtering against window functions is involved as it | ||||||
|                     # requires complex realising. |                     # requires complex realising. | ||||||
|                     annotation_mask = set() |                     annotation_mask = set() | ||||||
|                     for name in added_aggregate_names: |                     for aggregate in aggregates.values(): | ||||||
|                         annotation_mask.add(name) |                         annotation_mask |= aggregate.get_refs() | ||||||
|                         annotation_mask |= inner_query.annotations[name].get_refs() |  | ||||||
|                     inner_query.set_annotation_mask(annotation_mask) |                     inner_query.set_annotation_mask(annotation_mask) | ||||||
|  |  | ||||||
|             # Remove any aggregates marked for reduction from the subquery and |             # Add aggregates to the outer AggregateQuery. This requires making | ||||||
|             # move them to the outer AggregateQuery. This requires making sure |             # sure all columns referenced by the aggregates are selected in the | ||||||
|             # all columns referenced by the aggregates are selected in the |             # inner query. It is achieved by retrieving all column references | ||||||
|             # subquery. It is achieved by retrieving all column references from |             # by the aggregates, explicitly selecting them in the inner query, | ||||||
|             # the aggregates, explicitly selecting them if they are not |             # and making sure the aggregates are repointed to them. | ||||||
|             # already, and making sure the aggregates are repointed to |  | ||||||
|             # referenced to them. |  | ||||||
|             col_refs = {} |             col_refs = {} | ||||||
|             for alias, expression in list(inner_query.annotation_select.items()): |             for alias, aggregate in aggregates.items(): | ||||||
|                 if not expression.is_summary: |  | ||||||
|                     continue |  | ||||||
|                 annotation_select_mask = inner_query.annotation_select_mask |  | ||||||
|                 replacements = {} |                 replacements = {} | ||||||
|                 for col in self._gen_cols([expression], resolve_refs=False): |                 for col in self._gen_cols([aggregate], resolve_refs=False): | ||||||
|                     if not (col_ref := col_refs.get(col)): |                     if not (col_ref := col_refs.get(col)): | ||||||
|                         index = len(col_refs) + 1 |                         index = len(col_refs) + 1 | ||||||
|                         col_alias = f"__col{index}" |                         col_alias = f"__col{index}" | ||||||
| @@ -476,13 +474,9 @@ class Query(BaseExpression): | |||||||
|                         inner_query.annotations[col_alias] = col |                         inner_query.annotations[col_alias] = col | ||||||
|                         inner_query.append_annotation_mask([col_alias]) |                         inner_query.append_annotation_mask([col_alias]) | ||||||
|                     replacements[col] = col_ref |                     replacements[col] = col_ref | ||||||
|                 outer_query.annotations[alias] = expression.replace_expressions( |                 outer_query.annotations[alias] = aggregate.replace_expressions( | ||||||
|                     replacements |                     replacements | ||||||
|                 ) |                 ) | ||||||
|                 del inner_query.annotations[alias] |  | ||||||
|                 annotation_select_mask.remove(alias) |  | ||||||
|                 # Make sure the annotation_select wont use cached results. |  | ||||||
|                 inner_query.set_annotation_mask(inner_query.annotation_select_mask) |  | ||||||
|             if ( |             if ( | ||||||
|                 inner_query.select == () |                 inner_query.select == () | ||||||
|                 and not inner_query.default_cols |                 and not inner_query.default_cols | ||||||
| @@ -499,19 +493,21 @@ class Query(BaseExpression): | |||||||
|             self.select = () |             self.select = () | ||||||
|             self.default_cols = False |             self.default_cols = False | ||||||
|             self.extra = {} |             self.extra = {} | ||||||
|             if existing_annotations: |             if self.annotations: | ||||||
|                 # Inline reference to existing annotations and mask them as |                 # Inline reference to existing annotations and mask them as | ||||||
|                 # they are unnecessary given only the summarized aggregations |                 # they are unnecessary given only the summarized aggregations | ||||||
|                 # are requested. |                 # are requested. | ||||||
|                 replacements = { |                 replacements = { | ||||||
|                     Ref(alias, annotation): annotation |                     Ref(alias, annotation): annotation | ||||||
|                     for alias, annotation in existing_annotations.items() |                     for alias, annotation in self.annotations.items() | ||||||
|                 } |                 } | ||||||
|                 for name in added_aggregate_names: |                 self.annotations = { | ||||||
|                     self.annotations[name] = self.annotations[name].replace_expressions( |                     alias: aggregate.replace_expressions(replacements) | ||||||
|                         replacements |                     for alias, aggregate in aggregates.items() | ||||||
|                     ) |                 } | ||||||
|                 self.set_annotation_mask(added_aggregate_names) |             else: | ||||||
|  |                 self.annotations = aggregates | ||||||
|  |             self.set_annotation_mask(aggregates) | ||||||
|  |  | ||||||
|         empty_set_result = [ |         empty_set_result = [ | ||||||
|             expression.empty_result_set_value |             expression.empty_result_set_value | ||||||
| @@ -537,8 +533,7 @@ class Query(BaseExpression): | |||||||
|         Perform a COUNT() query using the current filter constraints. |         Perform a COUNT() query using the current filter constraints. | ||||||
|         """ |         """ | ||||||
|         obj = self.clone() |         obj = self.clone() | ||||||
|         obj.add_annotation(Count("*"), alias="__count", is_summary=True) |         return obj.get_aggregation(using, {"__count": Count("*")})["__count"] | ||||||
|         return obj.get_aggregation(using, ["__count"])["__count"] |  | ||||||
|  |  | ||||||
|     def has_filters(self): |     def has_filters(self): | ||||||
|         return self.where |         return self.where | ||||||
| @@ -1085,12 +1080,10 @@ class Query(BaseExpression): | |||||||
|                 "semicolons, or SQL comments." |                 "semicolons, or SQL comments." | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|     def add_annotation(self, annotation, alias, is_summary=False, select=True): |     def add_annotation(self, annotation, alias, select=True): | ||||||
|         """Add a single annotation expression to the Query.""" |         """Add a single annotation expression to the Query.""" | ||||||
|         self.check_alias(alias) |         self.check_alias(alias) | ||||||
|         annotation = annotation.resolve_expression( |         annotation = annotation.resolve_expression(self, allow_joins=True, reuse=None) | ||||||
|             self, allow_joins=True, reuse=None, summarize=is_summary |  | ||||||
|         ) |  | ||||||
|         if select: |         if select: | ||||||
|             self.append_annotation_mask([alias]) |             self.append_annotation_mask([alias]) | ||||||
|         else: |         else: | ||||||
|   | |||||||
| @@ -395,6 +395,9 @@ Miscellaneous | |||||||
| * The undocumented ``negated`` parameter of the | * The undocumented ``negated`` parameter of the | ||||||
|   :class:`~django.db.models.Exists` expression is removed. |   :class:`~django.db.models.Exists` expression is removed. | ||||||
|  |  | ||||||
|  | * The ``is_summary`` argument of the undocumented ``Query.add_annotation()`` | ||||||
|  |   method is removed. | ||||||
|  |  | ||||||
| .. _deprecated-features-4.2: | .. _deprecated-features-4.2: | ||||||
|  |  | ||||||
| Features deprecated in 4.2 | Features deprecated in 4.2 | ||||||
|   | |||||||
| @@ -1258,11 +1258,11 @@ class AggregateTestCase(TestCase): | |||||||
|         self.assertEqual(author.sum_age, other_author.sum_age) |         self.assertEqual(author.sum_age, other_author.sum_age) | ||||||
|  |  | ||||||
|     def test_aggregate_over_aggregate(self): |     def test_aggregate_over_aggregate(self): | ||||||
|         msg = "Cannot compute Avg('age'): 'age' is an aggregate" |         msg = "Cannot resolve keyword 'age_agg' into field." | ||||||
|         with self.assertRaisesMessage(FieldError, msg): |         with self.assertRaisesMessage(FieldError, msg): | ||||||
|             Author.objects.annotate(age_alias=F("age"),).aggregate( |             Author.objects.aggregate( | ||||||
|                 age=Sum(F("age")), |                 age_agg=Sum(F("age")), | ||||||
|                 avg_age=Avg(F("age")), |                 avg_age=Avg(F("age_agg")), | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|     def test_annotated_aggregate_over_annotated_aggregate(self): |     def test_annotated_aggregate_over_annotated_aggregate(self): | ||||||
| @@ -2086,6 +2086,14 @@ class AggregateTestCase(TestCase): | |||||||
|         ) |         ) | ||||||
|         self.assertEqual(len(qs), 6) |         self.assertEqual(len(qs), 6) | ||||||
|  |  | ||||||
|  |     def test_aggregation_over_annotation_shared_alias(self): | ||||||
|  |         self.assertEqual( | ||||||
|  |             Publisher.objects.annotate(agg=Count("book__authors"),).aggregate( | ||||||
|  |                 agg=Count("agg"), | ||||||
|  |             ), | ||||||
|  |             {"agg": 5}, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class AggregateAnnotationPruningTests(TestCase): | class AggregateAnnotationPruningTests(TestCase): | ||||||
|     @classmethod |     @classmethod | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user