diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index f15b5856bf..6174b7bc98 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -15,7 +15,11 @@ from django.db import connection, connections, router from django.db.models.constants import LOOKUP_SEP from django.db.models.query_utils import DeferredAttribute, RegisterLookupMixin from django.utils import timezone -from django.utils.choices import CallableChoiceIterator, normalize_choices +from django.utils.choices import ( + CallableChoiceIterator, + flatten_choices, + normalize_choices, +) from django.utils.datastructures import DictWrapper from django.utils.dateparse import ( parse_date, @@ -1080,19 +1084,10 @@ class Field(RegisterLookupMixin): """ return str(self.value_from_object(obj)) - def _get_flatchoices(self): + @property + def flatchoices(self): """Flattened version of choices tuple.""" - if self.choices is None: - return [] - flat = [] - for choice, value in self.choices: - if isinstance(value, (list, tuple)): - flat.extend(value) - else: - flat.append((choice, value)) - return flat - - flatchoices = property(_get_flatchoices) + return list(flatten_choices(self.choices)) def save_form_data(self, instance, data): setattr(instance, self.name, data) diff --git a/django/utils/choices.py b/django/utils/choices.py index 734b9331a1..54dbdcb3aa 100644 --- a/django/utils/choices.py +++ b/django/utils/choices.py @@ -6,6 +6,7 @@ from django.utils.functional import Promise __all__ = [ "BaseChoiceIterator", "CallableChoiceIterator", + "flatten_choices", "normalize_choices", ] @@ -43,6 +44,15 @@ class CallableChoiceIterator(BaseChoiceIterator): yield from normalize_choices(self.func()) +def flatten_choices(choices): + """Flatten choices by removing nested values.""" + for value_or_group, label_or_nested in choices or (): + if isinstance(label_or_nested, (list, tuple)): + yield from label_or_nested + else: + yield value_or_group, label_or_nested + + def normalize_choices(value, *, depth=0): """Normalize choices values consistently for fields and widgets.""" # Avoid circular import when importing django.forms. diff --git a/tests/utils_tests/test_choices.py b/tests/utils_tests/test_choices.py index a2ad5541a4..e3e3766ea9 100644 --- a/tests/utils_tests/test_choices.py +++ b/tests/utils_tests/test_choices.py @@ -1,3 +1,4 @@ +import collections.abc from unittest import mock from django.db.models import TextChoices @@ -5,6 +6,7 @@ from django.test import SimpleTestCase from django.utils.choices import ( BaseChoiceIterator, CallableChoiceIterator, + flatten_choices, normalize_choices, ) from django.utils.translation import gettext_lazy as _ @@ -56,6 +58,46 @@ class ChoiceIteratorTests(SimpleTestCase): self.assertTrue(str(ctx.exception).endswith("index out of range")) +class FlattenChoicesTests(SimpleTestCase): + def test_empty(self): + def generator(): + yield from () + + for choices in ({}, [], (), set(), frozenset(), generator(), None, ""): + with self.subTest(choices=choices): + result = flatten_choices(choices) + self.assertIsInstance(result, collections.abc.Generator) + self.assertEqual(list(result), []) + + def test_non_empty(self): + choices = [ + ("C", _("Club")), + ("D", _("Diamond")), + ("H", _("Heart")), + ("S", _("Spade")), + ] + result = flatten_choices(choices) + self.assertIsInstance(result, collections.abc.Generator) + self.assertEqual(list(result), choices) + + def test_nested_choices(self): + choices = [ + ("Audio", [("vinyl", _("Vinyl")), ("cd", _("CD"))]), + ("Video", [("vhs", _("VHS Tape")), ("dvd", _("DVD"))]), + ("unknown", _("Unknown")), + ] + expected = [ + ("vinyl", _("Vinyl")), + ("cd", _("CD")), + ("vhs", _("VHS Tape")), + ("dvd", _("DVD")), + ("unknown", _("Unknown")), + ] + result = flatten_choices(choices) + self.assertIsInstance(result, collections.abc.Generator) + self.assertEqual(list(result), expected) + + class NormalizeFieldChoicesTests(SimpleTestCase): expected = [ ("C", _("Club")),