From ee5147cfd7de2add74a285537a8968ec074e70cd Mon Sep 17 00:00:00 2001 From: Amir Karimi Date: Thu, 12 Sep 2024 10:56:18 +0200 Subject: [PATCH] Fixed #29522 -- Refactored the Deserializer functions to classes. Co-authored-by: Emad Mokhtar --- AUTHORS | 2 +- django/core/serializers/json.py | 32 +++-- django/core/serializers/jsonl.py | 31 +++-- django/core/serializers/python.py | 129 +++++++++++-------- django/core/serializers/pyyaml.py | 31 +++-- docs/releases/5.2.txt | 4 +- docs/topics/serialization.txt | 80 ++++++++++++ tests/serializers/test_deserialization.py | 125 ++++++++++++++++++ tests/serializers/test_deserializedobject.py | 13 -- 9 files changed, 344 insertions(+), 103 deletions(-) create mode 100644 tests/serializers/test_deserialization.py delete mode 100644 tests/serializers/test_deserializedobject.py diff --git a/AUTHORS b/AUTHORS index 2c2b8d5b13..573a030ea1 100644 --- a/AUTHORS +++ b/AUTHORS @@ -68,7 +68,7 @@ answer newbie questions, and generally made Django that much better: Aljaž Košir Aljosa Mohorovic Alokik Vijay - Amir Karimi + Amir Karimi Amit Chakradeo Amit Ramon Amit Upadhyay diff --git a/django/core/serializers/json.py b/django/core/serializers/json.py index afac821465..7683368e62 100644 --- a/django/core/serializers/json.py +++ b/django/core/serializers/json.py @@ -59,19 +59,27 @@ class Serializer(PythonSerializer): return super(PythonSerializer, self).getvalue() -def Deserializer(stream_or_string, **options): +class Deserializer(PythonDeserializer): """Deserialize a stream or string of JSON data.""" - if not isinstance(stream_or_string, (bytes, str)): - stream_or_string = stream_or_string.read() - if isinstance(stream_or_string, bytes): - stream_or_string = stream_or_string.decode() - try: - objects = json.loads(stream_or_string) - yield from PythonDeserializer(objects, **options) - except (GeneratorExit, DeserializationError): - raise - except Exception as exc: - raise DeserializationError() from exc + + def __init__(self, stream_or_string, **options): + if not isinstance(stream_or_string, (bytes, str)): + stream_or_string = stream_or_string.read() + if isinstance(stream_or_string, bytes): + stream_or_string = stream_or_string.decode() + try: + objects = json.loads(stream_or_string) + except Exception as exc: + raise DeserializationError() from exc + super().__init__(objects, **options) + + def _handle_object(self, obj): + try: + yield from super()._handle_object(obj) + except (GeneratorExit, DeserializationError): + raise + except Exception as exc: + raise DeserializationError(f"Error deserializing object: {exc}") from exc class DjangoJSONEncoder(json.JSONEncoder): diff --git a/django/core/serializers/jsonl.py b/django/core/serializers/jsonl.py index c264c2ccaf..7bc9bed79f 100644 --- a/django/core/serializers/jsonl.py +++ b/django/core/serializers/jsonl.py @@ -39,19 +39,30 @@ class Serializer(PythonSerializer): return super(PythonSerializer, self).getvalue() -def Deserializer(stream_or_string, **options): +class Deserializer(PythonDeserializer): """Deserialize a stream or string of JSON data.""" - if isinstance(stream_or_string, bytes): - stream_or_string = stream_or_string.decode() - if isinstance(stream_or_string, (bytes, str)): - stream_or_string = stream_or_string.split("\n") - for line in stream_or_string: - if not line.strip(): - continue + def __init__(self, stream_or_string, **options): + if isinstance(stream_or_string, bytes): + stream_or_string = stream_or_string.decode() + if isinstance(stream_or_string, str): + stream_or_string = stream_or_string.splitlines() + super().__init__(Deserializer._get_lines(stream_or_string), **options) + + def _handle_object(self, obj): try: - yield from PythonDeserializer([json.loads(line)], **options) + yield from super()._handle_object(obj) except (GeneratorExit, DeserializationError): raise except Exception as exc: - raise DeserializationError() from exc + raise DeserializationError(f"Error deserializing object: {exc}") from exc + + @staticmethod + def _get_lines(stream): + for line in stream: + if not line.strip(): + continue + try: + yield json.loads(line) + except Exception as exc: + raise DeserializationError() from exc diff --git a/django/core/serializers/python.py b/django/core/serializers/python.py index 7ec894aa00..46ef9f0771 100644 --- a/django/core/serializers/python.py +++ b/django/core/serializers/python.py @@ -96,45 +96,60 @@ class Serializer(base.Serializer): return self.objects -def Deserializer( - object_list, *, using=DEFAULT_DB_ALIAS, ignorenonexistent=False, **options -): +class Deserializer(base.Deserializer): """ Deserialize simple Python objects back into Django ORM instances. It's expected that you pass the Python objects themselves (instead of a stream or a string) to the constructor """ - handle_forward_references = options.pop("handle_forward_references", False) - field_names_cache = {} # Model: - for d in object_list: - # Look up the model and starting build a dict of data for it. - try: - Model = _get_model(d["model"]) - except base.DeserializationError: - if ignorenonexistent: - continue - else: - raise + def __init__( + self, object_list, *, using=DEFAULT_DB_ALIAS, ignorenonexistent=False, **options + ): + super().__init__(object_list, **options) + self.handle_forward_references = options.pop("handle_forward_references", False) + self.using = using + self.ignorenonexistent = ignorenonexistent + self.field_names_cache = {} # Model: + self._iterator = None + + def __iter__(self): + for obj in self.stream: + yield from self._handle_object(obj) + + def __next__(self): + if self._iterator is None: + self._iterator = iter(self) + return next(self._iterator) + + def _handle_object(self, obj): data = {} - if "pk" in d: - try: - data[Model._meta.pk.attname] = Model._meta.pk.to_python(d.get("pk")) - except Exception as e: - raise base.DeserializationError.WithData( - e, d["model"], d.get("pk"), None - ) m2m_data = {} deferred_fields = {} - if Model not in field_names_cache: - field_names_cache[Model] = {f.name for f in Model._meta.get_fields()} - field_names = field_names_cache[Model] + # Look up the model and starting build a dict of data for it. + try: + Model = self._get_model_from_node(obj["model"]) + except base.DeserializationError: + if self.ignorenonexistent: + return + raise + if "pk" in obj: + try: + data[Model._meta.pk.attname] = Model._meta.pk.to_python(obj.get("pk")) + except Exception as e: + raise base.DeserializationError.WithData( + e, obj["model"], obj.get("pk"), None + ) + + if Model not in self.field_names_cache: + self.field_names_cache[Model] = {f.name for f in Model._meta.get_fields()} + field_names = self.field_names_cache[Model] # Handle each field - for field_name, field_value in d["fields"].items(): - if ignorenonexistent and field_name not in field_names: + for field_name, field_value in obj["fields"].items(): + if self.ignorenonexistent and field_name not in field_names: # skip fields no longer on model continue @@ -145,51 +160,59 @@ def Deserializer( field.remote_field, models.ManyToManyRel ): try: - values = base.deserialize_m2m_values( - field, field_value, using, handle_forward_references - ) + values = self._handle_m2m_field_node(field, field_value) + if values == base.DEFER_FIELD: + deferred_fields[field] = field_value + else: + m2m_data[field.name] = values except base.M2MDeserializationError as e: raise base.DeserializationError.WithData( - e.original_exc, d["model"], d.get("pk"), e.pk + e.original_exc, obj["model"], obj.get("pk"), e.pk ) - if values == base.DEFER_FIELD: - deferred_fields[field] = field_value - else: - m2m_data[field.name] = values + # Handle FK fields elif field.remote_field and isinstance( field.remote_field, models.ManyToOneRel ): try: - value = base.deserialize_fk_value( - field, field_value, using, handle_forward_references - ) + value = self._handle_fk_field_node(field, field_value) + if value == base.DEFER_FIELD: + deferred_fields[field] = field_value + else: + data[field.attname] = value except Exception as e: raise base.DeserializationError.WithData( - e, d["model"], d.get("pk"), field_value + e, obj["model"], obj.get("pk"), field_value ) - if value == base.DEFER_FIELD: - deferred_fields[field] = field_value - else: - data[field.attname] = value + # Handle all other fields else: try: data[field.name] = field.to_python(field_value) except Exception as e: raise base.DeserializationError.WithData( - e, d["model"], d.get("pk"), field_value + e, obj["model"], obj.get("pk"), field_value ) - obj = base.build_instance(Model, data, using) - yield base.DeserializedObject(obj, m2m_data, deferred_fields) + model_instance = base.build_instance(Model, data, self.using) + yield base.DeserializedObject(model_instance, m2m_data, deferred_fields) - -def _get_model(model_identifier): - """Look up a model from an "app_label.model_name" string.""" - try: - return apps.get_model(model_identifier) - except (LookupError, TypeError): - raise base.DeserializationError( - "Invalid model identifier: '%s'" % model_identifier + def _handle_m2m_field_node(self, field, field_value): + return base.deserialize_m2m_values( + field, field_value, self.using, self.handle_forward_references ) + + def _handle_fk_field_node(self, field, field_value): + return base.deserialize_fk_value( + field, field_value, self.using, self.handle_forward_references + ) + + @staticmethod + def _get_model_from_node(model_identifier): + """Look up a model from an "app_label.model_name" string.""" + try: + return apps.get_model(model_identifier) + except (LookupError, TypeError): + raise base.DeserializationError( + f"Invalid model identifier: {model_identifier}" + ) diff --git a/django/core/serializers/pyyaml.py b/django/core/serializers/pyyaml.py index 9a20b6658f..ed6e4b3895 100644 --- a/django/core/serializers/pyyaml.py +++ b/django/core/serializers/pyyaml.py @@ -6,7 +6,6 @@ Requires PyYaml (https://pyyaml.org/), but that's checked for in __init__. import collections import decimal -from io import StringIO import yaml @@ -66,17 +65,23 @@ class Serializer(PythonSerializer): return super(PythonSerializer, self).getvalue() -def Deserializer(stream_or_string, **options): +class Deserializer(PythonDeserializer): """Deserialize a stream or string of YAML data.""" - if isinstance(stream_or_string, bytes): - stream_or_string = stream_or_string.decode() - if isinstance(stream_or_string, str): - stream = StringIO(stream_or_string) - else: + + def __init__(self, stream_or_string, **options): stream = stream_or_string - try: - yield from PythonDeserializer(yaml.load(stream, Loader=SafeLoader), **options) - except (GeneratorExit, DeserializationError): - raise - except Exception as exc: - raise DeserializationError() from exc + if isinstance(stream_or_string, bytes): + stream = stream_or_string.decode() + try: + objects = yaml.load(stream, Loader=SafeLoader) + except Exception as exc: + raise DeserializationError() from exc + super().__init__(objects, **options) + + def _handle_object(self, obj): + try: + yield from super()._handle_object(obj) + except (GeneratorExit, DeserializationError): + raise + except Exception as exc: + raise DeserializationError(f"Error deserializing object: {exc}") from exc diff --git a/docs/releases/5.2.txt b/docs/releases/5.2.txt index 901475a7b4..7a0361283a 100644 --- a/docs/releases/5.2.txt +++ b/docs/releases/5.2.txt @@ -241,7 +241,9 @@ Security Serialization ~~~~~~~~~~~~~ -* ... +* Each serialization format now defines a ``Deserializer`` class, rather than a + function, to improve extensibility when defining a + :ref:`custom serialization format `. Signals ~~~~~~~ diff --git a/docs/topics/serialization.txt b/docs/topics/serialization.txt index 0bb57642ab..12cc616a21 100644 --- a/docs/topics/serialization.txt +++ b/docs/topics/serialization.txt @@ -347,6 +347,86 @@ again a mapping with the key being name of the field and the value the value: Referential fields are again represented by the PK or sequence of PKs. +.. _custom-serialization-formats: + +Custom serialization formats +---------------------------- + +In addition to the default formats, you can create a custom serialization +format. + +For example, let’s consider a csv serializer and deserializer. First, define a +``Serializer`` and a ``Deserializer`` class. These can override existing +serialization format classes: + +.. code-block:: python + :caption: ``path/to/custom_csv_serializer.py`` + + import csv + + from django.apps import apps + from django.core import serializers + from django.core.serializers.base import DeserializationError + + + class Serializer(serializers.python.Serializer): + def get_dump_object(self, obj): + dumped_object = super().get_dump_object(obj) + row = [dumped_object["model"], str(dumped_object["pk"])] + row += [str(value) for value in dumped_object["fields"].values()] + return ",".join(row), dumped_object["model"] + + def end_object(self, obj): + dumped_object_str, model = self.get_dump_object(obj) + if self.first: + fields = [field.name for field in apps.get_model(model)._meta.fields] + header = ",".join(fields) + self.stream.write(f"model,{header}\n") + self.stream.write(f"{dumped_object_str}\n") + + def getvalue(self): + return super(serializers.python.Serializer, self).getvalue() + + + class Deserializer(serializers.python.Deserializer): + def __init__(self, stream_or_string, **options): + if isinstance(stream_or_string, bytes): + stream_or_string = stream_or_string.decode() + if isinstance(stream_or_string, str): + stream_or_string = stream_or_string.splitlines() + try: + objects = csv.DictReader(stream_or_string) + except Exception as exc: + raise DeserializationError() from exc + super().__init__(objects, **options) + + def _handle_object(self, obj): + try: + model_fields = apps.get_model(obj["model"])._meta.fields + obj["fields"] = { + field.name: obj[field.name] + for field in model_fields + if field.name in obj + } + yield from super()._handle_object(obj) + except (GeneratorExit, DeserializationError): + raise + except Exception as exc: + raise DeserializationError(f"Error deserializing object: {exc}") from exc + +Then add the module containing the serializer definitions to your +:setting:`SERIALIZATION_MODULES` setting:: + + SERIALIZATION_MODULES = { + "csv": "path.to.custom_csv_serializer", + "json": "django.core.serializers.json", + } + +.. versionchanged:: 5.2 + + A ``Deserializer`` class definition was added to each of the provided + serialization formats. + .. _topics-serialization-natural-keys: Natural keys diff --git a/tests/serializers/test_deserialization.py b/tests/serializers/test_deserialization.py new file mode 100644 index 0000000000..3c4af2ce33 --- /dev/null +++ b/tests/serializers/test_deserialization.py @@ -0,0 +1,125 @@ +import json + +from django.core.serializers.base import DeserializationError, DeserializedObject +from django.core.serializers.json import Deserializer as JsonDeserializer +from django.core.serializers.jsonl import Deserializer as JsonlDeserializer +from django.core.serializers.python import Deserializer +from django.core.serializers.pyyaml import Deserializer as YamlDeserializer +from django.test import SimpleTestCase + +from .models import Author + + +class TestDeserializer(SimpleTestCase): + def setUp(self): + self.object_list = [ + {"pk": 1, "model": "serializers.author", "fields": {"name": "Jane"}}, + {"pk": 2, "model": "serializers.author", "fields": {"name": "Joe"}}, + ] + self.deserializer = Deserializer(self.object_list) + self.jane = Author(name="Jane", pk=1) + self.joe = Author(name="Joe", pk=2) + + def test_deserialized_object_repr(self): + deserial_obj = DeserializedObject(obj=self.jane) + self.assertEqual( + repr(deserial_obj), "" + ) + + def test_next_functionality(self): + first_item = next(self.deserializer) + + self.assertEqual(first_item.object, self.jane) + + second_item = next(self.deserializer) + self.assertEqual(second_item.object, self.joe) + + with self.assertRaises(StopIteration): + next(self.deserializer) + + def test_invalid_model_identifier(self): + invalid_object_list = [ + {"pk": 1, "model": "serializers.author2", "fields": {"name": "Jane"}} + ] + self.deserializer = Deserializer(invalid_object_list) + with self.assertRaises(DeserializationError): + next(self.deserializer) + + deserializer = Deserializer(object_list=[]) + with self.assertRaises(StopIteration): + next(deserializer) + + def test_custom_deserializer(self): + class CustomDeserializer(Deserializer): + @staticmethod + def _get_model_from_node(model_identifier): + return Author + + deserializer = CustomDeserializer(self.object_list) + result = next(iter(deserializer)) + deserialized_object = result.object + self.assertEqual( + self.jane, + deserialized_object, + ) + + def test_empty_object_list(self): + deserializer = Deserializer(object_list=[]) + with self.assertRaises(StopIteration): + next(deserializer) + + def test_json_bytes_input(self): + test_string = json.dumps(self.object_list) + stream = test_string.encode("utf-8") + deserializer = JsonDeserializer(stream_or_string=stream) + + first_item = next(deserializer) + second_item = next(deserializer) + + self.assertEqual(first_item.object, self.jane) + self.assertEqual(second_item.object, self.joe) + + def test_jsonl_bytes_input(self): + test_string = """ + {"pk": 1, "model": "serializers.author", "fields": {"name": "Jane"}} + {"pk": 2, "model": "serializers.author", "fields": {"name": "Joe"}} + {"pk": 3, "model": "serializers.author", "fields": {"name": "John"}} + {"pk": 4, "model": "serializers.author", "fields": {"name": "Smith"}}""" + stream = test_string.encode("utf-8") + deserializer = JsonlDeserializer(stream_or_string=stream) + + first_item = next(deserializer) + second_item = next(deserializer) + + self.assertEqual(first_item.object, self.jane) + self.assertEqual(second_item.object, self.joe) + + def test_yaml_bytes_input(self): + test_string = """- pk: 1 + model: serializers.author + fields: + name: Jane + +- pk: 2 + model: serializers.author + fields: + name: Joe + +- pk: 3 + model: serializers.author + fields: + name: John + +- pk: 4 + model: serializers.author + fields: + name: Smith +""" + stream = test_string.encode("utf-8") + deserializer = YamlDeserializer(stream_or_string=stream) + + first_item = next(deserializer) + second_item = next(deserializer) + + self.assertEqual(first_item.object, self.jane) + self.assertEqual(second_item.object, self.joe) diff --git a/tests/serializers/test_deserializedobject.py b/tests/serializers/test_deserializedobject.py deleted file mode 100644 index 1252052100..0000000000 --- a/tests/serializers/test_deserializedobject.py +++ /dev/null @@ -1,13 +0,0 @@ -from django.core.serializers.base import DeserializedObject -from django.test import SimpleTestCase - -from .models import Author - - -class TestDeserializedObjectTests(SimpleTestCase): - def test_repr(self): - author = Author(name="John", pk=1) - deserial_obj = DeserializedObject(obj=author) - self.assertEqual( - repr(deserial_obj), "" - )