From 41239fe34d64e801212dccaa4585e4802d0fac68 Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Mon, 27 Jan 2025 23:10:13 -0500 Subject: [PATCH] Fixed #36149 -- Allowed subquery values against tuple exact and in lookups. Non-tuple exact and in lookups have specialized logic for subqueries that can be adapted to properly assign select mask if unspecified and ensure the number of involved members are matching on both side of the operator. --- django/db/models/expressions.py | 3 ++ django/db/models/fields/related_lookups.py | 11 +++- django/db/models/fields/tuple_lookups.py | 51 +++++++------------ django/db/models/lookups.py | 23 ++++++--- django/db/models/query.py | 4 -- django/db/models/sql/query.py | 6 +++ tests/composite_pk/models/tenant.py | 1 + tests/composite_pk/test_filter.py | 58 ++++++++++++++++++++-- tests/composite_pk/tests.py | 9 ++-- tests/foreign_object/test_tuple_lookups.py | 18 ++++--- tests/lookup/tests.py | 16 ++++++ tests/queries/tests.py | 14 ------ 12 files changed, 137 insertions(+), 77 deletions(-) diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index a89acaf5a9..57ceadcec4 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -1367,6 +1367,9 @@ class ColPairs(Expression): def resolve_expression(self, *args, **kwargs): return self + def select_format(self, compiler, sql, params): + return sql, params + class Ref(Expression): """ diff --git a/django/db/models/fields/related_lookups.py b/django/db/models/fields/related_lookups.py index a6e28b11fb..38d6308f53 100644 --- a/django/db/models/fields/related_lookups.py +++ b/django/db/models/fields/related_lookups.py @@ -40,7 +40,16 @@ def get_normalized_value(value, lhs): class RelatedIn(In): def get_prep_lookup(self): - if not isinstance(self.lhs, ColPairs): + from django.db.models.sql.query import Query # avoid circular import + + if isinstance(self.lhs, ColPairs): + if ( + isinstance(self.rhs, Query) + and not self.rhs.has_select_fields + and self.lhs.output_field.related_model is self.rhs.model + ): + self.rhs.set_values([f.name for f in self.lhs.sources]) + else: if self.rhs_is_direct_value(): # If we get here, we are dealing with single-column relations. self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs] diff --git a/django/db/models/fields/tuple_lookups.py b/django/db/models/fields/tuple_lookups.py index b45bcaf2cd..f62a49bd60 100644 --- a/django/db/models/fields/tuple_lookups.py +++ b/django/db/models/fields/tuple_lookups.py @@ -47,7 +47,8 @@ class TupleLookupMixin: self.check_rhs_is_tuple_or_list() self.check_rhs_length_equals_lhs_length() else: - self.check_rhs_is_outer_ref() + self.check_rhs_is_supported_expression() + super().get_prep_lookup() return self.rhs def check_rhs_is_tuple_or_list(self): @@ -65,13 +66,13 @@ class TupleLookupMixin: f"{self.lookup_name!r} lookup of {lhs_str} must have {len_lhs} elements" ) - def check_rhs_is_outer_ref(self): - if not isinstance(self.rhs, ResolvedOuterRef): + def check_rhs_is_supported_expression(self): + if not isinstance(self.rhs, (ResolvedOuterRef, Query)): lhs_str = self.get_lhs_str() rhs_cls = self.rhs.__class__.__name__ raise ValueError( f"{self.lookup_name!r} subquery lookup of {lhs_str} " - f"only supports OuterRef objects (received {rhs_cls!r})" + f"only supports OuterRef and QuerySet objects (received {rhs_cls!r})" ) def get_lhs_str(self): @@ -101,11 +102,14 @@ class TupleLookupMixin: return compiler.compile(Tuple(*args)) else: sql, params = compiler.compile(self.rhs) - if not isinstance(self.rhs, ColPairs): + if isinstance(self.rhs, ColPairs): + return "(%s)" % sql, params + elif isinstance(self.rhs, Query): + return super().process_rhs(compiler, connection) + else: raise ValueError( "Composite field lookups only work with composite expressions." ) - return "(%s)" % sql, params def get_fallback_sql(self, compiler, connection): raise NotImplementedError( @@ -121,6 +125,8 @@ class TupleLookupMixin: class TupleExact(TupleLookupMixin, Exact): def get_fallback_sql(self, compiler, connection): + if isinstance(self.rhs, Query): + return super(TupleLookupMixin, self).as_sql(compiler, connection) # Process right-hand-side to trigger sanitization. self.process_rhs(compiler, connection) # e.g.: (a, b, c) == (x, y, z) as SQL: @@ -273,7 +279,7 @@ class TupleIn(TupleLookupMixin, In): self.check_rhs_elements_length_equals_lhs_length() else: self.check_rhs_is_query() - self.check_rhs_select_length_equals_lhs_length() + super(TupleLookupMixin, self).get_prep_lookup() return self.rhs # skip checks from mixin @@ -303,19 +309,10 @@ class TupleIn(TupleLookupMixin, In): f"must be a Query object (received {rhs_cls!r})" ) - def check_rhs_select_length_equals_lhs_length(self): - len_rhs = len(self.rhs.select) - if len_rhs == 1 and isinstance(self.rhs.select[0], ColPairs): - len_rhs = len(self.rhs.select[0]) - len_lhs = len(self.lhs) - if len_rhs != len_lhs: - lhs_str = self.get_lhs_str() - raise ValueError( - f"{self.lookup_name!r} subquery lookup of {lhs_str} " - f"must have {len_lhs} fields (received {len_rhs})" - ) - def process_rhs(self, compiler, connection): + if not self.rhs_is_direct_value(): + return super(TupleLookupMixin, self).process_rhs(compiler, connection) + rhs = self.rhs if not rhs: raise EmptyResultSet @@ -337,19 +334,12 @@ class TupleIn(TupleLookupMixin, In): return compiler.compile(Tuple(*result)) - def as_subquery_sql(self, compiler, connection): - lhs = self.lhs - rhs = self.rhs - if isinstance(lhs, ColPairs): - rhs = rhs.clone() - rhs.set_values([source.name for source in lhs.sources]) - lhs = Tuple(lhs) - return compiler.compile(In(lhs, rhs)) - def get_fallback_sql(self, compiler, connection): rhs = self.rhs if not rhs: raise EmptyResultSet + if not self.rhs_is_direct_value(): + return super(TupleLookupMixin, self).as_sql(compiler, connection) # e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL: # WHERE (a = x1 AND b = y1 AND c = z1) OR (a = x2 AND b = y2 AND c = z2) @@ -362,11 +352,6 @@ class TupleIn(TupleLookupMixin, In): return root.as_sql(compiler, connection) - def as_sql(self, compiler, connection): - if not self.rhs_is_direct_value(): - return self.as_subquery_sql(compiler, connection) - return super().as_sql(compiler, connection) - tuple_lookups = { "exact": TupleExact, diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index c95f572e00..51817710e9 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -373,16 +373,21 @@ class Exact(FieldGetDbPrepValueMixin, BuiltinLookup): def get_prep_lookup(self): from django.db.models.sql.query import Query # avoid circular import - if isinstance(self.rhs, Query): - if self.rhs.has_limit_one(): - if not self.rhs.has_select_fields: - self.rhs.clear_select_clause() - self.rhs.add_fields(["pk"]) - else: + if isinstance(query := self.rhs, Query): + if not query.has_limit_one(): raise ValueError( "The QuerySet value for an exact lookup must be limited to " "one result using slicing." ) + lhs_len = len(self.lhs) if isinstance(self.lhs, (ColPairs, tuple)) else 1 + if (rhs_len := query._subquery_fields_len) != lhs_len: + raise ValueError( + f"The QuerySet value for the exact lookup must have {lhs_len} " + f"selected fields (received {rhs_len})" + ) + if not query.has_select_fields: + query.clear_select_clause() + query.add_fields(["pk"]) return super().get_prep_lookup() def as_sql(self, compiler, connection): @@ -499,6 +504,12 @@ class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup): from django.db.models.sql.query import Query # avoid circular import if isinstance(self.rhs, Query): + lhs_len = len(self.lhs) if isinstance(self.lhs, (ColPairs, tuple)) else 1 + if (rhs_len := self.rhs._subquery_fields_len) != lhs_len: + raise ValueError( + f"The QuerySet value for the 'in' lookup must have {lhs_len} " + f"selected fields (received {rhs_len})" + ) self.rhs.clear_ordering(clear_default=True) if not self.rhs.has_select_fields: self.rhs.clear_select_clause() diff --git a/django/db/models/query.py b/django/db/models/query.py index 84806a5f72..175073b961 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -1958,10 +1958,6 @@ class QuerySet(AltersData): self._known_related_objects.setdefault(field, {}).update(objects) def resolve_expression(self, *args, **kwargs): - if self._fields and len(self._fields) > 1: - # values() queryset can only be used as nested queries - # if they are set up to select only a single field. - raise TypeError("Cannot use multi-field values as a filter value.") query = self.query.resolve_expression(*args, **kwargs) query._db = self._db return query diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index ec47d9aa24..0d1fe5fb43 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -1224,6 +1224,12 @@ class Query(BaseExpression): if self.selected: self.selected[alias] = alias + @property + def _subquery_fields_len(self): + if self.has_select_fields: + return len(self.selected) + return len(self.model._meta.pk_fields) + def resolve_expression(self, query, *args, **kwargs): clone = self.clone() # Subqueries need to use a different set of aliases than the outer query. diff --git a/tests/composite_pk/models/tenant.py b/tests/composite_pk/models/tenant.py index 6286ed2354..c85869afa7 100644 --- a/tests/composite_pk/models/tenant.py +++ b/tests/composite_pk/models/tenant.py @@ -44,6 +44,7 @@ class Comment(models.Model): related_name="comments", ) text = models.TextField(default="", blank=True) + integer = models.IntegerField(default=0) class Post(models.Model): diff --git a/tests/composite_pk/test_filter.py b/tests/composite_pk/test_filter.py index fe942b9e5b..d4c6ef13e0 100644 --- a/tests/composite_pk/test_filter.py +++ b/tests/composite_pk/test_filter.py @@ -10,7 +10,7 @@ from django.db.models import ( ) from django.db.models.functions import Cast from django.db.models.lookups import Exact -from django.test import TestCase +from django.test import TestCase, skipUnlessDBFeature from .models import Comment, Tenant, User @@ -182,6 +182,30 @@ class CompositePKFilterTests(TestCase): Comment.objects.filter(pk__in=pks).order_by("pk"), objs ) + def test_filter_comments_by_pk_in_subquery(self): + self.assertSequenceEqual( + Comment.objects.filter( + pk__in=Comment.objects.filter(pk=self.comment_1.pk), + ), + [self.comment_1], + ) + self.assertSequenceEqual( + Comment.objects.filter( + pk__in=Comment.objects.filter(pk=self.comment_1.pk).values( + "tenant_id", "id" + ), + ), + [self.comment_1], + ) + self.comment_2.integer = self.comment_1.id + self.comment_2.save() + self.assertSequenceEqual( + Comment.objects.filter( + pk__in=Comment.objects.values("tenant_id", "integer"), + ), + [self.comment_1], + ) + def test_filter_comments_by_user_and_order_by_pk_asc(self): self.assertSequenceEqual( Comment.objects.filter(user=self.user_1).order_by("pk"), @@ -440,16 +464,40 @@ class CompositePKFilterTests(TestCase): queryset = Comment.objects.filter(**{f"id{lookup}": subquery}) self.assertEqual(queryset.count(), expected_count) - def test_non_outer_ref_subquery(self): - # If rhs is any non-OuterRef object with an as_sql() function. + def test_unsupported_rhs(self): pk = Exact(F("tenant_id"), 1) msg = ( - "'exact' subquery lookup of 'pk' only supports OuterRef objects " - "(received 'Exact')" + "'exact' subquery lookup of 'pk' only supports OuterRef " + "and QuerySet objects (received 'Exact')" ) with self.assertRaisesMessage(ValueError, msg): Comment.objects.filter(pk=pk) + @skipUnlessDBFeature("allow_sliced_subqueries_with_in") + def test_filter_comments_by_pk_exact_subquery(self): + self.assertSequenceEqual( + Comment.objects.filter( + pk=Comment.objects.filter(pk=self.comment_1.pk)[:1], + ), + [self.comment_1], + ) + self.assertSequenceEqual( + Comment.objects.filter( + pk__in=Comment.objects.filter(pk=self.comment_1.pk).values( + "tenant_id", "id" + )[:1], + ), + [self.comment_1], + ) + self.comment_2.integer = self.comment_1.id + self.comment_2.save() + self.assertSequenceEqual( + Comment.objects.filter( + pk__in=Comment.objects.values("tenant_id", "integer"), + )[:1], + [self.comment_1], + ) + def test_outer_ref_not_composite_pk(self): subquery = Comment.objects.filter(pk=OuterRef("id")).values("id") queryset = Comment.objects.filter(id=Subquery(subquery)) diff --git a/tests/composite_pk/tests.py b/tests/composite_pk/tests.py index 6b09480fb0..18fa53d9c0 100644 --- a/tests/composite_pk/tests.py +++ b/tests/composite_pk/tests.py @@ -109,13 +109,10 @@ class CompositePKTests(TestCase): def test_composite_pk_in_fields(self): user_fields = {f.name for f in User._meta.get_fields()} - self.assertEqual(user_fields, {"pk", "tenant", "id", "email", "comments"}) + self.assertTrue({"pk", "tenant", "id"}.issubset(user_fields)) comment_fields = {f.name for f in Comment._meta.get_fields()} - self.assertEqual( - comment_fields, - {"pk", "tenant", "id", "user_id", "user", "text"}, - ) + self.assertTrue({"pk", "tenant", "id"}.issubset(comment_fields)) def test_pk_field(self): pk = User._meta.get_field("pk") @@ -174,7 +171,7 @@ class CompositePKTests(TestCase): self.assertEqual(user.email, self.user.email) def test_model_forms(self): - fields = ["tenant", "id", "user_id", "text"] + fields = ["tenant", "id", "user_id", "text", "integer"] self.assertEqual(list(CommentForm.base_fields), fields) form = modelform_factory(Comment, fields="__all__") diff --git a/tests/foreign_object/test_tuple_lookups.py b/tests/foreign_object/test_tuple_lookups.py index 42717c4f11..008f118994 100644 --- a/tests/foreign_object/test_tuple_lookups.py +++ b/tests/foreign_object/test_tuple_lookups.py @@ -63,9 +63,11 @@ class TupleLookupsTests(TestCase): ) def test_exact_subquery(self): - with self.assertRaisesMessage( - ValueError, "'exact' doesn't support multi-column subqueries." - ): + msg = ( + "The QuerySet value for the exact lookup must have 2 selected " + "fields (received 1)" + ) + with self.assertRaisesMessage(ValueError, msg): subquery = Customer.objects.filter(id=self.customer_1.id)[:1] self.assertSequenceEqual( Contact.objects.filter(customer=subquery).order_by("id"), () @@ -140,11 +142,11 @@ class TupleLookupsTests(TestCase): def test_tuple_in_subquery_must_have_2_fields(self): lhs = (F("customer_code"), F("company_code")) rhs = Customer.objects.values_list("customer_id").query - with self.assertRaisesMessage( - ValueError, - "'in' subquery lookup of ('customer_code', 'company_code') " - "must have 2 fields (received 1)", - ): + msg = ( + "The QuerySet value for the 'in' lookup must have 2 selected " + "fields (received 1)" + ) + with self.assertRaisesMessage(ValueError, msg): TupleIn(lhs, rhs) def test_tuple_in_subquery(self): diff --git a/tests/lookup/tests.py b/tests/lookup/tests.py index df96546d04..e19fbca521 100644 --- a/tests/lookup/tests.py +++ b/tests/lookup/tests.py @@ -789,6 +789,14 @@ class LookupTests(TestCase): sql = ctx.captured_queries[0]["sql"] self.assertIn("IN (%s)" % self.a1.pk, sql) + def test_in_select_mismatch(self): + msg = ( + "The QuerySet value for the 'in' lookup must have 1 " + "selected fields (received 2)" + ) + with self.assertRaisesMessage(ValueError, msg): + Article.objects.filter(id__in=Article.objects.values("id", "headline")) + def test_error_messages(self): # Programming errors are pointed out with nice error messages with self.assertRaisesMessage( @@ -1364,6 +1372,14 @@ class LookupTests(TestCase): authors = Author.objects.filter(id=authors_max_ids[:1]) self.assertEqual(authors.get(), newest_author) + def test_exact_query_rhs_with_selected_columns_mismatch(self): + msg = ( + "The QuerySet value for the exact lookup must have 1 " + "selected fields (received 2)" + ) + with self.assertRaisesMessage(ValueError, msg): + Author.objects.filter(id=Author.objects.values("id", "name")[:1]) + def test_isnull_non_boolean_value(self): msg = "The QuerySet value for an isnull lookup must be True or False." tests = [ diff --git a/tests/queries/tests.py b/tests/queries/tests.py index 45866fd50f..c429a93af3 100644 --- a/tests/queries/tests.py +++ b/tests/queries/tests.py @@ -922,20 +922,6 @@ class Queries1Tests(TestCase): [self.t2, self.t3], ) - # Multi-valued values() and values_list() querysets should raise errors. - with self.assertRaisesMessage( - TypeError, "Cannot use multi-field values as a filter value." - ): - Tag.objects.filter( - name__in=Tag.objects.filter(parent=self.t1).values("name", "id") - ) - with self.assertRaisesMessage( - TypeError, "Cannot use multi-field values as a filter value." - ): - Tag.objects.filter( - name__in=Tag.objects.filter(parent=self.t1).values_list("name", "id") - ) - def test_ticket9985(self): # qs.values_list(...).values(...) combinations should work. self.assertSequenceEqual(