1
0
mirror of https://github.com/django/django.git synced 2025-10-25 22:56:12 +00:00

Refs #373 -- Improved test coverage of tuple lookup checks.

This also removed unreachable checks.
This commit is contained in:
Bendeguz Csirmaz
2024-08-15 13:27:47 +08:00
committed by Sarah Boyce
parent 2a4321ba23
commit 347ab72c02
2 changed files with 37 additions and 29 deletions

View File

@@ -1,7 +1,7 @@
import itertools import itertools
from django.core.exceptions import EmptyResultSet from django.core.exceptions import EmptyResultSet
from django.db.models.expressions import ColPairs, Func, Value from django.db.models.expressions import Func, Value
from django.db.models.lookups import ( from django.db.models.lookups import (
Exact, Exact,
GreaterThan, GreaterThan,
@@ -24,36 +24,14 @@ class TupleLookupMixin:
return super().get_prep_lookup() return super().get_prep_lookup()
def check_tuple_lookup(self): def check_tuple_lookup(self):
assert isinstance(self.lhs, ColPairs)
self.check_rhs_is_tuple_or_list()
self.check_rhs_length_equals_lhs_length() self.check_rhs_length_equals_lhs_length()
def check_rhs_is_tuple_or_list(self):
if not isinstance(self.rhs, (tuple, list)):
raise ValueError(
f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field "
"must be a tuple or a list"
)
def check_rhs_length_equals_lhs_length(self): def check_rhs_length_equals_lhs_length(self):
if len(self.lhs) != len(self.rhs): len_lhs = len(self.lhs)
if len_lhs != len(self.rhs):
raise ValueError( raise ValueError(
f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field " f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field "
f"must have {len(self.lhs)} elements" f"must have {len_lhs} elements"
)
def check_rhs_is_collection_of_tuples_or_lists(self):
if not all(isinstance(vals, (tuple, list)) for vals in self.rhs):
raise ValueError(
f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field "
f"must be a collection of tuples or lists"
)
def check_rhs_elements_length_equals_lhs_length(self):
if not all(len(self.lhs) == len(vals) for vals in self.rhs):
raise ValueError(
f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field "
f"must have {len(self.lhs)} elements each"
) )
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
@@ -192,11 +170,16 @@ class TupleLessThanOrEqual(TupleLookupMixin, LessThanOrEqual):
class TupleIn(TupleLookupMixin, In): class TupleIn(TupleLookupMixin, In):
def check_tuple_lookup(self): def check_tuple_lookup(self):
assert isinstance(self.lhs, ColPairs)
self.check_rhs_is_tuple_or_list()
self.check_rhs_is_collection_of_tuples_or_lists()
self.check_rhs_elements_length_equals_lhs_length() self.check_rhs_elements_length_equals_lhs_length()
def check_rhs_elements_length_equals_lhs_length(self):
len_lhs = len(self.lhs)
if not all(len_lhs == len(vals) for vals in self.rhs):
raise ValueError(
f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field "
f"must have {len_lhs} elements each"
)
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
if not self.rhs: if not self.rhs:
raise EmptyResultSet raise EmptyResultSet

View File

@@ -240,3 +240,28 @@ class TupleLookupsTests(TestCase):
self.assertSequenceEqual( self.assertSequenceEqual(
Contact.objects.filter(customer__isnull=subquery).order_by("id"), () Contact.objects.filter(customer__isnull=subquery).order_by("id"), ()
) )
def test_lookup_errors(self):
m_2_elements = "'%s' lookup of 'customer' field must have 2 elements"
m_2_elements_each = "'in' lookup of 'customer' field must have 2 elements each"
test_cases = (
({"customer": 1}, m_2_elements % "exact"),
({"customer": (1, 2, 3)}, m_2_elements % "exact"),
({"customer__in": (1, 2, 3)}, m_2_elements_each),
({"customer__in": ("foo", "bar")}, m_2_elements_each),
({"customer__gt": 1}, m_2_elements % "gt"),
({"customer__gt": (1, 2, 3)}, m_2_elements % "gt"),
({"customer__gte": 1}, m_2_elements % "gte"),
({"customer__gte": (1, 2, 3)}, m_2_elements % "gte"),
({"customer__lt": 1}, m_2_elements % "lt"),
({"customer__lt": (1, 2, 3)}, m_2_elements % "lt"),
({"customer__lte": 1}, m_2_elements % "lte"),
({"customer__lte": (1, 2, 3)}, m_2_elements % "lte"),
)
for kwargs, message in test_cases:
with (
self.subTest(kwargs=kwargs),
self.assertRaisesMessage(ValueError, message),
):
Contact.objects.get(**kwargs)