1
0
mirror of https://github.com/django/django.git synced 2025-03-08 08:22:33 +00:00

Refs -- Updated Aggregate class to return consistent source expressions.

Refactored the filter and order_by expressions in the Aggregate class to
return a list of Expression (or None) values, ensuring that the list
item is always available and represents the filter expression.
For the PostgreSQL OrderableAggMixin, the returned list will always
include the filter and the order_by value as the last two elements.

Lastly, emtpy Q objects passed directly into aggregate objects using
Aggregate.filter in admin facets are filtered out when resolving the
expression to avoid errors in get_refs().

Thanks Simon Charette for the review.
This commit is contained in:
Chris Muthig 2024-04-03 16:06:39 -06:00 committed by nessita
parent ec8552417d
commit 42b567ab4c
3 changed files with 12 additions and 13 deletions
django
contrib/postgres/aggregates
db/models
tests/aggregation

@ -17,13 +17,10 @@ class OrderableAggMixin:
return super().resolve_expression(*args, **kwargs) return super().resolve_expression(*args, **kwargs)
def get_source_expressions(self): def get_source_expressions(self):
if self.order_by is not None: return super().get_source_expressions() + [self.order_by]
return super().get_source_expressions() + [self.order_by]
return super().get_source_expressions()
def set_source_expressions(self, exprs): def set_source_expressions(self, exprs):
if isinstance(exprs[-1], OrderByList): *exprs, self.order_by = exprs
*exprs, self.order_by = exprs
return super().set_source_expressions(exprs) return super().set_source_expressions(exprs)
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):

@ -50,12 +50,10 @@ class Aggregate(Func):
def get_source_expressions(self): def get_source_expressions(self):
source_expressions = super().get_source_expressions() source_expressions = super().get_source_expressions()
if self.filter: return source_expressions + [self.filter]
return source_expressions + [self.filter]
return source_expressions
def set_source_expressions(self, exprs): def set_source_expressions(self, exprs):
self.filter = self.filter and exprs.pop() *exprs, self.filter = exprs
return super().set_source_expressions(exprs) return super().set_source_expressions(exprs)
def resolve_expression( def resolve_expression(
@ -63,8 +61,10 @@ class Aggregate(Func):
): ):
# Aggregates are not allowed in UPDATE queries, so ignore for_save # Aggregates are not allowed in UPDATE queries, so ignore for_save
c = super().resolve_expression(query, allow_joins, reuse, summarize) c = super().resolve_expression(query, allow_joins, reuse, summarize)
c.filter = c.filter and c.filter.resolve_expression( c.filter = (
query, allow_joins, reuse, summarize c.filter.resolve_expression(query, allow_joins, reuse, summarize)
if c.filter
else None
) )
if summarize: if summarize:
# Summarized aggregates cannot refer to summarized aggregates. # Summarized aggregates cannot refer to summarized aggregates.
@ -104,7 +104,9 @@ class Aggregate(Func):
@property @property
def default_alias(self): def default_alias(self):
expressions = self.get_source_expressions() expressions = [
expr for expr in self.get_source_expressions() if expr is not None
]
if len(expressions) == 1 and hasattr(expressions[0], "name"): if len(expressions) == 1 and hasattr(expressions[0], "name"):
return "%s__%s" % (expressions[0].name, self.name.lower()) return "%s__%s" % (expressions[0].name, self.name.lower())
raise TypeError("Complex expressions require an alias") raise TypeError("Complex expressions require an alias")

@ -1291,7 +1291,7 @@ class AggregateTestCase(TestCase):
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
copy = self.copy() copy = self.copy()
copy.set_source_expressions(copy.get_source_expressions()[0:1]) copy.set_source_expressions(copy.get_source_expressions()[0:1] + [None])
return super(MyMax, copy).as_sql(compiler, connection) return super(MyMax, copy).as_sql(compiler, connection)
with self.assertRaisesMessage(TypeError, "Complex aggregates require an alias"): with self.assertRaisesMessage(TypeError, "Complex aggregates require an alias"):