diff --git a/django/db/models/fields/tuple_lookups.py b/django/db/models/fields/tuple_lookups.py index 468826f224..d317a69d43 100644 --- a/django/db/models/fields/tuple_lookups.py +++ b/django/db/models/fields/tuple_lookups.py @@ -1,7 +1,7 @@ import itertools from django.core.exceptions import EmptyResultSet -from django.db.models.expressions import ColPairs, Func, Value +from django.db.models.expressions import Func, Value from django.db.models.lookups import ( Exact, GreaterThan, @@ -24,36 +24,14 @@ class TupleLookupMixin: return super().get_prep_lookup() def check_tuple_lookup(self): - assert isinstance(self.lhs, ColPairs) - self.check_rhs_is_tuple_or_list() self.check_rhs_length_equals_lhs_length() - def check_rhs_is_tuple_or_list(self): - if not isinstance(self.rhs, (tuple, list)): - raise ValueError( - f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field " - "must be a tuple or a list" - ) - def check_rhs_length_equals_lhs_length(self): - if len(self.lhs) != len(self.rhs): + len_lhs = len(self.lhs) + if len_lhs != len(self.rhs): raise ValueError( f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field " - f"must have {len(self.lhs)} elements" - ) - - def check_rhs_is_collection_of_tuples_or_lists(self): - if not all(isinstance(vals, (tuple, list)) for vals in self.rhs): - raise ValueError( - f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field " - f"must be a collection of tuples or lists" - ) - - def check_rhs_elements_length_equals_lhs_length(self): - if not all(len(self.lhs) == len(vals) for vals in self.rhs): - raise ValueError( - f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field " - f"must have {len(self.lhs)} elements each" + f"must have {len_lhs} elements" ) def as_sql(self, compiler, connection): @@ -192,11 +170,16 @@ class TupleLessThanOrEqual(TupleLookupMixin, LessThanOrEqual): class TupleIn(TupleLookupMixin, In): def check_tuple_lookup(self): - assert isinstance(self.lhs, ColPairs) - self.check_rhs_is_tuple_or_list() - self.check_rhs_is_collection_of_tuples_or_lists() self.check_rhs_elements_length_equals_lhs_length() + def check_rhs_elements_length_equals_lhs_length(self): + len_lhs = len(self.lhs) + if not all(len_lhs == len(vals) for vals in self.rhs): + raise ValueError( + f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field " + f"must have {len_lhs} elements each" + ) + def as_sql(self, compiler, connection): if not self.rhs: raise EmptyResultSet diff --git a/tests/foreign_object/test_tuple_lookups.py b/tests/foreign_object/test_tuple_lookups.py index cf080d084b..2742d6e93d 100644 --- a/tests/foreign_object/test_tuple_lookups.py +++ b/tests/foreign_object/test_tuple_lookups.py @@ -240,3 +240,28 @@ class TupleLookupsTests(TestCase): self.assertSequenceEqual( Contact.objects.filter(customer__isnull=subquery).order_by("id"), () ) + + def test_lookup_errors(self): + m_2_elements = "'%s' lookup of 'customer' field must have 2 elements" + m_2_elements_each = "'in' lookup of 'customer' field must have 2 elements each" + test_cases = ( + ({"customer": 1}, m_2_elements % "exact"), + ({"customer": (1, 2, 3)}, m_2_elements % "exact"), + ({"customer__in": (1, 2, 3)}, m_2_elements_each), + ({"customer__in": ("foo", "bar")}, m_2_elements_each), + ({"customer__gt": 1}, m_2_elements % "gt"), + ({"customer__gt": (1, 2, 3)}, m_2_elements % "gt"), + ({"customer__gte": 1}, m_2_elements % "gte"), + ({"customer__gte": (1, 2, 3)}, m_2_elements % "gte"), + ({"customer__lt": 1}, m_2_elements % "lt"), + ({"customer__lt": (1, 2, 3)}, m_2_elements % "lt"), + ({"customer__lte": 1}, m_2_elements % "lte"), + ({"customer__lte": (1, 2, 3)}, m_2_elements % "lte"), + ) + + for kwargs, message in test_cases: + with ( + self.subTest(kwargs=kwargs), + self.assertRaisesMessage(ValueError, message), + ): + Contact.objects.get(**kwargs)