mirror of
				https://github.com/django/django.git
				synced 2025-10-26 07:06:08 +00:00 
			
		
		
		
	Fixed #36149 -- Allowed subquery values against tuple exact and in lookups.
Non-tuple exact and in lookups have specialized logic for subqueries that can be adapted to properly assign select mask if unspecified and ensure the number of involved members are matching on both side of the operator.
This commit is contained in:
		
				
					committed by
					
						 Sarah Boyce
						Sarah Boyce
					
				
			
			
				
	
			
			
			
						parent
						
							0597e8ad1e
						
					
				
				
					commit
					41239fe34d
				
			| @@ -1367,6 +1367,9 @@ class ColPairs(Expression): | ||||
|     def resolve_expression(self, *args, **kwargs): | ||||
|         return self | ||||
|  | ||||
|     def select_format(self, compiler, sql, params): | ||||
|         return sql, params | ||||
|  | ||||
|  | ||||
| class Ref(Expression): | ||||
|     """ | ||||
|   | ||||
| @@ -40,7 +40,16 @@ def get_normalized_value(value, lhs): | ||||
|  | ||||
| class RelatedIn(In): | ||||
|     def get_prep_lookup(self): | ||||
|         if not isinstance(self.lhs, ColPairs): | ||||
|         from django.db.models.sql.query import Query  # avoid circular import | ||||
|  | ||||
|         if isinstance(self.lhs, ColPairs): | ||||
|             if ( | ||||
|                 isinstance(self.rhs, Query) | ||||
|                 and not self.rhs.has_select_fields | ||||
|                 and self.lhs.output_field.related_model is self.rhs.model | ||||
|             ): | ||||
|                 self.rhs.set_values([f.name for f in self.lhs.sources]) | ||||
|         else: | ||||
|             if self.rhs_is_direct_value(): | ||||
|                 # If we get here, we are dealing with single-column relations. | ||||
|                 self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs] | ||||
|   | ||||
| @@ -47,7 +47,8 @@ class TupleLookupMixin: | ||||
|             self.check_rhs_is_tuple_or_list() | ||||
|             self.check_rhs_length_equals_lhs_length() | ||||
|         else: | ||||
|             self.check_rhs_is_outer_ref() | ||||
|             self.check_rhs_is_supported_expression() | ||||
|             super().get_prep_lookup() | ||||
|         return self.rhs | ||||
|  | ||||
|     def check_rhs_is_tuple_or_list(self): | ||||
| @@ -65,13 +66,13 @@ class TupleLookupMixin: | ||||
|                 f"{self.lookup_name!r} lookup of {lhs_str} must have {len_lhs} elements" | ||||
|             ) | ||||
|  | ||||
|     def check_rhs_is_outer_ref(self): | ||||
|         if not isinstance(self.rhs, ResolvedOuterRef): | ||||
|     def check_rhs_is_supported_expression(self): | ||||
|         if not isinstance(self.rhs, (ResolvedOuterRef, Query)): | ||||
|             lhs_str = self.get_lhs_str() | ||||
|             rhs_cls = self.rhs.__class__.__name__ | ||||
|             raise ValueError( | ||||
|                 f"{self.lookup_name!r} subquery lookup of {lhs_str} " | ||||
|                 f"only supports OuterRef objects (received {rhs_cls!r})" | ||||
|                 f"only supports OuterRef and QuerySet objects (received {rhs_cls!r})" | ||||
|             ) | ||||
|  | ||||
|     def get_lhs_str(self): | ||||
| @@ -101,11 +102,14 @@ class TupleLookupMixin: | ||||
|             return compiler.compile(Tuple(*args)) | ||||
|         else: | ||||
|             sql, params = compiler.compile(self.rhs) | ||||
|             if not isinstance(self.rhs, ColPairs): | ||||
|             if isinstance(self.rhs, ColPairs): | ||||
|                 return "(%s)" % sql, params | ||||
|             elif isinstance(self.rhs, Query): | ||||
|                 return super().process_rhs(compiler, connection) | ||||
|             else: | ||||
|                 raise ValueError( | ||||
|                     "Composite field lookups only work with composite expressions." | ||||
|                 ) | ||||
|             return "(%s)" % sql, params | ||||
|  | ||||
|     def get_fallback_sql(self, compiler, connection): | ||||
|         raise NotImplementedError( | ||||
| @@ -121,6 +125,8 @@ class TupleLookupMixin: | ||||
|  | ||||
| class TupleExact(TupleLookupMixin, Exact): | ||||
|     def get_fallback_sql(self, compiler, connection): | ||||
|         if isinstance(self.rhs, Query): | ||||
|             return super(TupleLookupMixin, self).as_sql(compiler, connection) | ||||
|         # Process right-hand-side to trigger sanitization. | ||||
|         self.process_rhs(compiler, connection) | ||||
|         # e.g.: (a, b, c) == (x, y, z) as SQL: | ||||
| @@ -273,7 +279,7 @@ class TupleIn(TupleLookupMixin, In): | ||||
|             self.check_rhs_elements_length_equals_lhs_length() | ||||
|         else: | ||||
|             self.check_rhs_is_query() | ||||
|             self.check_rhs_select_length_equals_lhs_length() | ||||
|             super(TupleLookupMixin, self).get_prep_lookup() | ||||
|  | ||||
|         return self.rhs  # skip checks from mixin | ||||
|  | ||||
| @@ -303,19 +309,10 @@ class TupleIn(TupleLookupMixin, In): | ||||
|                 f"must be a Query object (received {rhs_cls!r})" | ||||
|             ) | ||||
|  | ||||
|     def check_rhs_select_length_equals_lhs_length(self): | ||||
|         len_rhs = len(self.rhs.select) | ||||
|         if len_rhs == 1 and isinstance(self.rhs.select[0], ColPairs): | ||||
|             len_rhs = len(self.rhs.select[0]) | ||||
|         len_lhs = len(self.lhs) | ||||
|         if len_rhs != len_lhs: | ||||
|             lhs_str = self.get_lhs_str() | ||||
|             raise ValueError( | ||||
|                 f"{self.lookup_name!r} subquery lookup of {lhs_str} " | ||||
|                 f"must have {len_lhs} fields (received {len_rhs})" | ||||
|             ) | ||||
|  | ||||
|     def process_rhs(self, compiler, connection): | ||||
|         if not self.rhs_is_direct_value(): | ||||
|             return super(TupleLookupMixin, self).process_rhs(compiler, connection) | ||||
|  | ||||
|         rhs = self.rhs | ||||
|         if not rhs: | ||||
|             raise EmptyResultSet | ||||
| @@ -337,19 +334,12 @@ class TupleIn(TupleLookupMixin, In): | ||||
|  | ||||
|         return compiler.compile(Tuple(*result)) | ||||
|  | ||||
|     def as_subquery_sql(self, compiler, connection): | ||||
|         lhs = self.lhs | ||||
|         rhs = self.rhs | ||||
|         if isinstance(lhs, ColPairs): | ||||
|             rhs = rhs.clone() | ||||
|             rhs.set_values([source.name for source in lhs.sources]) | ||||
|             lhs = Tuple(lhs) | ||||
|         return compiler.compile(In(lhs, rhs)) | ||||
|  | ||||
|     def get_fallback_sql(self, compiler, connection): | ||||
|         rhs = self.rhs | ||||
|         if not rhs: | ||||
|             raise EmptyResultSet | ||||
|         if not self.rhs_is_direct_value(): | ||||
|             return super(TupleLookupMixin, self).as_sql(compiler, connection) | ||||
|  | ||||
|         # e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL: | ||||
|         # WHERE (a = x1 AND b = y1 AND c = z1) OR (a = x2 AND b = y2 AND c = z2) | ||||
| @@ -362,11 +352,6 @@ class TupleIn(TupleLookupMixin, In): | ||||
|  | ||||
|         return root.as_sql(compiler, connection) | ||||
|  | ||||
|     def as_sql(self, compiler, connection): | ||||
|         if not self.rhs_is_direct_value(): | ||||
|             return self.as_subquery_sql(compiler, connection) | ||||
|         return super().as_sql(compiler, connection) | ||||
|  | ||||
|  | ||||
| tuple_lookups = { | ||||
|     "exact": TupleExact, | ||||
|   | ||||
| @@ -373,16 +373,21 @@ class Exact(FieldGetDbPrepValueMixin, BuiltinLookup): | ||||
|     def get_prep_lookup(self): | ||||
|         from django.db.models.sql.query import Query  # avoid circular import | ||||
|  | ||||
|         if isinstance(self.rhs, Query): | ||||
|             if self.rhs.has_limit_one(): | ||||
|                 if not self.rhs.has_select_fields: | ||||
|                     self.rhs.clear_select_clause() | ||||
|                     self.rhs.add_fields(["pk"]) | ||||
|             else: | ||||
|         if isinstance(query := self.rhs, Query): | ||||
|             if not query.has_limit_one(): | ||||
|                 raise ValueError( | ||||
|                     "The QuerySet value for an exact lookup must be limited to " | ||||
|                     "one result using slicing." | ||||
|                 ) | ||||
|             lhs_len = len(self.lhs) if isinstance(self.lhs, (ColPairs, tuple)) else 1 | ||||
|             if (rhs_len := query._subquery_fields_len) != lhs_len: | ||||
|                 raise ValueError( | ||||
|                     f"The QuerySet value for the exact lookup must have {lhs_len} " | ||||
|                     f"selected fields (received {rhs_len})" | ||||
|                 ) | ||||
|             if not query.has_select_fields: | ||||
|                 query.clear_select_clause() | ||||
|                 query.add_fields(["pk"]) | ||||
|         return super().get_prep_lookup() | ||||
|  | ||||
|     def as_sql(self, compiler, connection): | ||||
| @@ -499,6 +504,12 @@ class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup): | ||||
|         from django.db.models.sql.query import Query  # avoid circular import | ||||
|  | ||||
|         if isinstance(self.rhs, Query): | ||||
|             lhs_len = len(self.lhs) if isinstance(self.lhs, (ColPairs, tuple)) else 1 | ||||
|             if (rhs_len := self.rhs._subquery_fields_len) != lhs_len: | ||||
|                 raise ValueError( | ||||
|                     f"The QuerySet value for the 'in' lookup must have {lhs_len} " | ||||
|                     f"selected fields (received {rhs_len})" | ||||
|                 ) | ||||
|             self.rhs.clear_ordering(clear_default=True) | ||||
|             if not self.rhs.has_select_fields: | ||||
|                 self.rhs.clear_select_clause() | ||||
|   | ||||
| @@ -1958,10 +1958,6 @@ class QuerySet(AltersData): | ||||
|             self._known_related_objects.setdefault(field, {}).update(objects) | ||||
|  | ||||
|     def resolve_expression(self, *args, **kwargs): | ||||
|         if self._fields and len(self._fields) > 1: | ||||
|             # values() queryset can only be used as nested queries | ||||
|             # if they are set up to select only a single field. | ||||
|             raise TypeError("Cannot use multi-field values as a filter value.") | ||||
|         query = self.query.resolve_expression(*args, **kwargs) | ||||
|         query._db = self._db | ||||
|         return query | ||||
|   | ||||
| @@ -1224,6 +1224,12 @@ class Query(BaseExpression): | ||||
|         if self.selected: | ||||
|             self.selected[alias] = alias | ||||
|  | ||||
|     @property | ||||
|     def _subquery_fields_len(self): | ||||
|         if self.has_select_fields: | ||||
|             return len(self.selected) | ||||
|         return len(self.model._meta.pk_fields) | ||||
|  | ||||
|     def resolve_expression(self, query, *args, **kwargs): | ||||
|         clone = self.clone() | ||||
|         # Subqueries need to use a different set of aliases than the outer query. | ||||
|   | ||||
| @@ -44,6 +44,7 @@ class Comment(models.Model): | ||||
|         related_name="comments", | ||||
|     ) | ||||
|     text = models.TextField(default="", blank=True) | ||||
|     integer = models.IntegerField(default=0) | ||||
|  | ||||
|  | ||||
| class Post(models.Model): | ||||
|   | ||||
| @@ -10,7 +10,7 @@ from django.db.models import ( | ||||
| ) | ||||
| from django.db.models.functions import Cast | ||||
| from django.db.models.lookups import Exact | ||||
| from django.test import TestCase | ||||
| from django.test import TestCase, skipUnlessDBFeature | ||||
|  | ||||
| from .models import Comment, Tenant, User | ||||
|  | ||||
| @@ -182,6 +182,30 @@ class CompositePKFilterTests(TestCase): | ||||
|                     Comment.objects.filter(pk__in=pks).order_by("pk"), objs | ||||
|                 ) | ||||
|  | ||||
|     def test_filter_comments_by_pk_in_subquery(self): | ||||
|         self.assertSequenceEqual( | ||||
|             Comment.objects.filter( | ||||
|                 pk__in=Comment.objects.filter(pk=self.comment_1.pk), | ||||
|             ), | ||||
|             [self.comment_1], | ||||
|         ) | ||||
|         self.assertSequenceEqual( | ||||
|             Comment.objects.filter( | ||||
|                 pk__in=Comment.objects.filter(pk=self.comment_1.pk).values( | ||||
|                     "tenant_id", "id" | ||||
|                 ), | ||||
|             ), | ||||
|             [self.comment_1], | ||||
|         ) | ||||
|         self.comment_2.integer = self.comment_1.id | ||||
|         self.comment_2.save() | ||||
|         self.assertSequenceEqual( | ||||
|             Comment.objects.filter( | ||||
|                 pk__in=Comment.objects.values("tenant_id", "integer"), | ||||
|             ), | ||||
|             [self.comment_1], | ||||
|         ) | ||||
|  | ||||
|     def test_filter_comments_by_user_and_order_by_pk_asc(self): | ||||
|         self.assertSequenceEqual( | ||||
|             Comment.objects.filter(user=self.user_1).order_by("pk"), | ||||
| @@ -440,16 +464,40 @@ class CompositePKFilterTests(TestCase): | ||||
|                 queryset = Comment.objects.filter(**{f"id{lookup}": subquery}) | ||||
|                 self.assertEqual(queryset.count(), expected_count) | ||||
|  | ||||
|     def test_non_outer_ref_subquery(self): | ||||
|         # If rhs is any non-OuterRef object with an as_sql() function. | ||||
|     def test_unsupported_rhs(self): | ||||
|         pk = Exact(F("tenant_id"), 1) | ||||
|         msg = ( | ||||
|             "'exact' subquery lookup of 'pk' only supports OuterRef objects " | ||||
|             "(received 'Exact')" | ||||
|             "'exact' subquery lookup of 'pk' only supports OuterRef " | ||||
|             "and QuerySet objects (received 'Exact')" | ||||
|         ) | ||||
|         with self.assertRaisesMessage(ValueError, msg): | ||||
|             Comment.objects.filter(pk=pk) | ||||
|  | ||||
|     @skipUnlessDBFeature("allow_sliced_subqueries_with_in") | ||||
|     def test_filter_comments_by_pk_exact_subquery(self): | ||||
|         self.assertSequenceEqual( | ||||
|             Comment.objects.filter( | ||||
|                 pk=Comment.objects.filter(pk=self.comment_1.pk)[:1], | ||||
|             ), | ||||
|             [self.comment_1], | ||||
|         ) | ||||
|         self.assertSequenceEqual( | ||||
|             Comment.objects.filter( | ||||
|                 pk__in=Comment.objects.filter(pk=self.comment_1.pk).values( | ||||
|                     "tenant_id", "id" | ||||
|                 )[:1], | ||||
|             ), | ||||
|             [self.comment_1], | ||||
|         ) | ||||
|         self.comment_2.integer = self.comment_1.id | ||||
|         self.comment_2.save() | ||||
|         self.assertSequenceEqual( | ||||
|             Comment.objects.filter( | ||||
|                 pk__in=Comment.objects.values("tenant_id", "integer"), | ||||
|             )[:1], | ||||
|             [self.comment_1], | ||||
|         ) | ||||
|  | ||||
|     def test_outer_ref_not_composite_pk(self): | ||||
|         subquery = Comment.objects.filter(pk=OuterRef("id")).values("id") | ||||
|         queryset = Comment.objects.filter(id=Subquery(subquery)) | ||||
|   | ||||
| @@ -109,13 +109,10 @@ class CompositePKTests(TestCase): | ||||
|  | ||||
|     def test_composite_pk_in_fields(self): | ||||
|         user_fields = {f.name for f in User._meta.get_fields()} | ||||
|         self.assertEqual(user_fields, {"pk", "tenant", "id", "email", "comments"}) | ||||
|         self.assertTrue({"pk", "tenant", "id"}.issubset(user_fields)) | ||||
|  | ||||
|         comment_fields = {f.name for f in Comment._meta.get_fields()} | ||||
|         self.assertEqual( | ||||
|             comment_fields, | ||||
|             {"pk", "tenant", "id", "user_id", "user", "text"}, | ||||
|         ) | ||||
|         self.assertTrue({"pk", "tenant", "id"}.issubset(comment_fields)) | ||||
|  | ||||
|     def test_pk_field(self): | ||||
|         pk = User._meta.get_field("pk") | ||||
| @@ -174,7 +171,7 @@ class CompositePKTests(TestCase): | ||||
|             self.assertEqual(user.email, self.user.email) | ||||
|  | ||||
|     def test_model_forms(self): | ||||
|         fields = ["tenant", "id", "user_id", "text"] | ||||
|         fields = ["tenant", "id", "user_id", "text", "integer"] | ||||
|         self.assertEqual(list(CommentForm.base_fields), fields) | ||||
|  | ||||
|         form = modelform_factory(Comment, fields="__all__") | ||||
|   | ||||
| @@ -63,9 +63,11 @@ class TupleLookupsTests(TestCase): | ||||
|                 ) | ||||
|  | ||||
|     def test_exact_subquery(self): | ||||
|         with self.assertRaisesMessage( | ||||
|             ValueError, "'exact' doesn't support multi-column subqueries." | ||||
|         ): | ||||
|         msg = ( | ||||
|             "The QuerySet value for the exact lookup must have 2 selected " | ||||
|             "fields (received 1)" | ||||
|         ) | ||||
|         with self.assertRaisesMessage(ValueError, msg): | ||||
|             subquery = Customer.objects.filter(id=self.customer_1.id)[:1] | ||||
|             self.assertSequenceEqual( | ||||
|                 Contact.objects.filter(customer=subquery).order_by("id"), () | ||||
| @@ -140,11 +142,11 @@ class TupleLookupsTests(TestCase): | ||||
|     def test_tuple_in_subquery_must_have_2_fields(self): | ||||
|         lhs = (F("customer_code"), F("company_code")) | ||||
|         rhs = Customer.objects.values_list("customer_id").query | ||||
|         with self.assertRaisesMessage( | ||||
|             ValueError, | ||||
|             "'in' subquery lookup of ('customer_code', 'company_code') " | ||||
|             "must have 2 fields (received 1)", | ||||
|         ): | ||||
|         msg = ( | ||||
|             "The QuerySet value for the 'in' lookup must have 2 selected " | ||||
|             "fields (received 1)" | ||||
|         ) | ||||
|         with self.assertRaisesMessage(ValueError, msg): | ||||
|             TupleIn(lhs, rhs) | ||||
|  | ||||
|     def test_tuple_in_subquery(self): | ||||
|   | ||||
| @@ -789,6 +789,14 @@ class LookupTests(TestCase): | ||||
|         sql = ctx.captured_queries[0]["sql"] | ||||
|         self.assertIn("IN (%s)" % self.a1.pk, sql) | ||||
|  | ||||
|     def test_in_select_mismatch(self): | ||||
|         msg = ( | ||||
|             "The QuerySet value for the 'in' lookup must have 1 " | ||||
|             "selected fields (received 2)" | ||||
|         ) | ||||
|         with self.assertRaisesMessage(ValueError, msg): | ||||
|             Article.objects.filter(id__in=Article.objects.values("id", "headline")) | ||||
|  | ||||
|     def test_error_messages(self): | ||||
|         # Programming errors are pointed out with nice error messages | ||||
|         with self.assertRaisesMessage( | ||||
| @@ -1364,6 +1372,14 @@ class LookupTests(TestCase): | ||||
|         authors = Author.objects.filter(id=authors_max_ids[:1]) | ||||
|         self.assertEqual(authors.get(), newest_author) | ||||
|  | ||||
|     def test_exact_query_rhs_with_selected_columns_mismatch(self): | ||||
|         msg = ( | ||||
|             "The QuerySet value for the exact lookup must have 1 " | ||||
|             "selected fields (received 2)" | ||||
|         ) | ||||
|         with self.assertRaisesMessage(ValueError, msg): | ||||
|             Author.objects.filter(id=Author.objects.values("id", "name")[:1]) | ||||
|  | ||||
|     def test_isnull_non_boolean_value(self): | ||||
|         msg = "The QuerySet value for an isnull lookup must be True or False." | ||||
|         tests = [ | ||||
|   | ||||
| @@ -922,20 +922,6 @@ class Queries1Tests(TestCase): | ||||
|             [self.t2, self.t3], | ||||
|         ) | ||||
|  | ||||
|         # Multi-valued values() and values_list() querysets should raise errors. | ||||
|         with self.assertRaisesMessage( | ||||
|             TypeError, "Cannot use multi-field values as a filter value." | ||||
|         ): | ||||
|             Tag.objects.filter( | ||||
|                 name__in=Tag.objects.filter(parent=self.t1).values("name", "id") | ||||
|             ) | ||||
|         with self.assertRaisesMessage( | ||||
|             TypeError, "Cannot use multi-field values as a filter value." | ||||
|         ): | ||||
|             Tag.objects.filter( | ||||
|                 name__in=Tag.objects.filter(parent=self.t1).values_list("name", "id") | ||||
|             ) | ||||
|  | ||||
|     def test_ticket9985(self): | ||||
|         # qs.values_list(...).values(...) combinations should work. | ||||
|         self.assertSequenceEqual( | ||||
|   | ||||
		Reference in New Issue
	
	Block a user