1
0
mirror of https://github.com/django/django.git synced 2024-12-22 09:05:43 +00:00

Fixed #31262 -- Added support for mappings on model fields and ChoiceField's choices.

This commit is contained in:
Nick Pope 2023-08-31 02:57:40 +01:00 committed by GitHub
parent 68a8996bdf
commit 500e01073a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 822 additions and 249 deletions

View File

@ -262,7 +262,6 @@ class RelatedFieldWidgetWrapper(forms.Widget):
): ):
self.needs_multipart_form = widget.needs_multipart_form self.needs_multipart_form = widget.needs_multipart_form
self.attrs = widget.attrs self.attrs = widget.attrs
self.choices = widget.choices
self.widget = widget self.widget = widget
self.rel = rel self.rel = rel
# Backwards compatible check for whether a user can add related # Backwards compatible check for whether a user can add related
@ -295,6 +294,14 @@ class RelatedFieldWidgetWrapper(forms.Widget):
def media(self): def media(self):
return self.widget.media return self.widget.media
@property
def choices(self):
return self.widget.choices
@choices.setter
def choices(self, value):
self.widget.choices = value
def get_related_url(self, info, action, *args): def get_related_url(self, info, action, *args):
return reverse( return reverse(
"admin:%s_%s_%s" % (info + (action,)), "admin:%s_%s_%s" % (info + (action,)),
@ -307,7 +314,6 @@ class RelatedFieldWidgetWrapper(forms.Widget):
rel_opts = self.rel.model._meta rel_opts = self.rel.model._meta
info = (rel_opts.app_label, rel_opts.model_name) info = (rel_opts.app_label, rel_opts.model_name)
self.widget.choices = self.choices
related_field_name = self.rel.get_related_field().name related_field_name = self.rel.get_related_field().name
url_params = "&".join( url_params = "&".join(
"%s=%s" % param "%s=%s" % param

View File

@ -1,4 +1,3 @@
import collections.abc
import copy import copy
import datetime import datetime
import decimal import decimal
@ -14,9 +13,9 @@ from django.conf import settings
from django.core import checks, exceptions, validators from django.core import checks, exceptions, validators
from django.db import connection, connections, router from django.db import connection, connections, router
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.enums import ChoicesMeta
from django.db.models.query_utils import DeferredAttribute, RegisterLookupMixin from django.db.models.query_utils import DeferredAttribute, RegisterLookupMixin
from django.utils import timezone from django.utils import timezone
from django.utils.choices import CallableChoiceIterator, normalize_choices
from django.utils.datastructures import DictWrapper from django.utils.datastructures import DictWrapper
from django.utils.dateparse import ( from django.utils.dateparse import (
parse_date, parse_date,
@ -225,10 +224,6 @@ class Field(RegisterLookupMixin):
self.unique_for_date = unique_for_date self.unique_for_date = unique_for_date
self.unique_for_month = unique_for_month self.unique_for_month = unique_for_month
self.unique_for_year = unique_for_year self.unique_for_year = unique_for_year
if isinstance(choices, ChoicesMeta):
choices = choices.choices
if isinstance(choices, collections.abc.Iterator):
choices = list(choices)
self.choices = choices self.choices = choices
self.help_text = help_text self.help_text = help_text
self.db_index = db_index self.db_index = db_index
@ -320,10 +315,13 @@ class Field(RegisterLookupMixin):
if not self.choices: if not self.choices:
return [] return []
if not is_iterable(self.choices) or isinstance(self.choices, str): if not is_iterable(self.choices) or isinstance(
self.choices, (str, CallableChoiceIterator)
):
return [ return [
checks.Error( checks.Error(
"'choices' must be an iterable (e.g., a list or tuple).", "'choices' must be a mapping (e.g. a dictionary) or an iterable "
"(e.g. a list or tuple).",
obj=self, obj=self,
id="fields.E004", id="fields.E004",
) )
@ -381,8 +379,8 @@ class Field(RegisterLookupMixin):
return [ return [
checks.Error( checks.Error(
"'choices' must be an iterable containing " "'choices' must be a mapping of actual values to human readable names "
"(actual value, human readable name) tuples.", "or an iterable containing (actual value, human readable name) tuples.",
obj=self, obj=self,
id="fields.E005", id="fields.E005",
) )
@ -543,6 +541,14 @@ class Field(RegisterLookupMixin):
return Col(alias, self, output_field) return Col(alias, self, output_field)
@property
def choices(self):
return self._choices
@choices.setter
def choices(self, value):
self._choices = normalize_choices(value)
@cached_property @cached_property
def cached_col(self): def cached_col(self):
from django.db.models.expressions import Col from django.db.models.expressions import Col
@ -625,9 +631,8 @@ class Field(RegisterLookupMixin):
equals_comparison = {"choices", "validators"} equals_comparison = {"choices", "validators"}
for name, default in possibles.items(): for name, default in possibles.items():
value = getattr(self, attr_overrides.get(name, name)) value = getattr(self, attr_overrides.get(name, name))
# Unroll anything iterable for choices into a concrete list if isinstance(value, CallableChoiceIterator):
if name == "choices" and isinstance(value, collections.abc.Iterable): value = value.func
value = list(value)
# Do correct kind of comparison # Do correct kind of comparison
if name in equals_comparison: if name in equals_comparison:
if value != default: if value != default:

View File

@ -17,7 +17,6 @@ from urllib.parse import urlsplit, urlunsplit
from django.core import validators from django.core import validators
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.db.models.enums import ChoicesMeta
from django.forms.boundfield import BoundField from django.forms.boundfield import BoundField
from django.forms.utils import from_current_timezone, to_current_timezone from django.forms.utils import from_current_timezone, to_current_timezone
from django.forms.widgets import ( from django.forms.widgets import (
@ -42,6 +41,7 @@ from django.forms.widgets import (
URLInput, URLInput,
) )
from django.utils import formats from django.utils import formats
from django.utils.choices import normalize_choices
from django.utils.dateparse import parse_datetime, parse_duration from django.utils.dateparse import parse_datetime, parse_duration
from django.utils.deprecation import RemovedInDjango60Warning from django.utils.deprecation import RemovedInDjango60Warning
from django.utils.duration import duration_string from django.utils.duration import duration_string
@ -861,14 +861,6 @@ class NullBooleanField(BooleanField):
pass pass
class CallableChoiceIterator:
def __init__(self, choices_func):
self.choices_func = choices_func
def __iter__(self):
yield from self.choices_func()
class ChoiceField(Field): class ChoiceField(Field):
widget = Select widget = Select
default_error_messages = { default_error_messages = {
@ -879,8 +871,6 @@ class ChoiceField(Field):
def __init__(self, *, choices=(), **kwargs): def __init__(self, *, choices=(), **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if isinstance(choices, ChoicesMeta):
choices = choices.choices
self.choices = choices self.choices = choices
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
@ -888,21 +878,15 @@ class ChoiceField(Field):
result._choices = copy.deepcopy(self._choices, memo) result._choices = copy.deepcopy(self._choices, memo)
return result return result
def _get_choices(self): @property
def choices(self):
return self._choices return self._choices
def _set_choices(self, value): @choices.setter
# Setting choices also sets the choices on the widget. def choices(self, value):
# choices can be any iterable, but we call list() on it because # Setting choices on the field also sets the choices on the widget.
# it will be consumed more than once. # Note that the property setter for the widget will re-normalize.
if callable(value): self._choices = self.widget.choices = normalize_choices(value)
value = CallableChoiceIterator(value)
else:
value = list(value)
self._choices = self.widget.choices = value
choices = property(_get_choices, _set_choices)
def to_python(self, value): def to_python(self, value):
"""Return a string.""" """Return a string."""

View File

@ -21,6 +21,7 @@ from django.forms.widgets import (
RadioSelect, RadioSelect,
SelectMultiple, SelectMultiple,
) )
from django.utils.choices import ChoiceIterator
from django.utils.text import capfirst, get_text_list from django.utils.text import capfirst, get_text_list
from django.utils.translation import gettext from django.utils.translation import gettext
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
@ -1402,7 +1403,7 @@ class ModelChoiceIteratorValue:
return self.value == other return self.value == other
class ModelChoiceIterator: class ModelChoiceIterator(ChoiceIterator):
def __init__(self, field): def __init__(self, field):
self.field = field self.field = field
self.queryset = field.queryset self.queryset = field.queryset
@ -1532,7 +1533,7 @@ class ModelChoiceField(ChoiceField):
# the queryset. # the queryset.
return self.iterator(self) return self.iterator(self)
choices = property(_get_choices, ChoiceField._set_choices) choices = property(_get_choices, ChoiceField.choices.fset)
def prepare_value(self, value): def prepare_value(self, value):
if hasattr(value, "_meta"): if hasattr(value, "_meta"):

View File

@ -12,6 +12,7 @@ from itertools import chain
from django.forms.utils import to_current_timezone from django.forms.utils import to_current_timezone
from django.templatetags.static import static from django.templatetags.static import static
from django.utils import formats from django.utils import formats
from django.utils.choices import normalize_choices
from django.utils.dates import MONTHS from django.utils.dates import MONTHS
from django.utils.formats import get_format from django.utils.formats import get_format
from django.utils.html import format_html, html_safe from django.utils.html import format_html, html_safe
@ -620,10 +621,7 @@ class ChoiceWidget(Widget):
def __init__(self, attrs=None, choices=()): def __init__(self, attrs=None, choices=()):
super().__init__(attrs) super().__init__(attrs)
# choices can be any iterable, but we may need to render this widget self.choices = choices
# multiple times. Thus, collapse it into a list so it can be consumed
# more than once.
self.choices = list(choices)
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
obj = copy.copy(self) obj = copy.copy(self)
@ -741,6 +739,14 @@ class ChoiceWidget(Widget):
value = [value] value = [value]
return [str(v) if v is not None else "" for v in value] return [str(v) if v is not None else "" for v in value]
@property
def choices(self):
return self._choices
@choices.setter
def choices(self, value):
self._choices = normalize_choices(value)
class Select(ChoiceWidget): class Select(ChoiceWidget):
input_type = "select" input_type = "select"

63
django/utils/choices.py Normal file
View File

@ -0,0 +1,63 @@
from collections.abc import Callable, Iterable, Iterator, Mapping
from django.db.models.enums import ChoicesMeta
from django.utils.functional import Promise
class ChoiceIterator:
"""Base class for lazy iterators for choices."""
class CallableChoiceIterator(ChoiceIterator):
"""Iterator to lazily normalize choices generated by a callable."""
def __init__(self, func):
self.func = func
def __iter__(self):
yield from normalize_choices(self.func())
def normalize_choices(value, *, depth=0):
"""Normalize choices values consistently for fields and widgets."""
match value:
case ChoiceIterator() | Promise() | bytes() | str():
# Avoid prematurely normalizing iterators that should be lazy.
# Because string-like types are iterable, return early to avoid
# iterating over them in the guard for the Iterable case below.
return value
case ChoicesMeta():
# Choices enumeration helpers already output in canonical form.
return value.choices
case Mapping() if depth < 2:
value = value.items()
case Iterator() if depth < 2:
# Although Iterator would be handled by the Iterable case below,
# the iterator would be consumed prematurely while checking that
# its elements are not string-like in the guard, so we handle it
# separately.
pass
case Iterable() if depth < 2 and not any(
isinstance(x, (Promise, bytes, str)) for x in value
):
# String-like types are iterable, so the guard above ensures that
# they're handled by the default case below.
pass
case Callable() if depth == 0:
# If at the top level, wrap callables to be evaluated lazily.
return CallableChoiceIterator(value)
case Callable() if depth < 2:
value = value()
case _:
return value
try:
# Recursive call to convert any nested values to a list of 2-tuples.
return [(k, normalize_choices(v, depth=depth + 1)) for k, v in value]
except (TypeError, ValueError):
# Return original value for the system check to raise if it has items
# that are not iterable or not 2-tuples:
# - TypeError: cannot unpack non-iterable <type> object
# - ValueError: <not enough / too many> values to unpack
return value

View File

@ -298,16 +298,23 @@ Model style
* Any custom methods * Any custom methods
* If ``choices`` is defined for a given model field, define each choice as a * If ``choices`` is defined for a given model field, define each choice as a
list of tuples, with an all-uppercase name as a class attribute on the model. mapping, with an all-uppercase name as a class attribute on the model.
Example:: Example::
class MyModel(models.Model): class MyModel(models.Model):
DIRECTION_UP = "U" DIRECTION_UP = "U"
DIRECTION_DOWN = "D" DIRECTION_DOWN = "D"
DIRECTION_CHOICES = [ DIRECTION_CHOICES = {
(DIRECTION_UP, "Up"), DIRECTION_UP: "Up",
(DIRECTION_DOWN, "Down"), DIRECTION_DOWN: "Down",
] }
Alternatively, consider using :ref:`field-choices-enum-types`::
class MyModel(models.Model):
class Direction(models.TextChoices):
UP = U, "Up"
DOWN = D, "Down"
Use of ``django.conf.settings`` Use of ``django.conf.settings``
=============================== ===============================

View File

@ -165,9 +165,11 @@ Model fields
* **fields.E002**: Field names must not contain ``"__"``. * **fields.E002**: Field names must not contain ``"__"``.
* **fields.E003**: ``pk`` is a reserved word that cannot be used as a field * **fields.E003**: ``pk`` is a reserved word that cannot be used as a field
name. name.
* **fields.E004**: ``choices`` must be an iterable (e.g., a list or tuple). * **fields.E004**: ``choices`` must be a mapping (e.g. a dictionary) or an
* **fields.E005**: ``choices`` must be an iterable containing ``(actual value, iterable (e.g. a list or tuple).
human readable name)`` tuples. * **fields.E005**: ``choices`` must be a mapping of actual values to human
readable names or an iterable containing ``(actual value, human readable
name)`` tuples.
* **fields.E006**: ``db_index`` must be ``None``, ``True`` or ``False``. * **fields.E006**: ``db_index`` must be ``None``, ``True`` or ``False``.
* **fields.E007**: Primary keys must not have ``null=True``. * **fields.E007**: Primary keys must not have ``null=True``.
* **fields.E008**: All ``validators`` must be callable. * **fields.E008**: All ``validators`` must be callable.

View File

@ -47,11 +47,11 @@ news application with an ``Article`` model::
from django.db import models from django.db import models
STATUS_CHOICES = [ STATUS_CHOICES = {
("d", "Draft"), "d": "Draft",
("p", "Published"), "p": "Published",
("w", "Withdrawn"), "w": "Withdrawn",
] }
class Article(models.Model): class Article(models.Model):

View File

@ -510,8 +510,9 @@ For each field, we describe the default widget used if you don't specify
.. versionchanged:: 5.0 .. versionchanged:: 5.0
Support for using :ref:`enumeration types <field-choices-enum-types>` Support for mappings and using
directly in the ``choices`` was added. :ref:`enumeration types <field-choices-enum-types>` directly in
``choices`` was added.
``DateField`` ``DateField``
------------- -------------

View File

@ -58,11 +58,11 @@ widget on the field. In the following example, the
from django import forms from django import forms
BIRTH_YEAR_CHOICES = ["1980", "1981", "1982"] BIRTH_YEAR_CHOICES = ["1980", "1981", "1982"]
FAVORITE_COLORS_CHOICES = [ FAVORITE_COLORS_CHOICES = {
("blue", "Blue"), "blue": "Blue",
("green", "Green"), "green": "Green",
("black", "Black"), "black": "Black",
] }
class SimpleForm(forms.Form): class SimpleForm(forms.Form):
@ -95,7 +95,7 @@ example:
.. code-block:: pycon .. code-block:: pycon
>>> from django import forms >>> from django import forms
>>> CHOICES = [("1", "First"), ("2", "Second")] >>> CHOICES = {"1": "First", "2": "Second"}
>>> choice_field = forms.ChoiceField(widget=forms.RadioSelect, choices=CHOICES) >>> choice_field = forms.ChoiceField(widget=forms.RadioSelect, choices=CHOICES)
>>> choice_field.choices >>> choice_field.choices
[('1', 'First'), ('2', 'Second')] [('1', 'First'), ('2', 'Second')]
@ -458,9 +458,9 @@ foundation for custom widgets.
class DateSelectorWidget(forms.MultiWidget): class DateSelectorWidget(forms.MultiWidget):
def __init__(self, attrs=None): def __init__(self, attrs=None):
days = [(day, day) for day in range(1, 32)] days = {day: day for day in range(1, 32)}
months = [(month, month) for month in range(1, 13)] months = {month: month for month in range(1, 13)}
years = [(year, year) for year in [2018, 2019, 2020]] years = {year: year for year in [2018, 2019, 2020]}
widgets = [ widgets = [
forms.Select(attrs=attrs, choices=days), forms.Select(attrs=attrs, choices=days),
forms.Select(attrs=attrs, choices=months), forms.Select(attrs=attrs, choices=months),

View File

@ -22,11 +22,11 @@ We'll be using the following model in the subsequent examples::
REGULAR = "R" REGULAR = "R"
GOLD = "G" GOLD = "G"
PLATINUM = "P" PLATINUM = "P"
ACCOUNT_TYPE_CHOICES = [ ACCOUNT_TYPE_CHOICES = {
(REGULAR, "Regular"), REGULAR: "Regular",
(GOLD, "Gold"), GOLD: "Gold",
(PLATINUM, "Platinum"), PLATINUM: "Platinum",
] }
name = models.CharField(max_length=50) name = models.CharField(max_length=50)
registered_on = models.DateField() registered_on = models.DateField()
account_type = models.CharField( account_type = models.CharField(

View File

@ -86,14 +86,26 @@ If a field has ``blank=False``, the field will be required.
.. attribute:: Field.choices .. attribute:: Field.choices
A :term:`sequence` consisting itself of iterables of exactly two items (e.g. A mapping or iterable in the format described below to use as choices for this
``[(A, B), (A, B) ...]``) to use as choices for this field. If choices are field. If choices are given, they're enforced by
given, they're enforced by :ref:`model validation <validating-objects>` and the :ref:`model validation <validating-objects>` and the default form widget will
default form widget will be a select box with these choices instead of the be a select box with these choices instead of the standard text field.
standard text field.
The first element in each tuple is the actual value to be set on the model, If a mapping is given, the key element is the actual value to be set on the
and the second element is the human-readable name. For example:: model, and the second element is the human readable name. For example::
YEAR_IN_SCHOOL_CHOICES = {
"FR": "Freshman",
"SO": "Sophomore",
"JR": "Junior",
"SR": "Senior",
"GR": "Graduate",
}
You can also pass a :term:`sequence` consisting itself of iterables of exactly
two items (e.g. ``[(A1, B1), (A2, B2), …]``). The first element in each tuple
is the actual value to be set on the model, and the second element is the
human-readable name. For example::
YEAR_IN_SCHOOL_CHOICES = [ YEAR_IN_SCHOOL_CHOICES = [
("FR", "Freshman"), ("FR", "Freshman"),
@ -103,6 +115,10 @@ and the second element is the human-readable name. For example::
("GR", "Graduate"), ("GR", "Graduate"),
] ]
.. versionchanged:: 5.0
Support for mappings was added.
Generally, it's best to define choices inside a model class, and to Generally, it's best to define choices inside a model class, and to
define a suitably-named constant for each value:: define a suitably-named constant for each value::
@ -115,13 +131,13 @@ define a suitably-named constant for each value::
JUNIOR = "JR" JUNIOR = "JR"
SENIOR = "SR" SENIOR = "SR"
GRADUATE = "GR" GRADUATE = "GR"
YEAR_IN_SCHOOL_CHOICES = [ YEAR_IN_SCHOOL_CHOICES = {
(FRESHMAN, "Freshman"), FRESHMAN: "Freshman",
(SOPHOMORE, "Sophomore"), SOPHOMORE: "Sophomore",
(JUNIOR, "Junior"), JUNIOR: "Junior",
(SENIOR, "Senior"), SENIOR: "Senior",
(GRADUATE, "Graduate"), GRADUATE: "Graduate",
] }
year_in_school = models.CharField( year_in_school = models.CharField(
max_length=2, max_length=2,
choices=YEAR_IN_SCHOOL_CHOICES, choices=YEAR_IN_SCHOOL_CHOICES,
@ -142,6 +158,25 @@ will work anywhere that the ``Student`` model has been imported).
You can also collect your available choices into named groups that can You can also collect your available choices into named groups that can
be used for organizational purposes:: be used for organizational purposes::
MEDIA_CHOICES = {
"Audio": {
"vinyl": "Vinyl",
"cd": "CD",
},
"Video": {
"vhs": "VHS Tape",
"dvd": "DVD",
},
"unknown": "Unknown",
}
The key of the mapping is the name to apply to the group and the value is the
choices inside that group, consisting of the field value and a human-readable
name for an option. Grouped options may be combined with ungrouped options
within a single mapping (such as the ``"unknown"`` option in this example).
You can also use a sequence, e.g. a list of 2-tuples::
MEDIA_CHOICES = [ MEDIA_CHOICES = [
( (
"Audio", "Audio",
@ -160,17 +195,6 @@ be used for organizational purposes::
("unknown", "Unknown"), ("unknown", "Unknown"),
] ]
The first element in each tuple is the name to apply to the group. The
second element is an iterable of 2-tuples, with each 2-tuple containing
a value and a human-readable name for an option. Grouped options may be
combined with ungrouped options within a single list (such as the
``'unknown'`` option in this example).
For each model field that has :attr:`~Field.choices` set, Django will add a
method to retrieve the human-readable name for the field's current value. See
:meth:`~django.db.models.Model.get_FOO_display` in the database API
documentation.
Note that choices can be any sequence object -- not necessarily a list or Note that choices can be any sequence object -- not necessarily a list or
tuple. This lets you construct choices dynamically. But if you find yourself tuple. This lets you construct choices dynamically. But if you find yourself
hacking :attr:`~Field.choices` to be dynamic, you're probably better off using hacking :attr:`~Field.choices` to be dynamic, you're probably better off using
@ -180,6 +204,12 @@ meant for static data that doesn't change much, if ever.
.. note:: .. note::
A new migration is created each time the order of ``choices`` changes. A new migration is created each time the order of ``choices`` changes.
For each model field that has :attr:`~Field.choices` set, Django will normalize
the choices to a list of 2-tuples and add a method to retrieve the
human-readable name for the field's current value. See
:meth:`~django.db.models.Model.get_FOO_display` in the database API
documentation.
.. _field-choices-blank-label: .. _field-choices-blank-label:
Unless :attr:`blank=False<Field.blank>` is set on the field along with a Unless :attr:`blank=False<Field.blank>` is set on the field along with a

View File

@ -912,11 +912,11 @@ For example::
class Person(models.Model): class Person(models.Model):
SHIRT_SIZES = [ SHIRT_SIZES = {
("S", "Small"), "S": "Small",
("M", "Medium"), "M": "Medium",
("L", "Large"), "L": "Large",
] }
name = models.CharField(max_length=60) name = models.CharField(max_length=60)
shirt_size = models.CharField(max_length=2, choices=SHIRT_SIZES) shirt_size = models.CharField(max_length=2, choices=SHIRT_SIZES)

View File

@ -129,6 +129,55 @@ sets a database-computed default value. For example::
created = models.DateTimeField(db_default=Now()) created = models.DateTimeField(db_default=Now())
circumference = models.FloatField(db_default=2 * Pi()) circumference = models.FloatField(db_default=2 * Pi())
More options for declaring field choices
----------------------------------------
:attr:`.Field.choices` *(for model fields)* and :attr:`.ChoiceField.choices`
*(for form fields)* allow for more flexibility when declaring their values. In
previous versions of Django, ``choices`` should either be a list of 2-tuples,
or an :ref:`field-choices-enum-types` subclass, but the latter required
accessing the ``.choices`` attribute to provide the values in the expected
form::
from django.db import models
Medal = models.TextChoices("Medal", "GOLD SILVER BRONZE")
SPORT_CHOICES = [
("Martial Arts", [("judo", "Judo"), ("karate", "Karate")]),
("Racket", [("badminton", "Badminton"), ("tennis", "Tennis")]),
("unknown", "Unknown"),
]
class Winners(models.Model):
name = models.CharField(...)
medal = models.CharField(..., choices=Medal.choices)
sport = models.CharField(..., choices=SPORT_CHOICES)
Django 5.0 supports providing a mapping instead of an iterable, and also no
longer requires ``.choices`` to be used directly to expand :ref:`enumeration
types <field-choices-enum-types>`::
from django.db import models
Medal = models.TextChoices("Medal", "GOLD SILVER BRONZE")
SPORT_CHOICES = { # Using a mapping instead of a list of 2-tuples.
"Martial Arts": {"judo": "Judo", "karate": "Karate"},
"Racket": {"badminton": "Badminton", "tennis": "Tennis"},
"unknown": "Unknown",
}
class Winners(models.Model):
name = models.CharField(...)
medal = models.CharField(..., choices=Medal) # Using `.choices` not required.
sport = models.CharField(..., choices=SPORT_CHOICES)
Under the hood the provided ``choices`` are normalized into a list of 2-tuples
as the canonical form whenever the ``choices`` value is updated.
Minor features Minor features
-------------- --------------
@ -304,10 +353,6 @@ File Uploads
Forms Forms
~~~~~ ~~~~~
* :attr:`.ChoiceField.choices` now accepts
:ref:`Choices classes <field-choices-enum-types>` directly instead of
requiring expansion with the ``choices`` attribute.
* The new ``assume_scheme`` argument for :class:`~django.forms.URLField` allows * The new ``assume_scheme`` argument for :class:`~django.forms.URLField` allows
specifying a default URL scheme. specifying a default URL scheme.
@ -357,10 +402,6 @@ Models
of ``ValidationError`` raised during of ``ValidationError`` raised during
:ref:`model validation <validating-objects>`. :ref:`model validation <validating-objects>`.
* :attr:`.Field.choices` now accepts
:ref:`Choices classes <field-choices-enum-types>` directly instead of
requiring expansion with the ``choices`` attribute.
* The :ref:`force_insert <ref-models-force-insert>` argument of * The :ref:`force_insert <ref-models-force-insert>` argument of
:meth:`.Model.save` now allows specifying a tuple of parent classes that must :meth:`.Model.save` now allows specifying a tuple of parent classes that must
be forced to be inserted. be forced to be inserted.

View File

@ -154,9 +154,7 @@ For example::
class Person(models.Model): class Person(models.Model):
first_name = models.CharField(max_length=50) first_name = models.CharField(max_length=50)
last_name = models.CharField(max_length=50) last_name = models.CharField(max_length=50)
role = models.CharField( role = models.CharField(max_length=1, choices={"A": _("Author"), "E": _("Editor")})
max_length=1, choices=[("A", _("Author")), ("E", _("Editor"))]
)
people = models.Manager() people = models.Manager()
authors = AuthorManager() authors = AuthorManager()
editors = EditorManager() editors = EditorManager()
@ -259,9 +257,7 @@ custom ``QuerySet`` if you also implement them on the ``Manager``::
class Person(models.Model): class Person(models.Model):
first_name = models.CharField(max_length=50) first_name = models.CharField(max_length=50)
last_name = models.CharField(max_length=50) last_name = models.CharField(max_length=50)
role = models.CharField( role = models.CharField(max_length=1, choices={"A": _("Author"), "E": _("Editor")})
max_length=1, choices=[("A", _("Author")), ("E", _("Editor"))]
)
people = PersonManager() people = PersonManager()
This example allows you to call both ``authors()`` and ``editors()`` directly from This example allows you to call both ``authors()`` and ``editors()`` directly from

View File

@ -185,11 +185,11 @@ ones:
class Person(models.Model): class Person(models.Model):
SHIRT_SIZES = [ SHIRT_SIZES = {
("S", "Small"), "S": "Small",
("M", "Medium"), "M": "Medium",
("L", "Large"), "L": "Large",
] }
name = models.CharField(max_length=60) name = models.CharField(max_length=60)
shirt_size = models.CharField(max_length=1, choices=SHIRT_SIZES) shirt_size = models.CharField(max_length=1, choices=SHIRT_SIZES)

View File

@ -173,11 +173,11 @@ Consider this set of models::
from django.db import models from django.db import models
from django.forms import ModelForm from django.forms import ModelForm
TITLE_CHOICES = [ TITLE_CHOICES = {
("MR", "Mr."), "MR": "Mr.",
("MRS", "Mrs."), "MRS": "Mrs.",
("MS", "Ms."), "MS": "Ms.",
] }
class Author(models.Model): class Author(models.Model):

View File

@ -2,6 +2,7 @@ from django.apps import apps
from django.db import models from django.db import models
from django.test import SimpleTestCase, override_settings from django.test import SimpleTestCase, override_settings
from django.test.utils import isolate_lru_cache from django.test.utils import isolate_lru_cache
from django.utils.choices import normalize_choices
class FieldDeconstructionTests(SimpleTestCase): class FieldDeconstructionTests(SimpleTestCase):
@ -105,12 +106,22 @@ class FieldDeconstructionTests(SimpleTestCase):
self.assertEqual(kwargs, {"choices": [(0, "0"), (1, "1"), (2, "2")]}) self.assertEqual(kwargs, {"choices": [(0, "0"), (1, "1"), (2, "2")]})
def test_choices_iterable(self): def test_choices_iterable(self):
# Pass an iterator (but not an iterable) to choices. # Pass an iterable (but not an iterator) to choices.
field = models.IntegerField(choices="012345") field = models.IntegerField(choices="012345")
name, path, args, kwargs = field.deconstruct() name, path, args, kwargs = field.deconstruct()
self.assertEqual(path, "django.db.models.IntegerField") self.assertEqual(path, "django.db.models.IntegerField")
self.assertEqual(args, []) self.assertEqual(args, [])
self.assertEqual(kwargs, {"choices": ["0", "1", "2", "3", "4", "5"]}) self.assertEqual(kwargs, {"choices": normalize_choices("012345")})
def test_choices_callable(self):
def get_choices():
return [(i, str(i)) for i in range(3)]
field = models.IntegerField(choices=get_choices)
name, path, args, kwargs = field.deconstruct()
self.assertEqual(path, "django.db.models.IntegerField")
self.assertEqual(args, [])
self.assertEqual(kwargs, {"choices": get_choices})
def test_csi_field(self): def test_csi_field(self):
field = models.CommaSeparatedIntegerField(max_length=100) field = models.CommaSeparatedIntegerField(max_length=100)

View File

@ -65,6 +65,51 @@ class ChoiceFieldTest(FormFieldAssertionsMixin, SimpleTestCase):
f = ChoiceField(choices=choices) f = ChoiceField(choices=choices)
self.assertEqual("J", f.clean("J")) self.assertEqual("J", f.clean("J"))
def test_choicefield_callable_mapping(self):
def choices():
return {"J": "John", "P": "Paul"}
f = ChoiceField(choices=choices)
self.assertEqual("J", f.clean("J"))
def test_choicefield_callable_grouped_mapping(self):
def choices():
return {
"Numbers": {"1": "One", "2": "Two"},
"Letters": {"3": "A", "4": "B"},
}
f = ChoiceField(choices=choices)
for i in ("1", "2", "3", "4"):
with self.subTest(i):
self.assertEqual(i, f.clean(i))
def test_choicefield_mapping(self):
f = ChoiceField(choices={"J": "John", "P": "Paul"})
self.assertEqual("J", f.clean("J"))
def test_choicefield_grouped_mapping(self):
f = ChoiceField(
choices={
"Numbers": (("1", "One"), ("2", "Two")),
"Letters": (("3", "A"), ("4", "B")),
}
)
for i in ("1", "2", "3", "4"):
with self.subTest(i):
self.assertEqual(i, f.clean(i))
def test_choicefield_grouped_mapping_inner_dict(self):
f = ChoiceField(
choices={
"Numbers": {"1": "One", "2": "Two"},
"Letters": {"3": "A", "4": "B"},
}
)
for i in ("1", "2", "3", "4"):
with self.subTest(i):
self.assertEqual(i, f.clean(i))
def test_choicefield_callable_may_evaluate_to_different_values(self): def test_choicefield_callable_may_evaluate_to_different_values(self):
choices = [] choices = []
@ -76,11 +121,13 @@ class ChoiceFieldTest(FormFieldAssertionsMixin, SimpleTestCase):
choices = [("J", "John")] choices = [("J", "John")]
form = ChoiceFieldForm() form = ChoiceFieldForm()
self.assertEqual([("J", "John")], list(form.fields["choicefield"].choices)) self.assertEqual(choices, list(form.fields["choicefield"].choices))
self.assertEqual(choices, list(form.fields["choicefield"].widget.choices))
choices = [("P", "Paul")] choices = [("P", "Paul")]
form = ChoiceFieldForm() form = ChoiceFieldForm()
self.assertEqual([("P", "Paul")], list(form.fields["choicefield"].choices)) self.assertEqual(choices, list(form.fields["choicefield"].choices))
self.assertEqual(choices, list(form.fields["choicefield"].widget.choices))
def test_choicefield_disabled(self): def test_choicefield_disabled(self):
f = ChoiceField(choices=[("J", "John"), ("P", "Paul")], disabled=True) f = ChoiceField(choices=[("J", "John"), ("P", "Paul")], disabled=True)

View File

@ -9,13 +9,26 @@ class ChoiceWidgetTest(WidgetTest):
widget = ChoiceWidget widget = ChoiceWidget
@property @property
def nested_widget(self): def nested_widgets(self):
return self.widget( nested_widget = self.widget(
choices=( choices=(
("outer1", "Outer 1"), ("outer1", "Outer 1"),
('Group "1"', (("inner1", "Inner 1"), ("inner2", "Inner 2"))), ('Group "1"', (("inner1", "Inner 1"), ("inner2", "Inner 2"))),
) ),
) )
nested_widget_dict = self.widget(
choices={
"outer1": "Outer 1",
'Group "1"': {"inner1": "Inner 1", "inner2": "Inner 2"},
},
)
nested_widget_dict_tuple = self.widget(
choices={
"outer1": "Outer 1",
'Group "1"': (("inner1", "Inner 1"), ("inner2", "Inner 2")),
},
)
return (nested_widget, nested_widget_dict, nested_widget_dict_tuple)
def test_deepcopy(self): def test_deepcopy(self):
""" """

View File

@ -13,7 +13,6 @@ class RadioSelectTest(ChoiceWidgetTest):
widget = RadioSelect widget = RadioSelect
def test_render(self): def test_render(self):
choices = BLANK_CHOICE_DASH + self.beatles
html = """ html = """
<div> <div>
<div> <div>
@ -33,7 +32,10 @@ class RadioSelectTest(ChoiceWidgetTest):
</div> </div>
</div> </div>
""" """
self.check_html(self.widget(choices=choices), "beatle", "J", html=html) beatles_with_blank = BLANK_CHOICE_DASH + self.beatles
for choices in (beatles_with_blank, dict(beatles_with_blank)):
with self.subTest(choices):
self.check_html(self.widget(choices=choices), "beatle", "J", html=html)
def test_nested_choices(self): def test_nested_choices(self):
nested_choices = ( nested_choices = (
@ -312,7 +314,9 @@ class RadioSelectTest(ChoiceWidgetTest):
</div> </div>
</div> </div>
""" """
self.check_html(self.nested_widget, "nestchoice", None, html=html) for widget in self.nested_widgets:
with self.subTest(widget):
self.check_html(widget, "nestchoice", None, html=html)
def test_choices_select_outer(self): def test_choices_select_outer(self):
html = """ html = """
@ -334,7 +338,9 @@ class RadioSelectTest(ChoiceWidgetTest):
</div> </div>
</div> </div>
""" """
self.check_html(self.nested_widget, "nestchoice", "outer1", html=html) for widget in self.nested_widgets:
with self.subTest(widget):
self.check_html(widget, "nestchoice", "outer1", html=html)
def test_choices_select_inner(self): def test_choices_select_inner(self):
html = """ html = """
@ -356,7 +362,9 @@ class RadioSelectTest(ChoiceWidgetTest):
</div> </div>
</div> </div>
""" """
self.check_html(self.nested_widget, "nestchoice", "inner2", html=html) for widget in self.nested_widgets:
with self.subTest(widget):
self.check_html(widget, "nestchoice", "inner2", html=html)
def test_render_attrs(self): def test_render_attrs(self):
""" """

View File

@ -11,19 +11,17 @@ class SelectTest(ChoiceWidgetTest):
widget = Select widget = Select
def test_render(self): def test_render(self):
self.check_html( html = """
self.widget(choices=self.beatles), <select name="beatle">
"beatle", <option value="J" selected>John</option>
"J", <option value="P">Paul</option>
html=( <option value="G">George</option>
"""<select name="beatle"> <option value="R">Ringo</option>
<option value="J" selected>John</option> </select>
<option value="P">Paul</option> """
<option value="G">George</option> for choices in (self.beatles, dict(self.beatles)):
<option value="R">Ringo</option> with self.subTest(choices):
</select>""" self.check_html(self.widget(choices=choices), "beatle", "J", html=html)
),
)
def test_render_none(self): def test_render_none(self):
""" """
@ -237,52 +235,46 @@ class SelectTest(ChoiceWidgetTest):
""" """
Choices can be nested one level in order to create HTML optgroups. Choices can be nested one level in order to create HTML optgroups.
""" """
self.check_html( html = """
self.nested_widget, <select name="nestchoice">
"nestchoice", <option value="outer1">Outer 1</option>
None, <optgroup label="Group &quot;1&quot;">
html=( <option value="inner1">Inner 1</option>
"""<select name="nestchoice"> <option value="inner2">Inner 2</option>
<option value="outer1">Outer 1</option> </optgroup>
<optgroup label="Group &quot;1&quot;"> </select>
<option value="inner1">Inner 1</option> """
<option value="inner2">Inner 2</option> for widget in self.nested_widgets:
</optgroup> with self.subTest(widget):
</select>""" self.check_html(widget, "nestchoice", None, html=html)
),
)
def test_choices_select_outer(self): def test_choices_select_outer(self):
self.check_html( html = """
self.nested_widget, <select name="nestchoice">
"nestchoice", <option value="outer1" selected>Outer 1</option>
"outer1", <optgroup label="Group &quot;1&quot;">
html=( <option value="inner1">Inner 1</option>
"""<select name="nestchoice"> <option value="inner2">Inner 2</option>
<option value="outer1" selected>Outer 1</option> </optgroup>
<optgroup label="Group &quot;1&quot;"> </select>
<option value="inner1">Inner 1</option> """
<option value="inner2">Inner 2</option> for widget in self.nested_widgets:
</optgroup> with self.subTest(widget):
</select>""" self.check_html(widget, "nestchoice", "outer1", html=html)
),
)
def test_choices_select_inner(self): def test_choices_select_inner(self):
self.check_html( html = """
self.nested_widget, <select name="nestchoice">
"nestchoice", <option value="outer1">Outer 1</option>
"inner1", <optgroup label="Group &quot;1&quot;">
html=( <option value="inner1" selected>Inner 1</option>
"""<select name="nestchoice"> <option value="inner2">Inner 2</option>
<option value="outer1">Outer 1</option> </optgroup>
<optgroup label="Group &quot;1&quot;"> </select>
<option value="inner1" selected>Inner 1</option> """
<option value="inner2">Inner 2</option> for widget in self.nested_widgets:
</optgroup> with self.subTest(widget):
</select>""" self.check_html(widget, "nestchoice", "inner1", html=html)
),
)
@override_settings(USE_THOUSAND_SEPARATOR=True) @override_settings(USE_THOUSAND_SEPARATOR=True)
def test_doesnt_localize_option_value(self): def test_doesnt_localize_option_value(self):
@ -312,24 +304,7 @@ class SelectTest(ChoiceWidgetTest):
""" """
self.check_html(self.widget(choices=choices), "time", None, html=html) self.check_html(self.widget(choices=choices), "time", None, html=html)
def test_optgroups(self): def _test_optgroups(self, choices):
choices = [
(
"Audio",
[
("vinyl", "Vinyl"),
("cd", "CD"),
],
),
(
"Video",
[
("vhs", "VHS Tape"),
("dvd", "DVD"),
],
),
("unknown", "Unknown"),
]
groups = list( groups = list(
self.widget(choices=choices).optgroups( self.widget(choices=choices).optgroups(
"name", "name",
@ -418,6 +393,27 @@ class SelectTest(ChoiceWidgetTest):
) )
self.assertEqual(index, 2) self.assertEqual(index, 2)
def test_optgroups(self):
choices_dict = {
"Audio": [
("vinyl", "Vinyl"),
("cd", "CD"),
],
"Video": [
("vhs", "VHS Tape"),
("dvd", "DVD"),
],
"unknown": "Unknown",
}
choices_list = list(choices_dict.items())
choices_nested_dict = {
k: dict(v) if isinstance(v, list) else v for k, v in choices_dict.items()
}
for choices in (choices_dict, choices_list, choices_nested_dict):
with self.subTest(choices):
self._test_optgroups(choices)
def test_doesnt_render_required_when_impossible_to_select_empty_field(self): def test_doesnt_render_required_when_impossible_to_select_empty_field(self):
widget = self.widget(choices=[("J", "John"), ("P", "Paul")]) widget = self.widget(choices=[("J", "John"), ("P", "Paul")])
self.assertIs(widget.use_required_attribute(initial=None), False) self.assertIs(widget.use_required_attribute(initial=None), False)

View File

@ -199,7 +199,8 @@ class CharFieldTests(TestCase):
field.check(), field.check(),
[ [
Error( Error(
"'choices' must be an iterable (e.g., a list or tuple).", "'choices' must be a mapping (e.g. a dictionary) or an iterable "
"(e.g. a list or tuple).",
obj=field, obj=field,
id="fields.E004", id="fields.E004",
), ),
@ -217,8 +218,9 @@ class CharFieldTests(TestCase):
field.check(), field.check(),
[ [
Error( Error(
"'choices' must be an iterable containing (actual value, " "'choices' must be a mapping of actual values to human readable "
"human readable name) tuples.", "names or an iterable containing (actual value, human readable "
"name) tuples.",
obj=field, obj=field,
id="fields.E005", id="fields.E005",
), ),
@ -260,8 +262,9 @@ class CharFieldTests(TestCase):
field.check(), field.check(),
[ [
Error( Error(
"'choices' must be an iterable containing (actual " "'choices' must be a mapping of actual values to human "
"value, human readable name) tuples.", "readable names or an iterable containing (actual value, "
"human readable name) tuples.",
obj=field, obj=field,
id="fields.E005", id="fields.E005",
), ),
@ -309,8 +312,9 @@ class CharFieldTests(TestCase):
field.check(), field.check(),
[ [
Error( Error(
"'choices' must be an iterable containing (actual value, " "'choices' must be a mapping of actual values to human readable "
"human readable name) tuples.", "names or an iterable containing (actual value, human readable "
"name) tuples.",
obj=field, obj=field,
id="fields.E005", id="fields.E005",
), ),
@ -337,8 +341,9 @@ class CharFieldTests(TestCase):
field.check(), field.check(),
[ [
Error( Error(
"'choices' must be an iterable containing (actual value, " "'choices' must be a mapping of actual values to human readable "
"human readable name) tuples.", "names or an iterable containing (actual value, human readable "
"name) tuples.",
obj=field, obj=field,
id="fields.E005", id="fields.E005",
), ),
@ -386,6 +391,26 @@ class CharFieldTests(TestCase):
], ],
) )
def test_choices_callable(self):
def get_choices():
return [(i, i) for i in range(3)]
class Model(models.Model):
field = models.CharField(max_length=10, choices=get_choices)
field = Model._meta.get_field("field")
self.assertEqual(
field.check(),
[
Error(
"'choices' must be a mapping (e.g. a dictionary) or an iterable "
"(e.g. a list or tuple).",
obj=field,
id="fields.E004",
),
],
)
def test_bad_db_index_value(self): def test_bad_db_index_value(self):
class Model(models.Model): class Model(models.Model):
field = models.CharField(max_length=10, db_index="bad") field = models.CharField(max_length=10, db_index="bad")
@ -854,7 +879,8 @@ class IntegerFieldTests(SimpleTestCase):
field.check(), field.check(),
[ [
Error( Error(
"'choices' must be an iterable (e.g., a list or tuple).", "'choices' must be a mapping (e.g. a dictionary) or an iterable "
"(e.g. a list or tuple).",
obj=field, obj=field,
id="fields.E004", id="fields.E004",
), ),
@ -872,8 +898,9 @@ class IntegerFieldTests(SimpleTestCase):
field.check(), field.check(),
[ [
Error( Error(
"'choices' must be an iterable containing (actual value, human " "'choices' must be a mapping of actual values to human readable "
"readable name) tuples.", "names or an iterable containing (actual value, human readable "
"name) tuples.",
obj=field, obj=field,
id="fields.E005", id="fields.E005",
), ),

View File

@ -460,6 +460,16 @@ class WriterTests(SimpleTestCase):
"default=datetime.date(1969, 11, 19))", "default=datetime.date(1969, 11, 19))",
) )
def test_serialize_dictionary_choices(self):
for choices in ({"Group": [(2, "2"), (1, "1")]}, {"Group": {2: "2", 1: "1"}}):
with self.subTest(choices):
field = models.IntegerField(choices=choices)
string = MigrationWriter.serialize(field)[0]
self.assertEqual(
string,
"models.IntegerField(choices=[('Group', [(2, '2'), (1, '1')])])",
)
def test_serialize_nested_class(self): def test_serialize_nested_class(self):
for nested_cls in [self.NestedEnum, self.NestedChoices]: for nested_cls in [self.NestedEnum, self.NestedChoices]:
cls_name = nested_cls.__name__ cls_name = nested_cls.__name__

View File

@ -31,24 +31,18 @@ class Bar(models.Model):
class Whiz(models.Model): class Whiz(models.Model):
CHOICES = ( CHOICES = {
( "Group 1": {
"Group 1", 1: "First",
( 2: "Second",
(1, "First"), },
(2, "Second"), "Group 2": (
), (3, "Third"),
(4, "Fourth"),
), ),
( 0: "Other",
"Group 2", 5: _("translated"),
( }
(3, "Third"),
(4, "Fourth"),
),
),
(0, "Other"),
(5, _("translated")),
)
c = models.IntegerField(choices=CHOICES, null=True) c = models.IntegerField(choices=CHOICES, null=True)
@ -61,7 +55,7 @@ WhizDelayed._meta.get_field("c").choices = Whiz.CHOICES
class WhizIter(models.Model): class WhizIter(models.Model):
c = models.IntegerField(choices=iter(Whiz.CHOICES), null=True) c = models.IntegerField(choices=iter(Whiz.CHOICES.items()), null=True)
class WhizIterEmpty(models.Model): class WhizIterEmpty(models.Model):
@ -78,6 +72,10 @@ class Choiceful(models.Model):
no_choices = models.IntegerField(null=True) no_choices = models.IntegerField(null=True)
empty_choices = models.IntegerField(choices=(), null=True) empty_choices = models.IntegerField(choices=(), null=True)
with_choices = models.IntegerField(choices=[(1, "A")], null=True) with_choices = models.IntegerField(choices=[(1, "A")], null=True)
with_choices_dict = models.IntegerField(choices={1: "A"}, null=True)
with_choices_nested_dict = models.IntegerField(
choices={"Thing": {1: "A"}}, null=True
)
empty_choices_bool = models.BooleanField(choices=()) empty_choices_bool = models.BooleanField(choices=())
empty_choices_text = models.TextField(choices=()) empty_choices_text = models.TextField(choices=())
choices_from_enum = models.IntegerField(choices=Suit) choices_from_enum = models.IntegerField(choices=Suit)

View File

@ -279,6 +279,14 @@ class ValidationTests(SimpleTestCase):
f = models.IntegerField(choices=(("group", ((10, "A"), (20, "B"))), (30, "C"))) f = models.IntegerField(choices=(("group", ((10, "A"), (20, "B"))), (30, "C")))
self.assertEqual(10, f.clean(10, None)) self.assertEqual(10, f.clean(10, None))
def test_choices_validation_supports_named_groups_dicts(self):
f = models.IntegerField(choices={"group": ((10, "A"), (20, "B")), 30: "C"})
self.assertEqual(10, f.clean(10, None))
def test_choices_validation_supports_named_groups_nested_dicts(self):
f = models.IntegerField(choices={"group": {10: "A", 20: "B"}, 30: "C"})
self.assertEqual(10, f.clean(10, None))
def test_nullable_integerfield_raises_error_with_blank_false(self): def test_nullable_integerfield_raises_error_with_blank_false(self):
f = models.IntegerField(null=True, blank=False) f = models.IntegerField(null=True, blank=False)
with self.assertRaises(ValidationError): with self.assertRaises(ValidationError):

View File

@ -156,15 +156,21 @@ class ChoicesTests(SimpleTestCase):
cls.empty_choices_bool = Choiceful._meta.get_field("empty_choices_bool") cls.empty_choices_bool = Choiceful._meta.get_field("empty_choices_bool")
cls.empty_choices_text = Choiceful._meta.get_field("empty_choices_text") cls.empty_choices_text = Choiceful._meta.get_field("empty_choices_text")
cls.with_choices = Choiceful._meta.get_field("with_choices") cls.with_choices = Choiceful._meta.get_field("with_choices")
cls.with_choices_dict = Choiceful._meta.get_field("with_choices_dict")
cls.with_choices_nested_dict = Choiceful._meta.get_field(
"with_choices_nested_dict"
)
cls.choices_from_enum = Choiceful._meta.get_field("choices_from_enum") cls.choices_from_enum = Choiceful._meta.get_field("choices_from_enum")
cls.choices_from_iterator = Choiceful._meta.get_field("choices_from_iterator") cls.choices_from_iterator = Choiceful._meta.get_field("choices_from_iterator")
def test_choices(self): def test_choices(self):
self.assertIsNone(self.no_choices.choices) self.assertIsNone(self.no_choices.choices)
self.assertEqual(self.empty_choices.choices, ()) self.assertEqual(self.empty_choices.choices, [])
self.assertEqual(self.empty_choices_bool.choices, ()) self.assertEqual(self.empty_choices_bool.choices, [])
self.assertEqual(self.empty_choices_text.choices, ()) self.assertEqual(self.empty_choices_text.choices, [])
self.assertEqual(self.with_choices.choices, [(1, "A")]) self.assertEqual(self.with_choices.choices, [(1, "A")])
self.assertEqual(self.with_choices_dict.choices, [(1, "A")])
self.assertEqual(self.with_choices_nested_dict.choices, [("Thing", [(1, "A")])])
self.assertEqual( self.assertEqual(
self.choices_from_iterator.choices, [(0, "0"), (1, "1"), (2, "2")] self.choices_from_iterator.choices, [(0, "0"), (1, "1"), (2, "2")]
) )
@ -175,6 +181,8 @@ class ChoicesTests(SimpleTestCase):
self.assertEqual(self.empty_choices_bool.flatchoices, []) self.assertEqual(self.empty_choices_bool.flatchoices, [])
self.assertEqual(self.empty_choices_text.flatchoices, []) self.assertEqual(self.empty_choices_text.flatchoices, [])
self.assertEqual(self.with_choices.flatchoices, [(1, "A")]) self.assertEqual(self.with_choices.flatchoices, [(1, "A")])
self.assertEqual(self.with_choices_dict.flatchoices, [(1, "A")])
self.assertEqual(self.with_choices_nested_dict.flatchoices, [(1, "A")])
self.assertEqual( self.assertEqual(
self.choices_from_iterator.flatchoices, [(0, "0"), (1, "1"), (2, "2")] self.choices_from_iterator.flatchoices, [(0, "0"), (1, "1"), (2, "2")]
) )
@ -290,11 +298,11 @@ class GetChoicesTests(SimpleTestCase):
("b", "Bar"), ("b", "Bar"),
( (
"Group", "Group",
( [
("", "No Preference"), ("", "No Preference"),
("fg", "Foo"), ("fg", "Foo"),
("bg", "Bar"), ("bg", "Bar"),
), ],
), ),
] ]
f = models.CharField(choices=choices) f = models.CharField(choices=choices)
@ -302,7 +310,7 @@ class GetChoicesTests(SimpleTestCase):
def test_lazy_strings_not_evaluated(self): def test_lazy_strings_not_evaluated(self):
lazy_func = lazy(lambda x: 0 / 0, int) # raises ZeroDivisionError if evaluated. lazy_func = lazy(lambda x: 0 / 0, int) # raises ZeroDivisionError if evaluated.
f = models.CharField(choices=[(lazy_func("group"), (("a", "A"), ("b", "B")))]) f = models.CharField(choices=[(lazy_func("group"), [("a", "A"), ("b", "B")])])
self.assertEqual(f.get_choices(include_blank=True)[0], ("", "---------")) self.assertEqual(f.get_choices(include_blank=True)[0], ("", "---------"))

View File

@ -0,0 +1,305 @@
from unittest import mock
from django.db.models import TextChoices
from django.test import SimpleTestCase
from django.utils.choices import CallableChoiceIterator, normalize_choices
from django.utils.translation import gettext_lazy as _
class NormalizeFieldChoicesTests(SimpleTestCase):
expected = [
("C", _("Club")),
("D", _("Diamond")),
("H", _("Heart")),
("S", _("Spade")),
]
expected_nested = [
("Audio", [("vinyl", _("Vinyl")), ("cd", _("CD"))]),
("Video", [("vhs", _("VHS Tape")), ("dvd", _("DVD"))]),
("unknown", _("Unknown")),
]
invalid = [
1j,
123,
123.45,
"invalid",
b"invalid",
_("invalid"),
object(),
None,
True,
False,
]
invalid_iterable = [
# Special cases of a string-likes which would unpack incorrectly.
["ab"],
[b"ab"],
[_("ab")],
# Non-iterable items or iterable items with incorrect number of
# elements that cannot be unpacked.
[123],
[("value",)],
[("value", "label", "other")],
]
invalid_nested = [
# Nested choices can only be two-levels deep, so return callables,
# mappings, iterables, etc. at deeper levels unmodified.
[("Group", [("Value", lambda: "Label")])],
[("Group", [("Value", {"Label 1?": "Label 2?"})])],
[("Group", [("Value", [("Label 1?", "Label 2?")])])],
]
def test_empty(self):
def generator():
yield from ()
for choices in ({}, [], (), set(), frozenset(), generator()):
with self.subTest(choices=choices):
self.assertEqual(normalize_choices(choices), [])
def test_choices(self):
class Medal(TextChoices):
GOLD = "GOLD", _("Gold")
SILVER = "SILVER", _("Silver")
BRONZE = "BRONZE", _("Bronze")
expected = [
("GOLD", _("Gold")),
("SILVER", _("Silver")),
("BRONZE", _("Bronze")),
]
self.assertEqual(normalize_choices(Medal), expected)
def test_callable(self):
def get_choices():
return {
"C": _("Club"),
"D": _("Diamond"),
"H": _("Heart"),
"S": _("Spade"),
}
get_choices_spy = mock.Mock(wraps=get_choices)
output = normalize_choices(get_choices_spy)
get_choices_spy.assert_not_called()
self.assertIsInstance(output, CallableChoiceIterator)
self.assertEqual(list(output), self.expected)
get_choices_spy.assert_called_once()
def test_mapping(self):
choices = {
"C": _("Club"),
"D": _("Diamond"),
"H": _("Heart"),
"S": _("Spade"),
}
self.assertEqual(normalize_choices(choices), self.expected)
def test_iterable(self):
choices = [
("C", _("Club")),
("D", _("Diamond")),
("H", _("Heart")),
("S", _("Spade")),
]
self.assertEqual(normalize_choices(choices), self.expected)
def test_iterator(self):
def generator():
yield "C", _("Club")
yield "D", _("Diamond")
yield "H", _("Heart")
yield "S", _("Spade")
choices = generator()
self.assertEqual(normalize_choices(choices), self.expected)
def test_nested_callable(self):
def get_audio_choices():
return [("vinyl", _("Vinyl")), ("cd", _("CD"))]
def get_video_choices():
return [("vhs", _("VHS Tape")), ("dvd", _("DVD"))]
def get_media_choices():
return [
("Audio", get_audio_choices),
("Video", get_video_choices),
("unknown", _("Unknown")),
]
get_media_choices_spy = mock.Mock(wraps=get_media_choices)
output = normalize_choices(get_media_choices_spy)
get_media_choices_spy.assert_not_called()
self.assertIsInstance(output, CallableChoiceIterator)
self.assertEqual(list(output), self.expected_nested)
get_media_choices_spy.assert_called_once()
def test_nested_mapping(self):
choices = {
"Audio": {"vinyl": _("Vinyl"), "cd": _("CD")},
"Video": {"vhs": _("VHS Tape"), "dvd": _("DVD")},
"unknown": _("Unknown"),
}
self.assertEqual(normalize_choices(choices), self.expected_nested)
def test_nested_iterable(self):
choices = [
("Audio", [("vinyl", _("Vinyl")), ("cd", _("CD"))]),
("Video", [("vhs", _("VHS Tape")), ("dvd", _("DVD"))]),
("unknown", _("Unknown")),
]
self.assertEqual(normalize_choices(choices), self.expected_nested)
def test_nested_iterator(self):
def generate_audio_choices():
yield "vinyl", _("Vinyl")
yield "cd", _("CD")
def generate_video_choices():
yield "vhs", _("VHS Tape")
yield "dvd", _("DVD")
def generate_media_choices():
yield "Audio", generate_audio_choices()
yield "Video", generate_video_choices()
yield "unknown", _("Unknown")
choices = generate_media_choices()
self.assertEqual(normalize_choices(choices), self.expected_nested)
def test_callable_non_canonical(self):
# Canonical form is list of 2-tuple, but nested lists should work.
def get_choices():
return [
["C", _("Club")],
["D", _("Diamond")],
["H", _("Heart")],
["S", _("Spade")],
]
get_choices_spy = mock.Mock(wraps=get_choices)
output = normalize_choices(get_choices_spy)
get_choices_spy.assert_not_called()
self.assertIsInstance(output, CallableChoiceIterator)
self.assertEqual(list(output), self.expected)
get_choices_spy.assert_called_once()
def test_iterable_non_canonical(self):
# Canonical form is list of 2-tuple, but nested lists should work.
choices = [
["C", _("Club")],
["D", _("Diamond")],
["H", _("Heart")],
["S", _("Spade")],
]
self.assertEqual(normalize_choices(choices), self.expected)
def test_iterator_non_canonical(self):
# Canonical form is list of 2-tuple, but nested lists should work.
def generator():
yield ["C", _("Club")]
yield ["D", _("Diamond")]
yield ["H", _("Heart")]
yield ["S", _("Spade")]
choices = generator()
self.assertEqual(normalize_choices(choices), self.expected)
def test_nested_callable_non_canonical(self):
# Canonical form is list of 2-tuple, but nested lists should work.
def get_audio_choices():
return [["vinyl", _("Vinyl")], ["cd", _("CD")]]
def get_video_choices():
return [["vhs", _("VHS Tape")], ["dvd", _("DVD")]]
def get_media_choices():
return [
["Audio", get_audio_choices],
["Video", get_video_choices],
["unknown", _("Unknown")],
]
get_media_choices_spy = mock.Mock(wraps=get_media_choices)
output = normalize_choices(get_media_choices_spy)
get_media_choices_spy.assert_not_called()
self.assertIsInstance(output, CallableChoiceIterator)
self.assertEqual(list(output), self.expected_nested)
get_media_choices_spy.assert_called_once()
def test_nested_iterable_non_canonical(self):
# Canonical form is list of 2-tuple, but nested lists should work.
choices = [
["Audio", [["vinyl", _("Vinyl")], ["cd", _("CD")]]],
["Video", [["vhs", _("VHS Tape")], ["dvd", _("DVD")]]],
["unknown", _("Unknown")],
]
self.assertEqual(normalize_choices(choices), self.expected_nested)
def test_nested_iterator_non_canonical(self):
# Canonical form is list of 2-tuple, but nested lists should work.
def generator():
yield ["Audio", [["vinyl", _("Vinyl")], ["cd", _("CD")]]]
yield ["Video", [["vhs", _("VHS Tape")], ["dvd", _("DVD")]]]
yield ["unknown", _("Unknown")]
choices = generator()
self.assertEqual(normalize_choices(choices), self.expected_nested)
def test_nested_mixed_mapping_and_iterable(self):
# Although not documented, as it's better to stick to either mappings
# or iterables, nesting of mappings within iterables and vice versa
# works and is likely to occur in the wild. This is supported by the
# recursive call to `normalize_choices()` which will normalize nested
# choices.
choices = {
"Audio": [("vinyl", _("Vinyl")), ("cd", _("CD"))],
"Video": [("vhs", _("VHS Tape")), ("dvd", _("DVD"))],
"unknown": _("Unknown"),
}
self.assertEqual(normalize_choices(choices), self.expected_nested)
choices = [
("Audio", {"vinyl": _("Vinyl"), "cd": _("CD")}),
("Video", {"vhs": _("VHS Tape"), "dvd": _("DVD")}),
("unknown", _("Unknown")),
]
self.assertEqual(normalize_choices(choices), self.expected_nested)
def test_iterable_set(self):
# Although not documented, as sets are unordered which results in
# randomised order in form fields, passing a set of 2-tuples works.
# Consistent ordering of choices on model fields in migrations is
# enforced by the migrations serializer.
choices = {
("C", _("Club")),
("D", _("Diamond")),
("H", _("Heart")),
("S", _("Spade")),
}
self.assertEqual(sorted(normalize_choices(choices)), sorted(self.expected))
def test_unsupported_values_returned_unmodified(self):
# Unsupported values must be returned unmodified for model system check
# to work correctly.
for value in self.invalid + self.invalid_iterable + self.invalid_nested:
with self.subTest(value=value):
self.assertEqual(normalize_choices(value), value)
def test_unsupported_values_from_callable_returned_unmodified(self):
for value in self.invalid_iterable + self.invalid_nested:
with self.subTest(value=value):
self.assertEqual(list(normalize_choices(lambda: value)), value)
def test_unsupported_values_from_iterator_returned_unmodified(self):
for value in self.invalid_nested:
with self.subTest(value=value):
self.assertEqual(
list(normalize_choices((lambda: (yield from value))())),
value,
)