1
0
mirror of https://github.com/django/django.git synced 2025-01-03 06:55:47 +00:00

Refs #29850 -- Added exclusion support to window frames.

This commit is contained in:
Sarah Boyce 2023-10-25 17:30:32 +02:00 committed by Mariusz Felisiak
parent 34b411762b
commit e4d012ca05
9 changed files with 310 additions and 5 deletions

View File

@ -263,6 +263,7 @@ class BaseDatabaseFeatures:
# Does the backend support window expressions (expression OVER (...))?
supports_over_clause = False
supports_frame_range_fixed_distance = False
supports_frame_exclusion = False
only_supports_unbounded_with_preceding_and_following = False
# Does the backend support CAST with precision?

View File

@ -162,3 +162,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
@cached_property
def supports_primitives_in_json_field(self):
return self.connection.oracle_version >= (21,)
@cached_property
def supports_frame_exclusion(self):
return self.connection.oracle_version >= (21,)

View File

@ -61,6 +61,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"""
requires_casted_case_in_updates = True
supports_over_clause = True
supports_frame_exclusion = True
only_supports_unbounded_with_preceding_and_following = True
supports_aggregate_filter_clause = True
supported_explain_formats = {"JSON", "TEXT", "XML", "YAML"}

View File

@ -32,6 +32,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
can_defer_constraint_checks = True
supports_over_clause = True
supports_frame_range_fixed_distance = Database.sqlite_version_info >= (3, 28, 0)
supports_frame_exclusion = Database.sqlite_version_info >= (3, 28, 0)
supports_aggregate_filter_clause = Database.sqlite_version_info >= (3, 30, 1)
supports_order_by_nulls_modifier = Database.sqlite_version_info >= (3, 30, 0)
# NULLS LAST/FIRST emulation on < 3.30 requires subquery wrapping.

View File

@ -34,6 +34,7 @@ from django.db.models.expressions import (
When,
Window,
WindowFrame,
WindowFrameExclusion,
)
from django.db.models.fields import * # NOQA
from django.db.models.fields import __all__ as fields_all
@ -91,6 +92,7 @@ __all__ += [
"When",
"Window",
"WindowFrame",
"WindowFrameExclusion",
"FileField",
"ImageField",
"GeneratedField",

View File

@ -4,6 +4,7 @@ import functools
import inspect
from collections import defaultdict
from decimal import Decimal
from enum import Enum
from types import NoneType
from uuid import UUID
@ -1848,6 +1849,16 @@ class Window(SQLiteNumericMixin, Expression):
return group_by_cols
class WindowFrameExclusion(Enum):
CURRENT_ROW = "CURRENT ROW"
GROUP = "GROUP"
TIES = "TIES"
NO_OTHERS = "NO OTHERS"
def __repr__(self):
return f"{self.__class__.__qualname__}.{self._name_}"
class WindowFrame(Expression):
"""
Model the frame clause in window expressions. There are two types of frame
@ -1857,11 +1868,17 @@ class WindowFrame(Expression):
row in the frame).
"""
template = "%(frame_type)s BETWEEN %(start)s AND %(end)s"
template = "%(frame_type)s BETWEEN %(start)s AND %(end)s%(exclude)s"
def __init__(self, start=None, end=None):
def __init__(self, start=None, end=None, exclusion=None):
self.start = Value(start)
self.end = Value(end)
if not isinstance(exclusion, (NoneType, WindowFrameExclusion)):
raise TypeError(
f"{self.__class__.__qualname__}.exclusion must be a "
"WindowFrameExclusion instance."
)
self.exclusion = exclusion
def set_source_expressions(self, exprs):
self.start, self.end = exprs
@ -1869,17 +1886,27 @@ class WindowFrame(Expression):
def get_source_expressions(self):
return [self.start, self.end]
def get_exclusion(self):
if self.exclusion is None:
return ""
return f" EXCLUDE {self.exclusion.value}"
def as_sql(self, compiler, connection):
connection.ops.check_expression_support(self)
start, end = self.window_frame_start_end(
connection, self.start.value, self.end.value
)
if self.exclusion and not connection.features.supports_frame_exclusion:
raise NotSupportedError(
"This backend does not support window frame exclusions."
)
return (
self.template
% {
"frame_type": self.frame_type,
"start": start,
"end": end,
"exclude": self.get_exclusion(),
},
[],
)
@ -1912,6 +1939,7 @@ class WindowFrame(Expression):
"frame_type": self.frame_type,
"start": start,
"end": end,
"exclude": self.get_exclusion(),
}
def window_frame_start_end(self, connection, start, end):

View File

@ -889,7 +889,7 @@ Frames
For a window frame, you can choose either a range-based sequence of rows or an
ordinary sequence of rows.
.. class:: ValueRange(start=None, end=None)
.. class:: ValueRange(start=None, end=None, exclusion=None)
.. attribute:: frame_type
@ -899,18 +899,48 @@ ordinary sequence of rows.
the standard start and end points, such as ``CURRENT ROW`` and ``UNBOUNDED
FOLLOWING``.
.. class:: RowRange(start=None, end=None)
.. versionchanged:: 5.1
The ``exclusion`` argument was added.
.. class:: RowRange(start=None, end=None, exclusion=None)
.. attribute:: frame_type
This attribute is set to ``'ROWS'``.
.. versionchanged:: 5.1
The ``exclusion`` argument was added.
Both classes return SQL with the template:
.. code-block:: sql
%(frame_type)s BETWEEN %(start)s AND %(end)s
.. class:: WindowFrameExclusion
.. versionadded:: 5.1
.. attribute:: CURRENT_ROW
.. attribute:: GROUP
.. attribute:: TIES
.. attribute:: NO_OTHERS
The ``exclusion`` argument allows excluding rows
(:attr:`~WindowFrameExclusion.CURRENT_ROW`), groups
(:attr:`~WindowFrameExclusion.GROUP`), and ties
(:attr:`~WindowFrameExclusion.TIES`) from the window frames on supported
databases:
.. code-block:: sql
%(frame_type)s BETWEEN %(start)s AND %(end)s EXCLUDE %(exclusion)s
Frames narrow the rows that are used for computing the result. They shift from
some start point to some specified end point. Frames can be used with and
without partitions, but it's often a good idea to specify an ordering of the

View File

@ -174,6 +174,11 @@ Models
* :class:`~django.db.models.expressions.RowRange` now accepts positive integers
for the ``start`` argument and negative integers for the ``end`` argument.
* The new ``exclusion`` argument of
:class:`~django.db.models.expressions.RowRange` and
:class:`~django.db.models.expressions.ValueRange` allows excluding rows,
groups, and ties from the window frames.
Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~

View File

@ -22,6 +22,7 @@ from django.db.models import (
When,
Window,
WindowFrame,
WindowFrameExclusion,
)
from django.db.models.fields.json import KeyTextTransform, KeyTransform
from django.db.models.functions import (
@ -41,7 +42,7 @@ from django.db.models.functions import (
Upper,
)
from django.db.models.lookups import Exact
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature
from django.test.utils import CaptureQueriesContext
from .models import Classification, Detail, Employee, PastEmployeeDepartment
@ -1211,6 +1212,47 @@ class WindowFunctionTests(TestCase):
ordered=False,
)
@skipUnlessDBFeature(
"supports_frame_exclusion", "supports_frame_range_fixed_distance"
)
def test_range_exclude_current(self):
qs = Employee.objects.annotate(
sum=Window(
expression=Sum("salary"),
order_by=F("salary").asc(),
partition_by="department",
frame=ValueRange(end=2, exclusion=WindowFrameExclusion.CURRENT_ROW),
)
).order_by("department", "salary")
self.assertIn(
"RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING EXCLUDE CURRENT ROW",
str(qs.query),
)
self.assertQuerySetEqual(
qs,
[
("Williams", 37000, "Accounting", datetime.date(2009, 6, 1), None),
("Jones", 45000, "Accounting", datetime.date(2005, 11, 1), 82000),
("Jenson", 45000, "Accounting", datetime.date(2008, 4, 1), 82000),
("Adams", 50000, "Accounting", datetime.date(2013, 7, 1), 127000),
("Moore", 34000, "IT", datetime.date(2013, 8, 1), None),
("Wilkinson", 60000, "IT", datetime.date(2011, 3, 1), 34000),
("Johnson", 80000, "Management", datetime.date(2005, 7, 1), None),
("Miller", 100000, "Management", datetime.date(2005, 6, 1), 80000),
("Smith", 38000, "Marketing", datetime.date(2009, 10, 1), None),
("Johnson", 40000, "Marketing", datetime.date(2012, 3, 1), 38000),
("Brown", 53000, "Sales", datetime.date(2009, 9, 1), None),
("Smith", 55000, "Sales", datetime.date(2007, 6, 1), 53000),
],
transform=lambda row: (
row.name,
row.salary,
row.department,
row.hire_date,
row.sum,
),
)
def test_range_unbound(self):
"""A query with RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING."""
qs = Employee.objects.annotate(
@ -1289,6 +1331,190 @@ class WindowFunctionTests(TestCase):
),
)
@skipUnlessDBFeature("supports_frame_exclusion")
def test_row_range_rank_exclude_current_row(self):
qs = Employee.objects.annotate(
avg_salary_cohort=Window(
expression=Avg("salary"),
order_by=[F("hire_date").asc(), F("name").desc()],
frame=RowRange(
start=-1, end=1, exclusion=WindowFrameExclusion.CURRENT_ROW
),
)
).order_by("hire_date")
self.assertIn(
"ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING EXCLUDE CURRENT ROW",
str(qs.query),
)
self.assertQuerySetEqual(
qs,
[
("Miller", 100000, "Management", datetime.date(2005, 6, 1), 80000),
("Johnson", 80000, "Management", datetime.date(2005, 7, 1), 72500),
("Jones", 45000, "Accounting", datetime.date(2005, 11, 1), 67500),
("Smith", 55000, "Sales", datetime.date(2007, 6, 1), 45000),
("Jenson", 45000, "Accounting", datetime.date(2008, 4, 1), 46000),
("Williams", 37000, "Accounting", datetime.date(2009, 6, 1), 49000),
("Brown", 53000, "Sales", datetime.date(2009, 9, 1), 37500),
("Smith", 38000, "Marketing", datetime.date(2009, 10, 1), 56500),
("Wilkinson", 60000, "IT", datetime.date(2011, 3, 1), 39000),
("Johnson", 40000, "Marketing", datetime.date(2012, 3, 1), 55000),
("Adams", 50000, "Accounting", datetime.date(2013, 7, 1), 37000),
("Moore", 34000, "IT", datetime.date(2013, 8, 1), 50000),
],
transform=lambda row: (
row.name,
row.salary,
row.department,
row.hire_date,
row.avg_salary_cohort,
),
)
@skipUnlessDBFeature("supports_frame_exclusion")
def test_row_range_rank_exclude_group(self):
qs = Employee.objects.annotate(
avg_salary_cohort=Window(
expression=Avg("salary"),
order_by=[F("hire_date").asc(), F("name").desc()],
frame=RowRange(start=-1, end=1, exclusion=WindowFrameExclusion.GROUP),
)
).order_by("hire_date")
self.assertIn(
"ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING EXCLUDE GROUP",
str(qs.query),
)
self.assertQuerySetEqual(
qs,
[
("Miller", 100000, "Management", datetime.date(2005, 6, 1), 80000),
("Johnson", 80000, "Management", datetime.date(2005, 7, 1), 72500),
("Jones", 45000, "Accounting", datetime.date(2005, 11, 1), 67500),
("Smith", 55000, "Sales", datetime.date(2007, 6, 1), 45000),
("Jenson", 45000, "Accounting", datetime.date(2008, 4, 1), 46000),
("Williams", 37000, "Accounting", datetime.date(2009, 6, 1), 49000),
("Brown", 53000, "Sales", datetime.date(2009, 9, 1), 37500),
("Smith", 38000, "Marketing", datetime.date(2009, 10, 1), 56500),
("Wilkinson", 60000, "IT", datetime.date(2011, 3, 1), 39000),
("Johnson", 40000, "Marketing", datetime.date(2012, 3, 1), 55000),
("Adams", 50000, "Accounting", datetime.date(2013, 7, 1), 37000),
("Moore", 34000, "IT", datetime.date(2013, 8, 1), 50000),
],
transform=lambda row: (
row.name,
row.salary,
row.department,
row.hire_date,
row.avg_salary_cohort,
),
)
@skipUnlessDBFeature("supports_frame_exclusion")
def test_row_range_rank_exclude_ties(self):
qs = Employee.objects.annotate(
sum_salary_cohort=Window(
expression=Sum("salary"),
order_by=[F("hire_date").asc(), F("name").desc()],
frame=RowRange(start=-1, end=1, exclusion=WindowFrameExclusion.TIES),
)
).order_by("hire_date")
self.assertIn(
"ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING EXCLUDE TIES",
str(qs.query),
)
self.assertQuerySetEqual(
qs,
[
("Miller", 100000, "Management", datetime.date(2005, 6, 1), 180000),
("Johnson", 80000, "Management", datetime.date(2005, 7, 1), 225000),
("Jones", 45000, "Accounting", datetime.date(2005, 11, 1), 180000),
("Smith", 55000, "Sales", datetime.date(2007, 6, 1), 145000),
("Jenson", 45000, "Accounting", datetime.date(2008, 4, 1), 137000),
("Williams", 37000, "Accounting", datetime.date(2009, 6, 1), 135000),
("Brown", 53000, "Sales", datetime.date(2009, 9, 1), 128000),
("Smith", 38000, "Marketing", datetime.date(2009, 10, 1), 151000),
("Wilkinson", 60000, "IT", datetime.date(2011, 3, 1), 138000),
("Johnson", 40000, "Marketing", datetime.date(2012, 3, 1), 150000),
("Adams", 50000, "Accounting", datetime.date(2013, 7, 1), 124000),
("Moore", 34000, "IT", datetime.date(2013, 8, 1), 84000),
],
transform=lambda row: (
row.name,
row.salary,
row.department,
row.hire_date,
row.sum_salary_cohort,
),
)
@skipUnlessDBFeature("supports_frame_exclusion")
def test_row_range_rank_exclude_no_others(self):
qs = Employee.objects.annotate(
sum_salary_cohort=Window(
expression=Sum("salary"),
order_by=[F("hire_date").asc(), F("name").desc()],
frame=RowRange(
start=-1, end=1, exclusion=WindowFrameExclusion.NO_OTHERS
),
)
).order_by("hire_date")
self.assertIn(
"ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING EXCLUDE NO OTHERS",
str(qs.query),
)
self.assertQuerySetEqual(
qs,
[
("Miller", 100000, "Management", datetime.date(2005, 6, 1), 180000),
("Johnson", 80000, "Management", datetime.date(2005, 7, 1), 225000),
("Jones", 45000, "Accounting", datetime.date(2005, 11, 1), 180000),
("Smith", 55000, "Sales", datetime.date(2007, 6, 1), 145000),
("Jenson", 45000, "Accounting", datetime.date(2008, 4, 1), 137000),
("Williams", 37000, "Accounting", datetime.date(2009, 6, 1), 135000),
("Brown", 53000, "Sales", datetime.date(2009, 9, 1), 128000),
("Smith", 38000, "Marketing", datetime.date(2009, 10, 1), 151000),
("Wilkinson", 60000, "IT", datetime.date(2011, 3, 1), 138000),
("Johnson", 40000, "Marketing", datetime.date(2012, 3, 1), 150000),
("Adams", 50000, "Accounting", datetime.date(2013, 7, 1), 124000),
("Moore", 34000, "IT", datetime.date(2013, 8, 1), 84000),
],
transform=lambda row: (
row.name,
row.salary,
row.department,
row.hire_date,
row.sum_salary_cohort,
),
)
@skipIfDBFeature("supports_frame_exclusion")
def test_unsupported_frame_exclusion_raises_error(self):
msg = "This backend does not support window frame exclusions."
with self.assertRaisesMessage(NotSupportedError, msg):
list(
Employee.objects.annotate(
avg_salary_cohort=Window(
expression=Avg("salary"),
order_by=[F("hire_date").asc(), F("name").desc()],
frame=RowRange(
start=-1, end=1, exclusion=WindowFrameExclusion.CURRENT_ROW
),
)
)
)
@skipUnlessDBFeature("supports_frame_exclusion")
def test_invalid_frame_exclusion_value_raises_error(self):
msg = "RowRange.exclusion must be a WindowFrameExclusion instance."
with self.assertRaisesMessage(TypeError, msg):
Employee.objects.annotate(
avg_salary_cohort=Window(
expression=Avg("salary"),
order_by=[F("hire_date").asc(), F("name").desc()],
frame=RowRange(start=-1, end=1, exclusion="RUBBISH"),
)
)
def test_row_range_rank(self):
"""
A query with ROWS BETWEEN UNBOUNDED PRECEDING AND 3 FOLLOWING.
@ -1735,6 +1961,13 @@ class NonQueryWindowTests(SimpleTestCase):
repr(RowRange(start=1, end=2)),
"<RowRange: ROWS BETWEEN 1 FOLLOWING AND 2 FOLLOWING>",
)
self.assertEqual(
repr(RowRange(start=1, end=2, exclusion=WindowFrameExclusion.CURRENT_ROW)),
"<RowRange: ROWS BETWEEN 1 FOLLOWING AND 2 FOLLOWING EXCLUDE CURRENT ROW>",
)
def test_window_frame_exclusion_repr(self):
self.assertEqual(repr(WindowFrameExclusion.TIES), "WindowFrameExclusion.TIES")
def test_empty_group_by_cols(self):
window = Window(expression=Sum("pk"))