mirror of
https://github.com/django/django.git
synced 2025-06-05 03:29:12 +00:00
Refs #28333 -- Added partial support for filtering against window functions.
Adds support for joint predicates against window annotations through subquery wrapping while maintaining errors for disjointed filter attempts. The "qualify" wording was used to refer to predicates against window annotations as it's the name of a specialized Snowflake extension to SQL that is to window functions what HAVING is to aggregates. While not complete the implementation should cover most of the common use cases for filtering against window functions without requiring the complex subquery pushdown and predicate re-aliasing machinery to deal with disjointed predicates against columns, aggregates, and window functions. A complete disjointed filtering implementation should likely be deferred until proper QUALIFY support lands or the ORM gains a proper subquery pushdown interface.
This commit is contained in:
parent
f3f9d03edf
commit
f387d024fc
@ -28,10 +28,13 @@ class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
|
|||||||
# the SQLDeleteCompiler's default implementation when multiple tables
|
# the SQLDeleteCompiler's default implementation when multiple tables
|
||||||
# are involved since MySQL/MariaDB will generate a more efficient query
|
# are involved since MySQL/MariaDB will generate a more efficient query
|
||||||
# plan than when using a subquery.
|
# plan than when using a subquery.
|
||||||
where, having = self.query.where.split_having()
|
where, having, qualify = self.query.where.split_having_qualify(
|
||||||
if self.single_alias or having:
|
must_group_by=self.query.group_by is not None
|
||||||
# DELETE FROM cannot be used when filtering against aggregates
|
)
|
||||||
# since it doesn't allow for GROUP BY and HAVING clauses.
|
if self.single_alias or having or qualify:
|
||||||
|
# DELETE FROM cannot be used when filtering against aggregates or
|
||||||
|
# window functions as it doesn't allow for GROUP BY/HAVING clauses
|
||||||
|
# and the subquery wrapping (necessary to emulate QUALIFY).
|
||||||
return super().as_sql()
|
return super().as_sql()
|
||||||
result = [
|
result = [
|
||||||
"DELETE %s FROM"
|
"DELETE %s FROM"
|
||||||
|
@ -836,6 +836,7 @@ class ResolvedOuterRef(F):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
contains_aggregate = False
|
contains_aggregate = False
|
||||||
|
contains_over_clause = False
|
||||||
|
|
||||||
def as_sql(self, *args, **kwargs):
|
def as_sql(self, *args, **kwargs):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -1210,6 +1211,12 @@ class OrderByList(Func):
|
|||||||
return "", ()
|
return "", ()
|
||||||
return super().as_sql(*args, **kwargs)
|
return super().as_sql(*args, **kwargs)
|
||||||
|
|
||||||
|
def get_group_by_cols(self):
|
||||||
|
group_by_cols = []
|
||||||
|
for order_by in self.get_source_expressions():
|
||||||
|
group_by_cols.extend(order_by.get_group_by_cols())
|
||||||
|
return group_by_cols
|
||||||
|
|
||||||
|
|
||||||
@deconstructible(path="django.db.models.ExpressionWrapper")
|
@deconstructible(path="django.db.models.ExpressionWrapper")
|
||||||
class ExpressionWrapper(SQLiteNumericMixin, Expression):
|
class ExpressionWrapper(SQLiteNumericMixin, Expression):
|
||||||
@ -1631,7 +1638,6 @@ class Window(SQLiteNumericMixin, Expression):
|
|||||||
# be introduced in the query as a result is not desired.
|
# be introduced in the query as a result is not desired.
|
||||||
contains_aggregate = False
|
contains_aggregate = False
|
||||||
contains_over_clause = True
|
contains_over_clause = True
|
||||||
filterable = False
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -1733,7 +1739,12 @@ class Window(SQLiteNumericMixin, Expression):
|
|||||||
return "<%s: %s>" % (self.__class__.__name__, self)
|
return "<%s: %s>" % (self.__class__.__name__, self)
|
||||||
|
|
||||||
def get_group_by_cols(self, alias=None):
|
def get_group_by_cols(self, alias=None):
|
||||||
return []
|
group_by_cols = []
|
||||||
|
if self.partition_by:
|
||||||
|
group_by_cols.extend(self.partition_by.get_group_by_cols())
|
||||||
|
if self.order_by is not None:
|
||||||
|
group_by_cols.extend(self.order_by.get_group_by_cols())
|
||||||
|
return group_by_cols
|
||||||
|
|
||||||
|
|
||||||
class WindowFrame(Expression):
|
class WindowFrame(Expression):
|
||||||
|
@ -14,6 +14,7 @@ from django.utils.deprecation import RemovedInDjango50Warning
|
|||||||
|
|
||||||
class MultiColSource:
|
class MultiColSource:
|
||||||
contains_aggregate = False
|
contains_aggregate = False
|
||||||
|
contains_over_clause = False
|
||||||
|
|
||||||
def __init__(self, alias, targets, sources, field):
|
def __init__(self, alias, targets, sources, field):
|
||||||
self.targets, self.sources, self.field, self.alias = (
|
self.targets, self.sources, self.field, self.alias = (
|
||||||
|
@ -9,6 +9,7 @@ from django.db import DatabaseError, NotSupportedError
|
|||||||
from django.db.models.constants import LOOKUP_SEP
|
from django.db.models.constants import LOOKUP_SEP
|
||||||
from django.db.models.expressions import F, OrderBy, RawSQL, Ref, Value
|
from django.db.models.expressions import F, OrderBy, RawSQL, Ref, Value
|
||||||
from django.db.models.functions import Cast, Random
|
from django.db.models.functions import Cast, Random
|
||||||
|
from django.db.models.lookups import Lookup
|
||||||
from django.db.models.query_utils import select_related_descend
|
from django.db.models.query_utils import select_related_descend
|
||||||
from django.db.models.sql.constants import (
|
from django.db.models.sql.constants import (
|
||||||
CURSOR,
|
CURSOR,
|
||||||
@ -73,7 +74,9 @@ class SQLCompiler:
|
|||||||
"""
|
"""
|
||||||
self.setup_query(with_col_aliases=with_col_aliases)
|
self.setup_query(with_col_aliases=with_col_aliases)
|
||||||
order_by = self.get_order_by()
|
order_by = self.get_order_by()
|
||||||
self.where, self.having = self.query.where.split_having()
|
self.where, self.having, self.qualify = self.query.where.split_having_qualify(
|
||||||
|
must_group_by=self.query.group_by is not None
|
||||||
|
)
|
||||||
extra_select = self.get_extra_select(order_by, self.select)
|
extra_select = self.get_extra_select(order_by, self.select)
|
||||||
self.has_extra_select = bool(extra_select)
|
self.has_extra_select = bool(extra_select)
|
||||||
group_by = self.get_group_by(self.select + extra_select, order_by)
|
group_by = self.get_group_by(self.select + extra_select, order_by)
|
||||||
@ -584,6 +587,74 @@ class SQLCompiler:
|
|||||||
params.extend(part)
|
params.extend(part)
|
||||||
return result, params
|
return result, params
|
||||||
|
|
||||||
|
def get_qualify_sql(self):
|
||||||
|
where_parts = []
|
||||||
|
if self.where:
|
||||||
|
where_parts.append(self.where)
|
||||||
|
if self.having:
|
||||||
|
where_parts.append(self.having)
|
||||||
|
inner_query = self.query.clone()
|
||||||
|
inner_query.subquery = True
|
||||||
|
inner_query.where = inner_query.where.__class__(where_parts)
|
||||||
|
# Augment the inner query with any window function references that
|
||||||
|
# might have been masked via values() and alias(). If any masked
|
||||||
|
# aliases are added they'll be masked again to avoid fetching
|
||||||
|
# the data in the `if qual_aliases` branch below.
|
||||||
|
select = {
|
||||||
|
expr: alias for expr, _, alias in self.get_select(with_col_aliases=True)[0]
|
||||||
|
}
|
||||||
|
qual_aliases = set()
|
||||||
|
replacements = {}
|
||||||
|
expressions = list(self.qualify.leaves())
|
||||||
|
while expressions:
|
||||||
|
expr = expressions.pop()
|
||||||
|
if select_alias := (select.get(expr) or replacements.get(expr)):
|
||||||
|
replacements[expr] = select_alias
|
||||||
|
elif isinstance(expr, Lookup):
|
||||||
|
expressions.extend(expr.get_source_expressions())
|
||||||
|
else:
|
||||||
|
num_qual_alias = len(qual_aliases)
|
||||||
|
select_alias = f"qual{num_qual_alias}"
|
||||||
|
qual_aliases.add(select_alias)
|
||||||
|
inner_query.add_annotation(expr, select_alias)
|
||||||
|
replacements[expr] = select_alias
|
||||||
|
self.qualify = self.qualify.replace_expressions(
|
||||||
|
{expr: Ref(alias, expr) for expr, alias in replacements.items()}
|
||||||
|
)
|
||||||
|
inner_query_compiler = inner_query.get_compiler(
|
||||||
|
self.using, elide_empty=self.elide_empty
|
||||||
|
)
|
||||||
|
inner_sql, inner_params = inner_query_compiler.as_sql(
|
||||||
|
# The limits must be applied to the outer query to avoid pruning
|
||||||
|
# results too eagerly.
|
||||||
|
with_limits=False,
|
||||||
|
# Force unique aliasing of selected columns to avoid collisions
|
||||||
|
# and make rhs predicates referencing easier.
|
||||||
|
with_col_aliases=True,
|
||||||
|
)
|
||||||
|
qualify_sql, qualify_params = self.compile(self.qualify)
|
||||||
|
result = [
|
||||||
|
"SELECT * FROM (",
|
||||||
|
inner_sql,
|
||||||
|
")",
|
||||||
|
self.connection.ops.quote_name("qualify"),
|
||||||
|
"WHERE",
|
||||||
|
qualify_sql,
|
||||||
|
]
|
||||||
|
if qual_aliases:
|
||||||
|
# If some select aliases were unmasked for filtering purposes they
|
||||||
|
# must be masked back.
|
||||||
|
cols = [self.connection.ops.quote_name(alias) for alias in select.values()]
|
||||||
|
result = [
|
||||||
|
"SELECT",
|
||||||
|
", ".join(cols),
|
||||||
|
"FROM (",
|
||||||
|
*result,
|
||||||
|
")",
|
||||||
|
self.connection.ops.quote_name("qualify_mask"),
|
||||||
|
]
|
||||||
|
return result, list(inner_params) + qualify_params
|
||||||
|
|
||||||
def as_sql(self, with_limits=True, with_col_aliases=False):
|
def as_sql(self, with_limits=True, with_col_aliases=False):
|
||||||
"""
|
"""
|
||||||
Create the SQL for this query. Return the SQL string and list of
|
Create the SQL for this query. Return the SQL string and list of
|
||||||
@ -614,6 +685,9 @@ class SQLCompiler:
|
|||||||
result, params = self.get_combinator_sql(
|
result, params = self.get_combinator_sql(
|
||||||
combinator, self.query.combinator_all
|
combinator, self.query.combinator_all
|
||||||
)
|
)
|
||||||
|
elif self.qualify:
|
||||||
|
result, params = self.get_qualify_sql()
|
||||||
|
order_by = None
|
||||||
else:
|
else:
|
||||||
distinct_fields, distinct_params = self.get_distinct()
|
distinct_fields, distinct_params = self.get_distinct()
|
||||||
# This must come after 'select', 'ordering', and 'distinct'
|
# This must come after 'select', 'ordering', and 'distinct'
|
||||||
|
@ -35,48 +35,81 @@ class WhereNode(tree.Node):
|
|||||||
resolved = False
|
resolved = False
|
||||||
conditional = True
|
conditional = True
|
||||||
|
|
||||||
def split_having(self, negated=False):
|
def split_having_qualify(self, negated=False, must_group_by=False):
|
||||||
"""
|
"""
|
||||||
Return two possibly None nodes: one for those parts of self that
|
Return three possibly None nodes: one for those parts of self that
|
||||||
should be included in the WHERE clause and one for those parts of
|
should be included in the WHERE clause, one for those parts of self
|
||||||
self that must be included in the HAVING clause.
|
that must be included in the HAVING clause, and one for those parts
|
||||||
|
that refer to window functions.
|
||||||
"""
|
"""
|
||||||
if not self.contains_aggregate:
|
if not self.contains_aggregate and not self.contains_over_clause:
|
||||||
return self, None
|
return self, None, None
|
||||||
in_negated = negated ^ self.negated
|
in_negated = negated ^ self.negated
|
||||||
# If the effective connector is OR or XOR and this node contains an
|
# Whether or not children must be connected in the same filtering
|
||||||
# aggregate, then we need to push the whole branch to HAVING clause.
|
# clause (WHERE > HAVING > QUALIFY) to maintain logical semantics.
|
||||||
may_need_split = (
|
must_remain_connected = (
|
||||||
(in_negated and self.connector == AND)
|
(in_negated and self.connector == AND)
|
||||||
or (not in_negated and self.connector == OR)
|
or (not in_negated and self.connector == OR)
|
||||||
or self.connector == XOR
|
or self.connector == XOR
|
||||||
)
|
)
|
||||||
if may_need_split and self.contains_aggregate:
|
if (
|
||||||
return None, self
|
must_remain_connected
|
||||||
|
and self.contains_aggregate
|
||||||
|
and not self.contains_over_clause
|
||||||
|
):
|
||||||
|
# It's much cheaper to short-circuit and stash everything in the
|
||||||
|
# HAVING clause than split children if possible.
|
||||||
|
return None, self, None
|
||||||
where_parts = []
|
where_parts = []
|
||||||
having_parts = []
|
having_parts = []
|
||||||
|
qualify_parts = []
|
||||||
for c in self.children:
|
for c in self.children:
|
||||||
if hasattr(c, "split_having"):
|
if hasattr(c, "split_having_qualify"):
|
||||||
where_part, having_part = c.split_having(in_negated)
|
where_part, having_part, qualify_part = c.split_having_qualify(
|
||||||
|
in_negated, must_group_by
|
||||||
|
)
|
||||||
if where_part is not None:
|
if where_part is not None:
|
||||||
where_parts.append(where_part)
|
where_parts.append(where_part)
|
||||||
if having_part is not None:
|
if having_part is not None:
|
||||||
having_parts.append(having_part)
|
having_parts.append(having_part)
|
||||||
|
if qualify_part is not None:
|
||||||
|
qualify_parts.append(qualify_part)
|
||||||
|
elif c.contains_over_clause:
|
||||||
|
qualify_parts.append(c)
|
||||||
elif c.contains_aggregate:
|
elif c.contains_aggregate:
|
||||||
having_parts.append(c)
|
having_parts.append(c)
|
||||||
else:
|
else:
|
||||||
where_parts.append(c)
|
where_parts.append(c)
|
||||||
having_node = (
|
if must_remain_connected and qualify_parts:
|
||||||
self.create(having_parts, self.connector, self.negated)
|
# Disjunctive heterogeneous predicates can be pushed down to
|
||||||
if having_parts
|
# qualify as long as no conditional aggregation is involved.
|
||||||
else None
|
if not where_parts or (where_parts and not must_group_by):
|
||||||
)
|
return None, None, self
|
||||||
|
elif where_parts:
|
||||||
|
# In theory this should only be enforced when dealing with
|
||||||
|
# where_parts containing predicates against multi-valued
|
||||||
|
# relationships that could affect aggregation results but this
|
||||||
|
# is complex to infer properly.
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Heterogeneous disjunctive predicates against window functions are "
|
||||||
|
"not implemented when performing conditional aggregation."
|
||||||
|
)
|
||||||
where_node = (
|
where_node = (
|
||||||
self.create(where_parts, self.connector, self.negated)
|
self.create(where_parts, self.connector, self.negated)
|
||||||
if where_parts
|
if where_parts
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
return where_node, having_node
|
having_node = (
|
||||||
|
self.create(having_parts, self.connector, self.negated)
|
||||||
|
if having_parts
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
qualify_node = (
|
||||||
|
self.create(qualify_parts, self.connector, self.negated)
|
||||||
|
if qualify_parts
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
return where_node, having_node, qualify_node
|
||||||
|
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection):
|
||||||
"""
|
"""
|
||||||
@ -183,6 +216,14 @@ class WhereNode(tree.Node):
|
|||||||
clone.relabel_aliases(change_map)
|
clone.relabel_aliases(change_map)
|
||||||
return clone
|
return clone
|
||||||
|
|
||||||
|
def replace_expressions(self, replacements):
|
||||||
|
if replacement := replacements.get(self):
|
||||||
|
return replacement
|
||||||
|
clone = self.create(connector=self.connector, negated=self.negated)
|
||||||
|
for child in self.children:
|
||||||
|
clone.children.append(child.replace_expressions(replacements))
|
||||||
|
return clone
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _contains_aggregate(cls, obj):
|
def _contains_aggregate(cls, obj):
|
||||||
if isinstance(obj, tree.Node):
|
if isinstance(obj, tree.Node):
|
||||||
@ -231,6 +272,10 @@ class WhereNode(tree.Node):
|
|||||||
|
|
||||||
return BooleanField()
|
return BooleanField()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _output_field_or_none(self):
|
||||||
|
return self.output_field
|
||||||
|
|
||||||
def select_format(self, compiler, sql, params):
|
def select_format(self, compiler, sql, params):
|
||||||
# Wrap filters with a CASE WHEN expression if a database backend
|
# Wrap filters with a CASE WHEN expression if a database backend
|
||||||
# (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
|
# (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
|
||||||
@ -245,19 +290,28 @@ class WhereNode(tree.Node):
|
|||||||
def get_lookup(self, lookup):
|
def get_lookup(self, lookup):
|
||||||
return self.output_field.get_lookup(lookup)
|
return self.output_field.get_lookup(lookup)
|
||||||
|
|
||||||
|
def leaves(self):
|
||||||
|
for child in self.children:
|
||||||
|
if isinstance(child, WhereNode):
|
||||||
|
yield from child.leaves()
|
||||||
|
else:
|
||||||
|
yield child
|
||||||
|
|
||||||
|
|
||||||
class NothingNode:
|
class NothingNode:
|
||||||
"""A node that matches nothing."""
|
"""A node that matches nothing."""
|
||||||
|
|
||||||
contains_aggregate = False
|
contains_aggregate = False
|
||||||
|
contains_over_clause = False
|
||||||
|
|
||||||
def as_sql(self, compiler=None, connection=None):
|
def as_sql(self, compiler=None, connection=None):
|
||||||
raise EmptyResultSet
|
raise EmptyResultSet
|
||||||
|
|
||||||
|
|
||||||
class ExtraWhere:
|
class ExtraWhere:
|
||||||
# The contents are a black box - assume no aggregates are used.
|
# The contents are a black box - assume no aggregates or windows are used.
|
||||||
contains_aggregate = False
|
contains_aggregate = False
|
||||||
|
contains_over_clause = False
|
||||||
|
|
||||||
def __init__(self, sqls, params):
|
def __init__(self, sqls, params):
|
||||||
self.sqls = sqls
|
self.sqls = sqls
|
||||||
@ -269,9 +323,10 @@ class ExtraWhere:
|
|||||||
|
|
||||||
|
|
||||||
class SubqueryConstraint:
|
class SubqueryConstraint:
|
||||||
# Even if aggregates would be used in a subquery, the outer query isn't
|
# Even if aggregates or windows would be used in a subquery,
|
||||||
# interested about those.
|
# the outer query isn't interested in those.
|
||||||
contains_aggregate = False
|
contains_aggregate = False
|
||||||
|
contains_over_clause = False
|
||||||
|
|
||||||
def __init__(self, alias, columns, targets, query_object):
|
def __init__(self, alias, columns, targets, query_object):
|
||||||
self.alias = alias
|
self.alias = alias
|
||||||
|
@ -741,12 +741,6 @@ instead they are part of the selected columns.
|
|||||||
|
|
||||||
.. class:: Window(expression, partition_by=None, order_by=None, frame=None, output_field=None)
|
.. class:: Window(expression, partition_by=None, order_by=None, frame=None, output_field=None)
|
||||||
|
|
||||||
.. attribute:: filterable
|
|
||||||
|
|
||||||
Defaults to ``False``. The SQL standard disallows referencing window
|
|
||||||
functions in the ``WHERE`` clause and Django raises an exception when
|
|
||||||
constructing a ``QuerySet`` that would do that.
|
|
||||||
|
|
||||||
.. attribute:: template
|
.. attribute:: template
|
||||||
|
|
||||||
Defaults to ``%(expression)s OVER (%(window)s)``. If only the
|
Defaults to ``%(expression)s OVER (%(window)s)``. If only the
|
||||||
@ -819,6 +813,31 @@ to reduce repetition::
|
|||||||
>>> ),
|
>>> ),
|
||||||
>>> )
|
>>> )
|
||||||
|
|
||||||
|
Filtering against window functions is supported as long as lookups are not
|
||||||
|
disjunctive (not using ``OR`` or ``XOR`` as a connector) when filtering a queryset
|
||||||
|
performing aggregation.
|
||||||
|
|
||||||
|
For example, a query that relies on aggregation and has an ``OR``-ed filter
|
||||||
|
against a window function and a field is not supported. Applying combined
|
||||||
|
predicates post-aggregation could cause rows that would normally be excluded
|
||||||
|
from groups to be included::
|
||||||
|
|
||||||
|
>>> qs = Movie.objects.annotate(
|
||||||
|
>>> category_rank=Window(
|
||||||
|
>>> Rank(), partition_by='category', order_by='-rating'
|
||||||
|
>>> ),
|
||||||
|
>>> scenes_count=Count('actors'),
|
||||||
|
>>> ).filter(
|
||||||
|
>>> Q(category_rank__lte=3) | Q(title__contains='Batman')
|
||||||
|
>>> )
|
||||||
|
>>> list(qs)
|
||||||
|
NotImplementedError: Heterogeneous disjunctive predicates against window functions
|
||||||
|
are not implemented when performing conditional aggregation.
|
||||||
|
|
||||||
|
.. versionchanged:: 4.2
|
||||||
|
|
||||||
|
Support for filtering against window functions was added.
|
||||||
|
|
||||||
Among Django's built-in database backends, MySQL 8.0.2+, PostgreSQL, and Oracle
|
Among Django's built-in database backends, MySQL 8.0.2+, PostgreSQL, and Oracle
|
||||||
support window expressions. Support for different window expression features
|
support window expressions. Support for different window expression features
|
||||||
varies among the different databases. For example, the options in
|
varies among the different databases. For example, the options in
|
||||||
|
@ -189,7 +189,9 @@ Migrations
|
|||||||
Models
|
Models
|
||||||
~~~~~~
|
~~~~~~
|
||||||
|
|
||||||
* ...
|
* ``QuerySet`` now extensively supports filtering against
|
||||||
|
:ref:`window-functions` with the exception of disjunctive filter lookups
|
||||||
|
against window functions when performing aggregation.
|
||||||
|
|
||||||
Requests and Responses
|
Requests and Responses
|
||||||
~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
@ -17,6 +17,13 @@ class Employee(models.Model):
|
|||||||
bonus = models.DecimalField(decimal_places=2, max_digits=15, null=True)
|
bonus = models.DecimalField(decimal_places=2, max_digits=15, null=True)
|
||||||
|
|
||||||
|
|
||||||
|
class PastEmployeeDepartment(models.Model):
|
||||||
|
employee = models.ForeignKey(
|
||||||
|
Employee, related_name="past_departments", on_delete=models.CASCADE
|
||||||
|
)
|
||||||
|
department = models.CharField(max_length=40, blank=False, null=False)
|
||||||
|
|
||||||
|
|
||||||
class Detail(models.Model):
|
class Detail(models.Model):
|
||||||
value = models.JSONField()
|
value = models.JSONField()
|
||||||
|
|
||||||
|
@ -6,10 +6,9 @@ from django.core.exceptions import FieldError
|
|||||||
from django.db import NotSupportedError, connection
|
from django.db import NotSupportedError, connection
|
||||||
from django.db.models import (
|
from django.db.models import (
|
||||||
Avg,
|
Avg,
|
||||||
BooleanField,
|
|
||||||
Case,
|
Case,
|
||||||
|
Count,
|
||||||
F,
|
F,
|
||||||
Func,
|
|
||||||
IntegerField,
|
IntegerField,
|
||||||
Max,
|
Max,
|
||||||
Min,
|
Min,
|
||||||
@ -41,15 +40,17 @@ from django.db.models.functions import (
|
|||||||
RowNumber,
|
RowNumber,
|
||||||
Upper,
|
Upper,
|
||||||
)
|
)
|
||||||
|
from django.db.models.lookups import Exact
|
||||||
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
|
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
|
||||||
|
|
||||||
from .models import Detail, Employee
|
from .models import Classification, Detail, Employee, PastEmployeeDepartment
|
||||||
|
|
||||||
|
|
||||||
@skipUnlessDBFeature("supports_over_clause")
|
@skipUnlessDBFeature("supports_over_clause")
|
||||||
class WindowFunctionTests(TestCase):
|
class WindowFunctionTests(TestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpTestData(cls):
|
def setUpTestData(cls):
|
||||||
|
classification = Classification.objects.create()
|
||||||
Employee.objects.bulk_create(
|
Employee.objects.bulk_create(
|
||||||
[
|
[
|
||||||
Employee(
|
Employee(
|
||||||
@ -59,6 +60,7 @@ class WindowFunctionTests(TestCase):
|
|||||||
hire_date=e[3],
|
hire_date=e[3],
|
||||||
age=e[4],
|
age=e[4],
|
||||||
bonus=Decimal(e[1]) / 400,
|
bonus=Decimal(e[1]) / 400,
|
||||||
|
classification=classification,
|
||||||
)
|
)
|
||||||
for e in [
|
for e in [
|
||||||
("Jones", 45000, "Accounting", datetime.datetime(2005, 11, 1), 20),
|
("Jones", 45000, "Accounting", datetime.datetime(2005, 11, 1), 20),
|
||||||
@ -82,6 +84,13 @@ class WindowFunctionTests(TestCase):
|
|||||||
]
|
]
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
employees = list(Employee.objects.order_by("pk"))
|
||||||
|
PastEmployeeDepartment.objects.bulk_create(
|
||||||
|
[
|
||||||
|
PastEmployeeDepartment(employee=employees[6], department="Sales"),
|
||||||
|
PastEmployeeDepartment(employee=employees[10], department="IT"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def test_dense_rank(self):
|
def test_dense_rank(self):
|
||||||
tests = [
|
tests = [
|
||||||
@ -902,6 +911,263 @@ class WindowFunctionTests(TestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(qs.count(), 12)
|
self.assertEqual(qs.count(), 12)
|
||||||
|
|
||||||
|
def test_filter(self):
|
||||||
|
qs = Employee.objects.annotate(
|
||||||
|
department_salary_rank=Window(
|
||||||
|
Rank(), partition_by="department", order_by="-salary"
|
||||||
|
),
|
||||||
|
department_avg_age_diff=(
|
||||||
|
Window(Avg("age"), partition_by="department") - F("age")
|
||||||
|
),
|
||||||
|
).order_by("department", "name")
|
||||||
|
# Direct window reference.
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
qs.filter(department_salary_rank=1),
|
||||||
|
["Adams", "Wilkinson", "Miller", "Johnson", "Smith"],
|
||||||
|
lambda employee: employee.name,
|
||||||
|
)
|
||||||
|
# Through a combined expression containing a window.
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
qs.filter(department_avg_age_diff__gt=0),
|
||||||
|
["Jenson", "Jones", "Williams", "Miller", "Smith"],
|
||||||
|
lambda employee: employee.name,
|
||||||
|
)
|
||||||
|
# Intersection of multiple windows.
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
qs.filter(department_salary_rank=1, department_avg_age_diff__gt=0),
|
||||||
|
["Miller"],
|
||||||
|
lambda employee: employee.name,
|
||||||
|
)
|
||||||
|
# Union of multiple windows.
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
qs.filter(Q(department_salary_rank=1) | Q(department_avg_age_diff__gt=0)),
|
||||||
|
[
|
||||||
|
"Adams",
|
||||||
|
"Jenson",
|
||||||
|
"Jones",
|
||||||
|
"Williams",
|
||||||
|
"Wilkinson",
|
||||||
|
"Miller",
|
||||||
|
"Johnson",
|
||||||
|
"Smith",
|
||||||
|
"Smith",
|
||||||
|
],
|
||||||
|
lambda employee: employee.name,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_filter_conditional_annotation(self):
|
||||||
|
qs = (
|
||||||
|
Employee.objects.annotate(
|
||||||
|
rank=Window(Rank(), partition_by="department", order_by="-salary"),
|
||||||
|
case_first_rank=Case(
|
||||||
|
When(rank=1, then=True),
|
||||||
|
default=False,
|
||||||
|
),
|
||||||
|
q_first_rank=Q(rank=1),
|
||||||
|
)
|
||||||
|
.order_by("name")
|
||||||
|
.values_list("name", flat=True)
|
||||||
|
)
|
||||||
|
for annotation in ["case_first_rank", "q_first_rank"]:
|
||||||
|
with self.subTest(annotation=annotation):
|
||||||
|
self.assertSequenceEqual(
|
||||||
|
qs.filter(**{annotation: True}),
|
||||||
|
["Adams", "Johnson", "Miller", "Smith", "Wilkinson"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_filter_conditional_expression(self):
|
||||||
|
qs = (
|
||||||
|
Employee.objects.filter(
|
||||||
|
Exact(Window(Rank(), partition_by="department", order_by="-salary"), 1)
|
||||||
|
)
|
||||||
|
.order_by("name")
|
||||||
|
.values_list("name", flat=True)
|
||||||
|
)
|
||||||
|
self.assertSequenceEqual(
|
||||||
|
qs, ["Adams", "Johnson", "Miller", "Smith", "Wilkinson"]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_filter_column_ref_rhs(self):
|
||||||
|
qs = (
|
||||||
|
Employee.objects.annotate(
|
||||||
|
max_dept_salary=Window(Max("salary"), partition_by="department")
|
||||||
|
)
|
||||||
|
.filter(max_dept_salary=F("salary"))
|
||||||
|
.order_by("name")
|
||||||
|
.values_list("name", flat=True)
|
||||||
|
)
|
||||||
|
self.assertSequenceEqual(
|
||||||
|
qs, ["Adams", "Johnson", "Miller", "Smith", "Wilkinson"]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_filter_values(self):
|
||||||
|
qs = (
|
||||||
|
Employee.objects.annotate(
|
||||||
|
department_salary_rank=Window(
|
||||||
|
Rank(), partition_by="department", order_by="-salary"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.order_by("department", "name")
|
||||||
|
.values_list(Upper("name"), flat=True)
|
||||||
|
)
|
||||||
|
self.assertSequenceEqual(
|
||||||
|
qs.filter(department_salary_rank=1),
|
||||||
|
["ADAMS", "WILKINSON", "MILLER", "JOHNSON", "SMITH"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_filter_alias(self):
|
||||||
|
qs = Employee.objects.alias(
|
||||||
|
department_avg_age_diff=(
|
||||||
|
Window(Avg("age"), partition_by="department") - F("age")
|
||||||
|
),
|
||||||
|
).order_by("department", "name")
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
qs.filter(department_avg_age_diff__gt=0),
|
||||||
|
["Jenson", "Jones", "Williams", "Miller", "Smith"],
|
||||||
|
lambda employee: employee.name,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_filter_select_related(self):
|
||||||
|
qs = (
|
||||||
|
Employee.objects.alias(
|
||||||
|
department_avg_age_diff=(
|
||||||
|
Window(Avg("age"), partition_by="department") - F("age")
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.select_related("classification")
|
||||||
|
.filter(department_avg_age_diff__gt=0)
|
||||||
|
.order_by("department", "name")
|
||||||
|
)
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
qs,
|
||||||
|
["Jenson", "Jones", "Williams", "Miller", "Smith"],
|
||||||
|
lambda employee: employee.name,
|
||||||
|
)
|
||||||
|
with self.assertNumQueries(0):
|
||||||
|
qs[0].classification
|
||||||
|
|
||||||
|
def test_exclude(self):
|
||||||
|
qs = Employee.objects.annotate(
|
||||||
|
department_salary_rank=Window(
|
||||||
|
Rank(), partition_by="department", order_by="-salary"
|
||||||
|
),
|
||||||
|
department_avg_age_diff=(
|
||||||
|
Window(Avg("age"), partition_by="department") - F("age")
|
||||||
|
),
|
||||||
|
).order_by("department", "name")
|
||||||
|
# Direct window reference.
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
qs.exclude(department_salary_rank__gt=1),
|
||||||
|
["Adams", "Wilkinson", "Miller", "Johnson", "Smith"],
|
||||||
|
lambda employee: employee.name,
|
||||||
|
)
|
||||||
|
# Through a combined expression containing a window.
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
qs.exclude(department_avg_age_diff__lte=0),
|
||||||
|
["Jenson", "Jones", "Williams", "Miller", "Smith"],
|
||||||
|
lambda employee: employee.name,
|
||||||
|
)
|
||||||
|
# Union of multiple windows.
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
qs.exclude(
|
||||||
|
Q(department_salary_rank__gt=1) | Q(department_avg_age_diff__lte=0)
|
||||||
|
),
|
||||||
|
["Miller"],
|
||||||
|
lambda employee: employee.name,
|
||||||
|
)
|
||||||
|
# Intersection of multiple windows.
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
qs.exclude(department_salary_rank__gt=1, department_avg_age_diff__lte=0),
|
||||||
|
[
|
||||||
|
"Adams",
|
||||||
|
"Jenson",
|
||||||
|
"Jones",
|
||||||
|
"Williams",
|
||||||
|
"Wilkinson",
|
||||||
|
"Miller",
|
||||||
|
"Johnson",
|
||||||
|
"Smith",
|
||||||
|
"Smith",
|
||||||
|
],
|
||||||
|
lambda employee: employee.name,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_heterogeneous_filter(self):
|
||||||
|
qs = (
|
||||||
|
Employee.objects.annotate(
|
||||||
|
department_salary_rank=Window(
|
||||||
|
Rank(), partition_by="department", order_by="-salary"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.order_by("name")
|
||||||
|
.values_list("name", flat=True)
|
||||||
|
)
|
||||||
|
# Heterogeneous filter between window function and field lookups pushes
|
||||||
|
# the WHERE clause to the QUALIFY outer query.
|
||||||
|
self.assertSequenceEqual(
|
||||||
|
qs.filter(
|
||||||
|
department_salary_rank=1, department__in=["Accounting", "Management"]
|
||||||
|
),
|
||||||
|
["Adams", "Miller"],
|
||||||
|
)
|
||||||
|
self.assertSequenceEqual(
|
||||||
|
qs.filter(
|
||||||
|
Q(department_salary_rank=1)
|
||||||
|
| Q(department__in=["Accounting", "Management"])
|
||||||
|
),
|
||||||
|
[
|
||||||
|
"Adams",
|
||||||
|
"Jenson",
|
||||||
|
"Johnson",
|
||||||
|
"Johnson",
|
||||||
|
"Jones",
|
||||||
|
"Miller",
|
||||||
|
"Smith",
|
||||||
|
"Wilkinson",
|
||||||
|
"Williams",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
# Heterogeneous filter between window function and aggregates pushes
|
||||||
|
# the HAVING clause to the QUALIFY outer query.
|
||||||
|
qs = qs.annotate(past_department_count=Count("past_departments"))
|
||||||
|
self.assertSequenceEqual(
|
||||||
|
qs.filter(department_salary_rank=1, past_department_count__gte=1),
|
||||||
|
["Johnson", "Miller"],
|
||||||
|
)
|
||||||
|
self.assertSequenceEqual(
|
||||||
|
qs.filter(Q(department_salary_rank=1) | Q(past_department_count__gte=1)),
|
||||||
|
["Adams", "Johnson", "Miller", "Smith", "Wilkinson"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_limited_filter(self):
|
||||||
|
"""
|
||||||
|
A query filtering against a window function has its limit applied
|
||||||
|
after window filtering takes place.
|
||||||
|
"""
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
Employee.objects.annotate(
|
||||||
|
department_salary_rank=Window(
|
||||||
|
Rank(), partition_by="department", order_by="-salary"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.filter(department_salary_rank=1)
|
||||||
|
.order_by("department")[0:3],
|
||||||
|
["Adams", "Wilkinson", "Miller"],
|
||||||
|
lambda employee: employee.name,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_filter_count(self):
|
||||||
|
self.assertEqual(
|
||||||
|
Employee.objects.annotate(
|
||||||
|
department_salary_rank=Window(
|
||||||
|
Rank(), partition_by="department", order_by="-salary"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.filter(department_salary_rank=1)
|
||||||
|
.count(),
|
||||||
|
5,
|
||||||
|
)
|
||||||
|
|
||||||
@skipUnlessDBFeature("supports_frame_range_fixed_distance")
|
@skipUnlessDBFeature("supports_frame_range_fixed_distance")
|
||||||
def test_range_n_preceding_and_following(self):
|
def test_range_n_preceding_and_following(self):
|
||||||
qs = Employee.objects.annotate(
|
qs = Employee.objects.annotate(
|
||||||
@ -1071,6 +1337,7 @@ class WindowFunctionTests(TestCase):
|
|||||||
),
|
),
|
||||||
year=ExtractYear("hire_date"),
|
year=ExtractYear("hire_date"),
|
||||||
)
|
)
|
||||||
|
.filter(sum__gte=45000)
|
||||||
.values("year", "sum")
|
.values("year", "sum")
|
||||||
.distinct("year")
|
.distinct("year")
|
||||||
.order_by("year")
|
.order_by("year")
|
||||||
@ -1081,7 +1348,6 @@ class WindowFunctionTests(TestCase):
|
|||||||
{"year": 2008, "sum": 45000},
|
{"year": 2008, "sum": 45000},
|
||||||
{"year": 2009, "sum": 128000},
|
{"year": 2009, "sum": 128000},
|
||||||
{"year": 2011, "sum": 60000},
|
{"year": 2011, "sum": 60000},
|
||||||
{"year": 2012, "sum": 40000},
|
|
||||||
{"year": 2013, "sum": 84000},
|
{"year": 2013, "sum": 84000},
|
||||||
]
|
]
|
||||||
for idx, val in zip(range(len(results)), results):
|
for idx, val in zip(range(len(results)), results):
|
||||||
@ -1348,34 +1614,18 @@ class NonQueryWindowTests(SimpleTestCase):
|
|||||||
frame.window_frame_start_end(None, None, None)
|
frame.window_frame_start_end(None, None, None)
|
||||||
|
|
||||||
def test_invalid_filter(self):
|
def test_invalid_filter(self):
|
||||||
msg = "Window is disallowed in the filter clause"
|
msg = (
|
||||||
qs = Employee.objects.annotate(dense_rank=Window(expression=DenseRank()))
|
"Heterogeneous disjunctive predicates against window functions are not "
|
||||||
with self.assertRaisesMessage(NotSupportedError, msg):
|
"implemented when performing conditional aggregation."
|
||||||
qs.filter(dense_rank__gte=1)
|
|
||||||
with self.assertRaisesMessage(NotSupportedError, msg):
|
|
||||||
qs.annotate(inc_rank=F("dense_rank") + Value(1)).filter(inc_rank__gte=1)
|
|
||||||
with self.assertRaisesMessage(NotSupportedError, msg):
|
|
||||||
qs.filter(id=F("dense_rank"))
|
|
||||||
with self.assertRaisesMessage(NotSupportedError, msg):
|
|
||||||
qs.filter(id=Func("dense_rank", 2, function="div"))
|
|
||||||
with self.assertRaisesMessage(NotSupportedError, msg):
|
|
||||||
qs.annotate(total=Sum("dense_rank", filter=Q(name="Jones"))).filter(total=1)
|
|
||||||
|
|
||||||
def test_conditional_annotation(self):
|
|
||||||
qs = Employee.objects.annotate(
|
|
||||||
dense_rank=Window(expression=DenseRank()),
|
|
||||||
).annotate(
|
|
||||||
equal=Case(
|
|
||||||
When(id=F("dense_rank"), then=Value(True)),
|
|
||||||
default=Value(False),
|
|
||||||
output_field=BooleanField(),
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
# The SQL standard disallows referencing window functions in the WHERE
|
qs = Employee.objects.annotate(
|
||||||
# clause.
|
window=Window(Rank()),
|
||||||
msg = "Window is disallowed in the filter clause"
|
past_dept_cnt=Count("past_departments"),
|
||||||
with self.assertRaisesMessage(NotSupportedError, msg):
|
)
|
||||||
qs.filter(equal=True)
|
with self.assertRaisesMessage(NotImplementedError, msg):
|
||||||
|
list(qs.filter(Q(window=1) | Q(department="Accounting")))
|
||||||
|
with self.assertRaisesMessage(NotImplementedError, msg):
|
||||||
|
list(qs.exclude(window=1, department="Accounting"))
|
||||||
|
|
||||||
def test_invalid_order_by(self):
|
def test_invalid_order_by(self):
|
||||||
msg = (
|
msg = (
|
||||||
|
Loading…
x
Reference in New Issue
Block a user