From e4d012ca05b23445f422b0400d6dd7aa85743885 Mon Sep 17 00:00:00 2001 From: Sarah Boyce <42296566+sarahboyce@users.noreply.github.com> Date: Wed, 25 Oct 2023 17:30:32 +0200 Subject: [PATCH] Refs #29850 -- Added exclusion support to window frames. --- django/db/backends/base/features.py | 1 + django/db/backends/oracle/features.py | 4 + django/db/backends/postgresql/features.py | 1 + django/db/backends/sqlite3/features.py | 1 + django/db/models/__init__.py | 2 + django/db/models/expressions.py | 32 ++- docs/ref/models/expressions.txt | 34 +++- docs/releases/5.1.txt | 5 + tests/expressions_window/tests.py | 235 +++++++++++++++++++++- 9 files changed, 310 insertions(+), 5 deletions(-) diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py index b1f0b9d491..c818fb10fe 100644 --- a/django/db/backends/base/features.py +++ b/django/db/backends/base/features.py @@ -263,6 +263,7 @@ class BaseDatabaseFeatures: # Does the backend support window expressions (expression OVER (...))? supports_over_clause = False supports_frame_range_fixed_distance = False + supports_frame_exclusion = False only_supports_unbounded_with_preceding_and_following = False # Does the backend support CAST with precision? 
diff --git a/django/db/backends/oracle/features.py b/django/db/backends/oracle/features.py index 9b894d0df6..e04dde621d 100644 --- a/django/db/backends/oracle/features.py +++ b/django/db/backends/oracle/features.py @@ -162,3 +162,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): @cached_property def supports_primitives_in_json_field(self): return self.connection.oracle_version >= (21,) + + @cached_property + def supports_frame_exclusion(self): + return self.connection.oracle_version >= (21,) diff --git a/django/db/backends/postgresql/features.py b/django/db/backends/postgresql/features.py index 0fe5d950a5..7e6455ae27 100644 --- a/django/db/backends/postgresql/features.py +++ b/django/db/backends/postgresql/features.py @@ -61,6 +61,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): """ requires_casted_case_in_updates = True supports_over_clause = True + supports_frame_exclusion = True only_supports_unbounded_with_preceding_and_following = True supports_aggregate_filter_clause = True supported_explain_formats = {"JSON", "TEXT", "XML", "YAML"} diff --git a/django/db/backends/sqlite3/features.py b/django/db/backends/sqlite3/features.py index 44ace18681..3d66466af5 100644 --- a/django/db/backends/sqlite3/features.py +++ b/django/db/backends/sqlite3/features.py @@ -32,6 +32,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): can_defer_constraint_checks = True supports_over_clause = True supports_frame_range_fixed_distance = Database.sqlite_version_info >= (3, 28, 0) + supports_frame_exclusion = Database.sqlite_version_info >= (3, 28, 0) supports_aggregate_filter_clause = Database.sqlite_version_info >= (3, 30, 1) supports_order_by_nulls_modifier = Database.sqlite_version_info >= (3, 30, 0) # NULLS LAST/FIRST emulation on < 3.30 requires subquery wrapping. 
diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py index 9426280215..3923cea591 100644 --- a/django/db/models/__init__.py +++ b/django/db/models/__init__.py @@ -34,6 +34,7 @@ from django.db.models.expressions import ( When, Window, WindowFrame, + WindowFrameExclusion, ) from django.db.models.fields import * # NOQA from django.db.models.fields import __all__ as fields_all @@ -91,6 +92,7 @@ __all__ += [ "When", "Window", "WindowFrame", + "WindowFrameExclusion", "FileField", "ImageField", "GeneratedField", diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index b1b4aadb8e..4a5946ed8d 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -4,6 +4,7 @@ import functools import inspect from collections import defaultdict from decimal import Decimal +from enum import Enum from types import NoneType from uuid import UUID @@ -1848,6 +1849,16 @@ class Window(SQLiteNumericMixin, Expression): return group_by_cols +class WindowFrameExclusion(Enum): + CURRENT_ROW = "CURRENT ROW" + GROUP = "GROUP" + TIES = "TIES" + NO_OTHERS = "NO OTHERS" + + def __repr__(self): + return f"{self.__class__.__qualname__}.{self._name_}" + + class WindowFrame(Expression): """ Model the frame clause in window expressions. There are two types of frame @@ -1857,11 +1868,17 @@ class WindowFrame(Expression): row in the frame). """ - template = "%(frame_type)s BETWEEN %(start)s AND %(end)s" + template = "%(frame_type)s BETWEEN %(start)s AND %(end)s%(exclude)s" - def __init__(self, start=None, end=None): + def __init__(self, start=None, end=None, exclusion=None): self.start = Value(start) self.end = Value(end) + if not isinstance(exclusion, (NoneType, WindowFrameExclusion)): + raise TypeError( + f"{self.__class__.__qualname__}.exclusion must be a " + "WindowFrameExclusion instance." 
+ ) + self.exclusion = exclusion def set_source_expressions(self, exprs): self.start, self.end = exprs @@ -1869,17 +1886,27 @@ class WindowFrame(Expression): def get_source_expressions(self): return [self.start, self.end] + def get_exclusion(self): + if self.exclusion is None: + return "" + return f" EXCLUDE {self.exclusion.value}" + def as_sql(self, compiler, connection): connection.ops.check_expression_support(self) start, end = self.window_frame_start_end( connection, self.start.value, self.end.value ) + if self.exclusion and not connection.features.supports_frame_exclusion: + raise NotSupportedError( + "This backend does not support window frame exclusions." + ) return ( self.template % { "frame_type": self.frame_type, "start": start, "end": end, + "exclude": self.get_exclusion(), }, [], ) @@ -1912,6 +1939,7 @@ class WindowFrame(Expression): "frame_type": self.frame_type, "start": start, "end": end, + "exclude": self.get_exclusion(), } def window_frame_start_end(self, connection, start, end): diff --git a/docs/ref/models/expressions.txt b/docs/ref/models/expressions.txt index 6dfe6c453a..807a946551 100644 --- a/docs/ref/models/expressions.txt +++ b/docs/ref/models/expressions.txt @@ -889,7 +889,7 @@ Frames For a window frame, you can choose either a range-based sequence of rows or an ordinary sequence of rows. -.. class:: ValueRange(start=None, end=None) +.. class:: ValueRange(start=None, end=None, exclusion=None) .. attribute:: frame_type @@ -899,18 +899,48 @@ ordinary sequence of rows. the standard start and end points, such as ``CURRENT ROW`` and ``UNBOUNDED FOLLOWING``. -.. class:: RowRange(start=None, end=None) + .. versionchanged:: 5.1 + + The ``exclusion`` argument was added. + +.. class:: RowRange(start=None, end=None, exclusion=None) .. attribute:: frame_type This attribute is set to ``'ROWS'``. + .. versionchanged:: 5.1 + + The ``exclusion`` argument was added. + Both classes return SQL with the template: .. 
code-block:: sql %(frame_type)s BETWEEN %(start)s AND %(end)s +.. class:: WindowFrameExclusion + + .. versionadded:: 5.1 + + .. attribute:: CURRENT_ROW + + .. attribute:: GROUP + + .. attribute:: TIES + + .. attribute:: NO_OTHERS + +The ``exclusion`` argument allows excluding rows +(:attr:`~WindowFrameExclusion.CURRENT_ROW`), groups +(:attr:`~WindowFrameExclusion.GROUP`), and ties +(:attr:`~WindowFrameExclusion.TIES`) from the window frames on supported +databases: + +.. code-block:: sql + + %(frame_type)s BETWEEN %(start)s AND %(end)s EXCLUDE %(exclusion)s + Frames narrow the rows that are used for computing the result. They shift from some start point to some specified end point. Frames can be used with and without partitions, but it's often a good idea to specify an ordering of the diff --git a/docs/releases/5.1.txt b/docs/releases/5.1.txt index 42fb17d003..73343b7499 100644 --- a/docs/releases/5.1.txt +++ b/docs/releases/5.1.txt @@ -174,6 +174,11 @@ Models * :class:`~django.db.models.expressions.RowRange` now accepts positive integers for the ``start`` argument and negative integers for the ``end`` argument. +* The new ``exclusion`` argument of + :class:`~django.db.models.expressions.RowRange` and + :class:`~django.db.models.expressions.ValueRange` allows excluding rows, + groups, and ties from the window frames. 
+ Requests and Responses ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/expressions_window/tests.py b/tests/expressions_window/tests.py index f73bd511e5..ee51ace93b 100644 --- a/tests/expressions_window/tests.py +++ b/tests/expressions_window/tests.py @@ -22,6 +22,7 @@ from django.db.models import ( When, Window, WindowFrame, + WindowFrameExclusion, ) from django.db.models.fields.json import KeyTextTransform, KeyTransform from django.db.models.functions import ( @@ -41,7 +42,7 @@ from django.db.models.functions import ( Upper, ) from django.db.models.lookups import Exact -from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature +from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature from django.test.utils import CaptureQueriesContext from .models import Classification, Detail, Employee, PastEmployeeDepartment @@ -1211,6 +1212,47 @@ class WindowFunctionTests(TestCase): ordered=False, ) + @skipUnlessDBFeature( + "supports_frame_exclusion", "supports_frame_range_fixed_distance" + ) + def test_range_exclude_current(self): + qs = Employee.objects.annotate( + sum=Window( + expression=Sum("salary"), + order_by=F("salary").asc(), + partition_by="department", + frame=ValueRange(end=2, exclusion=WindowFrameExclusion.CURRENT_ROW), + ) + ).order_by("department", "salary") + self.assertIn( + "RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING EXCLUDE CURRENT ROW", + str(qs.query), + ) + self.assertQuerySetEqual( + qs, + [ + ("Williams", 37000, "Accounting", datetime.date(2009, 6, 1), None), + ("Jones", 45000, "Accounting", datetime.date(2005, 11, 1), 82000), + ("Jenson", 45000, "Accounting", datetime.date(2008, 4, 1), 82000), + ("Adams", 50000, "Accounting", datetime.date(2013, 7, 1), 127000), + ("Moore", 34000, "IT", datetime.date(2013, 8, 1), None), + ("Wilkinson", 60000, "IT", datetime.date(2011, 3, 1), 34000), + ("Johnson", 80000, "Management", datetime.date(2005, 7, 1), None), + ("Miller", 100000, "Management", datetime.date(2005, 6, 
1), 80000), + ("Smith", 38000, "Marketing", datetime.date(2009, 10, 1), None), + ("Johnson", 40000, "Marketing", datetime.date(2012, 3, 1), 38000), + ("Brown", 53000, "Sales", datetime.date(2009, 9, 1), None), + ("Smith", 55000, "Sales", datetime.date(2007, 6, 1), 53000), + ], + transform=lambda row: ( + row.name, + row.salary, + row.department, + row.hire_date, + row.sum, + ), + ) + def test_range_unbound(self): """A query with RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING.""" qs = Employee.objects.annotate( @@ -1289,6 +1331,190 @@ class WindowFunctionTests(TestCase): ), ) + @skipUnlessDBFeature("supports_frame_exclusion") + def test_row_range_rank_exclude_current_row(self): + qs = Employee.objects.annotate( + avg_salary_cohort=Window( + expression=Avg("salary"), + order_by=[F("hire_date").asc(), F("name").desc()], + frame=RowRange( + start=-1, end=1, exclusion=WindowFrameExclusion.CURRENT_ROW + ), + ) + ).order_by("hire_date") + self.assertIn( + "ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING EXCLUDE CURRENT ROW", + str(qs.query), + ) + self.assertQuerySetEqual( + qs, + [ + ("Miller", 100000, "Management", datetime.date(2005, 6, 1), 80000), + ("Johnson", 80000, "Management", datetime.date(2005, 7, 1), 72500), + ("Jones", 45000, "Accounting", datetime.date(2005, 11, 1), 67500), + ("Smith", 55000, "Sales", datetime.date(2007, 6, 1), 45000), + ("Jenson", 45000, "Accounting", datetime.date(2008, 4, 1), 46000), + ("Williams", 37000, "Accounting", datetime.date(2009, 6, 1), 49000), + ("Brown", 53000, "Sales", datetime.date(2009, 9, 1), 37500), + ("Smith", 38000, "Marketing", datetime.date(2009, 10, 1), 56500), + ("Wilkinson", 60000, "IT", datetime.date(2011, 3, 1), 39000), + ("Johnson", 40000, "Marketing", datetime.date(2012, 3, 1), 55000), + ("Adams", 50000, "Accounting", datetime.date(2013, 7, 1), 37000), + ("Moore", 34000, "IT", datetime.date(2013, 8, 1), 50000), + ], + transform=lambda row: ( + row.name, + row.salary, + row.department, + row.hire_date, + 
row.avg_salary_cohort, + ), + ) + + @skipUnlessDBFeature("supports_frame_exclusion") + def test_row_range_rank_exclude_group(self): + qs = Employee.objects.annotate( + avg_salary_cohort=Window( + expression=Avg("salary"), + order_by=[F("hire_date").asc(), F("name").desc()], + frame=RowRange(start=-1, end=1, exclusion=WindowFrameExclusion.GROUP), + ) + ).order_by("hire_date") + self.assertIn( + "ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING EXCLUDE GROUP", + str(qs.query), + ) + self.assertQuerySetEqual( + qs, + [ + ("Miller", 100000, "Management", datetime.date(2005, 6, 1), 80000), + ("Johnson", 80000, "Management", datetime.date(2005, 7, 1), 72500), + ("Jones", 45000, "Accounting", datetime.date(2005, 11, 1), 67500), + ("Smith", 55000, "Sales", datetime.date(2007, 6, 1), 45000), + ("Jenson", 45000, "Accounting", datetime.date(2008, 4, 1), 46000), + ("Williams", 37000, "Accounting", datetime.date(2009, 6, 1), 49000), + ("Brown", 53000, "Sales", datetime.date(2009, 9, 1), 37500), + ("Smith", 38000, "Marketing", datetime.date(2009, 10, 1), 56500), + ("Wilkinson", 60000, "IT", datetime.date(2011, 3, 1), 39000), + ("Johnson", 40000, "Marketing", datetime.date(2012, 3, 1), 55000), + ("Adams", 50000, "Accounting", datetime.date(2013, 7, 1), 37000), + ("Moore", 34000, "IT", datetime.date(2013, 8, 1), 50000), + ], + transform=lambda row: ( + row.name, + row.salary, + row.department, + row.hire_date, + row.avg_salary_cohort, + ), + ) + + @skipUnlessDBFeature("supports_frame_exclusion") + def test_row_range_rank_exclude_ties(self): + qs = Employee.objects.annotate( + sum_salary_cohort=Window( + expression=Sum("salary"), + order_by=[F("hire_date").asc(), F("name").desc()], + frame=RowRange(start=-1, end=1, exclusion=WindowFrameExclusion.TIES), + ) + ).order_by("hire_date") + self.assertIn( + "ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING EXCLUDE TIES", + str(qs.query), + ) + self.assertQuerySetEqual( + qs, + [ + ("Miller", 100000, "Management", datetime.date(2005, 6, 1), 180000), + 
("Johnson", 80000, "Management", datetime.date(2005, 7, 1), 225000), + ("Jones", 45000, "Accounting", datetime.date(2005, 11, 1), 180000), + ("Smith", 55000, "Sales", datetime.date(2007, 6, 1), 145000), + ("Jenson", 45000, "Accounting", datetime.date(2008, 4, 1), 137000), + ("Williams", 37000, "Accounting", datetime.date(2009, 6, 1), 135000), + ("Brown", 53000, "Sales", datetime.date(2009, 9, 1), 128000), + ("Smith", 38000, "Marketing", datetime.date(2009, 10, 1), 151000), + ("Wilkinson", 60000, "IT", datetime.date(2011, 3, 1), 138000), + ("Johnson", 40000, "Marketing", datetime.date(2012, 3, 1), 150000), + ("Adams", 50000, "Accounting", datetime.date(2013, 7, 1), 124000), + ("Moore", 34000, "IT", datetime.date(2013, 8, 1), 84000), + ], + transform=lambda row: ( + row.name, + row.salary, + row.department, + row.hire_date, + row.sum_salary_cohort, + ), + ) + + @skipUnlessDBFeature("supports_frame_exclusion") + def test_row_range_rank_exclude_no_others(self): + qs = Employee.objects.annotate( + sum_salary_cohort=Window( + expression=Sum("salary"), + order_by=[F("hire_date").asc(), F("name").desc()], + frame=RowRange( + start=-1, end=1, exclusion=WindowFrameExclusion.NO_OTHERS + ), + ) + ).order_by("hire_date") + self.assertIn( + "ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING EXCLUDE NO OTHERS", + str(qs.query), + ) + self.assertQuerySetEqual( + qs, + [ + ("Miller", 100000, "Management", datetime.date(2005, 6, 1), 180000), + ("Johnson", 80000, "Management", datetime.date(2005, 7, 1), 225000), + ("Jones", 45000, "Accounting", datetime.date(2005, 11, 1), 180000), + ("Smith", 55000, "Sales", datetime.date(2007, 6, 1), 145000), + ("Jenson", 45000, "Accounting", datetime.date(2008, 4, 1), 137000), + ("Williams", 37000, "Accounting", datetime.date(2009, 6, 1), 135000), + ("Brown", 53000, "Sales", datetime.date(2009, 9, 1), 128000), + ("Smith", 38000, "Marketing", datetime.date(2009, 10, 1), 151000), + ("Wilkinson", 60000, "IT", datetime.date(2011, 3, 1), 138000), + ("Johnson", 
40000, "Marketing", datetime.date(2012, 3, 1), 150000), + ("Adams", 50000, "Accounting", datetime.date(2013, 7, 1), 124000), + ("Moore", 34000, "IT", datetime.date(2013, 8, 1), 84000), + ], + transform=lambda row: ( + row.name, + row.salary, + row.department, + row.hire_date, + row.sum_salary_cohort, + ), + ) + + @skipIfDBFeature("supports_frame_exclusion") + def test_unsupported_frame_exclusion_raises_error(self): + msg = "This backend does not support window frame exclusions." + with self.assertRaisesMessage(NotSupportedError, msg): + list( + Employee.objects.annotate( + avg_salary_cohort=Window( + expression=Avg("salary"), + order_by=[F("hire_date").asc(), F("name").desc()], + frame=RowRange( + start=-1, end=1, exclusion=WindowFrameExclusion.CURRENT_ROW + ), + ) + ) + ) + + @skipUnlessDBFeature("supports_frame_exclusion") + def test_invalid_frame_exclusion_value_raises_error(self): + msg = "RowRange.exclusion must be a WindowFrameExclusion instance." + with self.assertRaisesMessage(TypeError, msg): + Employee.objects.annotate( + avg_salary_cohort=Window( + expression=Avg("salary"), + order_by=[F("hire_date").asc(), F("name").desc()], + frame=RowRange(start=-1, end=1, exclusion="RUBBISH"), + ) + ) + def test_row_range_rank(self): """ A query with ROWS BETWEEN UNBOUNDED PRECEDING AND 3 FOLLOWING. @@ -1735,6 +1961,13 @@ class NonQueryWindowTests(SimpleTestCase): repr(RowRange(start=1, end=2)), "<RowRange: ROWS BETWEEN 1 FOLLOWING AND 2 FOLLOWING>", ) + self.assertEqual( + repr(RowRange(start=1, end=2, exclusion=WindowFrameExclusion.CURRENT_ROW)), + "<RowRange: ROWS BETWEEN 1 FOLLOWING AND 2 FOLLOWING EXCLUDE CURRENT ROW>", + ) + + def test_window_frame_exclusion_repr(self): + self.assertEqual(repr(WindowFrameExclusion.TIES), "WindowFrameExclusion.TIES") def test_empty_group_by_cols(self): window = Window(expression=Sum("pk"))