diff --git a/django/contrib/admin/widgets.py b/django/contrib/admin/widgets.py index c7a79f520b..fc0cd941d1 100644 --- a/django/contrib/admin/widgets.py +++ b/django/contrib/admin/widgets.py @@ -262,7 +262,6 @@ class RelatedFieldWidgetWrapper(forms.Widget): ): self.needs_multipart_form = widget.needs_multipart_form self.attrs = widget.attrs - self.choices = widget.choices self.widget = widget self.rel = rel # Backwards compatible check for whether a user can add related @@ -295,6 +294,14 @@ class RelatedFieldWidgetWrapper(forms.Widget): def media(self): return self.widget.media + @property + def choices(self): + return self.widget.choices + + @choices.setter + def choices(self, value): + self.widget.choices = value + def get_related_url(self, info, action, *args): return reverse( "admin:%s_%s_%s" % (info + (action,)), @@ -307,7 +314,6 @@ class RelatedFieldWidgetWrapper(forms.Widget): rel_opts = self.rel.model._meta info = (rel_opts.app_label, rel_opts.model_name) - self.widget.choices = self.choices related_field_name = self.rel.get_related_field().name url_params = "&".join( "%s=%s" % param diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index f958589bea..15fe9d9c9c 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -1,4 +1,3 @@ -import collections.abc import copy import datetime import decimal @@ -14,9 +13,9 @@ from django.conf import settings from django.core import checks, exceptions, validators from django.db import connection, connections, router from django.db.models.constants import LOOKUP_SEP -from django.db.models.enums import ChoicesMeta from django.db.models.query_utils import DeferredAttribute, RegisterLookupMixin from django.utils import timezone +from django.utils.choices import CallableChoiceIterator, normalize_choices from django.utils.datastructures import DictWrapper from django.utils.dateparse import ( parse_date, @@ -225,10 +224,6 @@ class Field(RegisterLookupMixin): self.unique_for_date = unique_for_date self.unique_for_month = unique_for_month self.unique_for_year = unique_for_year - if isinstance(choices, ChoicesMeta): - choices = choices.choices - if isinstance(choices, collections.abc.Iterator): - choices = list(choices) self.choices = choices self.help_text = help_text self.db_index = db_index @@ -320,10 +315,13 @@ class Field(RegisterLookupMixin): if not self.choices: return [] - if not is_iterable(self.choices) or isinstance(self.choices, str): + if not is_iterable(self.choices) or isinstance( + self.choices, (str, CallableChoiceIterator) + ): return [ checks.Error( - "'choices' must be an iterable (e.g., a list or tuple).", + "'choices' must be a mapping (e.g. a dictionary) or an iterable " + "(e.g. a list or tuple).", obj=self, id="fields.E004", ) @@ -381,8 +379,8 @@ class Field(RegisterLookupMixin): return [ checks.Error( - "'choices' must be an iterable containing " - "(actual value, human readable name) tuples.", + "'choices' must be a mapping of actual values to human readable names " + "or an iterable containing (actual value, human readable name) tuples.", obj=self, id="fields.E005", ) @@ -543,6 +541,14 @@ class Field(RegisterLookupMixin): return Col(alias, self, output_field) + @property + def choices(self): + return self._choices + + @choices.setter + def choices(self, value): + self._choices = normalize_choices(value) + @cached_property def cached_col(self): from django.db.models.expressions import Col @@ -625,9 +631,8 @@ class Field(RegisterLookupMixin): equals_comparison = {"choices", "validators"} for name, default in possibles.items(): value = getattr(self, attr_overrides.get(name, name)) - # Unroll anything iterable for choices into a concrete list - if name == "choices" and isinstance(value, collections.abc.Iterable): - value = list(value) + if isinstance(value, CallableChoiceIterator): + value = value.func # Do correct kind of comparison if name in equals_comparison: if value != default: diff --git a/django/forms/fields.py b/django/forms/fields.py index bd226de543..3a0510ffc9 100644 --- a/django/forms/fields.py +++ b/django/forms/fields.py @@ -17,7 +17,6 @@ from urllib.parse import urlsplit, urlunsplit from django.core import validators from django.core.exceptions import ValidationError -from django.db.models.enums import ChoicesMeta from django.forms.boundfield import BoundField from django.forms.utils import from_current_timezone, to_current_timezone from django.forms.widgets import ( @@ -42,6 +41,7 @@ from django.forms.widgets import ( URLInput, ) from django.utils import formats +from django.utils.choices import normalize_choices from django.utils.dateparse import parse_datetime, parse_duration from django.utils.deprecation import RemovedInDjango60Warning from django.utils.duration import duration_string @@ -861,14 +861,6 @@ class NullBooleanField(BooleanField): pass -class CallableChoiceIterator: - def __init__(self, choices_func): - self.choices_func = choices_func - - def __iter__(self): - yield from self.choices_func() - - class ChoiceField(Field): widget = Select default_error_messages = { @@ -879,8 +871,6 @@ class ChoiceField(Field): def __init__(self, *, choices=(), **kwargs): super().__init__(**kwargs) - if isinstance(choices, ChoicesMeta): - choices = choices.choices self.choices = choices def __deepcopy__(self, memo): @@ -888,21 +878,15 @@ class ChoiceField(Field): result._choices = copy.deepcopy(self._choices, memo) return result - def _get_choices(self): + @property + def choices(self): return self._choices - def _set_choices(self, value): - # Setting choices also sets the choices on the widget. - # choices can be any iterable, but we call list() on it because - # it will be consumed more than once. - if callable(value): - value = CallableChoiceIterator(value) - else: - value = list(value) - - self._choices = self.widget.choices = value - - choices = property(_get_choices, _set_choices) + @choices.setter + def choices(self, value): + # Setting choices on the field also sets the choices on the widget. + # Note that the property setter for the widget will re-normalize. + self._choices = self.widget.choices = normalize_choices(value) def to_python(self, value): """Return a string.""" diff --git a/django/forms/models.py b/django/forms/models.py index dc30d79b5d..d353da4ddc 100644 --- a/django/forms/models.py +++ b/django/forms/models.py @@ -21,6 +21,7 @@ from django.forms.widgets import ( RadioSelect, SelectMultiple, ) +from django.utils.choices import ChoiceIterator from django.utils.text import capfirst, get_text_list from django.utils.translation import gettext from django.utils.translation import gettext_lazy as _ @@ -1402,7 +1403,7 @@ class ModelChoiceIteratorValue: return self.value == other -class ModelChoiceIterator: +class ModelChoiceIterator(ChoiceIterator): def __init__(self, field): self.field = field self.queryset = field.queryset @@ -1532,7 +1533,7 @@ class ModelChoiceField(ChoiceField): # the queryset. return self.iterator(self) - choices = property(_get_choices, ChoiceField._set_choices) + choices = property(_get_choices, ChoiceField.choices.fset) def prepare_value(self, value): if hasattr(value, "_meta"): diff --git a/django/forms/widgets.py b/django/forms/widgets.py index ab7c0f755f..2c734052d5 100644 --- a/django/forms/widgets.py +++ b/django/forms/widgets.py @@ -12,6 +12,7 @@ from itertools import chain from django.forms.utils import to_current_timezone from django.templatetags.static import static from django.utils import formats +from django.utils.choices import normalize_choices from django.utils.dates import MONTHS from django.utils.formats import get_format from django.utils.html import format_html, html_safe @@ -620,10 +621,7 @@ class ChoiceWidget(Widget): def __init__(self, attrs=None, choices=()): super().__init__(attrs) - # choices can be any iterable, but we may need to render this widget - # multiple times. Thus, collapse it into a list so it can be consumed - # more than once. - self.choices = list(choices) + self.choices = choices def __deepcopy__(self, memo): obj = copy.copy(self) @@ -741,6 +739,14 @@ class ChoiceWidget(Widget): value = [value] return [str(v) if v is not None else "" for v in value] + @property + def choices(self): + return self._choices + + @choices.setter + def choices(self, value): + self._choices = normalize_choices(value) + class Select(ChoiceWidget): input_type = "select" diff --git a/django/utils/choices.py b/django/utils/choices.py new file mode 100644 index 0000000000..fc8267af34 --- /dev/null +++ b/django/utils/choices.py @@ -0,0 +1,63 @@ +from collections.abc import Callable, Iterable, Iterator, Mapping + +from django.db.models.enums import ChoicesMeta +from django.utils.functional import Promise + + +class ChoiceIterator: + """Base class for lazy iterators for choices.""" + + +class CallableChoiceIterator(ChoiceIterator): + """Iterator to lazily normalize choices generated by a callable.""" + + def __init__(self, func): + self.func = func + + def __iter__(self): + yield from normalize_choices(self.func()) + + +def normalize_choices(value, *, depth=0): + """Normalize choices values consistently for fields and widgets.""" + + match value: + case ChoiceIterator() | Promise() | bytes() | str(): + # Avoid prematurely normalizing iterators that should be lazy. + # Because string-like types are iterable, return early to avoid + # iterating over them in the guard for the Iterable case below. + return value + case ChoicesMeta(): + # Choices enumeration helpers already output in canonical form. + return value.choices + case Mapping() if depth < 2: + value = value.items() + case Iterator() if depth < 2: + # Although Iterator would be handled by the Iterable case below, + # the iterator would be consumed prematurely while checking that + # its elements are not string-like in the guard, so we handle it + # separately. + pass + case Iterable() if depth < 2 and not any( + isinstance(x, (Promise, bytes, str)) for x in value + ): + # String-like types are iterable, so the guard above ensures that + # they're handled by the default case below. + pass + case Callable() if depth == 0: + # If at the top level, wrap callables to be evaluated lazily. + return CallableChoiceIterator(value) + case Callable() if depth < 2: + value = value() + case _: + return value + + try: + # Recursive call to convert any nested values to a list of 2-tuples. + return [(k, normalize_choices(v, depth=depth + 1)) for k, v in value] + except (TypeError, ValueError): + # Return original value for the system check to raise if it has items + # that are not iterable or not 2-tuples: + # - TypeError: cannot unpack non-iterable object + # - ValueError: values to unpack + return value diff --git a/docs/internals/contributing/writing-code/coding-style.txt b/docs/internals/contributing/writing-code/coding-style.txt index d227e04ba0..6871d43d7b 100644 --- a/docs/internals/contributing/writing-code/coding-style.txt +++ b/docs/internals/contributing/writing-code/coding-style.txt @@ -298,16 +298,23 @@ Model style * Any custom methods * If ``choices`` is defined for a given model field, define each choice as a - list of tuples, with an all-uppercase name as a class attribute on the model. + mapping, with an all-uppercase name as a class attribute on the model. Example:: class MyModel(models.Model): DIRECTION_UP = "U" DIRECTION_DOWN = "D" - DIRECTION_CHOICES = [ - (DIRECTION_UP, "Up"), - (DIRECTION_DOWN, "Down"), - ] + DIRECTION_CHOICES = { + DIRECTION_UP: "Up", + DIRECTION_DOWN: "Down", + } + + Alternatively, consider using :ref:`field-choices-enum-types`:: + + class MyModel(models.Model): + class Direction(models.TextChoices): + UP = U, "Up" + DOWN = D, "Down" Use of ``django.conf.settings`` =============================== diff --git a/docs/ref/checks.txt b/docs/ref/checks.txt index a208114596..cd539aa964 100644 --- a/docs/ref/checks.txt +++ b/docs/ref/checks.txt @@ -165,9 +165,11 @@ Model fields * **fields.E002**: Field names must not contain ``"__"``. * **fields.E003**: ``pk`` is a reserved word that cannot be used as a field name. -* **fields.E004**: ``choices`` must be an iterable (e.g., a list or tuple). -* **fields.E005**: ``choices`` must be an iterable containing ``(actual value, - human readable name)`` tuples. +* **fields.E004**: ``choices`` must be a mapping (e.g. a dictionary) or an + iterable (e.g. a list or tuple). +* **fields.E005**: ``choices`` must be a mapping of actual values to human + readable names or an iterable containing ``(actual value, human readable + name)`` tuples. * **fields.E006**: ``db_index`` must be ``None``, ``True`` or ``False``. * **fields.E007**: Primary keys must not have ``null=True``. * **fields.E008**: All ``validators`` must be callable. diff --git a/docs/ref/contrib/admin/actions.txt b/docs/ref/contrib/admin/actions.txt index f5f91982c1..263d43f8a3 100644 --- a/docs/ref/contrib/admin/actions.txt +++ b/docs/ref/contrib/admin/actions.txt @@ -47,11 +47,11 @@ news application with an ``Article`` model:: from django.db import models - STATUS_CHOICES = [ - ("d", "Draft"), - ("p", "Published"), - ("w", "Withdrawn"), - ] + STATUS_CHOICES = { + "d": "Draft", + "p": "Published", + "w": "Withdrawn", + } class Article(models.Model): diff --git a/docs/ref/forms/fields.txt b/docs/ref/forms/fields.txt index 307ebb15a2..5e42404f94 100644 --- a/docs/ref/forms/fields.txt +++ b/docs/ref/forms/fields.txt @@ -510,8 +510,9 @@ For each field, we describe the default widget used if you don't specify .. versionchanged:: 5.0 - Support for using :ref:`enumeration types ` - directly in the ``choices`` was added. + Support for mappings and using + :ref:`enumeration types ` directly in + ``choices`` was added. ``DateField`` ------------- diff --git a/docs/ref/forms/widgets.txt b/docs/ref/forms/widgets.txt index efff81ddbc..f76759b254 100644 --- a/docs/ref/forms/widgets.txt +++ b/docs/ref/forms/widgets.txt @@ -58,11 +58,11 @@ widget on the field. In the following example, the from django import forms BIRTH_YEAR_CHOICES = ["1980", "1981", "1982"] - FAVORITE_COLORS_CHOICES = [ - ("blue", "Blue"), - ("green", "Green"), - ("black", "Black"), - ] + FAVORITE_COLORS_CHOICES = { + "blue": "Blue", + "green": "Green", + "black": "Black", + } class SimpleForm(forms.Form): @@ -95,7 +95,7 @@ example: .. code-block:: pycon >>> from django import forms - >>> CHOICES = [("1", "First"), ("2", "Second")] + >>> CHOICES = {"1": "First", "2": "Second"} >>> choice_field = forms.ChoiceField(widget=forms.RadioSelect, choices=CHOICES) >>> choice_field.choices [('1', 'First'), ('2', 'Second')] @@ -458,9 +458,9 @@ foundation for custom widgets. class DateSelectorWidget(forms.MultiWidget): def __init__(self, attrs=None): - days = [(day, day) for day in range(1, 32)] - months = [(month, month) for month in range(1, 13)] - years = [(year, year) for year in [2018, 2019, 2020]] + days = {day: day for day in range(1, 32)} + months = {month: month for month in range(1, 13)} + years = {year: year for year in [2018, 2019, 2020]} widgets = [ forms.Select(attrs=attrs, choices=days), forms.Select(attrs=attrs, choices=months), diff --git a/docs/ref/models/conditional-expressions.txt b/docs/ref/models/conditional-expressions.txt index d14312870f..cfdbd0790a 100644 --- a/docs/ref/models/conditional-expressions.txt +++ b/docs/ref/models/conditional-expressions.txt @@ -22,11 +22,11 @@ We'll be using the following model in the subsequent examples:: REGULAR = "R" GOLD = "G" PLATINUM = "P" - ACCOUNT_TYPE_CHOICES = [ - (REGULAR, "Regular"), - (GOLD, "Gold"), - (PLATINUM, "Platinum"), - ] + ACCOUNT_TYPE_CHOICES = { + REGULAR: "Regular", + GOLD: "Gold", + PLATINUM: "Platinum", + } name = models.CharField(max_length=50) registered_on = models.DateField() account_type = models.CharField( diff --git a/docs/ref/models/fields.txt b/docs/ref/models/fields.txt index ff67312949..fbc90e5420 100644 --- a/docs/ref/models/fields.txt +++ b/docs/ref/models/fields.txt @@ -86,14 +86,26 @@ If a field has ``blank=False``, the field will be required. .. attribute:: Field.choices -A :term:`sequence` consisting itself of iterables of exactly two items (e.g. -``[(A, B), (A, B) ...]``) to use as choices for this field. If choices are -given, they're enforced by :ref:`model validation ` and the -default form widget will be a select box with these choices instead of the -standard text field. +A mapping or iterable in the format described below to use as choices for this +field. If choices are given, they're enforced by +:ref:`model validation ` and the default form widget will +be a select box with these choices instead of the standard text field. -The first element in each tuple is the actual value to be set on the model, -and the second element is the human-readable name. For example:: +If a mapping is given, the key element is the actual value to be set on the +model, and the second element is the human readable name. For example:: + + YEAR_IN_SCHOOL_CHOICES = { + "FR": "Freshman", + "SO": "Sophomore", + "JR": "Junior", + "SR": "Senior", + "GR": "Graduate", + } + +You can also pass a :term:`sequence` consisting itself of iterables of exactly +two items (e.g. ``[(A1, B1), (A2, B2), …]``). The first element in each tuple +is the actual value to be set on the model, and the second element is the +human-readable name. For example:: YEAR_IN_SCHOOL_CHOICES = [ ("FR", "Freshman"), @@ -103,6 +115,10 @@ and the second element is the human-readable name. For example:: ("GR", "Graduate"), ] +.. versionchanged:: 5.0 + + Support for mappings was added. + Generally, it's best to define choices inside a model class, and to define a suitably-named constant for each value:: @@ -115,13 +131,13 @@ define a suitably-named constant for each value:: JUNIOR = "JR" SENIOR = "SR" GRADUATE = "GR" - YEAR_IN_SCHOOL_CHOICES = [ - (FRESHMAN, "Freshman"), - (SOPHOMORE, "Sophomore"), - (JUNIOR, "Junior"), - (SENIOR, "Senior"), - (GRADUATE, "Graduate"), - ] + YEAR_IN_SCHOOL_CHOICES = { + FRESHMAN: "Freshman", + SOPHOMORE: "Sophomore", + JUNIOR: "Junior", + SENIOR: "Senior", + GRADUATE: "Graduate", + } year_in_school = models.CharField( max_length=2, choices=YEAR_IN_SCHOOL_CHOICES, @@ -142,6 +158,25 @@ will work anywhere that the ``Student`` model has been imported). You can also collect your available choices into named groups that can be used for organizational purposes:: + MEDIA_CHOICES = { + "Audio": { + "vinyl": "Vinyl", + "cd": "CD", + }, + "Video": { + "vhs": "VHS Tape", + "dvd": "DVD", + }, + "unknown": "Unknown", + } + +The key of the mapping is the name to apply to the group and the value is the +choices inside that group, consisting of the field value and a human-readable +name for an option. Grouped options may be combined with ungrouped options +within a single mapping (such as the ``"unknown"`` option in this example). + +You can also use a sequence, e.g. a list of 2-tuples:: + MEDIA_CHOICES = [ ( "Audio", @@ -160,17 +195,6 @@ be used for organizational purposes:: ("unknown", "Unknown"), ] -The first element in each tuple is the name to apply to the group. The -second element is an iterable of 2-tuples, with each 2-tuple containing -a value and a human-readable name for an option. Grouped options may be -combined with ungrouped options within a single list (such as the -``'unknown'`` option in this example). - -For each model field that has :attr:`~Field.choices` set, Django will add a -method to retrieve the human-readable name for the field's current value. See -:meth:`~django.db.models.Model.get_FOO_display` in the database API -documentation. - Note that choices can be any sequence object -- not necessarily a list or tuple. This lets you construct choices dynamically. But if you find yourself hacking :attr:`~Field.choices` to be dynamic, you're probably better off using @@ -180,6 +204,12 @@ meant for static data that doesn't change much, if ever. .. note:: A new migration is created each time the order of ``choices`` changes. +For each model field that has :attr:`~Field.choices` set, Django will normalize +the choices to a list of 2-tuples and add a method to retrieve the +human-readable name for the field's current value. See +:meth:`~django.db.models.Model.get_FOO_display` in the database API +documentation. + .. _field-choices-blank-label: Unless :attr:`blank=False` is set on the field along with a diff --git a/docs/ref/models/instances.txt b/docs/ref/models/instances.txt index 6bfd521aaa..6ceb0703ab 100644 --- a/docs/ref/models/instances.txt +++ b/docs/ref/models/instances.txt @@ -912,11 +912,11 @@ For example:: class Person(models.Model): - SHIRT_SIZES = [ - ("S", "Small"), - ("M", "Medium"), - ("L", "Large"), - ] + SHIRT_SIZES = { + "S": "Small", + "M": "Medium", + "L": "Large", + } name = models.CharField(max_length=60) shirt_size = models.CharField(max_length=2, choices=SHIRT_SIZES) diff --git a/docs/releases/5.0.txt b/docs/releases/5.0.txt index 36c2c650fa..cae3668bfc 100644 --- a/docs/releases/5.0.txt +++ b/docs/releases/5.0.txt @@ -129,6 +129,55 @@ sets a database-computed default value. For example:: created = models.DateTimeField(db_default=Now()) circumference = models.FloatField(db_default=2 * Pi()) +More options for declaring field choices +---------------------------------------- + +:attr:`.Field.choices` *(for model fields)* and :attr:`.ChoiceField.choices` +*(for form fields)* allow for more flexibility when declaring their values. In +previous versions of Django, ``choices`` should either be a list of 2-tuples, +or an :ref:`field-choices-enum-types` subclass, but the latter required +accessing the ``.choices`` attribute to provide the values in the expected +form:: + + from django.db import models + + Medal = models.TextChoices("Medal", "GOLD SILVER BRONZE") + + SPORT_CHOICES = [ + ("Martial Arts", [("judo", "Judo"), ("karate", "Karate")]), + ("Racket", [("badminton", "Badminton"), ("tennis", "Tennis")]), + ("unknown", "Unknown"), + ] + + + class Winners(models.Model): + name = models.CharField(...) + medal = models.CharField(..., choices=Medal.choices) + sport = models.CharField(..., choices=SPORT_CHOICES) + +Django 5.0 supports providing a mapping instead of an iterable, and also no +longer requires ``.choices`` to be used directly to expand :ref:`enumeration +types `:: + + from django.db import models + + Medal = models.TextChoices("Medal", "GOLD SILVER BRONZE") + + SPORT_CHOICES = { # Using a mapping instead of a list of 2-tuples. + "Martial Arts": {"judo": "Judo", "karate": "Karate"}, + "Racket": {"badminton": "Badminton", "tennis": "Tennis"}, + "unknown": "Unknown", + } + + + class Winners(models.Model): + name = models.CharField(...) + medal = models.CharField(..., choices=Medal) # Using `.choices` not required. + sport = models.CharField(..., choices=SPORT_CHOICES) + +Under the hood the provided ``choices`` are normalized into a list of 2-tuples +as the canonical form whenever the ``choices`` value is updated. + Minor features -------------- @@ -304,10 +353,6 @@ File Uploads Forms ~~~~~ -* :attr:`.ChoiceField.choices` now accepts - :ref:`Choices classes ` directly instead of - requiring expansion with the ``choices`` attribute. - * The new ``assume_scheme`` argument for :class:`~django.forms.URLField` allows specifying a default URL scheme. @@ -357,10 +402,6 @@ Models of ``ValidationError`` raised during :ref:`model validation `. -* :attr:`.Field.choices` now accepts - :ref:`Choices classes ` directly instead of - requiring expansion with the ``choices`` attribute. - * The :ref:`force_insert ` argument of :meth:`.Model.save` now allows specifying a tuple of parent classes that must be forced to be inserted. diff --git a/docs/topics/db/managers.txt b/docs/topics/db/managers.txt index 047d02ebae..61de153898 100644 --- a/docs/topics/db/managers.txt +++ b/docs/topics/db/managers.txt @@ -154,9 +154,7 @@ For example:: class Person(models.Model): first_name = models.CharField(max_length=50) last_name = models.CharField(max_length=50) - role = models.CharField( - max_length=1, choices=[("A", _("Author")), ("E", _("Editor"))] - ) + role = models.CharField(max_length=1, choices={"A": _("Author"), "E": _("Editor")}) people = models.Manager() authors = AuthorManager() editors = EditorManager() @@ -259,9 +257,7 @@ custom ``QuerySet`` if you also implement them on the ``Manager``:: class Person(models.Model): first_name = models.CharField(max_length=50) last_name = models.CharField(max_length=50) - role = models.CharField( - max_length=1, choices=[("A", _("Author")), ("E", _("Editor"))] - ) + role = models.CharField(max_length=1, choices={"A": _("Author"), "E": _("Editor")}) people = PersonManager() This example allows you to call both ``authors()`` and ``editors()`` directly from diff --git a/docs/topics/db/models.txt b/docs/topics/db/models.txt index 33a515f14f..cc6c1f5298 100644 --- a/docs/topics/db/models.txt +++ b/docs/topics/db/models.txt @@ -185,11 +185,11 @@ ones: class Person(models.Model): - SHIRT_SIZES = [ - ("S", "Small"), - ("M", "Medium"), - ("L", "Large"), - ] + SHIRT_SIZES = { + "S": "Small", + "M": "Medium", + "L": "Large", + } name = models.CharField(max_length=60) shirt_size = models.CharField(max_length=1, choices=SHIRT_SIZES) diff --git a/docs/topics/forms/modelforms.txt b/docs/topics/forms/modelforms.txt index 7f3c042f30..fbd5695c17 100644 --- a/docs/topics/forms/modelforms.txt +++ b/docs/topics/forms/modelforms.txt @@ -173,11 +173,11 @@ Consider this set of models:: from django.db import models from django.forms import ModelForm - TITLE_CHOICES = [ - ("MR", "Mr."), - ("MRS", "Mrs."), - ("MS", "Ms."), - ] + TITLE_CHOICES = { + "MR": "Mr.", + "MRS": "Mrs.", + "MS": "Ms.", + } class Author(models.Model): diff --git a/tests/field_deconstruction/tests.py b/tests/field_deconstruction/tests.py index 3663886708..3b10ee0091 100644 --- a/tests/field_deconstruction/tests.py +++ b/tests/field_deconstruction/tests.py @@ -2,6 +2,7 @@ from django.apps import apps from django.db import models from django.test import SimpleTestCase, override_settings from django.test.utils import isolate_lru_cache +from django.utils.choices import normalize_choices class FieldDeconstructionTests(SimpleTestCase): @@ -105,12 +106,22 @@ class FieldDeconstructionTests(SimpleTestCase): self.assertEqual(kwargs, {"choices": [(0, "0"), (1, "1"), (2, "2")]}) def test_choices_iterable(self): - # Pass an iterator (but not an iterable) to choices. + # Pass an iterable (but not an iterator) to choices. field = models.IntegerField(choices="012345") name, path, args, kwargs = field.deconstruct() self.assertEqual(path, "django.db.models.IntegerField") self.assertEqual(args, []) - self.assertEqual(kwargs, {"choices": ["0", "1", "2", "3", "4", "5"]}) + self.assertEqual(kwargs, {"choices": normalize_choices("012345")}) + + def test_choices_callable(self): + def get_choices(): + return [(i, str(i)) for i in range(3)] + + field = models.IntegerField(choices=get_choices) + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.IntegerField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"choices": get_choices}) def test_csi_field(self): field = models.CommaSeparatedIntegerField(max_length=100) diff --git a/tests/forms_tests/field_tests/test_choicefield.py b/tests/forms_tests/field_tests/test_choicefield.py index e7893abe57..ebc6c14bea 100644 --- a/tests/forms_tests/field_tests/test_choicefield.py +++ b/tests/forms_tests/field_tests/test_choicefield.py @@ -65,6 +65,51 @@ class ChoiceFieldTest(FormFieldAssertionsMixin, SimpleTestCase): f = ChoiceField(choices=choices) self.assertEqual("J", f.clean("J")) + def test_choicefield_callable_mapping(self): + def choices(): + return {"J": "John", "P": "Paul"} + + f = ChoiceField(choices=choices) + self.assertEqual("J", f.clean("J")) + + def test_choicefield_callable_grouped_mapping(self): + def choices(): + return { + "Numbers": {"1": "One", "2": "Two"}, + "Letters": {"3": "A", "4": "B"}, + } + + f = ChoiceField(choices=choices) + for i in ("1", "2", "3", "4"): + with self.subTest(i): + self.assertEqual(i, f.clean(i)) + + def test_choicefield_mapping(self): + f = ChoiceField(choices={"J": "John", "P": "Paul"}) + self.assertEqual("J", f.clean("J")) + + def test_choicefield_grouped_mapping(self): + f = ChoiceField( + choices={ + "Numbers": (("1", "One"), ("2", "Two")), + "Letters": (("3", "A"), ("4", "B")), + } + ) + for i in ("1", "2", "3", "4"): + with self.subTest(i): + self.assertEqual(i, f.clean(i)) + + def test_choicefield_grouped_mapping_inner_dict(self): + f = ChoiceField( + choices={ + "Numbers": {"1": "One", "2": "Two"}, + "Letters": {"3": "A", "4": "B"}, + } + ) + for i in ("1", "2", "3", "4"): + with self.subTest(i): + self.assertEqual(i, f.clean(i)) + def test_choicefield_callable_may_evaluate_to_different_values(self): choices = [] @@ -76,11 +121,13 @@ class ChoiceFieldTest(FormFieldAssertionsMixin, SimpleTestCase): choices = [("J", "John")] form = ChoiceFieldForm() - self.assertEqual([("J", "John")], list(form.fields["choicefield"].choices)) + self.assertEqual(choices, list(form.fields["choicefield"].choices)) + self.assertEqual(choices, list(form.fields["choicefield"].widget.choices)) choices = [("P", "Paul")] form = ChoiceFieldForm() - self.assertEqual([("P", "Paul")], list(form.fields["choicefield"].choices)) + self.assertEqual(choices, list(form.fields["choicefield"].choices)) + self.assertEqual(choices, list(form.fields["choicefield"].widget.choices)) def test_choicefield_disabled(self): f = ChoiceField(choices=[("J", "John"), ("P", "Paul")], disabled=True) diff --git a/tests/forms_tests/widget_tests/test_choicewidget.py b/tests/forms_tests/widget_tests/test_choicewidget.py index 129178f207..abd1961b32 100644 --- a/tests/forms_tests/widget_tests/test_choicewidget.py +++ b/tests/forms_tests/widget_tests/test_choicewidget.py @@ -9,13 +9,26 @@ class ChoiceWidgetTest(WidgetTest): widget = ChoiceWidget @property - def nested_widget(self): - return self.widget( + def nested_widgets(self): + nested_widget = self.widget( choices=( ("outer1", "Outer 1"), ('Group "1"', (("inner1", "Inner 1"), ("inner2", "Inner 2"))), - ) + ), ) + nested_widget_dict = self.widget( + choices={ + "outer1": "Outer 1", + 'Group "1"': {"inner1": "Inner 1", "inner2": "Inner 2"}, + }, + ) + nested_widget_dict_tuple = self.widget( + choices={ + "outer1": "Outer 1", + 'Group "1"': (("inner1", "Inner 1"), ("inner2", "Inner 2")), + }, + ) + return (nested_widget, nested_widget_dict, nested_widget_dict_tuple) def test_deepcopy(self): """ diff --git a/tests/forms_tests/widget_tests/test_radioselect.py b/tests/forms_tests/widget_tests/test_radioselect.py index 5e5cea9d35..be336151ef 100644 --- a/tests/forms_tests/widget_tests/test_radioselect.py +++ b/tests/forms_tests/widget_tests/test_radioselect.py @@ -13,7 +13,6 @@ class RadioSelectTest(ChoiceWidgetTest): widget = RadioSelect def test_render(self): - choices = BLANK_CHOICE_DASH + self.beatles html = """
@@ -33,7 +32,10 @@ class RadioSelectTest(ChoiceWidgetTest):
""" - self.check_html(self.widget(choices=choices), "beatle", "J", html=html) + beatles_with_blank = BLANK_CHOICE_DASH + self.beatles + for choices in (beatles_with_blank, dict(beatles_with_blank)): + with self.subTest(choices): + self.check_html(self.widget(choices=choices), "beatle", "J", html=html) def test_nested_choices(self): nested_choices = ( @@ -312,7 +314,9 @@ class RadioSelectTest(ChoiceWidgetTest): """ - self.check_html(self.nested_widget, "nestchoice", None, html=html) + for widget in self.nested_widgets: + with self.subTest(widget): + self.check_html(widget, "nestchoice", None, html=html) def test_choices_select_outer(self): html = """ @@ -334,7 +338,9 @@ class RadioSelectTest(ChoiceWidgetTest): """ - self.check_html(self.nested_widget, "nestchoice", "outer1", html=html) + for widget in self.nested_widgets: + with self.subTest(widget): + self.check_html(widget, "nestchoice", "outer1", html=html) def test_choices_select_inner(self): html = """ @@ -356,7 +362,9 @@ class RadioSelectTest(ChoiceWidgetTest): """ - self.check_html(self.nested_widget, "nestchoice", "inner2", html=html) + for widget in self.nested_widgets: + with self.subTest(widget): + self.check_html(widget, "nestchoice", "inner2", html=html) def test_render_attrs(self): """ diff --git a/tests/forms_tests/widget_tests/test_select.py b/tests/forms_tests/widget_tests/test_select.py index 60a0b72880..6164d0b6b3 100644 --- a/tests/forms_tests/widget_tests/test_select.py +++ b/tests/forms_tests/widget_tests/test_select.py @@ -11,19 +11,17 @@ class SelectTest(ChoiceWidgetTest): widget = Select def test_render(self): - self.check_html( - self.widget(choices=self.beatles), - "beatle", - "J", - html=( - """""" - ), - ) + html = """ + + """ + for choices in (self.beatles, dict(self.beatles)): + with self.subTest(choices): + self.check_html(self.widget(choices=choices), "beatle", "J", html=html) def test_render_none(self): """ @@ -237,52 +235,46 @@ class SelectTest(ChoiceWidgetTest): """ Choices can be nested one level in order to create HTML optgroups. """ - self.check_html( - self.nested_widget, - "nestchoice", - None, - html=( - """""" - ), - ) + html = """ + + """ + for widget in self.nested_widgets: + with self.subTest(widget): + self.check_html(widget, "nestchoice", None, html=html) def test_choices_select_outer(self): - self.check_html( - self.nested_widget, - "nestchoice", - "outer1", - html=( - """""" - ), - ) + html = """ + + """ + for widget in self.nested_widgets: + with self.subTest(widget): + self.check_html(widget, "nestchoice", "outer1", html=html) def test_choices_select_inner(self): - self.check_html( - self.nested_widget, - "nestchoice", - "inner1", - html=( - """""" - ), - ) + html = """ + + """ + for widget in self.nested_widgets: + with self.subTest(widget): + self.check_html(widget, "nestchoice", "inner1", html=html) @override_settings(USE_THOUSAND_SEPARATOR=True) def test_doesnt_localize_option_value(self): @@ -312,24 +304,7 @@ class SelectTest(ChoiceWidgetTest): """ self.check_html(self.widget(choices=choices), "time", None, html=html) - def test_optgroups(self): - choices = [ - ( - "Audio", - [ - ("vinyl", "Vinyl"), - ("cd", "CD"), - ], - ), - ( - "Video", - [ - ("vhs", "VHS Tape"), - ("dvd", "DVD"), - ], - ), - ("unknown", "Unknown"), - ] + def _test_optgroups(self, choices): groups = list( self.widget(choices=choices).optgroups( "name", @@ -418,6 +393,27 @@ class SelectTest(ChoiceWidgetTest): ) self.assertEqual(index, 2) + def test_optgroups(self): + choices_dict = { + "Audio": [ + ("vinyl", "Vinyl"), + ("cd", "CD"), + ], + "Video": [ + ("vhs", "VHS Tape"), + ("dvd", "DVD"), + ], + "unknown": "Unknown", + } + choices_list = list(choices_dict.items()) + choices_nested_dict = { + k: dict(v) if isinstance(v, list) else v for k, v in choices_dict.items() + } + + for choices in (choices_dict, choices_list, choices_nested_dict): + with self.subTest(choices): + self._test_optgroups(choices) + def test_doesnt_render_required_when_impossible_to_select_empty_field(self): widget = self.widget(choices=[("J", "John"), ("P", "Paul")]) self.assertIs(widget.use_required_attribute(initial=None), False) diff --git a/tests/invalid_models_tests/test_ordinary_fields.py b/tests/invalid_models_tests/test_ordinary_fields.py index 063c99e8bd..6014448013 100644 --- a/tests/invalid_models_tests/test_ordinary_fields.py +++ b/tests/invalid_models_tests/test_ordinary_fields.py @@ -199,7 +199,8 @@ class CharFieldTests(TestCase): field.check(), [ Error( - "'choices' must be an iterable (e.g., a list or tuple).", + "'choices' must be a mapping (e.g. a dictionary) or an iterable " + "(e.g. a list or tuple).", obj=field, id="fields.E004", ), @@ -217,8 +218,9 @@ class CharFieldTests(TestCase): field.check(), [ Error( - "'choices' must be an iterable containing (actual value, " - "human readable name) tuples.", + "'choices' must be a mapping of actual values to human readable " + "names or an iterable containing (actual value, human readable " + "name) tuples.", obj=field, id="fields.E005", ), @@ -260,8 +262,9 @@ class CharFieldTests(TestCase): field.check(), [ Error( - "'choices' must be an iterable containing (actual " - "value, human readable name) tuples.", + "'choices' must be a mapping of actual values to human " + "readable names or an iterable containing (actual value, " + "human readable name) tuples.", obj=field, id="fields.E005", ), @@ -309,8 +312,9 @@ class CharFieldTests(TestCase): field.check(), [ Error( - "'choices' must be an iterable containing (actual value, " - "human readable name) tuples.", + "'choices' must be a mapping of actual values to human readable " + "names or an iterable containing (actual value, human readable " + "name) tuples.", obj=field, id="fields.E005", ), @@ -337,8 +341,9 @@ class CharFieldTests(TestCase): field.check(), [ Error( - "'choices' must be an iterable containing (actual value, " - "human readable name) tuples.", + "'choices' must be a mapping of actual values to human readable " + "names or an iterable containing (actual value, human readable " + "name) tuples.", obj=field, id="fields.E005", ), @@ -386,6 +391,26 @@ class CharFieldTests(TestCase): ], ) + def test_choices_callable(self): + def get_choices(): + return [(i, i) for i in range(3)] + + class Model(models.Model): + field = models.CharField(max_length=10, choices=get_choices) + + field = Model._meta.get_field("field") + self.assertEqual( + field.check(), + [ + Error( + "'choices' must be a mapping (e.g. a dictionary) or an iterable " + "(e.g. a list or tuple).", + obj=field, + id="fields.E004", + ), + ], + ) + def test_bad_db_index_value(self): class Model(models.Model): field = models.CharField(max_length=10, db_index="bad") @@ -854,7 +879,8 @@ class IntegerFieldTests(SimpleTestCase): field.check(), [ Error( - "'choices' must be an iterable (e.g., a list or tuple).", + "'choices' must be a mapping (e.g. a dictionary) or an iterable " + "(e.g. a list or tuple).", obj=field, id="fields.E004", ), @@ -872,8 +898,9 @@ class IntegerFieldTests(SimpleTestCase): field.check(), [ Error( - "'choices' must be an iterable containing (actual value, human " - "readable name) tuples.", + "'choices' must be a mapping of actual values to human readable " + "names or an iterable containing (actual value, human readable " + "name) tuples.", obj=field, id="fields.E005", ), diff --git a/tests/migrations/test_writer.py b/tests/migrations/test_writer.py index 5ee814d2af..e28e52ffad 100644 --- a/tests/migrations/test_writer.py +++ b/tests/migrations/test_writer.py @@ -460,6 +460,16 @@ class WriterTests(SimpleTestCase): "default=datetime.date(1969, 11, 19))", ) + def test_serialize_dictionary_choices(self): + for choices in ({"Group": [(2, "2"), (1, "1")]}, {"Group": {2: "2", 1: "1"}}): + with self.subTest(choices): + field = models.IntegerField(choices=choices) + string = MigrationWriter.serialize(field)[0] + self.assertEqual( + string, + "models.IntegerField(choices=[('Group', [(2, '2'), (1, '1')])])", + ) + def test_serialize_nested_class(self): for nested_cls in [self.NestedEnum, self.NestedChoices]: cls_name = nested_cls.__name__ diff --git a/tests/model_fields/models.py b/tests/model_fields/models.py index 1776cb9bcb..e1a5a3872f 100644 --- a/tests/model_fields/models.py +++ b/tests/model_fields/models.py @@ -31,24 +31,18 @@ class Bar(models.Model): class Whiz(models.Model): - CHOICES = ( - ( - "Group 1", - ( - (1, "First"), - (2, "Second"), - ), + CHOICES = { + "Group 1": { + 1: "First", + 2: "Second", + }, + "Group 2": ( + (3, "Third"), + (4, "Fourth"), ), - ( - "Group 2", - ( - (3, "Third"), - (4, "Fourth"), - ), - ), - (0, "Other"), - (5, _("translated")), - ) + 0: "Other", + 5: _("translated"), + } c = models.IntegerField(choices=CHOICES, null=True) @@ -61,7 +55,7 @@ WhizDelayed._meta.get_field("c").choices = Whiz.CHOICES class WhizIter(models.Model): - c = models.IntegerField(choices=iter(Whiz.CHOICES), null=True) + c = models.IntegerField(choices=iter(Whiz.CHOICES.items()), null=True) class WhizIterEmpty(models.Model): @@ -78,6 +72,10 @@ class Choiceful(models.Model): no_choices = models.IntegerField(null=True) empty_choices = models.IntegerField(choices=(), null=True) with_choices = models.IntegerField(choices=[(1, "A")], null=True) + with_choices_dict = models.IntegerField(choices={1: "A"}, null=True) + with_choices_nested_dict = models.IntegerField( + choices={"Thing": {1: "A"}}, null=True + ) empty_choices_bool = models.BooleanField(choices=()) empty_choices_text = models.TextField(choices=()) choices_from_enum = models.IntegerField(choices=Suit) diff --git a/tests/model_fields/test_integerfield.py b/tests/model_fields/test_integerfield.py index 6761589b7e..0d91cff9eb 100644 --- a/tests/model_fields/test_integerfield.py +++ b/tests/model_fields/test_integerfield.py @@ -279,6 +279,14 @@ class ValidationTests(SimpleTestCase): f = models.IntegerField(choices=(("group", ((10, "A"), (20, "B"))), (30, "C"))) self.assertEqual(10, f.clean(10, None)) + def test_choices_validation_supports_named_groups_dicts(self): + f = models.IntegerField(choices={"group": ((10, "A"), (20, "B")), 30: "C"}) + self.assertEqual(10, f.clean(10, None)) + + def test_choices_validation_supports_named_groups_nested_dicts(self): + f = models.IntegerField(choices={"group": {10: "A", 20: "B"}, 30: "C"}) + self.assertEqual(10, f.clean(10, None)) + def test_nullable_integerfield_raises_error_with_blank_false(self): f = models.IntegerField(null=True, blank=False) with self.assertRaises(ValidationError): diff --git a/tests/model_fields/tests.py b/tests/model_fields/tests.py index 91ef6058f9..bc57a00a1d 100644 --- a/tests/model_fields/tests.py +++ b/tests/model_fields/tests.py @@ -156,15 +156,21 @@ class ChoicesTests(SimpleTestCase): cls.empty_choices_bool = Choiceful._meta.get_field("empty_choices_bool") cls.empty_choices_text = Choiceful._meta.get_field("empty_choices_text") cls.with_choices = Choiceful._meta.get_field("with_choices") + cls.with_choices_dict = Choiceful._meta.get_field("with_choices_dict") + cls.with_choices_nested_dict = Choiceful._meta.get_field( + "with_choices_nested_dict" + ) cls.choices_from_enum = Choiceful._meta.get_field("choices_from_enum") cls.choices_from_iterator = Choiceful._meta.get_field("choices_from_iterator") def test_choices(self): self.assertIsNone(self.no_choices.choices) - self.assertEqual(self.empty_choices.choices, ()) - self.assertEqual(self.empty_choices_bool.choices, ()) - self.assertEqual(self.empty_choices_text.choices, ()) + self.assertEqual(self.empty_choices.choices, []) + self.assertEqual(self.empty_choices_bool.choices, []) + self.assertEqual(self.empty_choices_text.choices, []) self.assertEqual(self.with_choices.choices, [(1, "A")]) + self.assertEqual(self.with_choices_dict.choices, [(1, "A")]) + self.assertEqual(self.with_choices_nested_dict.choices, [("Thing", [(1, "A")])]) self.assertEqual( self.choices_from_iterator.choices, [(0, "0"), (1, "1"), (2, "2")] ) @@ -175,6 +181,8 @@ class ChoicesTests(SimpleTestCase): self.assertEqual(self.empty_choices_bool.flatchoices, []) self.assertEqual(self.empty_choices_text.flatchoices, []) self.assertEqual(self.with_choices.flatchoices, [(1, "A")]) + self.assertEqual(self.with_choices_dict.flatchoices, [(1, "A")]) + self.assertEqual(self.with_choices_nested_dict.flatchoices, [(1, "A")]) self.assertEqual( self.choices_from_iterator.flatchoices, [(0, "0"), (1, "1"), (2, "2")] ) @@ -290,11 +298,11 @@ class GetChoicesTests(SimpleTestCase): ("b", "Bar"), ( "Group", - ( + [ ("", "No Preference"), ("fg", "Foo"), ("bg", "Bar"), - ), + ], ), ] f = models.CharField(choices=choices) @@ -302,7 +310,7 @@ class GetChoicesTests(SimpleTestCase): def test_lazy_strings_not_evaluated(self): lazy_func = lazy(lambda x: 0 / 0, int) # raises ZeroDivisionError if evaluated. - f = models.CharField(choices=[(lazy_func("group"), (("a", "A"), ("b", "B")))]) + f = models.CharField(choices=[(lazy_func("group"), [("a", "A"), ("b", "B")])]) self.assertEqual(f.get_choices(include_blank=True)[0], ("", "---------")) diff --git a/tests/utils_tests/test_choices.py b/tests/utils_tests/test_choices.py new file mode 100644 index 0000000000..d96c3d49c4 --- /dev/null +++ b/tests/utils_tests/test_choices.py @@ -0,0 +1,305 @@ +from unittest import mock + +from django.db.models import TextChoices +from django.test import SimpleTestCase +from django.utils.choices import CallableChoiceIterator, normalize_choices +from django.utils.translation import gettext_lazy as _ + + +class NormalizeFieldChoicesTests(SimpleTestCase): + expected = [ + ("C", _("Club")), + ("D", _("Diamond")), + ("H", _("Heart")), + ("S", _("Spade")), + ] + expected_nested = [ + ("Audio", [("vinyl", _("Vinyl")), ("cd", _("CD"))]), + ("Video", [("vhs", _("VHS Tape")), ("dvd", _("DVD"))]), + ("unknown", _("Unknown")), + ] + invalid = [ + 1j, + 123, + 123.45, + "invalid", + b"invalid", + _("invalid"), + object(), + None, + True, + False, + ] + invalid_iterable = [ + # Special cases of a string-likes which would unpack incorrectly. + ["ab"], + [b"ab"], + [_("ab")], + # Non-iterable items or iterable items with incorrect number of + # elements that cannot be unpacked. + [123], + [("value",)], + [("value", "label", "other")], + ] + invalid_nested = [ + # Nested choices can only be two-levels deep, so return callables, + # mappings, iterables, etc. at deeper levels unmodified. + [("Group", [("Value", lambda: "Label")])], + [("Group", [("Value", {"Label 1?": "Label 2?"})])], + [("Group", [("Value", [("Label 1?", "Label 2?")])])], + ] + + def test_empty(self): + def generator(): + yield from () + + for choices in ({}, [], (), set(), frozenset(), generator()): + with self.subTest(choices=choices): + self.assertEqual(normalize_choices(choices), []) + + def test_choices(self): + class Medal(TextChoices): + GOLD = "GOLD", _("Gold") + SILVER = "SILVER", _("Silver") + BRONZE = "BRONZE", _("Bronze") + + expected = [ + ("GOLD", _("Gold")), + ("SILVER", _("Silver")), + ("BRONZE", _("Bronze")), + ] + self.assertEqual(normalize_choices(Medal), expected) + + def test_callable(self): + def get_choices(): + return { + "C": _("Club"), + "D": _("Diamond"), + "H": _("Heart"), + "S": _("Spade"), + } + + get_choices_spy = mock.Mock(wraps=get_choices) + output = normalize_choices(get_choices_spy) + + get_choices_spy.assert_not_called() + self.assertIsInstance(output, CallableChoiceIterator) + self.assertEqual(list(output), self.expected) + get_choices_spy.assert_called_once() + + def test_mapping(self): + choices = { + "C": _("Club"), + "D": _("Diamond"), + "H": _("Heart"), + "S": _("Spade"), + } + self.assertEqual(normalize_choices(choices), self.expected) + + def test_iterable(self): + choices = [ + ("C", _("Club")), + ("D", _("Diamond")), + ("H", _("Heart")), + ("S", _("Spade")), + ] + self.assertEqual(normalize_choices(choices), self.expected) + + def test_iterator(self): + def generator(): + yield "C", _("Club") + yield "D", _("Diamond") + yield "H", _("Heart") + yield "S", _("Spade") + + choices = generator() + self.assertEqual(normalize_choices(choices), self.expected) + + def test_nested_callable(self): + def get_audio_choices(): + return [("vinyl", _("Vinyl")), ("cd", _("CD"))] + + def get_video_choices(): + return [("vhs", _("VHS Tape")), ("dvd", _("DVD"))] + + def get_media_choices(): + return [ + ("Audio", get_audio_choices), + ("Video", get_video_choices), + ("unknown", _("Unknown")), + ] + + get_media_choices_spy = mock.Mock(wraps=get_media_choices) + output = normalize_choices(get_media_choices_spy) + + get_media_choices_spy.assert_not_called() + self.assertIsInstance(output, CallableChoiceIterator) + self.assertEqual(list(output), self.expected_nested) + get_media_choices_spy.assert_called_once() + + def test_nested_mapping(self): + choices = { + "Audio": {"vinyl": _("Vinyl"), "cd": _("CD")}, + "Video": {"vhs": _("VHS Tape"), "dvd": _("DVD")}, + "unknown": _("Unknown"), + } + self.assertEqual(normalize_choices(choices), self.expected_nested) + + def test_nested_iterable(self): + choices = [ + ("Audio", [("vinyl", _("Vinyl")), ("cd", _("CD"))]), + ("Video", [("vhs", _("VHS Tape")), ("dvd", _("DVD"))]), + ("unknown", _("Unknown")), + ] + self.assertEqual(normalize_choices(choices), self.expected_nested) + + def test_nested_iterator(self): + def generate_audio_choices(): + yield "vinyl", _("Vinyl") + yield "cd", _("CD") + + def generate_video_choices(): + yield "vhs", _("VHS Tape") + yield "dvd", _("DVD") + + def generate_media_choices(): + yield "Audio", generate_audio_choices() + yield "Video", generate_video_choices() + yield "unknown", _("Unknown") + + choices = generate_media_choices() + self.assertEqual(normalize_choices(choices), self.expected_nested) + + def test_callable_non_canonical(self): + # Canonical form is list of 2-tuple, but nested lists should work. + def get_choices(): + return [ + ["C", _("Club")], + ["D", _("Diamond")], + ["H", _("Heart")], + ["S", _("Spade")], + ] + + get_choices_spy = mock.Mock(wraps=get_choices) + output = normalize_choices(get_choices_spy) + + get_choices_spy.assert_not_called() + self.assertIsInstance(output, CallableChoiceIterator) + self.assertEqual(list(output), self.expected) + get_choices_spy.assert_called_once() + + def test_iterable_non_canonical(self): + # Canonical form is list of 2-tuple, but nested lists should work. + choices = [ + ["C", _("Club")], + ["D", _("Diamond")], + ["H", _("Heart")], + ["S", _("Spade")], + ] + self.assertEqual(normalize_choices(choices), self.expected) + + def test_iterator_non_canonical(self): + # Canonical form is list of 2-tuple, but nested lists should work. + def generator(): + yield ["C", _("Club")] + yield ["D", _("Diamond")] + yield ["H", _("Heart")] + yield ["S", _("Spade")] + + choices = generator() + self.assertEqual(normalize_choices(choices), self.expected) + + def test_nested_callable_non_canonical(self): + # Canonical form is list of 2-tuple, but nested lists should work. + + def get_audio_choices(): + return [["vinyl", _("Vinyl")], ["cd", _("CD")]] + + def get_video_choices(): + return [["vhs", _("VHS Tape")], ["dvd", _("DVD")]] + + def get_media_choices(): + return [ + ["Audio", get_audio_choices], + ["Video", get_video_choices], + ["unknown", _("Unknown")], + ] + + get_media_choices_spy = mock.Mock(wraps=get_media_choices) + output = normalize_choices(get_media_choices_spy) + + get_media_choices_spy.assert_not_called() + self.assertIsInstance(output, CallableChoiceIterator) + self.assertEqual(list(output), self.expected_nested) + get_media_choices_spy.assert_called_once() + + def test_nested_iterable_non_canonical(self): + # Canonical form is list of 2-tuple, but nested lists should work. + choices = [ + ["Audio", [["vinyl", _("Vinyl")], ["cd", _("CD")]]], + ["Video", [["vhs", _("VHS Tape")], ["dvd", _("DVD")]]], + ["unknown", _("Unknown")], + ] + self.assertEqual(normalize_choices(choices), self.expected_nested) + + def test_nested_iterator_non_canonical(self): + # Canonical form is list of 2-tuple, but nested lists should work. + def generator(): + yield ["Audio", [["vinyl", _("Vinyl")], ["cd", _("CD")]]] + yield ["Video", [["vhs", _("VHS Tape")], ["dvd", _("DVD")]]] + yield ["unknown", _("Unknown")] + + choices = generator() + self.assertEqual(normalize_choices(choices), self.expected_nested) + + def test_nested_mixed_mapping_and_iterable(self): + # Although not documented, as it's better to stick to either mappings + # or iterables, nesting of mappings within iterables and vice versa + # works and is likely to occur in the wild. This is supported by the + # recursive call to `normalize_choices()` which will normalize nested + # choices. + choices = { + "Audio": [("vinyl", _("Vinyl")), ("cd", _("CD"))], + "Video": [("vhs", _("VHS Tape")), ("dvd", _("DVD"))], + "unknown": _("Unknown"), + } + self.assertEqual(normalize_choices(choices), self.expected_nested) + choices = [ + ("Audio", {"vinyl": _("Vinyl"), "cd": _("CD")}), + ("Video", {"vhs": _("VHS Tape"), "dvd": _("DVD")}), + ("unknown", _("Unknown")), + ] + self.assertEqual(normalize_choices(choices), self.expected_nested) + + def test_iterable_set(self): + # Although not documented, as sets are unordered which results in + # randomised order in form fields, passing a set of 2-tuples works. + # Consistent ordering of choices on model fields in migrations is + # enforced by the migrations serializer. + choices = { + ("C", _("Club")), + ("D", _("Diamond")), + ("H", _("Heart")), + ("S", _("Spade")), + } + self.assertEqual(sorted(normalize_choices(choices)), sorted(self.expected)) + + def test_unsupported_values_returned_unmodified(self): + # Unsupported values must be returned unmodified for model system check + # to work correctly. + for value in self.invalid + self.invalid_iterable + self.invalid_nested: + with self.subTest(value=value): + self.assertEqual(normalize_choices(value), value) + + def test_unsupported_values_from_callable_returned_unmodified(self): + for value in self.invalid_iterable + self.invalid_nested: + with self.subTest(value=value): + self.assertEqual(list(normalize_choices(lambda: value)), value) + + def test_unsupported_values_from_iterator_returned_unmodified(self): + for value in self.invalid_nested: + with self.subTest(value=value): + self.assertEqual( + list(normalize_choices((lambda: (yield from value))())), + value, + )