mirror of
https://github.com/django/django.git
synced 2025-03-12 10:22:37 +00:00
Fixed #35444 -- Added generic support for Aggregate.order_by.
This moves the behaviors of `order_by` used in Postgres aggregates into the `Aggregate` class. This allows for creating aggregate functions that support this behavior across all database engines. This is shown by moving the `StringAgg` class into the shared `aggregates` module and adding support for all databases. The Postgres `StringAgg` class is now a thin wrapper on the new shared `StringAgg` class. Thank you Simon Charette for the review.
This commit is contained in:
parent
6d1cf5375f
commit
4b977a5d72
@ -33,15 +33,14 @@ class GeoAggregate(Aggregate):
|
||||
if not self.is_extent:
|
||||
tolerance = self.extra.get("tolerance") or getattr(self, "tolerance", 0.05)
|
||||
clone = self.copy()
|
||||
source_expressions = self.get_source_expressions()
|
||||
source_expressions.pop() # Don't wrap filters with SDOAGGRTYPE().
|
||||
*source_exprs, filter_expr, order_by_expr = self.get_source_expressions()
|
||||
spatial_type_expr = Func(
|
||||
*source_expressions,
|
||||
*source_exprs,
|
||||
Value(tolerance),
|
||||
function="SDOAGGRTYPE",
|
||||
output_field=self.output_field,
|
||||
)
|
||||
source_expressions = [spatial_type_expr, self.filter]
|
||||
source_expressions = [spatial_type_expr, filter_expr, order_by_expr]
|
||||
clone.set_source_expressions(source_expressions)
|
||||
return clone.as_sql(compiler, connection, **extra_context)
|
||||
return self.as_sql(compiler, connection, **extra_context)
|
||||
|
@ -1,7 +1,12 @@
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
from django.db.models import Aggregate, BooleanField, JSONField, TextField, Value
|
||||
import warnings
|
||||
|
||||
from .mixins import OrderableAggMixin
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
from django.db.models import Aggregate, BooleanField, JSONField
|
||||
from django.db.models import StringAgg as _StringAgg
|
||||
from django.db.models import Value
|
||||
from django.utils.deprecation import RemovedInDjango70Warning
|
||||
|
||||
from .mixins import _DeprecatedOrdering
|
||||
|
||||
__all__ = [
|
||||
"ArrayAgg",
|
||||
@ -11,14 +16,16 @@ __all__ = [
|
||||
"BoolAnd",
|
||||
"BoolOr",
|
||||
"JSONBAgg",
|
||||
"StringAgg",
|
||||
"StringAgg", # RemovedInDjango70Warning.
|
||||
]
|
||||
|
||||
|
||||
class ArrayAgg(OrderableAggMixin, Aggregate):
|
||||
# RemovedInDjango61Warning: When the deprecation ends, replace with:
|
||||
# class ArrayAgg(Aggregate):
|
||||
class ArrayAgg(_DeprecatedOrdering, Aggregate):
|
||||
function = "ARRAY_AGG"
|
||||
template = "%(function)s(%(distinct)s%(expressions)s %(order_by)s)"
|
||||
allow_distinct = True
|
||||
allow_order_by = True
|
||||
|
||||
@property
|
||||
def output_field(self):
|
||||
@ -47,19 +54,37 @@ class BoolOr(Aggregate):
|
||||
output_field = BooleanField()
|
||||
|
||||
|
||||
class JSONBAgg(OrderableAggMixin, Aggregate):
|
||||
# RemovedInDjango61Warning: When the deprecation ends, replace with:
|
||||
# class JSONBAgg(Aggregate):
|
||||
class JSONBAgg(_DeprecatedOrdering, Aggregate):
|
||||
function = "JSONB_AGG"
|
||||
template = "%(function)s(%(distinct)s%(expressions)s %(order_by)s)"
|
||||
allow_distinct = True
|
||||
allow_order_by = True
|
||||
output_field = JSONField()
|
||||
|
||||
|
||||
class StringAgg(OrderableAggMixin, Aggregate):
|
||||
function = "STRING_AGG"
|
||||
template = "%(function)s(%(distinct)s%(expressions)s %(order_by)s)"
|
||||
allow_distinct = True
|
||||
output_field = TextField()
|
||||
# RemovedInDjango61Warning: When the deprecation ends, replace with:
|
||||
# class StringAgg(_StringAgg):
|
||||
# RemovedInDjango70Warning: When the deprecation ends, remove completely.
|
||||
class StringAgg(_DeprecatedOrdering, _StringAgg):
|
||||
|
||||
def __init__(self, expression, delimiter, **extra):
|
||||
delimiter_expr = Value(str(delimiter))
|
||||
super().__init__(expression, delimiter_expr, **extra)
|
||||
if isinstance(delimiter, str):
|
||||
warnings.warn(
|
||||
"delimiter: str will be resolved as a field reference instead "
|
||||
"of a string literal on Django 7.0. Pass "
|
||||
f"`delimiter=Value({delimiter!r})` to preserve the previous behaviour.",
|
||||
category=RemovedInDjango70Warning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
delimiter = Value(delimiter)
|
||||
|
||||
warnings.warn(
|
||||
"The PostgreSQL specific StringAgg function is deprecated. Use "
|
||||
"django.db.models.aggregate.StringAgg instead.",
|
||||
category=RemovedInDjango70Warning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
super().__init__(expression, delimiter, **extra)
|
||||
|
@ -1,15 +1,11 @@
|
||||
import warnings
|
||||
|
||||
from django.core.exceptions import FullResultSet
|
||||
from django.db.models.expressions import OrderByList
|
||||
from django.utils.deprecation import RemovedInDjango61Warning
|
||||
|
||||
|
||||
class OrderableAggMixin:
|
||||
# RemovedInDjango61Warning: When the deprecation ends, replace with:
|
||||
# def __init__(self, *expressions, order_by=(), **extra):
|
||||
# RemovedInDjango61Warning.
|
||||
class _DeprecatedOrdering:
|
||||
def __init__(self, *expressions, ordering=(), order_by=(), **extra):
|
||||
# RemovedInDjango61Warning.
|
||||
if ordering:
|
||||
warnings.warn(
|
||||
"The ordering argument is deprecated. Use order_by instead.",
|
||||
@ -19,44 +15,14 @@ class OrderableAggMixin:
|
||||
if order_by:
|
||||
raise TypeError("Cannot specify both order_by and ordering.")
|
||||
order_by = ordering
|
||||
if not order_by:
|
||||
self.order_by = None
|
||||
elif isinstance(order_by, (list, tuple)):
|
||||
self.order_by = OrderByList(*order_by)
|
||||
else:
|
||||
self.order_by = OrderByList(order_by)
|
||||
super().__init__(*expressions, **extra)
|
||||
|
||||
def resolve_expression(self, *args, **kwargs):
|
||||
if self.order_by is not None:
|
||||
self.order_by = self.order_by.resolve_expression(*args, **kwargs)
|
||||
return super().resolve_expression(*args, **kwargs)
|
||||
super().__init__(*expressions, order_by=order_by, **extra)
|
||||
|
||||
def get_source_expressions(self):
|
||||
return super().get_source_expressions() + [self.order_by]
|
||||
|
||||
def set_source_expressions(self, exprs):
|
||||
*exprs, self.order_by = exprs
|
||||
return super().set_source_expressions(exprs)
|
||||
# RemovedInDjango61Warning: When the deprecation ends, replace with:
|
||||
# class OrderableAggMixin:
|
||||
class OrderableAggMixin(_DeprecatedOrdering):
|
||||
allow_order_by = True
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
*source_exprs, filtering_expr, order_by_expr = self.get_source_expressions()
|
||||
|
||||
order_by_sql = ""
|
||||
order_by_params = []
|
||||
if order_by_expr is not None:
|
||||
order_by_sql, order_by_params = compiler.compile(order_by_expr)
|
||||
|
||||
filter_params = []
|
||||
if filtering_expr is not None:
|
||||
try:
|
||||
_, filter_params = compiler.compile(filtering_expr)
|
||||
except FullResultSet:
|
||||
pass
|
||||
|
||||
source_params = []
|
||||
for source_expr in source_exprs:
|
||||
source_params += compiler.compile(source_expr)[1]
|
||||
|
||||
sql, _ = super().as_sql(compiler, connection, order_by=order_by_sql)
|
||||
return sql, (*source_params, *order_by_params, *filter_params)
|
||||
def __init_subclass__(cls, /, *args, **kwargs):
|
||||
super().__init_subclass__(*args, **kwargs)
|
||||
|
@ -257,6 +257,15 @@ class BaseDatabaseFeatures:
|
||||
# expressions?
|
||||
supports_aggregate_filter_clause = False
|
||||
|
||||
# Does the database support ORDER BY in aggregate expressions?
|
||||
supports_aggregate_order_by_clause = False
|
||||
|
||||
# Does the database backend support DISTINCT when using multiple arguments in an
|
||||
# aggregate expression? For example, Sqlite treats the "delimiter" argument of
|
||||
# STRING_AGG/GROUP_CONCAT as an extra argument and does not allow using a custom
|
||||
# delimiter along with DISTINCT.
|
||||
supports_aggregate_distinct_multiple_argument = True
|
||||
|
||||
# Does the backend support indexing a TextField?
|
||||
supports_index_on_text_field = True
|
||||
|
||||
|
@ -19,6 +19,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
||||
requires_explicit_null_ordering_when_grouping = True
|
||||
atomic_transactions = False
|
||||
can_clone_databases = True
|
||||
supports_aggregate_order_by_clause = True
|
||||
supports_comments = True
|
||||
supports_comments_inline = True
|
||||
supports_temporal_subtraction = True
|
||||
|
@ -45,6 +45,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
||||
# does by uppercasing all identifiers.
|
||||
ignores_table_name_case = True
|
||||
supports_index_on_text_field = False
|
||||
supports_aggregate_order_by_clause = True
|
||||
create_test_procedure_without_params_sql = """
|
||||
CREATE PROCEDURE "TEST_PROCEDURE" AS
|
||||
V_I INTEGER;
|
||||
|
@ -64,6 +64,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
||||
supports_frame_exclusion = True
|
||||
only_supports_unbounded_with_preceding_and_following = True
|
||||
supports_aggregate_filter_clause = True
|
||||
supports_aggregate_order_by_clause = True
|
||||
supported_explain_formats = {"JSON", "TEXT", "XML", "YAML"}
|
||||
supports_deferrable_unique_constraints = True
|
||||
has_json_operators = True
|
||||
|
@ -34,6 +34,8 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
||||
supports_frame_range_fixed_distance = True
|
||||
supports_frame_exclusion = True
|
||||
supports_aggregate_filter_clause = True
|
||||
supports_aggregate_order_by_clause = Database.sqlite_version_info >= (3, 44, 0)
|
||||
supports_aggregate_distinct_multiple_argument = False
|
||||
order_by_nulls_first = True
|
||||
supports_json_field_contains = False
|
||||
supports_update_conflicts = True
|
||||
|
@ -3,8 +3,17 @@ Classes to represent the definitions of aggregate functions.
|
||||
"""
|
||||
|
||||
from django.core.exceptions import FieldError, FullResultSet
|
||||
from django.db.models.expressions import Case, ColPairs, Func, Star, Value, When
|
||||
from django.db.models.fields import IntegerField
|
||||
from django.db import NotSupportedError
|
||||
from django.db.models.expressions import (
|
||||
Case,
|
||||
ColPairs,
|
||||
Func,
|
||||
OrderByList,
|
||||
Star,
|
||||
Value,
|
||||
When,
|
||||
)
|
||||
from django.db.models.fields import IntegerField, TextField
|
||||
from django.db.models.functions import Coalesce
|
||||
from django.db.models.functions.mixins import (
|
||||
FixDurationInputMixin,
|
||||
@ -18,42 +27,91 @@ __all__ = [
|
||||
"Max",
|
||||
"Min",
|
||||
"StdDev",
|
||||
"StringAgg",
|
||||
"Sum",
|
||||
"Variance",
|
||||
]
|
||||
|
||||
|
||||
class AggregateFilter(Func):
|
||||
arity = 1
|
||||
template = " FILTER (WHERE %(expressions)s)"
|
||||
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
if not connection.features.supports_aggregate_filter_clause:
|
||||
raise NotSupportedError(
|
||||
"Aggregate filter clauses are not supported on this database backend."
|
||||
)
|
||||
try:
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
except FullResultSet:
|
||||
return "", ()
|
||||
|
||||
@property
|
||||
def condition(self):
|
||||
return self.source_expressions[0]
|
||||
|
||||
def __str__(self):
|
||||
return self.arg_joiner.join(str(arg) for arg in self.source_expressions)
|
||||
|
||||
|
||||
class AggregateOrderBy(OrderByList):
|
||||
template = " ORDER BY %(expressions)s"
|
||||
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
if not connection.features.supports_aggregate_order_by_clause:
|
||||
raise NotSupportedError(
|
||||
"This database backend does not support specifying an order on "
|
||||
"aggregates."
|
||||
)
|
||||
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class Aggregate(Func):
|
||||
template = "%(function)s(%(distinct)s%(expressions)s)"
|
||||
template = "%(function)s(%(distinct)s%(expressions)s%(order_by)s)%(filter)s"
|
||||
contains_aggregate = True
|
||||
name = None
|
||||
filter_template = "%s FILTER (WHERE %%(filter)s)"
|
||||
window_compatible = True
|
||||
allow_distinct = False
|
||||
allow_order_by = False
|
||||
empty_result_set_value = None
|
||||
|
||||
def __init__(
|
||||
self, *expressions, distinct=False, filter=None, default=None, **extra
|
||||
self,
|
||||
*expressions,
|
||||
distinct=False,
|
||||
filter=None,
|
||||
default=None,
|
||||
order_by=None,
|
||||
**extra,
|
||||
):
|
||||
if distinct and not self.allow_distinct:
|
||||
raise TypeError("%s does not allow distinct." % self.__class__.__name__)
|
||||
if order_by and not self.allow_order_by:
|
||||
raise TypeError("%s does not allow order_by." % self.__class__.__name__)
|
||||
if default is not None and self.empty_result_set_value is not None:
|
||||
raise TypeError(f"{self.__class__.__name__} does not allow default.")
|
||||
|
||||
self.distinct = distinct
|
||||
self.filter = filter
|
||||
self.filter = filter and AggregateFilter(filter)
|
||||
self.default = default
|
||||
self.order_by = AggregateOrderBy.from_param(
|
||||
f"{self.__class__.__name__}.order_by", order_by
|
||||
)
|
||||
super().__init__(*expressions, **extra)
|
||||
|
||||
def get_source_fields(self):
|
||||
# Don't return the filter expression since it's not a source field.
|
||||
# Don't consider filter and order by expression as they have nothing
|
||||
# to do with the output field resolution.
|
||||
return [e._output_field_or_none for e in super().get_source_expressions()]
|
||||
|
||||
def get_source_expressions(self):
|
||||
source_expressions = super().get_source_expressions()
|
||||
return source_expressions + [self.filter]
|
||||
return source_expressions + [self.filter, self.order_by]
|
||||
|
||||
def set_source_expressions(self, exprs):
|
||||
*exprs, self.filter = exprs
|
||||
*exprs, self.filter, self.order_by = exprs
|
||||
return super().set_source_expressions(exprs)
|
||||
|
||||
def resolve_expression(
|
||||
@ -66,6 +124,11 @@ class Aggregate(Func):
|
||||
if c.filter
|
||||
else None
|
||||
)
|
||||
c.order_by = (
|
||||
c.order_by.resolve_expression(query, allow_joins, reuse, summarize)
|
||||
if c.order_by
|
||||
else None
|
||||
)
|
||||
if summarize:
|
||||
# Summarized aggregates cannot refer to summarized aggregates.
|
||||
for ref in c.get_refs():
|
||||
@ -115,35 +178,45 @@ class Aggregate(Func):
|
||||
return []
|
||||
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
extra_context["distinct"] = "DISTINCT " if self.distinct else ""
|
||||
if self.filter:
|
||||
if connection.features.supports_aggregate_filter_clause:
|
||||
try:
|
||||
filter_sql, filter_params = self.filter.as_sql(compiler, connection)
|
||||
except FullResultSet:
|
||||
pass
|
||||
else:
|
||||
template = self.filter_template % extra_context.get(
|
||||
"template", self.template
|
||||
)
|
||||
sql, params = super().as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template=template,
|
||||
filter=filter_sql,
|
||||
**extra_context,
|
||||
)
|
||||
return sql, (*params, *filter_params)
|
||||
else:
|
||||
if (
|
||||
self.distinct
|
||||
and not connection.features.supports_aggregate_distinct_multiple_argument
|
||||
and len(super().get_source_expressions()) > 1
|
||||
):
|
||||
raise NotSupportedError(
|
||||
f"{self.name} does not support distinct with multiple expressions on "
|
||||
f"this database backend."
|
||||
)
|
||||
|
||||
distinct_sql = "DISTINCT " if self.distinct else ""
|
||||
order_by_sql = ""
|
||||
order_by_params = []
|
||||
filter_sql = ""
|
||||
filter_params = []
|
||||
|
||||
if (order_by := self.order_by) is not None:
|
||||
order_by_sql, order_by_params = compiler.compile(order_by)
|
||||
|
||||
if self.filter is not None:
|
||||
try:
|
||||
filter_sql, filter_params = compiler.compile(self.filter)
|
||||
except NotSupportedError:
|
||||
# Fallback to a CASE statement on backends that don't support
|
||||
# the FILTER clause.
|
||||
copy = self.copy()
|
||||
copy.filter = None
|
||||
source_expressions = copy.get_source_expressions()
|
||||
condition = When(self.filter, then=source_expressions[0])
|
||||
condition = When(self.filter.condition, then=source_expressions[0])
|
||||
copy.set_source_expressions([Case(condition)] + source_expressions[1:])
|
||||
return super(Aggregate, copy).as_sql(
|
||||
compiler, connection, **extra_context
|
||||
)
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
return copy.as_sql(compiler, connection, **extra_context)
|
||||
|
||||
extra_context.update(
|
||||
distinct=distinct_sql,
|
||||
filter=filter_sql,
|
||||
order_by=order_by_sql,
|
||||
)
|
||||
sql, params = super().as_sql(compiler, connection, **extra_context)
|
||||
return sql, (*params, *order_by_params, *filter_params)
|
||||
|
||||
def _get_repr_options(self):
|
||||
options = super()._get_repr_options()
|
||||
@ -151,6 +224,8 @@ class Aggregate(Func):
|
||||
options["distinct"] = self.distinct
|
||||
if self.filter:
|
||||
options["filter"] = self.filter
|
||||
if self.order_by:
|
||||
options["order_by"] = self.order_by
|
||||
return options
|
||||
|
||||
|
||||
@ -179,17 +254,17 @@ class Count(Aggregate):
|
||||
|
||||
def resolve_expression(self, *args, **kwargs):
|
||||
result = super().resolve_expression(*args, **kwargs)
|
||||
expr = result.source_expressions[0]
|
||||
source_expressions = result.get_source_expressions()
|
||||
|
||||
# In case of composite primary keys, count the first column.
|
||||
if isinstance(expr, ColPairs):
|
||||
if isinstance(expr := source_expressions[0], ColPairs):
|
||||
if self.distinct:
|
||||
raise ValueError(
|
||||
"COUNT(DISTINCT) doesn't support composite primary keys"
|
||||
)
|
||||
|
||||
cols = expr.get_cols()
|
||||
return Count(cols[0], filter=result.filter)
|
||||
source_expressions[0] = expr.get_cols()[0]
|
||||
result.set_source_expressions(source_expressions)
|
||||
|
||||
return result
|
||||
|
||||
@ -218,6 +293,88 @@ class StdDev(NumericOutputFieldMixin, Aggregate):
|
||||
return {**super()._get_repr_options(), "sample": self.function == "STDDEV_SAMP"}
|
||||
|
||||
|
||||
class StringAggDelimiter(Func):
|
||||
arity = 1
|
||||
template = "%(expressions)s"
|
||||
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
super().__init__(value)
|
||||
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
template = " SEPARATOR %(expressions)s"
|
||||
|
||||
return self.as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template=template,
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
|
||||
class StringAgg(Aggregate):
|
||||
template = "%(function)s(%(distinct)s%(expressions)s%(order_by)s)%(filter)s"
|
||||
function = "STRING_AGG"
|
||||
name = "StringAgg"
|
||||
allow_distinct = True
|
||||
allow_order_by = True
|
||||
output_field = TextField()
|
||||
|
||||
def __init__(self, expression, delimiter, **extra):
|
||||
self.delimiter = StringAggDelimiter(delimiter)
|
||||
super().__init__(expression, self.delimiter, **extra)
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
if self.order_by:
|
||||
template = (
|
||||
"%(function)s(%(distinct)s%(expressions)s) WITHIN GROUP (%(order_by)s)"
|
||||
"%(filter)s"
|
||||
)
|
||||
else:
|
||||
template = "%(function)s(%(distinct)s%(expressions)s)%(filter)s"
|
||||
|
||||
return self.as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
function="LISTAGG",
|
||||
template=template,
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
extra_context["function"] = "GROUP_CONCAT"
|
||||
|
||||
template = "%(function)s(%(distinct)s%(expressions)s%(order_by)s%(delimiter)s)"
|
||||
extra_context["template"] = template
|
||||
|
||||
c = self.copy()
|
||||
# The creation of the delimiter SQL and the ordering of the parameters must be
|
||||
# handled explicitly, as MySQL puts the delimiter at the end of the aggregate
|
||||
# using the `SEPARATOR` declaration (rather than treating as an expression like
|
||||
# other database backends).
|
||||
delimiter_params = []
|
||||
if c.delimiter:
|
||||
delimiter_sql, delimiter_params = compiler.compile(c.delimiter)
|
||||
# Drop the delimiter from the source expressions.
|
||||
c.source_expressions = c.source_expressions[:-1]
|
||||
extra_context["delimiter"] = delimiter_sql
|
||||
|
||||
sql, params = c.as_sql(compiler, connection, **extra_context)
|
||||
|
||||
return sql, (*params, *delimiter_params)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
if connection.get_database_version() < (3, 44):
|
||||
return self.as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
function="GROUP_CONCAT",
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
return self.as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class Sum(FixDurationInputMixin, Aggregate):
|
||||
function = "SUM"
|
||||
name = "Sum"
|
||||
|
@ -1481,6 +1481,21 @@ class OrderByList(ExpressionList):
|
||||
)
|
||||
super().__init__(*expressions, **extra)
|
||||
|
||||
@classmethod
|
||||
def from_param(cls, context, param):
|
||||
if param is None:
|
||||
return None
|
||||
if isinstance(param, (list, tuple)):
|
||||
if not param:
|
||||
return None
|
||||
return cls(*param)
|
||||
elif isinstance(param, str) or hasattr(param, "resolve_expression"):
|
||||
return cls(param)
|
||||
raise ValueError(
|
||||
f"{context} must be either a string reference to a "
|
||||
f"field, an expression, or a list or tuple of them not {param!r}."
|
||||
)
|
||||
|
||||
|
||||
@deconstructible(path="django.db.models.ExpressionWrapper")
|
||||
class ExpressionWrapper(SQLiteNumericMixin, Expression):
|
||||
@ -1943,16 +1958,7 @@ class Window(SQLiteNumericMixin, Expression):
|
||||
self.partition_by = (self.partition_by,)
|
||||
self.partition_by = ExpressionList(*self.partition_by)
|
||||
|
||||
if self.order_by is not None:
|
||||
if isinstance(self.order_by, (list, tuple)):
|
||||
self.order_by = OrderByList(*self.order_by)
|
||||
elif isinstance(self.order_by, (BaseExpression, str)):
|
||||
self.order_by = OrderByList(self.order_by)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Window.order_by must be either a string reference to a "
|
||||
"field, an expression, or a list or tuple of them."
|
||||
)
|
||||
self.order_by = OrderByList.from_param("Window.order_by", self.order_by)
|
||||
super().__init__(output_field=output_field)
|
||||
self.source_expression = self._parse_expressions(expression)[0]
|
||||
|
||||
|
@ -18,6 +18,8 @@ details on these changes.
|
||||
* The ``serialize`` keyword argument of
|
||||
``BaseDatabaseCreation.create_test_db()`` will be removed.
|
||||
|
||||
* The ``django.contrib.postgres.aggregates.StringAgg`` class will be removed.
|
||||
|
||||
.. _deprecation-removed-in-6.1:
|
||||
|
||||
6.1
|
||||
|
@ -194,6 +194,8 @@ General-purpose aggregation functions
|
||||
|
||||
.. class:: StringAgg(expression, delimiter, distinct=False, filter=None, default=None, order_by=())
|
||||
|
||||
.. deprecated:: 6.0
|
||||
|
||||
Returns the input values concatenated into a string, separated by
|
||||
the ``delimiter`` string, or ``default`` if there are no values.
|
||||
|
||||
|
@ -448,7 +448,7 @@ some complex computations::
|
||||
|
||||
The ``Aggregate`` API is as follows:
|
||||
|
||||
.. class:: Aggregate(*expressions, output_field=None, distinct=False, filter=None, default=None, **extra)
|
||||
.. class:: Aggregate(*expressions, output_field=None, distinct=False, filter=None, default=None, order_by=None, **extra)
|
||||
|
||||
.. attribute:: template
|
||||
|
||||
@ -473,6 +473,15 @@ The ``Aggregate`` API is as follows:
|
||||
allows passing a ``distinct`` keyword argument. If set to ``False``
|
||||
(default), ``TypeError`` is raised if ``distinct=True`` is passed.
|
||||
|
||||
.. attribute:: allow_order_by
|
||||
|
||||
.. versionadded:: 6.0
|
||||
|
||||
A class attribute determining whether or not this aggregate function
|
||||
allows passing a ``order_by`` keyword argument. If set to ``False``
|
||||
(default), ``TypeError`` is raised if ``order_by`` is passed as a value
|
||||
other than ``None``.
|
||||
|
||||
.. attribute:: empty_result_set_value
|
||||
|
||||
Defaults to ``None`` since most aggregate functions result in ``NULL``
|
||||
@ -491,6 +500,12 @@ The ``filter`` argument takes a :class:`Q object <django.db.models.Q>` that's
|
||||
used to filter the rows that are aggregated. See :ref:`conditional-aggregation`
|
||||
and :ref:`filtering-on-annotations` for example usage.
|
||||
|
||||
The ``order_by`` argument behaves similarly to the ``field_names`` input of the
|
||||
:meth:`~.QuerySet.order_by` function, accepting a field name (with an optional
|
||||
``"-"`` prefix which indicates descending order) or an expression (or a tuple
|
||||
or list of strings and/or expressions) that specifies the ordering of the
|
||||
elements in the result.
|
||||
|
||||
The ``default`` argument takes a value that will be passed along with the
|
||||
aggregate to :class:`~django.db.models.functions.Coalesce`. This is useful for
|
||||
specifying a value to be returned other than ``None`` when the queryset (or
|
||||
@ -499,6 +514,10 @@ grouping) contains no entries.
|
||||
The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated
|
||||
into the ``template`` attribute.
|
||||
|
||||
.. versionchanged:: 6.0
|
||||
|
||||
The ``order_by`` argument was added.
|
||||
|
||||
Creating your own Aggregate Functions
|
||||
-------------------------------------
|
||||
|
||||
|
@ -4046,6 +4046,25 @@ by the aggregate.
|
||||
However, if ``sample=True``, the return value will be the sample
|
||||
variance.
|
||||
|
||||
``StringAgg``
|
||||
~~~~~~~~~~~~~
|
||||
|
||||
.. versionadded:: 6.0
|
||||
|
||||
.. class:: StringAgg(expression, delimiter, output_field=None, distinct=False, filter=None, order_by=None, default=None, **extra)
|
||||
|
||||
Returns the input values concatenated into a string, separated by the
|
||||
``delimiter`` string, or ``default`` if there are no values.
|
||||
|
||||
* Default alias: ``<field>__stringagg``
|
||||
* Return type: ``string`` or ``output_field`` if supplied. If the
|
||||
queryset or grouping is empty, ``default`` is returned.
|
||||
|
||||
.. attribute:: delimiter
|
||||
|
||||
A ``Value`` or expression representing the string that should separate
|
||||
each of the values. For example, ``Value(",")``.
|
||||
|
||||
Query-related tools
|
||||
===================
|
||||
|
||||
|
@ -184,6 +184,16 @@ Models
|
||||
* :doc:`Constraints </ref/models/constraints>` now implement a ``check()``
|
||||
method that is already registered with the check framework.
|
||||
|
||||
* The new ``order_by`` argument for :class:`~django.db.models.Aggregate` allows
|
||||
specifying the ordering of the elements in the result.
|
||||
|
||||
* The new :attr:`.Aggregate.allow_order_by` class attribute determines whether
|
||||
the aggregate function allows passing an ``order_by`` keyword argument.
|
||||
|
||||
* The new :class:`~django.db.models.StringAgg` aggregate returns the input
|
||||
values concatenated into a string, separated by the ``delimiter`` string.
|
||||
This aggregate was previously supported only for PostgreSQL.
|
||||
|
||||
Requests and Responses
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@ -288,6 +298,9 @@ Miscellaneous
|
||||
* ``BaseDatabaseCreation.create_test_db(serialize)`` is deprecated. Use
|
||||
``serialize_db_to_string()`` instead.
|
||||
|
||||
* The PostgreSQL ``StringAgg`` class is deprecated in favor of the generally
|
||||
available :class:`~django.db.models.StringAgg` class.
|
||||
|
||||
Features removed in 6.0
|
||||
=======================
|
||||
|
||||
|
@ -43,3 +43,7 @@ class Store(models.Model):
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class Employee(models.Model):
|
||||
work_day_preferences = models.JSONField()
|
||||
|
@ -4,10 +4,11 @@ import re
|
||||
from decimal import Decimal
|
||||
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db import connection
|
||||
from django.db import NotSupportedError, connection
|
||||
from django.db.models import (
|
||||
Avg,
|
||||
Case,
|
||||
CharField,
|
||||
Count,
|
||||
DateField,
|
||||
DateTimeField,
|
||||
@ -22,6 +23,7 @@ from django.db.models import (
|
||||
OuterRef,
|
||||
Q,
|
||||
StdDev,
|
||||
StringAgg,
|
||||
Subquery,
|
||||
Sum,
|
||||
TimeField,
|
||||
@ -32,9 +34,11 @@ from django.db.models import (
|
||||
Window,
|
||||
)
|
||||
from django.db.models.expressions import Func, RawSQL
|
||||
from django.db.models.fields.json import KeyTextTransform
|
||||
from django.db.models.functions import (
|
||||
Cast,
|
||||
Coalesce,
|
||||
Concat,
|
||||
Greatest,
|
||||
Least,
|
||||
Lower,
|
||||
@ -45,11 +49,11 @@ from django.db.models.functions import (
|
||||
TruncHour,
|
||||
)
|
||||
from django.test import TestCase
|
||||
from django.test.testcases import skipUnlessDBFeature
|
||||
from django.test.testcases import skipIfDBFeature, skipUnlessDBFeature
|
||||
from django.test.utils import Approximate, CaptureQueriesContext
|
||||
from django.utils import timezone
|
||||
|
||||
from .models import Author, Book, Publisher, Store
|
||||
from .models import Author, Book, Employee, Publisher, Store
|
||||
|
||||
|
||||
class NowUTC(Now):
|
||||
@ -566,6 +570,28 @@ class AggregateTestCase(TestCase):
|
||||
)
|
||||
self.assertEqual(books["ratings"], expected_result)
|
||||
|
||||
@skipUnlessDBFeature("supports_aggregate_distinct_multiple_argument")
|
||||
def test_distinct_on_stringagg(self):
|
||||
books = Book.objects.aggregate(
|
||||
ratings=StringAgg(Cast(F("rating"), CharField()), Value(","), distinct=True)
|
||||
)
|
||||
self.assertEqual(books["ratings"], "3,4,4.5,5")
|
||||
|
||||
@skipIfDBFeature("supports_aggregate_distinct_multiple_argument")
|
||||
def test_raises_error_on_multiple_argument_distinct(self):
|
||||
message = (
|
||||
"StringAgg does not support distinct with multiple expressions on this "
|
||||
"database backend."
|
||||
)
|
||||
with self.assertRaisesMessage(NotSupportedError, message):
|
||||
Book.objects.aggregate(
|
||||
ratings=StringAgg(
|
||||
Cast(F("rating"), CharField()),
|
||||
Value(","),
|
||||
distinct=True,
|
||||
)
|
||||
)
|
||||
|
||||
def test_non_grouped_annotation_not_in_group_by(self):
|
||||
"""
|
||||
An annotation not included in values() before an aggregate should be
|
||||
@ -1288,24 +1314,30 @@ class AggregateTestCase(TestCase):
|
||||
Book.objects.annotate(Max("id")).annotate(my_max=MyMax("id__max", "price"))
|
||||
|
||||
def test_multi_arg_aggregate(self):
|
||||
class MyMax(Max):
|
||||
class MultiArgAgg(Max):
|
||||
output_field = DecimalField()
|
||||
arity = None
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
copy = self.copy()
|
||||
copy.set_source_expressions(copy.get_source_expressions()[0:1] + [None])
|
||||
return super(MyMax, copy).as_sql(compiler, connection)
|
||||
# Most database backends do not support compiling multiple arguments on
|
||||
# the Max aggregate, and that isn't what is being tested here anyway. To
|
||||
# avoid errors, the extra argument is just dropped.
|
||||
copy.set_source_expressions(
|
||||
copy.get_source_expressions()[0:1] + [None, None]
|
||||
)
|
||||
|
||||
return super(MultiArgAgg, copy).as_sql(compiler, connection)
|
||||
|
||||
with self.assertRaisesMessage(TypeError, "Complex aggregates require an alias"):
|
||||
Book.objects.aggregate(MyMax("pages", "price"))
|
||||
Book.objects.aggregate(MultiArgAgg("pages", "price"))
|
||||
|
||||
with self.assertRaisesMessage(
|
||||
TypeError, "Complex annotations require an alias"
|
||||
):
|
||||
Book.objects.annotate(MyMax("pages", "price"))
|
||||
Book.objects.annotate(MultiArgAgg("pages", "price"))
|
||||
|
||||
Book.objects.aggregate(max_field=MyMax("pages", "price"))
|
||||
Book.objects.aggregate(max_field=MultiArgAgg("pages", "price"))
|
||||
|
||||
def test_add_implementation(self):
|
||||
class MySum(Sum):
|
||||
@ -1318,6 +1350,8 @@ class AggregateTestCase(TestCase):
|
||||
"function": self.function.lower(),
|
||||
"expressions": sql,
|
||||
"distinct": "",
|
||||
"filter": "",
|
||||
"order_by": "",
|
||||
}
|
||||
substitutions.update(self.extra)
|
||||
return self.template % substitutions, params
|
||||
@ -1351,7 +1385,13 @@ class AggregateTestCase(TestCase):
|
||||
|
||||
# test overriding all parts of the template
|
||||
def be_evil(self, compiler, connection):
|
||||
substitutions = {"function": "MAX", "expressions": "2", "distinct": ""}
|
||||
substitutions = {
|
||||
"function": "MAX",
|
||||
"expressions": "2",
|
||||
"distinct": "",
|
||||
"filter": "",
|
||||
"order_by": "",
|
||||
}
|
||||
substitutions.update(self.extra)
|
||||
return self.template % substitutions, ()
|
||||
|
||||
@ -1779,10 +1819,12 @@ class AggregateTestCase(TestCase):
|
||||
Publisher.objects.none().aggregate(
|
||||
sum_awards=Sum("num_awards"),
|
||||
books_count=Count("book"),
|
||||
all_names=StringAgg("name", Value(",")),
|
||||
),
|
||||
{
|
||||
"sum_awards": None,
|
||||
"books_count": 0,
|
||||
"all_names": None,
|
||||
},
|
||||
)
|
||||
# Expression without empty_result_set_value forces queries to be
|
||||
@ -1874,6 +1916,12 @@ class AggregateTestCase(TestCase):
|
||||
)
|
||||
self.assertEqual(result["value"], 35)
|
||||
|
||||
def test_stringagg_default_value(self):
|
||||
result = Author.objects.filter(age__gt=100).aggregate(
|
||||
value=StringAgg("name", delimiter=Value(";"), default=Value("<empty>")),
|
||||
)
|
||||
self.assertEqual(result["value"], "<empty>")
|
||||
|
||||
def test_aggregation_default_group_by(self):
|
||||
qs = (
|
||||
Publisher.objects.values("name")
|
||||
@ -2202,6 +2250,167 @@ class AggregateTestCase(TestCase):
|
||||
with self.assertRaisesMessage(TypeError, msg):
|
||||
super(function, func_instance).__init__(Value(1), Value(2))
|
||||
|
||||
def test_string_agg_requires_delimiter(self):
|
||||
with self.assertRaises(TypeError):
|
||||
Book.objects.aggregate(stringagg=StringAgg("name"))
|
||||
|
||||
def test_string_agg_escapes_delimiter(self):
|
||||
values = Publisher.objects.aggregate(
|
||||
stringagg=StringAgg("name", delimiter=Value("'"))
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
values,
|
||||
{
|
||||
"stringagg": "Apress'Sams'Prentice Hall'Morgan Kaufmann'Jonno's House "
|
||||
"of Books",
|
||||
},
|
||||
)
|
||||
|
||||
@skipUnlessDBFeature("supports_aggregate_order_by_clause")
|
||||
def test_string_agg_order_by(self):
|
||||
order_by_test_cases = (
|
||||
(
|
||||
F("original_opening").desc(),
|
||||
"Books.com;Amazon.com;Mamma and Pappa's Books",
|
||||
),
|
||||
(
|
||||
F("original_opening").asc(),
|
||||
"Mamma and Pappa's Books;Amazon.com;Books.com",
|
||||
),
|
||||
(F("original_opening"), "Mamma and Pappa's Books;Amazon.com;Books.com"),
|
||||
("original_opening", "Mamma and Pappa's Books;Amazon.com;Books.com"),
|
||||
("-original_opening", "Books.com;Amazon.com;Mamma and Pappa's Books"),
|
||||
(
|
||||
Concat("original_opening", Value("@")),
|
||||
"Mamma and Pappa's Books;Amazon.com;Books.com",
|
||||
),
|
||||
(
|
||||
Concat("original_opening", Value("@")).desc(),
|
||||
"Books.com;Amazon.com;Mamma and Pappa's Books",
|
||||
),
|
||||
)
|
||||
for order_by, expected_output in order_by_test_cases:
|
||||
with self.subTest(order_by=order_by, expected_output=expected_output):
|
||||
values = Store.objects.aggregate(
|
||||
stringagg=StringAgg("name", delimiter=Value(";"), order_by=order_by)
|
||||
)
|
||||
self.assertEqual(values, {"stringagg": expected_output})
|
||||
|
||||
@skipIfDBFeature("supports_aggregate_order_by_clause")
|
||||
def test_string_agg_order_by_is_not_supported(self):
|
||||
message = (
|
||||
"This database backend does not support specifying an order on aggregates."
|
||||
)
|
||||
with self.assertRaisesMessage(NotSupportedError, message):
|
||||
Store.objects.aggregate(
|
||||
stringagg=StringAgg(
|
||||
"name",
|
||||
delimiter=Value(";"),
|
||||
order_by="original_opening",
|
||||
)
|
||||
)
|
||||
|
||||
def test_string_agg_filter(self):
|
||||
values = Book.objects.aggregate(
|
||||
stringagg=StringAgg(
|
||||
"name",
|
||||
delimiter=Value(";"),
|
||||
filter=Q(name__startswith="P"),
|
||||
)
|
||||
)
|
||||
|
||||
expected_values = {
|
||||
"stringagg": "Practical Django Projects;"
|
||||
"Python Web Development with Django;Paradigms of Artificial "
|
||||
"Intelligence Programming: Case Studies in Common Lisp",
|
||||
}
|
||||
self.assertEqual(values, expected_values)
|
||||
|
||||
@skipUnlessDBFeature("supports_json_field", "supports_aggregate_order_by_clause")
|
||||
def test_string_agg_jsonfield_order_by(self):
|
||||
Employee.objects.bulk_create(
|
||||
[
|
||||
Employee(work_day_preferences={"Monday": "morning"}),
|
||||
Employee(work_day_preferences={"Monday": "afternoon"}),
|
||||
]
|
||||
)
|
||||
values = Employee.objects.aggregate(
|
||||
stringagg=StringAgg(
|
||||
KeyTextTransform("Monday", "work_day_preferences"),
|
||||
delimiter=Value(","),
|
||||
order_by=KeyTextTransform(
|
||||
"Monday",
|
||||
"work_day_preferences",
|
||||
),
|
||||
output_field=CharField(),
|
||||
),
|
||||
)
|
||||
self.assertEqual(values, {"stringagg": "afternoon,morning"})
|
||||
|
||||
def test_string_agg_filter_in_subquery(self):
|
||||
aggregate = StringAgg(
|
||||
"authors__name",
|
||||
delimiter=Value(";"),
|
||||
filter=~Q(authors__name__startswith="J"),
|
||||
)
|
||||
subquery = (
|
||||
Book.objects.filter(
|
||||
pk=OuterRef("pk"),
|
||||
)
|
||||
.annotate(agg=aggregate)
|
||||
.values("agg")
|
||||
)
|
||||
values = list(
|
||||
Book.objects.annotate(
|
||||
agg=Subquery(subquery),
|
||||
).values_list("agg", flat=True)
|
||||
)
|
||||
|
||||
expected_values = [
|
||||
"Adrian Holovaty",
|
||||
"Brad Dayley",
|
||||
"Paul Bissex;Wesley J. Chun",
|
||||
"Peter Norvig;Stuart Russell",
|
||||
"Peter Norvig",
|
||||
"" if connection.features.interprets_empty_strings_as_nulls else None,
|
||||
]
|
||||
|
||||
self.assertQuerySetEqual(expected_values, values, ordered=False)
|
||||
|
||||
@skipUnlessDBFeature("supports_aggregate_order_by_clause")
|
||||
def test_order_by_in_subquery(self):
|
||||
aggregate = StringAgg(
|
||||
"authors__name",
|
||||
delimiter=Value(";"),
|
||||
order_by="authors__name",
|
||||
)
|
||||
subquery = (
|
||||
Book.objects.filter(
|
||||
pk=OuterRef("pk"),
|
||||
)
|
||||
.annotate(agg=aggregate)
|
||||
.values("agg")
|
||||
)
|
||||
values = list(
|
||||
Book.objects.annotate(
|
||||
agg=Subquery(subquery),
|
||||
)
|
||||
.order_by("agg")
|
||||
.values_list("agg", flat=True)
|
||||
)
|
||||
|
||||
expected_values = [
|
||||
"Adrian Holovaty;Jacob Kaplan-Moss",
|
||||
"Brad Dayley",
|
||||
"James Bennett",
|
||||
"Jeffrey Forcier;Paul Bissex;Wesley J. Chun",
|
||||
"Peter Norvig",
|
||||
"Peter Norvig;Stuart Russell",
|
||||
]
|
||||
|
||||
self.assertEqual(expected_values, values)
|
||||
|
||||
|
||||
class AggregateAnnotationPruningTests(TestCase):
|
||||
@classmethod
|
||||
|
@ -1720,14 +1720,14 @@ class WindowFunctionTests(TestCase):
|
||||
"""Window expressions can't be used in an INSERT statement."""
|
||||
msg = (
|
||||
"Window expressions are not allowed in this query (salary=<Window: "
|
||||
"Sum(Value(10000), order_by=OrderBy(F(pk), descending=False)) OVER ()"
|
||||
"Sum(Value(10000)) OVER ()"
|
||||
)
|
||||
with self.assertRaisesMessage(FieldError, msg):
|
||||
Employee.objects.create(
|
||||
name="Jameson",
|
||||
department="Management",
|
||||
hire_date=datetime.date(2007, 7, 1),
|
||||
salary=Window(expression=Sum(Value(10000), order_by=F("pk").asc())),
|
||||
salary=Window(expression=Sum(Value(10000))),
|
||||
)
|
||||
|
||||
def test_window_expression_within_subquery(self):
|
||||
@ -2025,7 +2025,7 @@ class NonQueryWindowTests(SimpleTestCase):
|
||||
def test_invalid_order_by(self):
|
||||
msg = (
|
||||
"Window.order_by must be either a string reference to a field, an "
|
||||
"expression, or a list or tuple of them."
|
||||
"expression, or a list or tuple of them not {'-horse'}."
|
||||
)
|
||||
with self.assertRaisesMessage(ValueError, msg):
|
||||
Window(expression=Sum("power"), order_by={"-horse"})
|
||||
|
@ -1,3 +1,5 @@
|
||||
import warnings
|
||||
|
||||
from django.db import transaction
|
||||
from django.db.models import (
|
||||
CharField,
|
||||
@ -11,16 +13,19 @@ from django.db.models import (
|
||||
Value,
|
||||
Window,
|
||||
)
|
||||
from django.db.models.fields.json import KeyTextTransform, KeyTransform
|
||||
from django.db.models.fields.json import KeyTransform
|
||||
from django.db.models.functions import Cast, Concat, LPad, Substr
|
||||
from django.test.utils import Approximate
|
||||
from django.utils import timezone
|
||||
from django.utils.deprecation import RemovedInDjango61Warning
|
||||
from django.utils.deprecation import RemovedInDjango61Warning, RemovedInDjango70Warning
|
||||
|
||||
from . import PostgreSQLTestCase
|
||||
from .models import AggregateTestModel, HotelReservation, Room, StatTestModel
|
||||
|
||||
try:
|
||||
from django.contrib.postgres.aggregates import (
|
||||
StringAgg, # RemovedInDjango70Warning.
|
||||
)
|
||||
from django.contrib.postgres.aggregates import (
|
||||
ArrayAgg,
|
||||
BitAnd,
|
||||
@ -41,7 +46,6 @@ try:
|
||||
RegrSXY,
|
||||
RegrSYY,
|
||||
StatAggregate,
|
||||
StringAgg,
|
||||
)
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
except ImportError:
|
||||
@ -94,7 +98,6 @@ class TestGeneralAggregate(PostgreSQLTestCase):
|
||||
BoolAnd("boolean_field"),
|
||||
BoolOr("boolean_field"),
|
||||
JSONBAgg("integer_field"),
|
||||
StringAgg("char_field", delimiter=";"),
|
||||
BitXor("integer_field"),
|
||||
]
|
||||
for aggregation in tests:
|
||||
@ -127,11 +130,6 @@ class TestGeneralAggregate(PostgreSQLTestCase):
|
||||
JSONBAgg("integer_field", default=Value(["<empty>"], JSONField())),
|
||||
["<empty>"],
|
||||
),
|
||||
(StringAgg("char_field", delimiter=";", default="<empty>"), "<empty>"),
|
||||
(
|
||||
StringAgg("char_field", delimiter=";", default=Value("<empty>")),
|
||||
"<empty>",
|
||||
),
|
||||
(BitXor("integer_field", default=0), 0),
|
||||
]
|
||||
for aggregation, expected_result in tests:
|
||||
@ -158,8 +156,9 @@ class TestGeneralAggregate(PostgreSQLTestCase):
|
||||
self.assertEqual(values, {"arrayagg": [2, 1, 0, 0]})
|
||||
self.assertEqual(ctx.filename, __file__)
|
||||
|
||||
# RemovedInDjango61Warning: Remove this test
|
||||
def test_ordering_and_order_by_causes_error(self):
|
||||
with self.assertWarns(RemovedInDjango61Warning):
|
||||
with warnings.catch_warnings(record=True, action="always") as wm:
|
||||
with self.assertRaisesMessage(
|
||||
TypeError,
|
||||
"Cannot specify both order_by and ordering.",
|
||||
@ -173,6 +172,21 @@ class TestGeneralAggregate(PostgreSQLTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
first_warning = wm[0]
|
||||
self.assertEqual(first_warning.category, RemovedInDjango70Warning)
|
||||
self.assertEqual(
|
||||
"The PostgreSQL specific StringAgg function is deprecated. Use "
|
||||
"django.db.models.aggregate.StringAgg instead.",
|
||||
str(first_warning.message),
|
||||
)
|
||||
|
||||
second_warning = wm[1]
|
||||
self.assertEqual(second_warning.category, RemovedInDjango61Warning)
|
||||
self.assertEqual(
|
||||
"The ordering argument is deprecated. Use order_by instead.",
|
||||
str(second_warning.message),
|
||||
)
|
||||
|
||||
def test_array_agg_charfield(self):
|
||||
values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg("char_field"))
|
||||
self.assertEqual(values, {"arrayagg": ["Foo1", "Foo2", "Foo4", "Foo3"]})
|
||||
@ -425,66 +439,6 @@ class TestGeneralAggregate(PostgreSQLTestCase):
|
||||
)
|
||||
self.assertEqual(values, {"boolor": False})
|
||||
|
||||
def test_string_agg_requires_delimiter(self):
|
||||
with self.assertRaises(TypeError):
|
||||
AggregateTestModel.objects.aggregate(stringagg=StringAgg("char_field"))
|
||||
|
||||
def test_string_agg_delimiter_escaping(self):
|
||||
values = AggregateTestModel.objects.aggregate(
|
||||
stringagg=StringAgg("char_field", delimiter="'")
|
||||
)
|
||||
self.assertEqual(values, {"stringagg": "Foo1'Foo2'Foo4'Foo3"})
|
||||
|
||||
def test_string_agg_charfield(self):
|
||||
values = AggregateTestModel.objects.aggregate(
|
||||
stringagg=StringAgg("char_field", delimiter=";")
|
||||
)
|
||||
self.assertEqual(values, {"stringagg": "Foo1;Foo2;Foo4;Foo3"})
|
||||
|
||||
def test_string_agg_default_output_field(self):
|
||||
values = AggregateTestModel.objects.aggregate(
|
||||
stringagg=StringAgg("text_field", delimiter=";"),
|
||||
)
|
||||
self.assertEqual(values, {"stringagg": "Text1;Text2;Text4;Text3"})
|
||||
|
||||
def test_string_agg_charfield_order_by(self):
|
||||
order_by_test_cases = (
|
||||
(F("char_field").desc(), "Foo4;Foo3;Foo2;Foo1"),
|
||||
(F("char_field").asc(), "Foo1;Foo2;Foo3;Foo4"),
|
||||
(F("char_field"), "Foo1;Foo2;Foo3;Foo4"),
|
||||
("char_field", "Foo1;Foo2;Foo3;Foo4"),
|
||||
("-char_field", "Foo4;Foo3;Foo2;Foo1"),
|
||||
(Concat("char_field", Value("@")), "Foo1;Foo2;Foo3;Foo4"),
|
||||
(Concat("char_field", Value("@")).desc(), "Foo4;Foo3;Foo2;Foo1"),
|
||||
)
|
||||
for order_by, expected_output in order_by_test_cases:
|
||||
with self.subTest(order_by=order_by, expected_output=expected_output):
|
||||
values = AggregateTestModel.objects.aggregate(
|
||||
stringagg=StringAgg("char_field", delimiter=";", order_by=order_by)
|
||||
)
|
||||
self.assertEqual(values, {"stringagg": expected_output})
|
||||
|
||||
def test_string_agg_jsonfield_order_by(self):
|
||||
values = AggregateTestModel.objects.aggregate(
|
||||
stringagg=StringAgg(
|
||||
KeyTextTransform("lang", "json_field"),
|
||||
delimiter=";",
|
||||
order_by=KeyTextTransform("lang", "json_field"),
|
||||
output_field=CharField(),
|
||||
),
|
||||
)
|
||||
self.assertEqual(values, {"stringagg": "en;pl"})
|
||||
|
||||
def test_string_agg_filter(self):
|
||||
values = AggregateTestModel.objects.aggregate(
|
||||
stringagg=StringAgg(
|
||||
"char_field",
|
||||
delimiter=";",
|
||||
filter=Q(char_field__endswith="3") | Q(char_field__endswith="1"),
|
||||
)
|
||||
)
|
||||
self.assertEqual(values, {"stringagg": "Foo1;Foo3"})
|
||||
|
||||
def test_orderable_agg_alternative_fields(self):
|
||||
values = AggregateTestModel.objects.aggregate(
|
||||
arrayagg=ArrayAgg("integer_field", order_by=F("char_field").asc())
|
||||
@ -593,48 +547,36 @@ class TestGeneralAggregate(PostgreSQLTestCase):
|
||||
],
|
||||
)
|
||||
|
||||
def test_string_agg_array_agg_order_by_in_subquery(self):
|
||||
def test_array_agg_order_by_in_subquery(self):
|
||||
stats = []
|
||||
for i, agg in enumerate(AggregateTestModel.objects.order_by("char_field")):
|
||||
stats.append(StatTestModel(related_field=agg, int1=i, int2=i + 1))
|
||||
stats.append(StatTestModel(related_field=agg, int1=i + 1, int2=i))
|
||||
StatTestModel.objects.bulk_create(stats)
|
||||
|
||||
for aggregate, expected_result in (
|
||||
(
|
||||
ArrayAgg("stattestmodel__int1", order_by="-stattestmodel__int2"),
|
||||
[
|
||||
("Foo1", [0, 1]),
|
||||
("Foo2", [1, 2]),
|
||||
("Foo3", [2, 3]),
|
||||
("Foo4", [3, 4]),
|
||||
],
|
||||
),
|
||||
(
|
||||
StringAgg(
|
||||
Cast("stattestmodel__int1", CharField()),
|
||||
delimiter=";",
|
||||
order_by="-stattestmodel__int2",
|
||||
),
|
||||
[("Foo1", "0;1"), ("Foo2", "1;2"), ("Foo3", "2;3"), ("Foo4", "3;4")],
|
||||
),
|
||||
):
|
||||
with self.subTest(aggregate=aggregate.__class__.__name__):
|
||||
subquery = (
|
||||
AggregateTestModel.objects.filter(
|
||||
pk=OuterRef("pk"),
|
||||
)
|
||||
.annotate(agg=aggregate)
|
||||
.values("agg")
|
||||
)
|
||||
values = (
|
||||
AggregateTestModel.objects.annotate(
|
||||
agg=Subquery(subquery),
|
||||
)
|
||||
.order_by("char_field")
|
||||
.values_list("char_field", "agg")
|
||||
)
|
||||
self.assertEqual(list(values), expected_result)
|
||||
aggregate = ArrayAgg("stattestmodel__int1", order_by="-stattestmodel__int2")
|
||||
expected_result = [
|
||||
("Foo1", [0, 1]),
|
||||
("Foo2", [1, 2]),
|
||||
("Foo3", [2, 3]),
|
||||
("Foo4", [3, 4]),
|
||||
]
|
||||
|
||||
subquery = (
|
||||
AggregateTestModel.objects.filter(
|
||||
pk=OuterRef("pk"),
|
||||
)
|
||||
.annotate(agg=aggregate)
|
||||
.values("agg")
|
||||
)
|
||||
values = (
|
||||
AggregateTestModel.objects.annotate(
|
||||
agg=Subquery(subquery),
|
||||
)
|
||||
.order_by("char_field")
|
||||
.values_list("char_field", "agg")
|
||||
)
|
||||
self.assertEqual(list(values), expected_result)
|
||||
|
||||
def test_string_agg_array_agg_filter_in_subquery(self):
|
||||
StatTestModel.objects.bulk_create(
|
||||
@ -644,56 +586,31 @@ class TestGeneralAggregate(PostgreSQLTestCase):
|
||||
StatTestModel(related_field=self.aggs[0], int1=2, int2=3),
|
||||
]
|
||||
)
|
||||
for aggregate, expected_result in (
|
||||
(
|
||||
ArrayAgg("stattestmodel__int1", filter=Q(stattestmodel__int2__gt=3)),
|
||||
[("Foo1", [0, 1]), ("Foo2", None)],
|
||||
),
|
||||
(
|
||||
StringAgg(
|
||||
Cast("stattestmodel__int2", CharField()),
|
||||
delimiter=";",
|
||||
filter=Q(stattestmodel__int1__lt=2),
|
||||
),
|
||||
[("Foo1", "5;4"), ("Foo2", None)],
|
||||
),
|
||||
):
|
||||
with self.subTest(aggregate=aggregate.__class__.__name__):
|
||||
subquery = (
|
||||
AggregateTestModel.objects.filter(
|
||||
pk=OuterRef("pk"),
|
||||
)
|
||||
.annotate(agg=aggregate)
|
||||
.values("agg")
|
||||
)
|
||||
values = (
|
||||
AggregateTestModel.objects.annotate(
|
||||
agg=Subquery(subquery),
|
||||
)
|
||||
.filter(
|
||||
char_field__in=["Foo1", "Foo2"],
|
||||
)
|
||||
.order_by("char_field")
|
||||
.values_list("char_field", "agg")
|
||||
)
|
||||
self.assertEqual(list(values), expected_result)
|
||||
|
||||
def test_string_agg_filter_in_subquery_with_exclude(self):
|
||||
aggregate = ArrayAgg(
|
||||
"stattestmodel__int1",
|
||||
filter=Q(stattestmodel__int2__gt=3),
|
||||
)
|
||||
expected_result = [("Foo1", [0, 1]), ("Foo2", None)]
|
||||
|
||||
subquery = (
|
||||
AggregateTestModel.objects.annotate(
|
||||
stringagg=StringAgg(
|
||||
"char_field",
|
||||
delimiter=";",
|
||||
filter=Q(char_field__endswith="1"),
|
||||
)
|
||||
AggregateTestModel.objects.filter(
|
||||
pk=OuterRef("pk"),
|
||||
)
|
||||
.exclude(stringagg="")
|
||||
.values("id")
|
||||
.annotate(agg=aggregate)
|
||||
.values("agg")
|
||||
)
|
||||
self.assertSequenceEqual(
|
||||
AggregateTestModel.objects.filter(id__in=Subquery(subquery)),
|
||||
[self.aggs[0]],
|
||||
values = (
|
||||
AggregateTestModel.objects.annotate(
|
||||
agg=Subquery(subquery),
|
||||
)
|
||||
.filter(
|
||||
char_field__in=["Foo1", "Foo2"],
|
||||
)
|
||||
.order_by("char_field")
|
||||
.values_list("char_field", "agg")
|
||||
)
|
||||
self.assertEqual(list(values), expected_result)
|
||||
|
||||
def test_ordering_isnt_cleared_for_array_subquery(self):
|
||||
inner_qs = AggregateTestModel.objects.order_by("-integer_field")
|
||||
@ -729,11 +646,41 @@ class TestGeneralAggregate(PostgreSQLTestCase):
|
||||
tests = [ArrayAgg("integer_field"), JSONBAgg("integer_field")]
|
||||
for aggregation in tests:
|
||||
with self.subTest(aggregation=aggregation):
|
||||
results = AggregateTestModel.objects.annotate(
|
||||
agg=aggregation
|
||||
).values_list("agg")
|
||||
self.assertCountEqual(
|
||||
AggregateTestModel.objects.values_list(aggregation),
|
||||
results,
|
||||
[([0],), ([1],), ([2],), ([0],)],
|
||||
)
|
||||
|
||||
def test_string_agg_delimiter_deprecation(self):
|
||||
msg = (
|
||||
"delimiter: str will be resolved as a field reference instead "
|
||||
'of a string literal on Django 7.0. Pass `delimiter=Value("\'")` to '
|
||||
"preserve the previous behaviour."
|
||||
)
|
||||
|
||||
with self.assertWarnsMessage(RemovedInDjango70Warning, msg) as ctx:
|
||||
values = AggregateTestModel.objects.aggregate(
|
||||
stringagg=StringAgg("char_field", delimiter="'")
|
||||
)
|
||||
self.assertEqual(values, {"stringagg": "Foo1'Foo2'Foo4'Foo3"})
|
||||
self.assertEqual(ctx.filename, __file__)
|
||||
|
||||
def test_string_agg_deprecation(self):
|
||||
msg = (
|
||||
"The PostgreSQL specific StringAgg function is deprecated. Use "
|
||||
"django.db.models.aggregate.StringAgg instead."
|
||||
)
|
||||
|
||||
with self.assertWarnsMessage(RemovedInDjango70Warning, msg) as ctx:
|
||||
values = AggregateTestModel.objects.aggregate(
|
||||
stringagg=StringAgg("char_field", delimiter=Value("'"))
|
||||
)
|
||||
self.assertEqual(values, {"stringagg": "Foo1'Foo2'Foo4'Foo3"})
|
||||
self.assertEqual(ctx.filename, __file__)
|
||||
|
||||
|
||||
class TestAggregateDistinct(PostgreSQLTestCase):
|
||||
@classmethod
|
||||
@ -742,20 +689,6 @@ class TestAggregateDistinct(PostgreSQLTestCase):
|
||||
AggregateTestModel.objects.create(char_field="Foo")
|
||||
AggregateTestModel.objects.create(char_field="Bar")
|
||||
|
||||
def test_string_agg_distinct_false(self):
|
||||
values = AggregateTestModel.objects.aggregate(
|
||||
stringagg=StringAgg("char_field", delimiter=" ", distinct=False)
|
||||
)
|
||||
self.assertEqual(values["stringagg"].count("Foo"), 2)
|
||||
self.assertEqual(values["stringagg"].count("Bar"), 1)
|
||||
|
||||
def test_string_agg_distinct_true(self):
|
||||
values = AggregateTestModel.objects.aggregate(
|
||||
stringagg=StringAgg("char_field", delimiter=" ", distinct=True)
|
||||
)
|
||||
self.assertEqual(values["stringagg"].count("Foo"), 1)
|
||||
self.assertEqual(values["stringagg"].count("Bar"), 1)
|
||||
|
||||
def test_array_agg_distinct_false(self):
|
||||
values = AggregateTestModel.objects.aggregate(
|
||||
arrayagg=ArrayAgg("char_field", distinct=False)
|
||||
|
Loading…
x
Reference in New Issue
Block a user