1
0
mirror of https://github.com/django/django.git synced 2024-12-27 11:35:53 +00:00

[5.0.x] Refs #31262 -- Added __eq__() and __getitem__() to BaseChoiceIterator.

This makes it easier to work with lazy iterators used for callables,
etc. when extracting items or comparing to lists, e.g. during testing.

Also added `BaseChoiceIterator.__iter__()` to make it clear that
subclasses must implement this and added `__all__` to the module.

Co-authored-by: Adam Johnson <me@adamj.eu>
Co-authored-by: Natalia Bidart <124304+nessita@users.noreply.github.com>

Backport of 07fa79ef2b from main
This commit is contained in:
Nick Pope 2023-10-16 19:25:17 +01:00 committed by Natalia
parent 08aa336af4
commit 711c054722
2 changed files with 83 additions and 7 deletions

View File

@ -1,11 +1,37 @@
from collections.abc import Callable, Iterable, Iterator, Mapping from collections.abc import Callable, Iterable, Iterator, Mapping
from itertools import islice, zip_longest
from django.utils.functional import Promise from django.utils.functional import Promise
__all__ = [
"BaseChoiceIterator",
"CallableChoiceIterator",
"normalize_choices",
]
class BaseChoiceIterator: class BaseChoiceIterator:
"""Base class for lazy iterators for choices.""" """Base class for lazy iterators for choices."""
def __eq__(self, other):
if isinstance(other, Iterable):
return all(a == b for a, b in zip_longest(self, other, fillvalue=object()))
return super().__eq__(other)
def __getitem__(self, index):
if index < 0:
# Suboptimally consume whole iterator to handle negative index.
return list(self)[index]
try:
return next(islice(self, index, index + 1))
except StopIteration:
raise IndexError("index out of range") from None
def __iter__(self):
raise NotImplementedError(
"BaseChoiceIterator subclasses must implement __iter__()."
)
class CallableChoiceIterator(BaseChoiceIterator): class CallableChoiceIterator(BaseChoiceIterator):
"""Iterator to lazily normalize choices generated by a callable.""" """Iterator to lazily normalize choices generated by a callable."""

View File

@ -2,10 +2,60 @@ from unittest import mock
from django.db.models import TextChoices from django.db.models import TextChoices
from django.test import SimpleTestCase from django.test import SimpleTestCase
from django.utils.choices import CallableChoiceIterator, normalize_choices from django.utils.choices import (
BaseChoiceIterator,
CallableChoiceIterator,
normalize_choices,
)
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
class SimpleChoiceIterator(BaseChoiceIterator):
def __iter__(self):
return ((i, f"Item #{i}") for i in range(1, 4))
class ChoiceIteratorTests(SimpleTestCase):
def test_not_implemented_error_on_missing_iter(self):
class InvalidChoiceIterator(BaseChoiceIterator):
pass # Not overriding __iter__().
msg = "BaseChoiceIterator subclasses must implement __iter__()."
with self.assertRaisesMessage(NotImplementedError, msg):
iter(InvalidChoiceIterator())
def test_eq(self):
unrolled = [(1, "Item #1"), (2, "Item #2"), (3, "Item #3")]
self.assertEqual(SimpleChoiceIterator(), unrolled)
self.assertEqual(unrolled, SimpleChoiceIterator())
def test_eq_instances(self):
self.assertEqual(SimpleChoiceIterator(), SimpleChoiceIterator())
def test_not_equal_subset(self):
self.assertNotEqual(SimpleChoiceIterator(), [(1, "Item #1"), (2, "Item #2")])
def test_not_equal_superset(self):
self.assertNotEqual(
SimpleChoiceIterator(),
[(1, "Item #1"), (2, "Item #2"), (3, "Item #3"), None],
)
def test_getitem(self):
choices = SimpleChoiceIterator()
for i, expected in [(0, (1, "Item #1")), (-1, (3, "Item #3"))]:
with self.subTest(index=i):
self.assertEqual(choices[i], expected)
def test_getitem_indexerror(self):
choices = SimpleChoiceIterator()
for i in (4, -4):
with self.subTest(index=i):
with self.assertRaises(IndexError) as ctx:
choices[i]
self.assertTrue(str(ctx.exception).endswith("index out of range"))
class NormalizeFieldChoicesTests(SimpleTestCase): class NormalizeFieldChoicesTests(SimpleTestCase):
expected = [ expected = [
("C", _("Club")), ("C", _("Club")),
@ -84,7 +134,7 @@ class NormalizeFieldChoicesTests(SimpleTestCase):
get_choices_spy.assert_not_called() get_choices_spy.assert_not_called()
self.assertIsInstance(output, CallableChoiceIterator) self.assertIsInstance(output, CallableChoiceIterator)
self.assertEqual(list(output), self.expected) self.assertEqual(output, self.expected)
get_choices_spy.assert_called_once() get_choices_spy.assert_called_once()
def test_mapping(self): def test_mapping(self):
@ -134,7 +184,7 @@ class NormalizeFieldChoicesTests(SimpleTestCase):
get_media_choices_spy.assert_not_called() get_media_choices_spy.assert_not_called()
self.assertIsInstance(output, CallableChoiceIterator) self.assertIsInstance(output, CallableChoiceIterator)
self.assertEqual(list(output), self.expected_nested) self.assertEqual(output, self.expected_nested)
get_media_choices_spy.assert_called_once() get_media_choices_spy.assert_called_once()
def test_nested_mapping(self): def test_nested_mapping(self):
@ -185,7 +235,7 @@ class NormalizeFieldChoicesTests(SimpleTestCase):
get_choices_spy.assert_not_called() get_choices_spy.assert_not_called()
self.assertIsInstance(output, CallableChoiceIterator) self.assertIsInstance(output, CallableChoiceIterator)
self.assertEqual(list(output), self.expected) self.assertEqual(output, self.expected)
get_choices_spy.assert_called_once() get_choices_spy.assert_called_once()
def test_iterable_non_canonical(self): def test_iterable_non_canonical(self):
@ -230,7 +280,7 @@ class NormalizeFieldChoicesTests(SimpleTestCase):
get_media_choices_spy.assert_not_called() get_media_choices_spy.assert_not_called()
self.assertIsInstance(output, CallableChoiceIterator) self.assertIsInstance(output, CallableChoiceIterator)
self.assertEqual(list(output), self.expected_nested) self.assertEqual(output, self.expected_nested)
get_media_choices_spy.assert_called_once() get_media_choices_spy.assert_called_once()
def test_nested_iterable_non_canonical(self): def test_nested_iterable_non_canonical(self):
@ -294,12 +344,12 @@ class NormalizeFieldChoicesTests(SimpleTestCase):
def test_unsupported_values_from_callable_returned_unmodified(self): def test_unsupported_values_from_callable_returned_unmodified(self):
for value in self.invalid_iterable + self.invalid_nested: for value in self.invalid_iterable + self.invalid_nested:
with self.subTest(value=value): with self.subTest(value=value):
self.assertEqual(list(normalize_choices(lambda: value)), value) self.assertEqual(normalize_choices(lambda: value), value)
def test_unsupported_values_from_iterator_returned_unmodified(self): def test_unsupported_values_from_iterator_returned_unmodified(self):
for value in self.invalid_nested: for value in self.invalid_nested:
with self.subTest(value=value): with self.subTest(value=value):
self.assertEqual( self.assertEqual(
list(normalize_choices((lambda: (yield from value))())), normalize_choices((lambda: (yield from value))()),
value, value,
) )