From fafdc4c213816a8efa1deeb37657535b9b6e51b5 Mon Sep 17 00:00:00 2001 From: Shafiya Adzhani Date: Sat, 24 Aug 2024 18:06:35 +0700 Subject: [PATCH] Added JSONRemove for removing keys in JSONField. --- django/db/backends/sqlite3/features.py | 2 + django/db/models/functions/__init__.py | 3 +- django/db/models/functions/json.py | 97 ++++++++ docs/ref/models/database-functions.txt | 28 +++ tests/db_functions/json/test_json_remove.py | 249 ++++++++++++++++++++ 5 files changed, 378 insertions(+), 1 deletion(-) create mode 100644 tests/db_functions/json/test_json_remove.py diff --git a/django/db/backends/sqlite3/features.py b/django/db/backends/sqlite3/features.py index 720e2bfd50..5ddae0bef8 100644 --- a/django/db/backends/sqlite3/features.py +++ b/django/db/backends/sqlite3/features.py @@ -134,6 +134,8 @@ class DatabaseFeatures(BaseDatabaseFeatures): "test_lookups_special_chars_double_quotes", "db_functions.json.test_json_set.JSONSetTests." "test_set_special_chars_double_quotes", + "db_functions.json.test_json_remove.JSONRemoveTests." + "test_remove_special_chars_double_quotes", }, } ) diff --git a/django/db/models/functions/__init__.py b/django/db/models/functions/__init__.py index 742a574571..c799297564 100644 --- a/django/db/models/functions/__init__.py +++ b/django/db/models/functions/__init__.py @@ -25,7 +25,7 @@ from .datetime import ( TruncWeek, TruncYear, ) -from .json import JSONObject, JSONSet +from .json import JSONObject, JSONRemove, JSONSet from .math import ( Abs, ACos, @@ -127,6 +127,7 @@ __all__ = [ "TruncYear", # json "JSONObject", + "JSONRemove", "JSONSet", # math "Abs", diff --git a/django/db/models/functions/json.py b/django/db/models/functions/json.py index 3fc67071e7..cfce932f12 100644 --- a/django/db/models/functions/json.py +++ b/django/db/models/functions/json.py @@ -219,3 +219,100 @@ class JSONSet(Func): arg_joiner=self, **extra_context, ) + + +class JSONRemove(Func): + def __init__(self, expression, *paths, **kwargs): + if not paths: + raise TypeError("JSONRemove requires at least one path to remove") + self.paths = paths + super().__init__(expression, **kwargs) + + def _get_repr_options(self): + return {**super().get_repr_options(), **self.fields} + + def join(self, args): + path = self.paths[0] + key_paths = path.split(LOOKUP_SEP) + key_paths_join = compile_json_path(key_paths) + + return f"{args[0]}, REMOVE q'\uffff{key_paths_join}\uffff'" + + def as_sql( + self, + compiler, + connection, + function=None, + template=None, + arg_joiner=None, + **extra_context, + ): + if not connection.features.supports_partial_json_update: + raise NotSupportedError( + "JSONRemove() is not supported on this database backend." + ) + + copy = self.copy() + new_source_expressions = copy.get_source_expressions() + + for path in self.paths: + key_paths = path.split(LOOKUP_SEP) + key_paths_join = compile_json_path(key_paths) + new_source_expressions.append(Value(key_paths_join)) + + copy.set_source_expressions(new_source_expressions) + + return super(JSONRemove, copy).as_sql( + compiler, + connection, + function="JSON_REMOVE", + **extra_context, + ) + + def as_postgresql(self, compiler, connection, **extra_context): + copy = self.copy() + path, *rest = self.paths + + if rest: + copy.paths = (path,) + return JSONRemove(copy, *rest).as_postgresql( + compiler, connection, **extra_context + ) + + new_source_expressions = copy.get_source_expressions() + key_paths = path.split(LOOKUP_SEP) + new_source_expressions.append(Value(key_paths)) + copy.set_source_expressions(new_source_expressions) + + return super(JSONRemove, copy).as_sql( + compiler, + connection, + template="%(expressions)s", + arg_joiner="#- ", + **extra_context, + ) + + def as_oracle(self, compiler, connection, **extra_context): + if not connection.features.supports_partial_json_update: + raise NotSupportedError( + "JSONRemove() is not supported on this database backend." + ) + + copy = self.copy() + + all_items = self.paths + path, *rest = all_items + + if rest: + copy.paths = (path,) + return JSONRemove(copy, *rest).as_oracle( + compiler, connection, **extra_context + ) + + return super(JSONRemove, copy).as_sql( + compiler, + connection, + function="JSON_TRANSFORM", + arg_joiner=self, + **extra_context, + ) diff --git a/docs/ref/models/database-functions.txt b/docs/ref/models/database-functions.txt index ba00614226..b8c49b1864 100644 --- a/docs/ref/models/database-functions.txt +++ b/docs/ref/models/database-functions.txt @@ -926,6 +926,34 @@ Usage example: >>> user_preferences.settings {'font': {'name': 'Arial', 'size': 10, 'color': 'white'}, 'notifications': True} # theme__type will not be added +``JSONRemove`` +-------------- + +.. class:: JSONRemove(expression, *paths, **kwargs) + +Removes specified paths from a :class:`~django.db.models.JSONField` in the +database. Accepts an expression identifying the +:class:`~django.db.models.JSONField` and a variable number of paths defining +the keys to be deleted from the JSON structure. + +Usage example: + +.. code-block:: pycon + + >>> from django.db.models.functions import JSONRemove + >>> user_preferences = UserPreferences.objects.create( + ... settings={ + ... "font": {"name": "Arial", "size": 10}, + ... "notifications": True, + ... } + ... ) + >>> UserPreferences.objects.update( + ... settings=JSONRemove("settings", "font__size", "notifications") + ... ) + 1 + >>> user_preferences = UserPreferences.objects.get(pk=user_preferences.pk) + >>> print(user_preferences.settings) + {'font': {'name': 'Arial'}} .. _math-functions: diff --git a/tests/db_functions/json/test_json_remove.py b/tests/db_functions/json/test_json_remove.py new file mode 100644 index 0000000000..0a93617f00 --- /dev/null +++ b/tests/db_functions/json/test_json_remove.py @@ -0,0 +1,249 @@ +from django.db import NotSupportedError +from django.db.models import IntegerField, JSONField, Sum, Value +from django.db.models.functions import Cast, JSONRemove +from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature + +from ..models import UserPreferences + + +@skipUnlessDBFeature("supports_partial_json_update") +class JSONRemoveTests(TestCase): + def test_remove_single_key(self): + user_preferences = UserPreferences.objects.create( + settings={"theme": "dark", "font": "Arial"} + ) + UserPreferences.objects.update(settings=JSONRemove("settings", "theme")) + user_preferences = UserPreferences.objects.get(pk=user_preferences.pk) + self.assertEqual(user_preferences.settings, {"font": "Arial"}) + + def test_remove_nonexistent_key(self): + user_preferences = UserPreferences.objects.create(settings={"theme": "dark"}) + UserPreferences.objects.update(settings=JSONRemove("settings", "font")) + user_preferences = UserPreferences.objects.get(pk=user_preferences.pk) + self.assertEqual(user_preferences.settings, {"theme": "dark"}) + + def test_remove_nested_key(self): + user_preferences = UserPreferences.objects.create( + settings={"font": {"size": 20, "color": "red"}} + ) + UserPreferences.objects.update(settings=JSONRemove("settings", "font__color")) + user_preferences = UserPreferences.objects.get(pk=user_preferences.pk) + self.assertEqual(user_preferences.settings, {"font": {"size": 20}}) + + def test_remove_nested_keys_to_be_empty_object(self): + user_preferences = UserPreferences.objects.create( + settings={"font": {"color": "red"}, "notifications": True} + ) + UserPreferences.objects.update( + settings=JSONRemove("settings", "font__color"), + ) + user_preferences = UserPreferences.objects.get(pk=user_preferences.pk) + self.assertEqual( + user_preferences.settings, + { + "font": {}, + "notifications": True, + }, + ) + + def test_remove_multiple_keys(self): + user_preferences = UserPreferences.objects.create( + settings={"font": {"size": 20, "color": "red"}, "theme": "dark"} + ) + UserPreferences.objects.update( + settings=JSONRemove("settings", "font__color", "theme") + ) + user_preferences = UserPreferences.objects.get(pk=user_preferences.pk) + self.assertEqual(user_preferences.settings, {"font": {"size": 20}}) + + def test_remove_keys_with_recursive_call(self): + user_preferences = UserPreferences.objects.create( + settings={"font": {"size": 20, "color": "red"}, "theme": "dark"} + ) + UserPreferences.objects.update( + settings=JSONRemove(JSONRemove("settings", "font__color"), "theme") + ) + user_preferences = UserPreferences.objects.get(pk=user_preferences.pk) + self.assertEqual(user_preferences.settings, {"font": {"size": 20}}) + + def test_save_on_model_field(self): + user_preferences = UserPreferences.objects.create( + settings={"theme": "dark", "font": "Arial"} + ) + user_preferences.settings = JSONRemove("settings", "theme") + user_preferences.save() + + user_preferences = UserPreferences.objects.get(pk=user_preferences.pk) + self.assertEqual(user_preferences.settings, {"font": "Arial"}) + + def test_update_or_create_not_created(self): + user_preferences = UserPreferences.objects.create( + settings={ + "theme": {"color": "black", "font": "Arial"}, + "notifications": {"email": False, "sms": True}, + } + ) + + updated_user_preferences, created = UserPreferences.objects.update_or_create( + defaults={"settings": JSONRemove("settings", "theme__color")}, + id=user_preferences.id, + ) + self.assertIs(created, False) + # Refresh the object to avoid expression persistence after the update. + updated_user_preferences.refresh_from_db() + self.assertEqual( + updated_user_preferences.settings, + { + "theme": {"font": "Arial"}, + "notifications": {"email": False, "sms": True}, + }, + ) + + def test_update_or_create_created(self): + updated_user_preferences, created = UserPreferences.objects.update_or_create( + defaults={"settings": JSONRemove("settings", "theme")}, + id=9999, + create_defaults={ + "settings": JSONRemove( + Value( + {"theme": "dark", "notifications": True}, + output_field=JSONField(), + ), + "theme", + ) + }, + ) + self.assertIs(created, True) + updated_user_preferences.refresh_from_db() + self.assertEqual(updated_user_preferences.id, 9999) + self.assertEqual( + updated_user_preferences.settings, + {"notifications": True}, + ) + + def test_remove_special_chars(self): + test_keys = [ + "CONTROL", + "single'", + "dollar$", + "dot.dot", + "with space", + "back\\slash", + "question?mark", + "user@name", + "emo🤡'ji", + "com,ma", + "curly{{{brace}}}s", + "escape\uffff'seq'\uffffue\uffff'nce", + ] + for key in test_keys: + with self.subTest(key=key): + user_preferences = UserPreferences.objects.create( + settings={key: 20, "notifications": True, "font": {"size": 30}} + ) + UserPreferences.objects.update(settings=JSONRemove("settings", key)) + user_preferences = UserPreferences.objects.get(pk=user_preferences.pk) + self.assertEqual( + user_preferences.settings, + {"notifications": True, "font": {"size": 30}}, + ) + + def test_remove_special_chars_double_quotes(self): + test_keys = [ + 'double"', + "m\\i@x. m🤡'a,t{{{ch}}}e?d$\"'es\uffff'ca\uffff'pe", + ] + for key in test_keys: + with self.subTest(key=key): + user_preferences = UserPreferences.objects.create( + settings={key: 20, "notifications": True, "font": {"size": 30}} + ) + UserPreferences.objects.update(settings=JSONRemove("settings", key)) + user_preferences = UserPreferences.objects.get(pk=user_preferences.pk) + self.assertEqual( + user_preferences.settings, + {"notifications": True, "font": {"size": 30}}, + ) + + def test_remove_with_values(self): + UserPreferences.objects.create( + settings={"font": {"name": "Arial", "size": 10}, "notifications": True} + ) + user_preferences_value = ( + UserPreferences.objects.annotate( + settings_updated=JSONRemove("settings", "font__size") + ) + .values("settings_updated", "settings__font__size") + .first() + ) + + self.assertEqual( + user_preferences_value, + { + "settings_updated": {"font": {"name": "Arial"}, "notifications": True}, + "settings__font__size": 10, + }, + ) + + def test_remove_with_values_list(self): + UserPreferences.objects.create( + settings={"font": {"name": "Arial", "size": 10}, "notifications": True} + ) + UserPreferences.objects.create( + settings={ + "font": {"name": "Comic Sans", "size": 20}, + "notifications": False, + } + ) + user_preferences_values = UserPreferences.objects.annotate( + settings_updated=JSONRemove("settings", "font__size") + ).values_list("settings_updated", flat=True) + + self.assertEqual( + user_preferences_values[0], + {"font": {"name": "Arial"}, "notifications": True}, + ) + + self.assertEqual( + user_preferences_values[1], + {"font": {"name": "Comic Sans"}, "notifications": False}, + ) + + def test_remove_with_aggregate(self): + UserPreferences.objects.create( + settings={"font": {"name": "Arial", "size": 10}, "notifications": True} + ) + UserPreferences.objects.create( + settings={"font": {"name": "Comic Sans", "size": 20}, "notifications": True} + ) + result = UserPreferences.objects.annotate( + settings_updated=JSONRemove("settings", "font__size") + ).aggregate( + total_font_size=Sum( + Cast( + "settings_updated__font__size", + IntegerField(), + ) + ) + ) + + self.assertEqual(result["total_font_size"], None) + + +class InvalidJSONRemoveTests(TestCase): + @skipIfDBFeature("supports_partial_json_update") + def test_remove_not_supported(self): + with self.assertRaisesMessage( + NotSupportedError, "JSONRemove() is not supported on this database backend." + ): + UserPreferences.objects.create(settings={"theme": "dark", "font": "Arial"}) + UserPreferences.objects.update(settings=JSONRemove("settings", "theme")) + + def test_remove_missing_path_to_be_removed_error(self): + with self.assertRaisesMessage( + TypeError, "JSONRemove requires at least one path to remove" + ): + UserPreferences.objects.create( + settings={"theme": "dark", "notifications": True} + ) + UserPreferences.objects.update(settings=JSONRemove("settings"))