diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py index 0ee2ffc5c4..b1d818d90e 100644 --- a/django/db/backends/base/operations.py +++ b/django/db/backends/base/operations.py @@ -714,42 +714,50 @@ class BaseDatabaseOperations: "This backend does not support %s subtraction." % internal_type ) - def window_frame_start(self, start): - if isinstance(start, int): - if start < 0: - return "%d %s" % (abs(start), self.PRECEDING) - elif start == 0: + def window_frame_value(self, value): + if isinstance(value, int): + if value == 0: return self.CURRENT_ROW - elif start is None: - return self.UNBOUNDED_PRECEDING - raise ValueError( - "start argument must be a negative integer, zero, or None, but got '%s'." - % start - ) - - def window_frame_end(self, end): - if isinstance(end, int): - if end == 0: - return self.CURRENT_ROW - elif end > 0: - return "%d %s" % (end, self.FOLLOWING) - elif end is None: - return self.UNBOUNDED_FOLLOWING - raise ValueError( - "end argument must be a positive integer, zero, or None, but got '%s'." - % end - ) + elif value < 0: + return "%d %s" % (abs(value), self.PRECEDING) + else: + return "%d %s" % (value, self.FOLLOWING) def window_frame_rows_start_end(self, start=None, end=None): """ Return SQL for start and end points in an OVER clause window frame. """ - if not self.connection.features.supports_over_clause: - raise NotSupportedError("This backend does not support window expressions.") - return self.window_frame_start(start), self.window_frame_end(end) + if isinstance(start, int) and isinstance(end, int) and start > end: + raise ValueError("start cannot be greater than end.") + if start is not None and not isinstance(start, int): + raise ValueError( + f"start argument must be an integer, zero, or None, but got '{start}'." + ) + if end is not None and not isinstance(end, int): + raise ValueError( + f"end argument must be an integer, zero, or None, but got '{end}'." + ) + start_ = self.window_frame_value(start) or self.UNBOUNDED_PRECEDING + end_ = self.window_frame_value(end) or self.UNBOUNDED_FOLLOWING + return start_, end_ def window_frame_range_start_end(self, start=None, end=None): - start_, end_ = self.window_frame_rows_start_end(start, end) + if (start is not None and not isinstance(start, int)) or ( + isinstance(start, int) and start > 0 + ): + raise ValueError( + "start argument must be a negative integer, zero, or None, " + "but got '%s'." % start + ) + if (end is not None and not isinstance(end, int)) or ( + isinstance(end, int) and end < 0 + ): + raise ValueError( + "end argument must be a positive integer, zero, or None, but got '%s'." + % end + ) + start_ = self.window_frame_value(start) or self.UNBOUNDED_PRECEDING + end_ = self.window_frame_value(end) or self.UNBOUNDED_FOLLOWING features = self.connection.features if features.only_supports_unbounded_with_preceding_and_following and ( (start and start < 0) or (end and end > 0) diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 9aa78a281f..b1b4aadb8e 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -1895,6 +1895,8 @@ class WindowFrame(Expression): start = "%d %s" % (abs(self.start.value), connection.ops.PRECEDING) elif self.start.value is not None and self.start.value == 0: start = connection.ops.CURRENT_ROW + elif self.start.value is not None and self.start.value > 0: + start = "%d %s" % (self.start.value, connection.ops.FOLLOWING) else: start = connection.ops.UNBOUNDED_PRECEDING @@ -1902,6 +1904,8 @@ class WindowFrame(Expression): end = "%d %s" % (self.end.value, connection.ops.FOLLOWING) elif self.end.value is not None and self.end.value == 0: end = connection.ops.CURRENT_ROW + elif self.end.value is not None and self.end.value < 0: + end = "%d %s" % (abs(self.end.value), connection.ops.PRECEDING) else: end = connection.ops.UNBOUNDED_FOLLOWING return self.template % { diff --git a/docs/ref/models/expressions.txt b/docs/ref/models/expressions.txt index 2697e948ff..6dfe6c453a 100644 --- a/docs/ref/models/expressions.txt +++ b/docs/ref/models/expressions.txt @@ -923,9 +923,12 @@ SQL generated by the ORM and is by default ``UNBOUNDED FOLLOWING``. The default frame includes all rows from the partition to the last row in the set. The accepted values for the ``start`` and ``end`` arguments are ``None``, an -integer, or zero. A negative integer for ``start`` results in ``N preceding``, -while ``None`` yields ``UNBOUNDED PRECEDING``. For both ``start`` and ``end``, -zero will return ``CURRENT ROW``. Positive integers are accepted for ``end``. +integer, or zero. A negative integer for ``start`` results in ``N PRECEDING``, +while ``None`` yields ``UNBOUNDED PRECEDING``. In ``ROWS`` mode, a positive +integer can be used for ```start`` resulting in ``N FOLLOWING``. Positive +integers are accepted for ``end`` and results in ``N FOLLOWING``. In ``ROWS`` +mode, a negative integer can be used for ```end`` resulting in ``N PRECEDING``. +For both ``start`` and ``end``, zero will return ``CURRENT ROW``. There's a difference in what ``CURRENT ROW`` includes. When specified in ``ROWS`` mode, the frame starts or ends with the current row. When specified in @@ -970,6 +973,11 @@ released between twelve months before and twelve months after each movie: ... ), ... ) +.. versionchanged:: 5.1 + + Support for positive integer ``start`` and negative integer ``end`` was + added for ``RowRange``. + .. currentmodule:: django.db.models Technical Information diff --git a/docs/releases/5.1.txt b/docs/releases/5.1.txt index e352b9c04f..42fb17d003 100644 --- a/docs/releases/5.1.txt +++ b/docs/releases/5.1.txt @@ -171,6 +171,9 @@ Models * :meth:`.QuerySet.explain` now supports the ``generic_plan`` option on PostgreSQL 16+. +* :class:`~django.db.models.expressions.RowRange` now accepts positive integers + for the ``start`` argument and negative integers for the ``end`` argument. + Requests and Responses ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/expressions_window/tests.py b/tests/expressions_window/tests.py index 3a02a36707..f73bd511e5 100644 --- a/tests/expressions_window/tests.py +++ b/tests/expressions_window/tests.py @@ -1328,6 +1328,84 @@ class WindowFunctionTests(TestCase): ), ) + def test_row_range_both_preceding(self): + """ + A query with ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING. + The resulting sum is the sum of the previous two (if they exist) rows + according to the ordering clause. + """ + qs = Employee.objects.annotate( + sum=Window( + expression=Sum("salary"), + order_by=[F("hire_date").asc(), F("name").desc()], + frame=RowRange(start=-2, end=-1), + ) + ).order_by("hire_date") + self.assertIn("ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING", str(qs.query)) + self.assertQuerySetEqual( + qs, + [ + ("Miller", 100000, "Management", datetime.date(2005, 6, 1), None), + ("Johnson", 80000, "Management", datetime.date(2005, 7, 1), 100000), + ("Jones", 45000, "Accounting", datetime.date(2005, 11, 1), 180000), + ("Smith", 55000, "Sales", datetime.date(2007, 6, 1), 125000), + ("Jenson", 45000, "Accounting", datetime.date(2008, 4, 1), 100000), + ("Williams", 37000, "Accounting", datetime.date(2009, 6, 1), 100000), + ("Brown", 53000, "Sales", datetime.date(2009, 9, 1), 82000), + ("Smith", 38000, "Marketing", datetime.date(2009, 10, 1), 90000), + ("Wilkinson", 60000, "IT", datetime.date(2011, 3, 1), 91000), + ("Johnson", 40000, "Marketing", datetime.date(2012, 3, 1), 98000), + ("Adams", 50000, "Accounting", datetime.date(2013, 7, 1), 100000), + ("Moore", 34000, "IT", datetime.date(2013, 8, 1), 90000), + ], + transform=lambda row: ( + row.name, + row.salary, + row.department, + row.hire_date, + row.sum, + ), + ) + + def test_row_range_both_following(self): + """ + A query with ROWS BETWEEN 1 FOLLOWING AND 2 FOLLOWING. + The resulting sum is the sum of the following two (if they exist) rows + according to the ordering clause. + """ + qs = Employee.objects.annotate( + sum=Window( + expression=Sum("salary"), + order_by=[F("hire_date").asc(), F("name").desc()], + frame=RowRange(start=1, end=2), + ) + ).order_by("hire_date") + self.assertIn("ROWS BETWEEN 1 FOLLOWING AND 2 FOLLOWING", str(qs.query)) + self.assertQuerySetEqual( + qs, + [ + ("Miller", 100000, "Management", datetime.date(2005, 6, 1), 125000), + ("Johnson", 80000, "Management", datetime.date(2005, 7, 1), 100000), + ("Jones", 45000, "Accounting", datetime.date(2005, 11, 1), 100000), + ("Smith", 55000, "Sales", datetime.date(2007, 6, 1), 82000), + ("Jenson", 45000, "Accounting", datetime.date(2008, 4, 1), 90000), + ("Williams", 37000, "Accounting", datetime.date(2009, 6, 1), 91000), + ("Brown", 53000, "Sales", datetime.date(2009, 9, 1), 98000), + ("Smith", 38000, "Marketing", datetime.date(2009, 10, 1), 100000), + ("Wilkinson", 60000, "IT", datetime.date(2011, 3, 1), 90000), + ("Johnson", 40000, "Marketing", datetime.date(2012, 3, 1), 84000), + ("Adams", 50000, "Accounting", datetime.date(2013, 7, 1), 34000), + ("Moore", 34000, "IT", datetime.date(2013, 8, 1), None), + ], + transform=lambda row: ( + row.name, + row.salary, + row.department, + row.hire_date, + row.sum, + ), + ) + @skipUnlessDBFeature("can_distinct_on_fields") def test_distinct_window_function(self): """ @@ -1479,6 +1557,19 @@ class WindowFunctionTests(TestCase): ) ) + def test_invalid_start_end_value_for_row_range(self): + msg = "start cannot be greater than end." + with self.assertRaisesMessage(ValueError, msg): + list( + Employee.objects.annotate( + test=Window( + expression=Sum("salary"), + order_by=F("hire_date").asc(), + frame=RowRange(start=4, end=-3), + ) + ) + ) + def test_invalid_type_end_value_range(self): msg = "end argument must be a positive integer, zero, or None, but got 'a'." with self.assertRaisesMessage(ValueError, msg): @@ -1505,7 +1596,7 @@ class WindowFunctionTests(TestCase): ) def test_invalid_type_end_row_range(self): - msg = "end argument must be a positive integer, zero, or None, but got 'a'." + msg = "end argument must be an integer, zero, or None, but got 'a'." with self.assertRaisesMessage(ValueError, msg): list( Employee.objects.annotate( @@ -1551,7 +1642,7 @@ class WindowFunctionTests(TestCase): ) def test_invalid_type_start_row_range(self): - msg = "start argument must be a negative integer, zero, or None, but got 'a'." + msg = "start argument must be an integer, zero, or None, but got 'a'." with self.assertRaisesMessage(ValueError, msg): list( Employee.objects.annotate( @@ -1636,6 +1727,14 @@ class NonQueryWindowTests(SimpleTestCase): repr(RowRange(start=0, end=0)), "", ) + self.assertEqual( + repr(RowRange(start=-2, end=-1)), + "", + ) + self.assertEqual( + repr(RowRange(start=1, end=2)), + "", + ) def test_empty_group_by_cols(self): window = Window(expression=Sum("pk"))