diff --git a/django/contrib/postgres/fields/jsonb.py b/django/contrib/postgres/fields/jsonb.py index ae83d9e379..e333f30619 100644 --- a/django/contrib/postgres/fields/jsonb.py +++ b/django/contrib/postgres/fields/jsonb.py @@ -10,6 +10,19 @@ from django.utils.translation import ugettext_lazy as _ __all__ = ['JSONField'] +class JsonAdapter(Json): + """ + Customized psycopg2.extras.Json to allow for a custom encoder. + """ + def __init__(self, adapted, dumps=None, encoder=None): + self.encoder = encoder + super(JsonAdapter, self).__init__(adapted, dumps=dumps) + + def dumps(self, obj): + options = {'cls': self.encoder} if self.encoder else {} + return json.dumps(obj, **options) + + class JSONField(Field): empty_strings_allowed = False description = _('A JSON object') @@ -17,9 +30,21 @@ class JSONField(Field): 'invalid': _("Value must be valid JSON."), } + def __init__(self, verbose_name=None, name=None, encoder=None, **kwargs): + if encoder and not callable(encoder): + raise ValueError("The encoder parameter must be a callable object.") + self.encoder = encoder + super(JSONField, self).__init__(verbose_name, name, **kwargs) + def db_type(self, connection): return 'jsonb' + def deconstruct(self): + name, path, args, kwargs = super(JSONField, self).deconstruct() + if self.encoder is not None: + kwargs['encoder'] = self.encoder + return name, path, args, kwargs + def get_transform(self, name): transform = super(JSONField, self).get_transform(name) if transform: @@ -28,13 +53,14 @@ class JSONField(Field): def get_prep_value(self, value): if value is not None: - return Json(value) + return JsonAdapter(value, encoder=self.encoder) return value def validate(self, value, model_instance): super(JSONField, self).validate(value, model_instance) + options = {'cls': self.encoder} if self.encoder else {} try: - json.dumps(value) + json.dumps(value, **options) except TypeError: raise exceptions.ValidationError( self.error_messages['invalid'], diff --git a/docs/ref/contrib/postgres/fields.txt b/docs/ref/contrib/postgres/fields.txt index fdb69a7370..3d22b13577 100644 --- a/docs/ref/contrib/postgres/fields.txt +++ b/docs/ref/contrib/postgres/fields.txt @@ -458,17 +458,32 @@ using in conjunction with lookups on ``JSONField`` ============= -.. class:: JSONField(**options) +.. class:: JSONField(encoder=None, **options) A field for storing JSON encoded data. In Python the data is represented in its Python native format: dictionaries, lists, strings, numbers, booleans and ``None``. - If you want to store other data types, you'll need to serialize them first. - For example, you might cast a ``datetime`` to a string. You might also want - to convert the string back to a ``datetime`` when you retrieve the data - from the database. There are some third-party ``JSONField`` implementations - which do this sort of thing automatically. + .. attribute:: encoder + + .. versionadded:: 1.11 + + An optional JSON-encoding class to serialize data types not supported + by the standard JSON serializer (``datetime``, ``uuid``, etc.). For + example, you can use the + :class:`~django.core.serializers.json.DjangoJSONEncoder` class or any + other :py:class:`json.JSONEncoder` subclass. + + When the value is retrieved from the database, it will be in the format + chosen by the custom encoder (most often a string), so you'll need to + take extra steps to convert the value back to the initial data type + (:meth:`Model.from_db() ` and + :meth:`Field.from_db_value() ` + are two possible hooks for that purpose). Your deserialization may need + to account for the fact that you can't be certain of the input type. + For example, you run the risk of returning a ``datetime`` that was + actually a string that just happened to be in the same format chosen + for ``datetime``\s. If you give the field a :attr:`~django.db.models.Field.default`, ensure it's a callable such as ``dict`` (for an empty default) or a callable that diff --git a/docs/releases/1.11.txt b/docs/releases/1.11.txt index f50b000016..7d9af8d132 100644 --- a/docs/releases/1.11.txt +++ b/docs/releases/1.11.txt @@ -129,6 +129,10 @@ Minor features * The new :class:`~django.contrib.postgres.indexes.GinIndex` class allows creating gin indexes in the database. +* :class:`~django.contrib.postgres.fields.JSONField` accepts a new ``encoder`` + parameter to specify a custom class to encode data types not supported by the + standard encoder. + :mod:`django.contrib.redirects` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/postgres_tests/fields.py b/tests/postgres_tests/fields.py index d50c6d6a91..c02bc2faf4 100644 --- a/tests/postgres_tests/fields.py +++ b/tests/postgres_tests/fields.py @@ -23,6 +23,10 @@ except ImportError: }) return name, path, args, kwargs + class DummyJSONField(models.Field): + def __init__(self, encoder=None, **kwargs): + super(DummyJSONField, self).__init__(**kwargs) + ArrayField = DummyArrayField BigIntegerRangeField = models.Field DateRangeField = models.Field @@ -30,5 +34,5 @@ except ImportError: FloatRangeField = models.Field HStoreField = models.Field IntegerRangeField = models.Field - JSONField = models.Field + JSONField = DummyJSONField SearchVectorField = models.Field diff --git a/tests/postgres_tests/migrations/0002_create_test_models.py b/tests/postgres_tests/migrations/0002_create_test_models.py index 69c5b17f77..974aae1ec8 100644 --- a/tests/postgres_tests/migrations/0002_create_test_models.py +++ b/tests/postgres_tests/migrations/0002_create_test_models.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import unicode_literals +from django.core.serializers.json import DjangoJSONEncoder from django.db import migrations, models from ..fields import ( @@ -223,6 +224,7 @@ class Migration(migrations.Migration): fields=[ ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), ('field', JSONField(null=True, blank=True)), + ('field_custom', JSONField(null=True, blank=True, encoder=DjangoJSONEncoder)), ], options={ }, diff --git a/tests/postgres_tests/models.py b/tests/postgres_tests/models.py index d94eb90d4a..52dbc0335c 100644 --- a/tests/postgres_tests/models.py +++ b/tests/postgres_tests/models.py @@ -1,3 +1,4 @@ +from django.core.serializers.json import DjangoJSONEncoder from django.db import connection, models from .fields import ( @@ -132,6 +133,7 @@ class RangeLookupsModel(PostgreSQLModel): if connection.vendor == 'postgresql' and connection.pg_version >= 90400: class JSONModel(models.Model): field = JSONField(blank=True, null=True) + field_custom = JSONField(blank=True, null=True, encoder=DjangoJSONEncoder) else: # create an object with this name so we don't have failing imports class JSONModel(object): diff --git a/tests/postgres_tests/test_json.py b/tests/postgres_tests/test_json.py index b88d103761..fd4db7fae4 100644 --- a/tests/postgres_tests/test_json.py +++ b/tests/postgres_tests/test_json.py @@ -1,7 +1,12 @@ +from __future__ import unicode_literals + import datetime import unittest +import uuid +from decimal import Decimal from django.core import exceptions, serializers +from django.core.serializers.json import DjangoJSONEncoder from django.db import connection from django.forms import CharField, Form, widgets from django.test import TestCase @@ -79,6 +84,27 @@ class TestSaveLoad(TestCase): loaded = JSONModel.objects.get() self.assertEqual(loaded.field, obj) + def test_custom_encoding(self): + """ + JSONModel.field_custom has a custom DjangoJSONEncoder. + """ + some_uuid = uuid.uuid4() + obj_before = { + 'date': datetime.date(2016, 8, 12), + 'datetime': datetime.datetime(2016, 8, 12, 13, 44, 47, 575981), + 'decimal': Decimal('10.54'), + 'uuid': some_uuid, + } + obj_after = { + 'date': '2016-08-12', + 'datetime': '2016-08-12T13:44:47.575', + 'decimal': '10.54', + 'uuid': str(some_uuid), + } + JSONModel.objects.create(field_custom=obj_before) + loaded = JSONModel.objects.get() + self.assertEqual(loaded.field_custom, obj_after) + @skipUnlessPG94 class TestQuerying(TestCase): @@ -215,7 +241,10 @@ class TestQuerying(TestCase): @skipUnlessPG94 class TestSerialization(TestCase): - test_data = '[{"fields": {"field": {"a": "b", "c": null}}, "model": "postgres_tests.jsonmodel", "pk": null}]' + test_data = ( + '[{"fields": {"field": {"a": "b", "c": null}, "field_custom": null}, ' + '"model": "postgres_tests.jsonmodel", "pk": null}]' + ) def test_dumping(self): instance = JSONModel(field={'a': 'b', 'c': None}) @@ -236,6 +265,12 @@ class TestValidation(PostgreSQLTestCase): self.assertEqual(cm.exception.code, 'invalid') self.assertEqual(cm.exception.message % cm.exception.params, "Value must be valid JSON.") + def test_custom_encoder(self): + with self.assertRaisesMessage(ValueError, "The encoder parameter must be a callable object."): + field = JSONField(encoder=DjangoJSONEncoder()) + field = JSONField(encoder=DjangoJSONEncoder) + self.assertEqual(field.clean(datetime.timedelta(days=1), None), datetime.timedelta(days=1)) + class TestFormField(PostgreSQLTestCase):