mirror of https://github.com/django/django.git
Refs #31262 -- Added __eq__() and __getitem__() to BaseChoiceIterator.
This makes it easier to work with lazy iterators used for callables, etc. when extracting items or comparing to lists, e.g. during testing. Also added `BaseChoiceIterator.__iter__()` to make it clear that subclasses must implement this and added `__all__` to the module. Co-authored-by: Adam Johnson <me@adamj.eu> Co-authored-by: Natalia Bidart <124304+nessita@users.noreply.github.com>
This commit is contained in:
parent
e2922b0d5f
commit
07fa79ef2b
|
@ -1,11 +1,37 @@
|
|||
from collections.abc import Callable, Iterable, Iterator, Mapping
|
||||
from itertools import islice, zip_longest
|
||||
|
||||
from django.utils.functional import Promise
|
||||
|
||||
__all__ = [
|
||||
"BaseChoiceIterator",
|
||||
"CallableChoiceIterator",
|
||||
"normalize_choices",
|
||||
]
|
||||
|
||||
|
||||
class BaseChoiceIterator:
|
||||
"""Base class for lazy iterators for choices."""
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, Iterable):
|
||||
return all(a == b for a, b in zip_longest(self, other, fillvalue=object()))
|
||||
return super().__eq__(other)
|
||||
|
||||
def __getitem__(self, index):
|
||||
if index < 0:
|
||||
# Suboptimally consume whole iterator to handle negative index.
|
||||
return list(self)[index]
|
||||
try:
|
||||
return next(islice(self, index, index + 1))
|
||||
except StopIteration:
|
||||
raise IndexError("index out of range") from None
|
||||
|
||||
def __iter__(self):
|
||||
raise NotImplementedError(
|
||||
"BaseChoiceIterator subclasses must implement __iter__()."
|
||||
)
|
||||
|
||||
|
||||
class CallableChoiceIterator(BaseChoiceIterator):
|
||||
"""Iterator to lazily normalize choices generated by a callable."""
|
||||
|
|
|
@ -2,10 +2,60 @@ from unittest import mock
|
|||
|
||||
from django.db.models import TextChoices
|
||||
from django.test import SimpleTestCase
|
||||
from django.utils.choices import CallableChoiceIterator, normalize_choices
|
||||
from django.utils.choices import (
|
||||
BaseChoiceIterator,
|
||||
CallableChoiceIterator,
|
||||
normalize_choices,
|
||||
)
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class SimpleChoiceIterator(BaseChoiceIterator):
|
||||
def __iter__(self):
|
||||
return ((i, f"Item #{i}") for i in range(1, 4))
|
||||
|
||||
|
||||
class ChoiceIteratorTests(SimpleTestCase):
|
||||
def test_not_implemented_error_on_missing_iter(self):
|
||||
class InvalidChoiceIterator(BaseChoiceIterator):
|
||||
pass # Not overriding __iter__().
|
||||
|
||||
msg = "BaseChoiceIterator subclasses must implement __iter__()."
|
||||
with self.assertRaisesMessage(NotImplementedError, msg):
|
||||
iter(InvalidChoiceIterator())
|
||||
|
||||
def test_eq(self):
|
||||
unrolled = [(1, "Item #1"), (2, "Item #2"), (3, "Item #3")]
|
||||
self.assertEqual(SimpleChoiceIterator(), unrolled)
|
||||
self.assertEqual(unrolled, SimpleChoiceIterator())
|
||||
|
||||
def test_eq_instances(self):
|
||||
self.assertEqual(SimpleChoiceIterator(), SimpleChoiceIterator())
|
||||
|
||||
def test_not_equal_subset(self):
|
||||
self.assertNotEqual(SimpleChoiceIterator(), [(1, "Item #1"), (2, "Item #2")])
|
||||
|
||||
def test_not_equal_superset(self):
|
||||
self.assertNotEqual(
|
||||
SimpleChoiceIterator(),
|
||||
[(1, "Item #1"), (2, "Item #2"), (3, "Item #3"), None],
|
||||
)
|
||||
|
||||
def test_getitem(self):
|
||||
choices = SimpleChoiceIterator()
|
||||
for i, expected in [(0, (1, "Item #1")), (-1, (3, "Item #3"))]:
|
||||
with self.subTest(index=i):
|
||||
self.assertEqual(choices[i], expected)
|
||||
|
||||
def test_getitem_indexerror(self):
|
||||
choices = SimpleChoiceIterator()
|
||||
for i in (4, -4):
|
||||
with self.subTest(index=i):
|
||||
with self.assertRaises(IndexError) as ctx:
|
||||
choices[i]
|
||||
self.assertTrue(str(ctx.exception).endswith("index out of range"))
|
||||
|
||||
|
||||
class NormalizeFieldChoicesTests(SimpleTestCase):
|
||||
expected = [
|
||||
("C", _("Club")),
|
||||
|
@ -84,7 +134,7 @@ class NormalizeFieldChoicesTests(SimpleTestCase):
|
|||
|
||||
get_choices_spy.assert_not_called()
|
||||
self.assertIsInstance(output, CallableChoiceIterator)
|
||||
self.assertEqual(list(output), self.expected)
|
||||
self.assertEqual(output, self.expected)
|
||||
get_choices_spy.assert_called_once()
|
||||
|
||||
def test_mapping(self):
|
||||
|
@ -134,7 +184,7 @@ class NormalizeFieldChoicesTests(SimpleTestCase):
|
|||
|
||||
get_media_choices_spy.assert_not_called()
|
||||
self.assertIsInstance(output, CallableChoiceIterator)
|
||||
self.assertEqual(list(output), self.expected_nested)
|
||||
self.assertEqual(output, self.expected_nested)
|
||||
get_media_choices_spy.assert_called_once()
|
||||
|
||||
def test_nested_mapping(self):
|
||||
|
@ -185,7 +235,7 @@ class NormalizeFieldChoicesTests(SimpleTestCase):
|
|||
|
||||
get_choices_spy.assert_not_called()
|
||||
self.assertIsInstance(output, CallableChoiceIterator)
|
||||
self.assertEqual(list(output), self.expected)
|
||||
self.assertEqual(output, self.expected)
|
||||
get_choices_spy.assert_called_once()
|
||||
|
||||
def test_iterable_non_canonical(self):
|
||||
|
@ -230,7 +280,7 @@ class NormalizeFieldChoicesTests(SimpleTestCase):
|
|||
|
||||
get_media_choices_spy.assert_not_called()
|
||||
self.assertIsInstance(output, CallableChoiceIterator)
|
||||
self.assertEqual(list(output), self.expected_nested)
|
||||
self.assertEqual(output, self.expected_nested)
|
||||
get_media_choices_spy.assert_called_once()
|
||||
|
||||
def test_nested_iterable_non_canonical(self):
|
||||
|
@ -294,12 +344,12 @@ class NormalizeFieldChoicesTests(SimpleTestCase):
|
|||
def test_unsupported_values_from_callable_returned_unmodified(self):
|
||||
for value in self.invalid_iterable + self.invalid_nested:
|
||||
with self.subTest(value=value):
|
||||
self.assertEqual(list(normalize_choices(lambda: value)), value)
|
||||
self.assertEqual(normalize_choices(lambda: value), value)
|
||||
|
||||
def test_unsupported_values_from_iterator_returned_unmodified(self):
|
||||
for value in self.invalid_nested:
|
||||
with self.subTest(value=value):
|
||||
self.assertEqual(
|
||||
list(normalize_choices((lambda: (yield from value))())),
|
||||
normalize_choices((lambda: (yield from value))()),
|
||||
value,
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue