diff --git a/django/db/models/fields/json.py b/django/db/models/fields/json.py index 8d743c436a..188fcf520c 100644 --- a/django/db/models/fields/json.py +++ b/django/db/models/fields/json.py @@ -99,18 +99,23 @@ class JSONField(CheckFieldDefaultMixin, Field): def get_db_prep_value(self, value, connection, prepared=False): if not prepared: value = self.get_prep_value(value) - if isinstance(value, expressions.Value) and isinstance( - value.output_field, JSONField - ): - value = value.value - elif hasattr(value, "as_sql"): - return value return connection.ops.adapt_json_value(value, self.encoder) def get_db_prep_save(self, value, connection): + # This slightly involved logic is to allow for `None` to be used to + # store SQL `NULL` while `Value(None, JSONField())` can be used to + # store JSON `null` while preventing compilable `as_sql` values from + # making their way to `get_db_prep_value`, which is what the `super()` + # implementation does. if value is None: return value - return self.get_db_prep_value(value, connection) + if ( + isinstance(value, expressions.Value) + and value.value is None + and isinstance(value.output_field, JSONField) + ): + value = None + return super().get_db_prep_save(value, connection) def get_transform(self, name): transform = super().get_transform(name) diff --git a/tests/model_fields/models.py b/tests/model_fields/models.py index fdea06b23d..299e927615 100644 --- a/tests/model_fields/models.py +++ b/tests/model_fields/models.py @@ -430,6 +430,17 @@ class RelatedJSONModel(models.Model): required_db_features = {"supports_json_field"} +class CustomSerializationJSONModel(models.Model): + class StringifiedJSONField(models.JSONField): + def get_prep_value(self, value): + return json.dumps(value, cls=self.encoder) + + json_field = StringifiedJSONField() + + class Meta: + required_db_features = {"supports_json_field"} + + class AllFieldsModel(models.Model): big_integer = models.BigIntegerField() binary = models.BinaryField() diff --git a/tests/model_fields/test_jsonfield.py b/tests/model_fields/test_jsonfield.py index 5a9cf9ad7a..267b9a0e66 100644 --- a/tests/model_fields/test_jsonfield.py +++ b/tests/model_fields/test_jsonfield.py @@ -40,7 +40,13 @@ from django.db.models.functions import Cast from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature from django.test.utils import CaptureQueriesContext -from .models import CustomJSONDecoder, JSONModel, NullableJSONModel, RelatedJSONModel +from .models import ( + CustomJSONDecoder, + CustomSerializationJSONModel, + JSONModel, + NullableJSONModel, + RelatedJSONModel, +) @skipUnlessDBFeature("supports_json_field") @@ -298,6 +304,17 @@ class TestSaveLoad(TestCase): obj.refresh_from_db() self.assertEqual(obj.value, value) + def test_bulk_update_custom_get_prep_value(self): + objs = CustomSerializationJSONModel.objects.bulk_create( + [CustomSerializationJSONModel(pk=1, json_field={"version": "1"})] + ) + objs[0].json_field["version"] = "1-alpha" + CustomSerializationJSONModel.objects.bulk_update(objs, ["json_field"]) + self.assertSequenceEqual( + CustomSerializationJSONModel.objects.values("json_field"), + [{"json_field": '{"version": "1-alpha"}'}], + ) + @skipUnlessDBFeature("supports_json_field") class TestQuerying(TestCase):