From 42b567ab4c5bfb1bbd3e629b1079271c5ae44ea0 Mon Sep 17 00:00:00 2001 From: Chris Muthig Date: Wed, 3 Apr 2024 16:06:39 -0600 Subject: [PATCH] Refs #35339 -- Updated Aggregate class to return consistent source expressions. Refactored the filter and order_by expressions in the Aggregate class to return a list of Expression (or None) values, ensuring that the list item is always available and represents the filter expression. For the PostgreSQL OrderableAggMixin, the returned list will always include the filter and the order_by value as the last two elements. Lastly, emtpy Q objects passed directly into aggregate objects using Aggregate.filter in admin facets are filtered out when resolving the expression to avoid errors in get_refs(). Thanks Simon Charette for the review. --- django/contrib/postgres/aggregates/mixins.py | 7 ++----- django/db/models/aggregates.py | 16 +++++++++------- tests/aggregation/tests.py | 2 +- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/django/contrib/postgres/aggregates/mixins.py b/django/contrib/postgres/aggregates/mixins.py index 68f24a5ce3..527626da4c 100644 --- a/django/contrib/postgres/aggregates/mixins.py +++ b/django/contrib/postgres/aggregates/mixins.py @@ -17,13 +17,10 @@ class OrderableAggMixin: return super().resolve_expression(*args, **kwargs) def get_source_expressions(self): - if self.order_by is not None: - return super().get_source_expressions() + [self.order_by] - return super().get_source_expressions() + return super().get_source_expressions() + [self.order_by] def set_source_expressions(self, exprs): - if isinstance(exprs[-1], OrderByList): - *exprs, self.order_by = exprs + *exprs, self.order_by = exprs return super().set_source_expressions(exprs) def as_sql(self, compiler, connection): diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index 0cbffacd1b..bf94decab7 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -50,12 +50,10 @@ class Aggregate(Func): def get_source_expressions(self): source_expressions = super().get_source_expressions() - if self.filter: - return source_expressions + [self.filter] - return source_expressions + return source_expressions + [self.filter] def set_source_expressions(self, exprs): - self.filter = self.filter and exprs.pop() + *exprs, self.filter = exprs return super().set_source_expressions(exprs) def resolve_expression( @@ -63,8 +61,10 @@ class Aggregate(Func): ): # Aggregates are not allowed in UPDATE queries, so ignore for_save c = super().resolve_expression(query, allow_joins, reuse, summarize) - c.filter = c.filter and c.filter.resolve_expression( - query, allow_joins, reuse, summarize + c.filter = ( + c.filter.resolve_expression(query, allow_joins, reuse, summarize) + if c.filter + else None ) if summarize: # Summarized aggregates cannot refer to summarized aggregates. @@ -104,7 +104,9 @@ class Aggregate(Func): @property def default_alias(self): - expressions = self.get_source_expressions() + expressions = [ + expr for expr in self.get_source_expressions() if expr is not None + ] if len(expressions) == 1 and hasattr(expressions[0], "name"): return "%s__%s" % (expressions[0].name, self.name.lower()) raise TypeError("Complex expressions require an alias") diff --git a/tests/aggregation/tests.py b/tests/aggregation/tests.py index 4408535228..075e707102 100644 --- a/tests/aggregation/tests.py +++ b/tests/aggregation/tests.py @@ -1291,7 +1291,7 @@ class AggregateTestCase(TestCase): def as_sql(self, compiler, connection): copy = self.copy() - copy.set_source_expressions(copy.get_source_expressions()[0:1]) + copy.set_source_expressions(copy.get_source_expressions()[0:1] + [None]) return super(MyMax, copy).as_sql(compiler, connection) with self.assertRaisesMessage(TypeError, "Complex aggregates require an alias"):