diff --git a/django/contrib/postgres/aggregates/mixins.py b/django/contrib/postgres/aggregates/mixins.py index 68f24a5ce3..527626da4c 100644 --- a/django/contrib/postgres/aggregates/mixins.py +++ b/django/contrib/postgres/aggregates/mixins.py @@ -17,13 +17,10 @@ class OrderableAggMixin: return super().resolve_expression(*args, **kwargs) def get_source_expressions(self): - if self.order_by is not None: - return super().get_source_expressions() + [self.order_by] - return super().get_source_expressions() + return super().get_source_expressions() + [self.order_by] def set_source_expressions(self, exprs): - if isinstance(exprs[-1], OrderByList): - *exprs, self.order_by = exprs + *exprs, self.order_by = exprs return super().set_source_expressions(exprs) def as_sql(self, compiler, connection): diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index 0cbffacd1b..bf94decab7 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -50,12 +50,10 @@ class Aggregate(Func): def get_source_expressions(self): source_expressions = super().get_source_expressions() - if self.filter: - return source_expressions + [self.filter] - return source_expressions + return source_expressions + [self.filter] def set_source_expressions(self, exprs): - self.filter = self.filter and exprs.pop() + *exprs, self.filter = exprs return super().set_source_expressions(exprs) def resolve_expression( @@ -63,8 +61,10 @@ class Aggregate(Func): ): # Aggregates are not allowed in UPDATE queries, so ignore for_save c = super().resolve_expression(query, allow_joins, reuse, summarize) - c.filter = c.filter and c.filter.resolve_expression( - query, allow_joins, reuse, summarize + c.filter = ( + c.filter.resolve_expression(query, allow_joins, reuse, summarize) + if c.filter + else None ) if summarize: # Summarized aggregates cannot refer to summarized aggregates. @@ -104,7 +104,9 @@ class Aggregate(Func): @property def default_alias(self): - expressions = self.get_source_expressions() + expressions = [ + expr for expr in self.get_source_expressions() if expr is not None + ] if len(expressions) == 1 and hasattr(expressions[0], "name"): return "%s__%s" % (expressions[0].name, self.name.lower()) raise TypeError("Complex expressions require an alias") diff --git a/tests/aggregation/tests.py b/tests/aggregation/tests.py index 4408535228..075e707102 100644 --- a/tests/aggregation/tests.py +++ b/tests/aggregation/tests.py @@ -1291,7 +1291,7 @@ class AggregateTestCase(TestCase): def as_sql(self, compiler, connection): copy = self.copy() - copy.set_source_expressions(copy.get_source_expressions()[0:1]) + copy.set_source_expressions(copy.get_source_expressions()[0:1] + [None]) return super(MyMax, copy).as_sql(compiler, connection) with self.assertRaisesMessage(TypeError, "Complex aggregates require an alias"):