From 711c0547224de00aee39b8720c706ac4977e89fd Mon Sep 17 00:00:00 2001 From: Nick Pope Date: Mon, 16 Oct 2023 19:25:17 +0100 Subject: [PATCH] [5.0.x] Refs #31262 -- Added __eq__() and __getitem__() to BaseChoiceIterator. This makes it easier to work with lazy iterators used for callables, etc. when extracting items or comparing to lists, e.g. during testing. Also added `BaseChoiceIterator.__iter__()` to make it clear that subclasses must implement this and added `__all__` to the module. Co-authored-by: Adam Johnson Co-authored-by: Natalia Bidart <124304+nessita@users.noreply.github.com> Backport of 07fa79ef2bb3e8cace7bd87b292c6c85230eed05 from main --- django/utils/choices.py | 26 +++++++++++++ tests/utils_tests/test_choices.py | 64 +++++++++++++++++++++++++++---- 2 files changed, 83 insertions(+), 7 deletions(-) diff --git a/django/utils/choices.py b/django/utils/choices.py index a0611d96f1..734b9331a1 100644 --- a/django/utils/choices.py +++ b/django/utils/choices.py @@ -1,11 +1,37 @@ from collections.abc import Callable, Iterable, Iterator, Mapping +from itertools import islice, zip_longest from django.utils.functional import Promise +__all__ = [ + "BaseChoiceIterator", + "CallableChoiceIterator", + "normalize_choices", +] + class BaseChoiceIterator: """Base class for lazy iterators for choices.""" + def __eq__(self, other): + if isinstance(other, Iterable): + return all(a == b for a, b in zip_longest(self, other, fillvalue=object())) + return super().__eq__(other) + + def __getitem__(self, index): + if index < 0: + # Suboptimally consume whole iterator to handle negative index. + return list(self)[index] + try: + return next(islice(self, index, index + 1)) + except StopIteration: + raise IndexError("index out of range") from None + + def __iter__(self): + raise NotImplementedError( + "BaseChoiceIterator subclasses must implement __iter__()." + ) + class CallableChoiceIterator(BaseChoiceIterator): """Iterator to lazily normalize choices generated by a callable.""" diff --git a/tests/utils_tests/test_choices.py b/tests/utils_tests/test_choices.py index d96c3d49c4..a2ad5541a4 100644 --- a/tests/utils_tests/test_choices.py +++ b/tests/utils_tests/test_choices.py @@ -2,10 +2,60 @@ from unittest import mock from django.db.models import TextChoices from django.test import SimpleTestCase -from django.utils.choices import CallableChoiceIterator, normalize_choices +from django.utils.choices import ( + BaseChoiceIterator, + CallableChoiceIterator, + normalize_choices, +) from django.utils.translation import gettext_lazy as _ +class SimpleChoiceIterator(BaseChoiceIterator): + def __iter__(self): + return ((i, f"Item #{i}") for i in range(1, 4)) + + +class ChoiceIteratorTests(SimpleTestCase): + def test_not_implemented_error_on_missing_iter(self): + class InvalidChoiceIterator(BaseChoiceIterator): + pass # Not overriding __iter__(). + + msg = "BaseChoiceIterator subclasses must implement __iter__()." + with self.assertRaisesMessage(NotImplementedError, msg): + iter(InvalidChoiceIterator()) + + def test_eq(self): + unrolled = [(1, "Item #1"), (2, "Item #2"), (3, "Item #3")] + self.assertEqual(SimpleChoiceIterator(), unrolled) + self.assertEqual(unrolled, SimpleChoiceIterator()) + + def test_eq_instances(self): + self.assertEqual(SimpleChoiceIterator(), SimpleChoiceIterator()) + + def test_not_equal_subset(self): + self.assertNotEqual(SimpleChoiceIterator(), [(1, "Item #1"), (2, "Item #2")]) + + def test_not_equal_superset(self): + self.assertNotEqual( + SimpleChoiceIterator(), + [(1, "Item #1"), (2, "Item #2"), (3, "Item #3"), None], + ) + + def test_getitem(self): + choices = SimpleChoiceIterator() + for i, expected in [(0, (1, "Item #1")), (-1, (3, "Item #3"))]: + with self.subTest(index=i): + self.assertEqual(choices[i], expected) + + def test_getitem_indexerror(self): + choices = SimpleChoiceIterator() + for i in (4, -4): + with self.subTest(index=i): + with self.assertRaises(IndexError) as ctx: + choices[i] + self.assertTrue(str(ctx.exception).endswith("index out of range")) + + class NormalizeFieldChoicesTests(SimpleTestCase): expected = [ ("C", _("Club")), @@ -84,7 +134,7 @@ class NormalizeFieldChoicesTests(SimpleTestCase): get_choices_spy.assert_not_called() self.assertIsInstance(output, CallableChoiceIterator) - self.assertEqual(list(output), self.expected) + self.assertEqual(output, self.expected) get_choices_spy.assert_called_once() def test_mapping(self): @@ -134,7 +184,7 @@ class NormalizeFieldChoicesTests(SimpleTestCase): get_media_choices_spy.assert_not_called() self.assertIsInstance(output, CallableChoiceIterator) - self.assertEqual(list(output), self.expected_nested) + self.assertEqual(output, self.expected_nested) get_media_choices_spy.assert_called_once() def test_nested_mapping(self): @@ -185,7 +235,7 @@ class NormalizeFieldChoicesTests(SimpleTestCase): get_choices_spy.assert_not_called() self.assertIsInstance(output, CallableChoiceIterator) - self.assertEqual(list(output), self.expected) + self.assertEqual(output, self.expected) get_choices_spy.assert_called_once() def test_iterable_non_canonical(self): @@ -230,7 +280,7 @@ class NormalizeFieldChoicesTests(SimpleTestCase): get_media_choices_spy.assert_not_called() self.assertIsInstance(output, CallableChoiceIterator) - self.assertEqual(list(output), self.expected_nested) + self.assertEqual(output, self.expected_nested) get_media_choices_spy.assert_called_once() def test_nested_iterable_non_canonical(self): @@ -294,12 +344,12 @@ class NormalizeFieldChoicesTests(SimpleTestCase): def test_unsupported_values_from_callable_returned_unmodified(self): for value in self.invalid_iterable + self.invalid_nested: with self.subTest(value=value): - self.assertEqual(list(normalize_choices(lambda: value)), value) + self.assertEqual(normalize_choices(lambda: value), value) def test_unsupported_values_from_iterator_returned_unmodified(self): for value in self.invalid_nested: with self.subTest(value=value): self.assertEqual( - list(normalize_choices((lambda: (yield from value))())), + normalize_choices((lambda: (yield from value))()), value, )