1
0
mirror of https://github.com/django/django.git synced 2025-08-21 01:09:13 +00:00

Fixed #36210, Refs #36181 -- Allowed Subquery usage in further lookups against composite pks.

Follow-up to 8561100425876bde3be4b2a22324655f74ff9609.

Co-authored-by: Simon Charette <charette.s@gmail.com>
This commit is contained in:
Jacob Walls 2025-05-11 22:04:09 -04:00 committed by Sarah Boyce
parent de7bb7eab8
commit fd569dd45b
7 changed files with 101 additions and 6 deletions

View File

@ -385,6 +385,10 @@ class BaseDatabaseFeatures:
# Does the backend support native tuple lookups (=, >, <, IN)? # Does the backend support native tuple lookups (=, >, <, IN)?
supports_tuple_lookups = True supports_tuple_lookups = True
# Does the backend support native tuple gt(e), lt(e) comparisons against
# subqueries?
supports_tuple_comparison_against_subquery = True
# Collation names for use by the Django test suite. # Collation names for use by the Django test suite.
test_collations = { test_collations = {
"ci": None, # Case-insensitive. "ci": None, # Case-insensitive.

View File

@ -21,6 +21,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
can_return_columns_from_insert = True can_return_columns_from_insert = True
supports_subqueries_in_group_by = False supports_subqueries_in_group_by = False
ignores_unnecessary_order_by_in_subqueries = False ignores_unnecessary_order_by_in_subqueries = False
supports_tuple_comparison_against_subquery = False
supports_transactions = True supports_transactions = True
supports_timezones = False supports_timezones = False
has_native_duration_field = True has_native_duration_field = True

View File

@ -1781,6 +1781,7 @@ class Subquery(BaseExpression, Combinable):
# Allow the usage of both QuerySet and sql.Query objects. # Allow the usage of both QuerySet and sql.Query objects.
self.query = getattr(queryset, "query", queryset).clone() self.query = getattr(queryset, "query", queryset).clone()
self.query.subquery = True self.query.subquery = True
self.template = extra.pop("template", self.template)
self.extra = extra self.extra = extra
super().__init__(output_field) super().__init__(output_field)
@ -1793,6 +1794,21 @@ class Subquery(BaseExpression, Combinable):
def _resolve_output_field(self): def _resolve_output_field(self):
return self.query.output_field return self.query.output_field
def resolve_expression(self, *args, **kwargs):
resolved = super().resolve_expression(*args, **kwargs)
if type(self) is Subquery and self.template == Subquery.template:
resolved.query.contains_subquery = True
# Subquery is an unnecessary shim for a resolved query as it
# complexifies the lookup's right-hand-side introspection.
try:
self.output_field
except AttributeError:
return resolved.query
if self.output_field and self.output_field != resolved.query.output_field:
return ExpressionWrapper(resolved.query, output_field=self.output_field)
return resolved.query
return resolved
def copy(self): def copy(self):
clone = super().copy() clone = super().copy()
clone.query = clone.query.clone() clone.query = clone.query.clone()

View File

@ -1,7 +1,7 @@
import itertools import itertools
from django.core.exceptions import EmptyResultSet from django.core.exceptions import EmptyResultSet
from django.db import models from django.db import NotSupportedError, models
from django.db.models.expressions import ( from django.db.models.expressions import (
ColPairs, ColPairs,
Exists, Exists,
@ -129,6 +129,20 @@ class TupleLookupMixin:
) )
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
if (
not connection.features.supports_tuple_comparison_against_subquery
and isinstance(self.rhs, Query)
and self.rhs.subquery
and isinstance(
self, (GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual)
)
):
lookup = self.lookup_name
msg = (
f'"{lookup}" cannot be used to target composite fields '
"through subqueries on this backend"
)
raise NotSupportedError(msg)
if not connection.features.supports_tuple_lookups: if not connection.features.supports_tuple_lookups:
return self.get_fallback_sql(compiler, connection) return self.get_fallback_sql(compiler, connection)
return super().as_sql(compiler, connection) return super().as_sql(compiler, connection)

View File

@ -242,6 +242,7 @@ class Query(BaseExpression):
filter_is_sticky = False filter_is_sticky = False
subquery = False subquery = False
contains_subquery = False
# SQL-related attributes. # SQL-related attributes.
# Select and related select clauses are expressions to use in the SELECT # Select and related select clauses are expressions to use in the SELECT

View File

@ -1,6 +1,6 @@
from unittest.mock import patch from unittest.mock import patch
from django.db import connection from django.db import NotSupportedError, connection
from django.db.models import ( from django.db.models import (
Case, Case,
F, F,
@ -14,7 +14,7 @@ from django.db.models import (
) )
from django.db.models.functions import Cast from django.db.models.functions import Cast
from django.db.models.lookups import Exact from django.db.models.lookups import Exact
from django.test import TestCase, skipUnlessDBFeature from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
from .models import Comment, Tenant, User from .models import Comment, Tenant, User
@ -492,6 +492,39 @@ class CompositePKFilterTests(TestCase):
queryset = Comment.objects.filter(**{f"id{lookup}": subquery}) queryset = Comment.objects.filter(**{f"id{lookup}": subquery})
self.assertEqual(queryset.count(), expected_count) self.assertEqual(queryset.count(), expected_count)
def test_outer_ref_pk_filter_on_pk_exact(self):
subquery = Subquery(User.objects.filter(pk=OuterRef("pk")).values("pk")[:1])
qs = Comment.objects.filter(pk=subquery)
self.assertEqual(qs.count(), 2)
@skipUnlessDBFeature("supports_tuple_comparison_against_subquery")
def test_outer_ref_pk_filter_on_pk_comparison(self):
subquery = Subquery(User.objects.filter(pk=OuterRef("pk")).values("pk")[:1])
tests = [
("gt", 0),
("gte", 2),
("lt", 0),
("lte", 2),
]
for lookup, expected_count in tests:
with self.subTest(f"pk__{lookup}"):
qs = Comment.objects.filter(**{f"pk__{lookup}": subquery})
self.assertEqual(qs.count(), expected_count)
@skipIfDBFeature("supports_tuple_comparison_against_subquery")
def test_outer_ref_pk_filter_on_pk_comparison_unsupported(self):
subquery = Subquery(User.objects.filter(pk=OuterRef("pk")).values("pk")[:1])
tests = ["gt", "gte", "lt", "lte"]
for lookup in tests:
with self.subTest(f"pk__{lookup}"):
qs = Comment.objects.filter(**{f"pk__{lookup}": subquery})
with self.assertRaisesMessage(
NotSupportedError,
f'"{lookup}" cannot be used to target composite fields '
"through subqueries on this backend",
):
qs.count()
def test_unsupported_rhs(self): def test_unsupported_rhs(self):
pk = Exact(F("tenant_id"), 1) pk = Exact(F("tenant_id"), 1)
msg = ( msg = (
@ -561,7 +594,11 @@ class CompositePKFilterTests(TestCase):
@skipUnlessDBFeature("supports_tuple_lookups") @skipUnlessDBFeature("supports_tuple_lookups")
class CompositePKFilterTupleLookupFallbackTests(CompositePKFilterTests): class CompositePKFilterTupleLookupFallbackTests(CompositePKFilterTests):
def setUp(self): def setUp(self):
feature_patch = patch.object( feature_patch_1 = patch.object(
connection.features, "supports_tuple_lookups", False connection.features, "supports_tuple_lookups", False
) )
self.enterContext(feature_patch) feature_patch_2 = patch.object(
connection.features, "supports_tuple_comparison_against_subquery", False
)
self.enterContext(feature_patch_1)
self.enterContext(feature_patch_2)

View File

@ -988,11 +988,24 @@ class BasicExpressionsTests(TestCase):
) )
.order_by("-salary_raise") .order_by("-salary_raise")
.values("salary_raise")[:1], .values("salary_raise")[:1],
output_field=IntegerField(),
), ),
).get(pk=self.gmbh.pk) ).get(pk=self.gmbh.pk)
self.assertEqual(gmbh_salary.max_ceo_salary_raise, 2332) self.assertEqual(gmbh_salary.max_ceo_salary_raise, 2332)
def test_annotation_with_outerref_and_output_field(self):
gmbh_salary = Company.objects.annotate(
max_ceo_salary_raise=Subquery(
Company.objects.annotate(
salary_raise=OuterRef("num_employees") + F("num_employees"),
)
.order_by("-salary_raise")
.values("salary_raise")[:1],
output_field=DecimalField(),
),
).get(pk=self.gmbh.pk)
self.assertEqual(gmbh_salary.max_ceo_salary_raise, 2332.0)
self.assertIsInstance(gmbh_salary.max_ceo_salary_raise, Decimal)
def test_annotation_with_nested_outerref(self): def test_annotation_with_nested_outerref(self):
self.gmbh.point_of_contact = Employee.objects.get(lastname="Meyer") self.gmbh.point_of_contact = Employee.objects.get(lastname="Meyer")
self.gmbh.save() self.gmbh.save()
@ -2542,6 +2555,15 @@ class ExistsTests(TestCase):
self.assertSequenceEqual(qs, [manager]) self.assertSequenceEqual(qs, [manager])
self.assertIs(qs.get().exists, False) self.assertIs(qs.get().exists, False)
def test_annotate_by_empty_custom_exists(self):
class CustomExists(Exists):
template = Subquery.template
manager = Manager.objects.create()
qs = Manager.objects.annotate(exists=CustomExists(Manager.objects.none()))
self.assertSequenceEqual(qs, [manager])
self.assertIs(qs.get().exists, False)
class FieldTransformTests(TestCase): class FieldTransformTests(TestCase):
@classmethod @classmethod