1
0
mirror of https://github.com/django/django.git synced 2025-10-26 15:16:09 +00:00

Fixed #10929 -- Added default argument to aggregates.

Thanks to Simon Charette and Adam Johnson for the reviews.
This commit is contained in:
Nick Pope
2021-02-21 01:38:55 +00:00
committed by Mariusz Felisiak
parent 59942a66ce
commit 501a8db465
11 changed files with 393 additions and 64 deletions

View File

@@ -18,7 +18,7 @@ class ArrayAgg(OrderableAggMixin, Aggregate):
return ArrayField(self.source_expressions[0].output_field)
def convert_value(self, value, expression, connection):
if not value:
if value is None and self.default is None:
return []
return value
@@ -48,7 +48,7 @@ class JSONBAgg(OrderableAggMixin, Aggregate):
output_field = JSONField()
def convert_value(self, value, expression, connection):
if not value:
if value is None and self.default is None:
return '[]'
return value
@@ -63,6 +63,6 @@ class StringAgg(OrderableAggMixin, Aggregate):
super().__init__(expression, delimiter_expr, **extra)
def convert_value(self, value, expression, connection):
if not value:
if value is None and self.default is None:
return ''
return value

View File

@@ -9,10 +9,10 @@ __all__ = [
class StatAggregate(Aggregate):
output_field = FloatField()
def __init__(self, y, x, output_field=None, filter=None):
def __init__(self, y, x, output_field=None, filter=None, default=None):
if not x or not y:
raise ValueError('Both y and x must be provided.')
super().__init__(y, x, output_field=output_field, filter=filter)
super().__init__(y, x, output_field=output_field, filter=filter, default=default)
class Corr(StatAggregate):
@@ -20,9 +20,9 @@ class Corr(StatAggregate):
class CovarPop(StatAggregate):
def __init__(self, y, x, sample=False, filter=None):
def __init__(self, y, x, sample=False, filter=None, default=None):
self.function = 'COVAR_SAMP' if sample else 'COVAR_POP'
super().__init__(y, x, filter=filter)
super().__init__(y, x, filter=filter, default=default)
class RegrAvgX(StatAggregate):

View File

@@ -88,6 +88,17 @@ class DatabaseFeatures(BaseDatabaseFeatures):
'annotations.tests.NonAggregateAnnotationTestCase.test_annotation_aggregate_with_m2o',
},
})
if not self.connection.mysql_is_mariadb and self.connection.mysql_version < (8,):
skips.update({
'Casting to datetime/time is not supported by MySQL < 8.0. (#30224)': {
'aggregation.tests.AggregateTestCase.test_aggregation_default_using_time_from_python',
'aggregation.tests.AggregateTestCase.test_aggregation_default_using_datetime_from_python',
},
'MySQL < 8.0 returns string type instead of datetime/time. (#30224)': {
'aggregation.tests.AggregateTestCase.test_aggregation_default_using_time_from_database',
'aggregation.tests.AggregateTestCase.test_aggregation_default_using_datetime_from_database',
},
})
if (
self.connection.mysql_is_mariadb and
(10, 4, 3) < self.connection.mysql_version < (10, 5, 2)

View File

@@ -4,6 +4,7 @@ Classes to represent the definitions of aggregate functions.
from django.core.exceptions import FieldError
from django.db.models.expressions import Case, Func, Star, When
from django.db.models.fields import IntegerField
from django.db.models.functions.comparison import Coalesce
from django.db.models.functions.mixins import (
FixDurationInputMixin, NumericOutputFieldMixin,
)
@@ -22,11 +23,14 @@ class Aggregate(Func):
allow_distinct = False
empty_aggregate_value = None
def __init__(self, *expressions, distinct=False, filter=None, **extra):
def __init__(self, *expressions, distinct=False, filter=None, default=None, **extra):
if distinct and not self.allow_distinct:
raise TypeError("%s does not allow distinct." % self.__class__.__name__)
if default is not None and self.empty_aggregate_value is not None:
raise TypeError(f'{self.__class__.__name__} does not allow default.')
self.distinct = distinct
self.filter = filter
self.default = default
super().__init__(*expressions, **extra)
def get_source_fields(self):
@@ -56,7 +60,12 @@ class Aggregate(Func):
before_resolved = self.get_source_expressions()[index]
name = before_resolved.name if hasattr(before_resolved, 'name') else repr(before_resolved)
raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (c.name, name, name))
return c
if (default := c.default) is None:
return c
if hasattr(default, 'resolve_expression'):
default = default.resolve_expression(query, allow_joins, reuse, summarize)
c.default = None # Reset the default argument before wrapping.
return Coalesce(c, default, output_field=c._output_field_or_none)
@property
def default_alias(self):