diff --git a/django/db/models/base.py b/django/db/models/base.py index 0a5e5ff673..e3b14a41a0 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -33,6 +33,7 @@ from django.db.models.signals import ( ) from django.db.models.utils import make_model_tuple from django.utils.encoding import force_str +from django.utils.hashable import make_hashable from django.utils.text import capfirst, get_text_list from django.utils.translation import gettext_lazy as _ from django.utils.version import get_version @@ -940,8 +941,9 @@ class Model(metaclass=ModelBase): def _get_FIELD_display(self, field): value = getattr(self, field.attname) + choices_dict = dict(make_hashable(field.flatchoices)) # force_str() to coerce lazy strings. - return force_str(dict(field.flatchoices).get(value, value), strings_only=True) + return force_str(choices_dict.get(make_hashable(value), value), strings_only=True) def _get_next_or_previous_by_FIELD(self, field, is_next, **kwargs): if not self.pk: diff --git a/docs/ref/models/instances.txt b/docs/ref/models/instances.txt index 396300d035..5f8f389506 100644 --- a/docs/ref/models/instances.txt +++ b/docs/ref/models/instances.txt @@ -797,6 +797,11 @@ For example:: >>> p.get_shirt_size_display() 'Large' +.. versionchanged:: 3.1 + + Support for :class:`~django.contrib.postgres.fields.ArrayField` and + :class:`~django.contrib.postgres.fields.RangeField` was added. + .. method:: Model.get_next_by_FOO(**kwargs) .. method:: Model.get_previous_by_FOO(**kwargs) diff --git a/docs/releases/3.1.txt b/docs/releases/3.1.txt index a3b4f30d63..de114e7098 100644 --- a/docs/releases/3.1.txt +++ b/docs/releases/3.1.txt @@ -76,6 +76,10 @@ Minor features :class:`~django.contrib.postgres.operations.BloomExtension` migration operation installs the ``bloom`` extension to add support for this index. +* :meth:`~django.db.models.Model.get_FOO_display` now supports + :class:`~django.contrib.postgres.fields.ArrayField` and + :class:`~django.contrib.postgres.fields.RangeField`. + :mod:`django.contrib.redirects` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/postgres_tests/test_array.py b/tests/postgres_tests/test_array.py index 7b7793f6c1..481d93f830 100644 --- a/tests/postgres_tests/test_array.py +++ b/tests/postgres_tests/test_array.py @@ -37,6 +37,53 @@ except ImportError: pass +@isolate_apps('postgres_tests') +class BasicTests(PostgreSQLSimpleTestCase): + def test_get_field_display(self): + class MyModel(PostgreSQLModel): + field = ArrayField( + models.CharField(max_length=16), + choices=[ + ['Media', [(['vinyl', 'cd'], 'Audio')]], + (('mp3', 'mp4'), 'Digital'), + ], + ) + + tests = ( + (['vinyl', 'cd'], 'Audio'), + (('mp3', 'mp4'), 'Digital'), + (('a', 'b'), "('a', 'b')"), + (['c', 'd'], "['c', 'd']"), + ) + for value, display in tests: + with self.subTest(value=value, display=display): + instance = MyModel(field=value) + self.assertEqual(instance.get_field_display(), display) + + def test_get_field_display_nested_array(self): + class MyModel(PostgreSQLModel): + field = ArrayField( + ArrayField(models.CharField(max_length=16)), + choices=[ + [ + 'Media', + [([['vinyl', 'cd'], ('x',)], 'Audio')], + ], + ((['mp3'], ('mp4',)), 'Digital'), + ], + ) + tests = ( + ([['vinyl', 'cd'], ('x',)], 'Audio'), + ((['mp3'], ('mp4',)), 'Digital'), + ((('a', 'b'), ('c',)), "(('a', 'b'), ('c',))"), + ([['a', 'b'], ['c']], "[['a', 'b'], ['c']]"), + ) + for value, display in tests: + with self.subTest(value=value, display=display): + instance = MyModel(field=value) + self.assertEqual(instance.get_field_display(), display) + + class TestSaveLoad(PostgreSQLTestCase): def test_integer(self): diff --git a/tests/postgres_tests/test_ranges.py b/tests/postgres_tests/test_ranges.py index b68f22112f..326015e46d 100644 --- a/tests/postgres_tests/test_ranges.py +++ b/tests/postgres_tests/test_ranges.py @@ -7,6 +7,7 @@ from django.core import exceptions, serializers from django.db.models import DateField, DateTimeField, F, Func, Value from django.http import QueryDict from django.test import override_settings +from django.test.utils import isolate_apps from django.utils import timezone from . import PostgreSQLSimpleTestCase, PostgreSQLTestCase @@ -22,6 +23,30 @@ except ImportError: pass +@isolate_apps('postgres_tests') +class BasicTests(PostgreSQLSimpleTestCase): + def test_get_field_display(self): + class Model(PostgreSQLModel): + field = pg_fields.IntegerRangeField( + choices=[ + ['1-50', [((1, 25), '1-25'), ([26, 50], '26-50')]], + ((51, 100), '51-100'), + ], + ) + + tests = ( + ((1, 25), '1-25'), + ([26, 50], '26-50'), + ((51, 100), '51-100'), + ((1, 2), '(1, 2)'), + ([1, 2], '[1, 2]'), + ) + for value, display in tests: + with self.subTest(value=value, display=display): + instance = Model(field=value) + self.assertEqual(instance.get_field_display(), display) + + class TestSaveLoad(PostgreSQLTestCase): def test_all_fields(self):