1
0
mirror of https://github.com/django/django.git synced 2025-10-23 21:59:11 +00:00

Refs #28333 -- Added partial support for filtering against window functions.

Adds support for joint predicates against window annotations through
subquery wrapping while maintaining errors for disjointed filter
attempts.

The "qualify" wording was used to refer to predicates against window
annotations as it's the name of a specialized Snowflake extension to
SQL that is to window functions what HAVING is to aggregates.

While not complete the implementation should cover most of the common
use cases for filtering against window functions without requiring
the complex subquery pushdown and predicate re-aliasing machinery to
deal with disjointed predicates against columns, aggregates, and window
functions.

A complete disjointed filtering implementation should likely be
deferred until proper QUALIFY support lands or the ORM gains a proper
subquery pushdown interface.
This commit is contained in:
Simon Charette
2022-08-10 08:22:01 -04:00
committed by Mariusz Felisiak
parent f3f9d03edf
commit f387d024fc
9 changed files with 489 additions and 67 deletions

View File

@@ -6,10 +6,9 @@ from django.core.exceptions import FieldError
from django.db import NotSupportedError, connection
from django.db.models import (
Avg,
BooleanField,
Case,
Count,
F,
Func,
IntegerField,
Max,
Min,
@@ -41,15 +40,17 @@ from django.db.models.functions import (
RowNumber,
Upper,
)
from django.db.models.lookups import Exact
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
from .models import Detail, Employee
from .models import Classification, Detail, Employee, PastEmployeeDepartment
@skipUnlessDBFeature("supports_over_clause")
class WindowFunctionTests(TestCase):
@classmethod
def setUpTestData(cls):
classification = Classification.objects.create()
Employee.objects.bulk_create(
[
Employee(
@@ -59,6 +60,7 @@ class WindowFunctionTests(TestCase):
hire_date=e[3],
age=e[4],
bonus=Decimal(e[1]) / 400,
classification=classification,
)
for e in [
("Jones", 45000, "Accounting", datetime.datetime(2005, 11, 1), 20),
@@ -82,6 +84,13 @@ class WindowFunctionTests(TestCase):
]
]
)
employees = list(Employee.objects.order_by("pk"))
PastEmployeeDepartment.objects.bulk_create(
[
PastEmployeeDepartment(employee=employees[6], department="Sales"),
PastEmployeeDepartment(employee=employees[10], department="IT"),
]
)
def test_dense_rank(self):
tests = [
@@ -902,6 +911,263 @@ class WindowFunctionTests(TestCase):
)
self.assertEqual(qs.count(), 12)
def test_filter(self):
qs = Employee.objects.annotate(
department_salary_rank=Window(
Rank(), partition_by="department", order_by="-salary"
),
department_avg_age_diff=(
Window(Avg("age"), partition_by="department") - F("age")
),
).order_by("department", "name")
# Direct window reference.
self.assertQuerysetEqual(
qs.filter(department_salary_rank=1),
["Adams", "Wilkinson", "Miller", "Johnson", "Smith"],
lambda employee: employee.name,
)
# Through a combined expression containing a window.
self.assertQuerysetEqual(
qs.filter(department_avg_age_diff__gt=0),
["Jenson", "Jones", "Williams", "Miller", "Smith"],
lambda employee: employee.name,
)
# Intersection of multiple windows.
self.assertQuerysetEqual(
qs.filter(department_salary_rank=1, department_avg_age_diff__gt=0),
["Miller"],
lambda employee: employee.name,
)
# Union of multiple windows.
self.assertQuerysetEqual(
qs.filter(Q(department_salary_rank=1) | Q(department_avg_age_diff__gt=0)),
[
"Adams",
"Jenson",
"Jones",
"Williams",
"Wilkinson",
"Miller",
"Johnson",
"Smith",
"Smith",
],
lambda employee: employee.name,
)
def test_filter_conditional_annotation(self):
qs = (
Employee.objects.annotate(
rank=Window(Rank(), partition_by="department", order_by="-salary"),
case_first_rank=Case(
When(rank=1, then=True),
default=False,
),
q_first_rank=Q(rank=1),
)
.order_by("name")
.values_list("name", flat=True)
)
for annotation in ["case_first_rank", "q_first_rank"]:
with self.subTest(annotation=annotation):
self.assertSequenceEqual(
qs.filter(**{annotation: True}),
["Adams", "Johnson", "Miller", "Smith", "Wilkinson"],
)
def test_filter_conditional_expression(self):
qs = (
Employee.objects.filter(
Exact(Window(Rank(), partition_by="department", order_by="-salary"), 1)
)
.order_by("name")
.values_list("name", flat=True)
)
self.assertSequenceEqual(
qs, ["Adams", "Johnson", "Miller", "Smith", "Wilkinson"]
)
def test_filter_column_ref_rhs(self):
qs = (
Employee.objects.annotate(
max_dept_salary=Window(Max("salary"), partition_by="department")
)
.filter(max_dept_salary=F("salary"))
.order_by("name")
.values_list("name", flat=True)
)
self.assertSequenceEqual(
qs, ["Adams", "Johnson", "Miller", "Smith", "Wilkinson"]
)
def test_filter_values(self):
qs = (
Employee.objects.annotate(
department_salary_rank=Window(
Rank(), partition_by="department", order_by="-salary"
),
)
.order_by("department", "name")
.values_list(Upper("name"), flat=True)
)
self.assertSequenceEqual(
qs.filter(department_salary_rank=1),
["ADAMS", "WILKINSON", "MILLER", "JOHNSON", "SMITH"],
)
def test_filter_alias(self):
qs = Employee.objects.alias(
department_avg_age_diff=(
Window(Avg("age"), partition_by="department") - F("age")
),
).order_by("department", "name")
self.assertQuerysetEqual(
qs.filter(department_avg_age_diff__gt=0),
["Jenson", "Jones", "Williams", "Miller", "Smith"],
lambda employee: employee.name,
)
def test_filter_select_related(self):
qs = (
Employee.objects.alias(
department_avg_age_diff=(
Window(Avg("age"), partition_by="department") - F("age")
),
)
.select_related("classification")
.filter(department_avg_age_diff__gt=0)
.order_by("department", "name")
)
self.assertQuerysetEqual(
qs,
["Jenson", "Jones", "Williams", "Miller", "Smith"],
lambda employee: employee.name,
)
with self.assertNumQueries(0):
qs[0].classification
def test_exclude(self):
qs = Employee.objects.annotate(
department_salary_rank=Window(
Rank(), partition_by="department", order_by="-salary"
),
department_avg_age_diff=(
Window(Avg("age"), partition_by="department") - F("age")
),
).order_by("department", "name")
# Direct window reference.
self.assertQuerysetEqual(
qs.exclude(department_salary_rank__gt=1),
["Adams", "Wilkinson", "Miller", "Johnson", "Smith"],
lambda employee: employee.name,
)
# Through a combined expression containing a window.
self.assertQuerysetEqual(
qs.exclude(department_avg_age_diff__lte=0),
["Jenson", "Jones", "Williams", "Miller", "Smith"],
lambda employee: employee.name,
)
# Union of multiple windows.
self.assertQuerysetEqual(
qs.exclude(
Q(department_salary_rank__gt=1) | Q(department_avg_age_diff__lte=0)
),
["Miller"],
lambda employee: employee.name,
)
# Intersection of multiple windows.
self.assertQuerysetEqual(
qs.exclude(department_salary_rank__gt=1, department_avg_age_diff__lte=0),
[
"Adams",
"Jenson",
"Jones",
"Williams",
"Wilkinson",
"Miller",
"Johnson",
"Smith",
"Smith",
],
lambda employee: employee.name,
)
def test_heterogeneous_filter(self):
qs = (
Employee.objects.annotate(
department_salary_rank=Window(
Rank(), partition_by="department", order_by="-salary"
),
)
.order_by("name")
.values_list("name", flat=True)
)
# Heterogeneous filter between window function and aggregates pushes
# the WHERE clause to the QUALIFY outer query.
self.assertSequenceEqual(
qs.filter(
department_salary_rank=1, department__in=["Accounting", "Management"]
),
["Adams", "Miller"],
)
self.assertSequenceEqual(
qs.filter(
Q(department_salary_rank=1)
| Q(department__in=["Accounting", "Management"])
),
[
"Adams",
"Jenson",
"Johnson",
"Johnson",
"Jones",
"Miller",
"Smith",
"Wilkinson",
"Williams",
],
)
# Heterogeneous filter between window function and aggregates pushes
# the HAVING clause to the QUALIFY outer query.
qs = qs.annotate(past_department_count=Count("past_departments"))
self.assertSequenceEqual(
qs.filter(department_salary_rank=1, past_department_count__gte=1),
["Johnson", "Miller"],
)
self.assertSequenceEqual(
qs.filter(Q(department_salary_rank=1) | Q(past_department_count__gte=1)),
["Adams", "Johnson", "Miller", "Smith", "Wilkinson"],
)
def test_limited_filter(self):
"""
A query filtering against a window function have its limit applied
after window filtering takes place.
"""
self.assertQuerysetEqual(
Employee.objects.annotate(
department_salary_rank=Window(
Rank(), partition_by="department", order_by="-salary"
)
)
.filter(department_salary_rank=1)
.order_by("department")[0:3],
["Adams", "Wilkinson", "Miller"],
lambda employee: employee.name,
)
def test_filter_count(self):
self.assertEqual(
Employee.objects.annotate(
department_salary_rank=Window(
Rank(), partition_by="department", order_by="-salary"
)
)
.filter(department_salary_rank=1)
.count(),
5,
)
@skipUnlessDBFeature("supports_frame_range_fixed_distance")
def test_range_n_preceding_and_following(self):
qs = Employee.objects.annotate(
@@ -1071,6 +1337,7 @@ class WindowFunctionTests(TestCase):
),
year=ExtractYear("hire_date"),
)
.filter(sum__gte=45000)
.values("year", "sum")
.distinct("year")
.order_by("year")
@@ -1081,7 +1348,6 @@ class WindowFunctionTests(TestCase):
{"year": 2008, "sum": 45000},
{"year": 2009, "sum": 128000},
{"year": 2011, "sum": 60000},
{"year": 2012, "sum": 40000},
{"year": 2013, "sum": 84000},
]
for idx, val in zip(range(len(results)), results):
@@ -1348,34 +1614,18 @@ class NonQueryWindowTests(SimpleTestCase):
frame.window_frame_start_end(None, None, None)
def test_invalid_filter(self):
msg = "Window is disallowed in the filter clause"
qs = Employee.objects.annotate(dense_rank=Window(expression=DenseRank()))
with self.assertRaisesMessage(NotSupportedError, msg):
qs.filter(dense_rank__gte=1)
with self.assertRaisesMessage(NotSupportedError, msg):
qs.annotate(inc_rank=F("dense_rank") + Value(1)).filter(inc_rank__gte=1)
with self.assertRaisesMessage(NotSupportedError, msg):
qs.filter(id=F("dense_rank"))
with self.assertRaisesMessage(NotSupportedError, msg):
qs.filter(id=Func("dense_rank", 2, function="div"))
with self.assertRaisesMessage(NotSupportedError, msg):
qs.annotate(total=Sum("dense_rank", filter=Q(name="Jones"))).filter(total=1)
def test_conditional_annotation(self):
qs = Employee.objects.annotate(
dense_rank=Window(expression=DenseRank()),
).annotate(
equal=Case(
When(id=F("dense_rank"), then=Value(True)),
default=Value(False),
output_field=BooleanField(),
),
msg = (
"Heterogeneous disjunctive predicates against window functions are not "
"implemented when performing conditional aggregation."
)
# The SQL standard disallows referencing window functions in the WHERE
# clause.
msg = "Window is disallowed in the filter clause"
with self.assertRaisesMessage(NotSupportedError, msg):
qs.filter(equal=True)
qs = Employee.objects.annotate(
window=Window(Rank()),
past_dept_cnt=Count("past_departments"),
)
with self.assertRaisesMessage(NotImplementedError, msg):
list(qs.filter(Q(window=1) | Q(department="Accounting")))
with self.assertRaisesMessage(NotImplementedError, msg):
list(qs.exclude(window=1, department="Accounting"))
def test_invalid_order_by(self):
msg = (