1
0
mirror of https://github.com/django/django.git synced 2025-07-04 09:49:12 +00:00

unicode: Made the serializers unicode-aware. Refs #3878, #4227.

git-svn-id: http://code.djangoproject.com/svn/django/branches/unicode@5248 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Malcolm Tredinnick 2007-05-15 16:14:55 +00:00
parent 0e4c3838ab
commit 9d001fa7f9
6 changed files with 103 additions and 102 deletions

View File

@ -6,7 +6,7 @@ Usage::
>>> from django.core import serializers >>> from django.core import serializers
>>> json = serializers.serialize("json", some_query_set) >>> json = serializers.serialize("json", some_query_set)
>>> objects = list(serializers.deserialize("json", json)) >>> objects = list(serializers.deserialize("json", json))
To add your own serializers, use the SERIALIZATION_MODULES setting:: To add your own serializers, use the SERIALIZATION_MODULES setting::
SERIALIZATION_MODULES = { SERIALIZATION_MODULES = {
@ -30,19 +30,19 @@ try:
import yaml import yaml
BUILTIN_SERIALIZERS["yaml"] = "django.core.serializers.pyyaml" BUILTIN_SERIALIZERS["yaml"] = "django.core.serializers.pyyaml"
except ImportError: except ImportError:
pass pass
_serializers = {} _serializers = {}
def register_serializer(format, serializer_module): def register_serializer(format, serializer_module):
"""Register a new serializer by passing in a module name.""" """Register a new serializer by passing in a module name."""
module = __import__(serializer_module, {}, {}, ['']) module = __import__(serializer_module, {}, {}, [''])
_serializers[format] = module _serializers[format] = module
def unregister_serializer(format): def unregister_serializer(format):
"""Unregister a given serializer""" """Unregister a given serializer"""
del _serializers[format] del _serializers[format]
def get_serializer(format): def get_serializer(format):
if not _serializers: if not _serializers:
_load_serializers() _load_serializers()
@ -52,12 +52,12 @@ def get_serializer_formats():
if not _serializers: if not _serializers:
_load_serializers() _load_serializers()
return _serializers.keys() return _serializers.keys()
def get_deserializer(format): def get_deserializer(format):
if not _serializers: if not _serializers:
_load_serializers() _load_serializers()
return _serializers[format].Deserializer return _serializers[format].Deserializer
def serialize(format, queryset, **options): def serialize(format, queryset, **options):
""" """
Serialize a queryset (or any iterator that returns database objects) using Serialize a queryset (or any iterator that returns database objects) using
@ -87,4 +87,4 @@ def _load_serializers():
register_serializer(format, BUILTIN_SERIALIZERS[format]) register_serializer(format, BUILTIN_SERIALIZERS[format])
if hasattr(settings, "SERIALIZATION_MODULES"): if hasattr(settings, "SERIALIZATION_MODULES"):
for format in settings.SERIALIZATION_MODULES: for format in settings.SERIALIZATION_MODULES:
register_serializer(format, settings.SERIALIZATION_MODULES[format]) register_serializer(format, settings.SERIALIZATION_MODULES[format])

View File

@ -7,6 +7,7 @@ try:
except ImportError: except ImportError:
from StringIO import StringIO from StringIO import StringIO
from django.db import models from django.db import models
from django.utils.encoding import smart_str, smart_unicode
class SerializationError(Exception): class SerializationError(Exception):
"""Something bad happened during serialization.""" """Something bad happened during serialization."""
@ -59,7 +60,7 @@ class Serializer(object):
value = getattr(obj, "get_%s_url" % field.name, lambda: None)() value = getattr(obj, "get_%s_url" % field.name, lambda: None)()
else: else:
value = field.flatten_data(follow=None, obj=obj).get(field.name, "") value = field.flatten_data(follow=None, obj=obj).get(field.name, "")
return str(value) return smart_unicode(value)
def start_serialization(self): def start_serialization(self):
""" """
@ -154,7 +155,7 @@ class DeserializedObject(object):
self.m2m_data = m2m_data self.m2m_data = m2m_data
def __repr__(self): def __repr__(self):
return "<DeserializedObject: %s>" % str(self.object) return "<DeserializedObject: %s>" % smart_str(self.object)
def save(self, save_m2m=True): def save(self, save_m2m=True):
self.object.save() self.object.save()

View File

@ -7,49 +7,50 @@ other serializers.
from django.conf import settings from django.conf import settings
from django.core.serializers import base from django.core.serializers import base
from django.db import models from django.db import models
from django.utils.encoding import smart_unicode
class Serializer(base.Serializer): class Serializer(base.Serializer):
""" """
Serializes a QuerySet to basic Python objects. Serializes a QuerySet to basic Python objects.
""" """
def start_serialization(self): def start_serialization(self):
self._current = None self._current = None
self.objects = [] self.objects = []
def end_serialization(self): def end_serialization(self):
pass pass
def start_object(self, obj): def start_object(self, obj):
self._current = {} self._current = {}
def end_object(self, obj): def end_object(self, obj):
self.objects.append({ self.objects.append({
"model" : str(obj._meta), "model" : smart_unicode(obj._meta),
"pk" : str(obj._get_pk_val()), "pk" : smart_unicode(obj._get_pk_val()),
"fields" : self._current "fields" : self._current
}) })
self._current = None self._current = None
def handle_field(self, obj, field): def handle_field(self, obj, field):
self._current[field.name] = getattr(obj, field.name) self._current[field.name] = getattr(obj, field.name)
def handle_fk_field(self, obj, field): def handle_fk_field(self, obj, field):
related = getattr(obj, field.name) related = getattr(obj, field.name)
if related is not None: if related is not None:
related = getattr(related, field.rel.field_name) related = getattr(related, field.rel.field_name)
self._current[field.name] = related self._current[field.name] = related
def handle_m2m_field(self, obj, field): def handle_m2m_field(self, obj, field):
self._current[field.name] = [related._get_pk_val() for related in getattr(obj, field.name).iterator()] self._current[field.name] = [related._get_pk_val() for related in getattr(obj, field.name).iterator()]
def getvalue(self): def getvalue(self):
return self.objects return self.objects
def Deserializer(object_list, **options): def Deserializer(object_list, **options):
""" """
Deserialize simple Python objects back into Django ORM instances. Deserialize simple Python objects back into Django ORM instances.
It's expected that you pass the Python objects themselves (instead of a It's expected that you pass the Python objects themselves (instead of a
stream or a string) to the constructor stream or a string) to the constructor
""" """
@ -59,36 +60,30 @@ def Deserializer(object_list, **options):
Model = _get_model(d["model"]) Model = _get_model(d["model"])
data = {Model._meta.pk.attname : Model._meta.pk.to_python(d["pk"])} data = {Model._meta.pk.attname : Model._meta.pk.to_python(d["pk"])}
m2m_data = {} m2m_data = {}
# Handle each field # Handle each field
for (field_name, field_value) in d["fields"].iteritems(): for (field_name, field_value) in d["fields"].iteritems():
if isinstance(field_value, unicode): if isinstance(field_value, str):
field_value = field_value.encode(options.get("encoding", settings.DEFAULT_CHARSET)) field_value = smart_unicode(field_value, options.get("encoding", settings.DEFAULT_CHARSET))
field = Model._meta.get_field(field_name) field = Model._meta.get_field(field_name)
# Handle M2M relations # Handle M2M relations
if field.rel and isinstance(field.rel, models.ManyToManyRel): if field.rel and isinstance(field.rel, models.ManyToManyRel):
pks = []
m2m_convert = field.rel.to._meta.pk.to_python m2m_convert = field.rel.to._meta.pk.to_python
for pk in field_value: m2m_data[field.name] = [m2m_convert(smart_unicode(pk)) for pk in field_value]
if isinstance(pk, unicode):
pks.append(m2m_convert(pk.encode(options.get("encoding", settings.DEFAULT_CHARSET))))
else:
pks.append(m2m_convert(pk))
m2m_data[field.name] = pks
# Handle FK fields # Handle FK fields
elif field.rel and isinstance(field.rel, models.ManyToOneRel): elif field.rel and isinstance(field.rel, models.ManyToOneRel):
if field_value: if field_value:
data[field.attname] = field.rel.to._meta.get_field(field.rel.field_name).to_python(field_value) data[field.attname] = field.rel.to._meta.get_field(field.rel.field_name).to_python(field_value)
else: else:
data[field.attname] = None data[field.attname] = None
# Handle all other fields # Handle all other fields
else: else:
data[field.name] = field.to_python(field_value) data[field.name] = field.to_python(field_value)
yield base.DeserializedObject(Model(**data), m2m_data) yield base.DeserializedObject(Model(**data), m2m_data)
def _get_model(model_identifier): def _get_model(model_identifier):
@ -100,5 +95,5 @@ def _get_model(model_identifier):
except TypeError: except TypeError:
Model = None Model = None
if Model is None: if Model is None:
raise base.DeserializationError("Invalid model identifier: '%s'" % model_identifier) raise base.DeserializationError(u"Invalid model identifier: '%s'" % model_identifier)
return Model return Model

View File

@ -19,7 +19,7 @@ class Serializer(PythonSerializer):
""" """
def end_serialization(self): def end_serialization(self):
yaml.dump(self.objects, self.stream, **self.options) yaml.dump(self.objects, self.stream, **self.options)
def getvalue(self): def getvalue(self):
return self.stream.getvalue() return self.stream.getvalue()
@ -33,4 +33,4 @@ def Deserializer(stream_or_string, **options):
stream = stream_or_string stream = stream_or_string
for obj in PythonDeserializer(yaml.load(stream)): for obj in PythonDeserializer(yaml.load(stream)):
yield obj yield obj

View File

@ -6,13 +6,14 @@ from django.conf import settings
from django.core.serializers import base from django.core.serializers import base
from django.db import models from django.db import models
from django.utils.xmlutils import SimplerXMLGenerator from django.utils.xmlutils import SimplerXMLGenerator
from django.utils.encoding import smart_unicode
from xml.dom import pulldom from xml.dom import pulldom
class Serializer(base.Serializer): class Serializer(base.Serializer):
""" """
Serializes a QuerySet to XML. Serializes a QuerySet to XML.
""" """
def indent(self, level): def indent(self, level):
if self.options.get('indent', None) is not None: if self.options.get('indent', None) is not None:
self.xml.ignorableWhitespace('\n' + ' ' * self.options.get('indent', None) * level) self.xml.ignorableWhitespace('\n' + ' ' * self.options.get('indent', None) * level)
@ -24,7 +25,7 @@ class Serializer(base.Serializer):
self.xml = SimplerXMLGenerator(self.stream, self.options.get("encoding", settings.DEFAULT_CHARSET)) self.xml = SimplerXMLGenerator(self.stream, self.options.get("encoding", settings.DEFAULT_CHARSET))
self.xml.startDocument() self.xml.startDocument()
self.xml.startElement("django-objects", {"version" : "1.0"}) self.xml.startElement("django-objects", {"version" : "1.0"})
def end_serialization(self): def end_serialization(self):
""" """
End serialization -- end the document. End serialization -- end the document.
@ -32,27 +33,27 @@ class Serializer(base.Serializer):
self.indent(0) self.indent(0)
self.xml.endElement("django-objects") self.xml.endElement("django-objects")
self.xml.endDocument() self.xml.endDocument()
def start_object(self, obj): def start_object(self, obj):
""" """
Called as each object is handled. Called as each object is handled.
""" """
if not hasattr(obj, "_meta"): if not hasattr(obj, "_meta"):
raise base.SerializationError("Non-model object (%s) encountered during serialization" % type(obj)) raise base.SerializationError("Non-model object (%s) encountered during serialization" % type(obj))
self.indent(1) self.indent(1)
self.xml.startElement("object", { self.xml.startElement("object", {
"pk" : str(obj._get_pk_val()), "pk" : smart_unicode(obj._get_pk_val()),
"model" : str(obj._meta), "model" : smart_unicode(obj._meta),
}) })
def end_object(self, obj): def end_object(self, obj):
""" """
Called after handling all fields for an object. Called after handling all fields for an object.
""" """
self.indent(1) self.indent(1)
self.xml.endElement("object") self.xml.endElement("object")
def handle_field(self, obj, field): def handle_field(self, obj, field):
""" """
Called to handle each field on an object (except for ForeignKeys and Called to handle each field on an object (except for ForeignKeys and
@ -63,17 +64,17 @@ class Serializer(base.Serializer):
"name" : field.name, "name" : field.name,
"type" : field.get_internal_type() "type" : field.get_internal_type()
}) })
# Get a "string version" of the object's data (this is handled by the # Get a "string version" of the object's data (this is handled by the
# serializer base class). # serializer base class).
if getattr(obj, field.name) is not None: if getattr(obj, field.name) is not None:
value = self.get_string_value(obj, field) value = self.get_string_value(obj, field)
self.xml.characters(str(value)) self.xml.characters(smart_unicode(value))
else: else:
self.xml.addQuickElement("None") self.xml.addQuickElement("None")
self.xml.endElement("field") self.xml.endElement("field")
def handle_fk_field(self, obj, field): def handle_fk_field(self, obj, field):
""" """
Called to handle a ForeignKey (we need to treat them slightly Called to handle a ForeignKey (we need to treat them slightly
@ -82,11 +83,11 @@ class Serializer(base.Serializer):
self._start_relational_field(field) self._start_relational_field(field)
related = getattr(obj, field.name) related = getattr(obj, field.name)
if related is not None: if related is not None:
self.xml.characters(str(getattr(related, field.rel.field_name))) self.xml.characters(smart_unicode(getattr(related, field.rel.field_name)))
else: else:
self.xml.addQuickElement("None") self.xml.addQuickElement("None")
self.xml.endElement("field") self.xml.endElement("field")
def handle_m2m_field(self, obj, field): def handle_m2m_field(self, obj, field):
""" """
Called to handle a ManyToManyField. Related objects are only Called to handle a ManyToManyField. Related objects are only
@ -95,9 +96,9 @@ class Serializer(base.Serializer):
""" """
self._start_relational_field(field) self._start_relational_field(field)
for relobj in getattr(obj, field.name).iterator(): for relobj in getattr(obj, field.name).iterator():
self.xml.addQuickElement("object", attrs={"pk" : str(relobj._get_pk_val())}) self.xml.addQuickElement("object", attrs={"pk" : smart_unicode(relobj._get_pk_val())})
self.xml.endElement("field") self.xml.endElement("field")
def _start_relational_field(self, field): def _start_relational_field(self, field):
""" """
Helper to output the <field> element for relational fields Helper to output the <field> element for relational fields
@ -106,33 +107,33 @@ class Serializer(base.Serializer):
self.xml.startElement("field", { self.xml.startElement("field", {
"name" : field.name, "name" : field.name,
"rel" : field.rel.__class__.__name__, "rel" : field.rel.__class__.__name__,
"to" : str(field.rel.to._meta), "to" : smart_unicode(field.rel.to._meta),
}) })
class Deserializer(base.Deserializer): class Deserializer(base.Deserializer):
""" """
Deserialize XML. Deserialize XML.
""" """
def __init__(self, stream_or_string, **options): def __init__(self, stream_or_string, **options):
super(Deserializer, self).__init__(stream_or_string, **options) super(Deserializer, self).__init__(stream_or_string, **options)
self.encoding = self.options.get("encoding", settings.DEFAULT_CHARSET) self.event_stream = pulldom.parse(self.stream)
self.event_stream = pulldom.parse(self.stream)
def next(self): def next(self):
for event, node in self.event_stream: for event, node in self.event_stream:
if event == "START_ELEMENT" and node.nodeName == "object": if event == "START_ELEMENT" and node.nodeName == "object":
self.event_stream.expandNode(node) self.event_stream.expandNode(node)
return self._handle_object(node) return self._handle_object(node)
raise StopIteration raise StopIteration
def _handle_object(self, node): def _handle_object(self, node):
""" """
Convert an <object> node to a DeserializedObject. Convert an <object> node to a DeserializedObject.
""" """
# Look up the model using the model loading mechanism. If this fails, bail. # Look up the model using the model loading mechanism. If this fails,
# bail.
Model = self._get_model_from_node(node, "model") Model = self._get_model_from_node(node, "model")
# Start building a data dictionary from the object. If the node is # Start building a data dictionary from the object. If the node is
# missing the pk attribute, bail. # missing the pk attribute, bail.
pk = node.getAttribute("pk") pk = node.getAttribute("pk")
@ -140,11 +141,11 @@ class Deserializer(base.Deserializer):
raise base.DeserializationError("<object> node is missing the 'pk' attribute") raise base.DeserializationError("<object> node is missing the 'pk' attribute")
data = {Model._meta.pk.attname : Model._meta.pk.to_python(pk)} data = {Model._meta.pk.attname : Model._meta.pk.to_python(pk)}
# Also start building a dict of m2m data (this is saved as # Also start building a dict of m2m data (this is saved as
# {m2m_accessor_attribute : [list_of_related_objects]}) # {m2m_accessor_attribute : [list_of_related_objects]})
m2m_data = {} m2m_data = {}
# Deseralize each field. # Deseralize each field.
for field_node in node.getElementsByTagName("field"): for field_node in node.getElementsByTagName("field"):
# If the field is missing the name attribute, bail (are you # If the field is missing the name attribute, bail (are you
@ -152,12 +153,12 @@ class Deserializer(base.Deserializer):
field_name = field_node.getAttribute("name") field_name = field_node.getAttribute("name")
if not field_name: if not field_name:
raise base.DeserializationError("<field> node is missing the 'name' attribute") raise base.DeserializationError("<field> node is missing the 'name' attribute")
# Get the field from the Model. This will raise a # Get the field from the Model. This will raise a
# FieldDoesNotExist if, well, the field doesn't exist, which will # FieldDoesNotExist if, well, the field doesn't exist, which will
# be propagated correctly. # be propagated correctly.
field = Model._meta.get_field(field_name) field = Model._meta.get_field(field_name)
# As is usually the case, relation fields get the special treatment. # As is usually the case, relation fields get the special treatment.
if field.rel and isinstance(field.rel, models.ManyToManyRel): if field.rel and isinstance(field.rel, models.ManyToManyRel):
m2m_data[field.name] = self._handle_m2m_field_node(field_node, field) m2m_data[field.name] = self._handle_m2m_field_node(field_node, field)
@ -167,12 +168,12 @@ class Deserializer(base.Deserializer):
if len(field_node.childNodes) == 1 and field_node.childNodes[0].nodeName == 'None': if len(field_node.childNodes) == 1 and field_node.childNodes[0].nodeName == 'None':
value = None value = None
else: else:
value = field.to_python(getInnerText(field_node).strip().encode(self.encoding)) value = field.to_python(getInnerText(field_node).strip())
data[field.name] = value data[field.name] = value
# Return a DeserializedObject so that the m2m data has a place to live. # Return a DeserializedObject so that the m2m data has a place to live.
return base.DeserializedObject(Model(**data), m2m_data) return base.DeserializedObject(Model(**data), m2m_data)
def _handle_fk_field_node(self, node, field): def _handle_fk_field_node(self, node, field):
""" """
Handle a <field> node for a ForeignKey Handle a <field> node for a ForeignKey
@ -182,16 +183,16 @@ class Deserializer(base.Deserializer):
return None return None
else: else:
return field.rel.to._meta.get_field(field.rel.field_name).to_python( return field.rel.to._meta.get_field(field.rel.field_name).to_python(
getInnerText(node).strip().encode(self.encoding)) getInnerText(node).strip())
def _handle_m2m_field_node(self, node, field): def _handle_m2m_field_node(self, node, field):
""" """
Handle a <field> node for a ManyToManyField Handle a <field> node for a ManyToManyField.
""" """
return [field.rel.to._meta.pk.to_python( return [field.rel.to._meta.pk.to_python(
c.getAttribute("pk").encode(self.encoding)) c.getAttribute("pk"))
for c in node.getElementsByTagName("object")] for c in node.getElementsByTagName("object")]
def _get_model_from_node(self, node, attr): def _get_model_from_node(self, node, attr):
""" """
Helper to look up a model from a <object model=...> or a <field Helper to look up a model from a <object model=...> or a <field
@ -211,8 +212,8 @@ class Deserializer(base.Deserializer):
"<%s> node has invalid model identifier: '%s'" % \ "<%s> node has invalid model identifier: '%s'" % \
(node.nodeName, model_identifier)) (node.nodeName, model_identifier))
return Model return Model
def getInnerText(node): def getInnerText(node):
""" """
Get all the inner text of a DOM node (recursively). Get all the inner text of a DOM node (recursively).
@ -226,4 +227,5 @@ def getInnerText(node):
inner_text.extend(getInnerText(child)) inner_text.extend(getInnerText(child))
else: else:
pass pass
return "".join(inner_text) return u"".join(inner_text)

View File

@ -2,7 +2,7 @@
A test spanning all the capabilities of all the serializers. A test spanning all the capabilities of all the serializers.
This class defines sample data and a dynamically generated This class defines sample data and a dynamically generated
test case that is capable of testing the capabilities of test case that is capable of testing the capabilities of
the serializers. This includes all valid data values, plus the serializers. This includes all valid data values, plus
forward, backwards and self references. forward, backwards and self references.
""" """
@ -22,7 +22,7 @@ from models import *
def data_create(pk, klass, data): def data_create(pk, klass, data):
instance = klass(id=pk) instance = klass(id=pk)
instance.data = data instance.data = data
instance.save() instance.save()
return instance return instance
def generic_create(pk, klass, data): def generic_create(pk, klass, data):
@ -32,13 +32,13 @@ def generic_create(pk, klass, data):
for tag in data[1:]: for tag in data[1:]:
instance.tags.create(data=tag) instance.tags.create(data=tag)
return instance return instance
def fk_create(pk, klass, data): def fk_create(pk, klass, data):
instance = klass(id=pk) instance = klass(id=pk)
setattr(instance, 'data_id', data) setattr(instance, 'data_id', data)
instance.save() instance.save()
return instance return instance
def m2m_create(pk, klass, data): def m2m_create(pk, klass, data):
instance = klass(id=pk) instance = klass(id=pk)
instance.save() instance.save()
@ -61,14 +61,14 @@ def pk_create(pk, klass, data):
# test data objects of various kinds # test data objects of various kinds
def data_compare(testcase, pk, klass, data): def data_compare(testcase, pk, klass, data):
instance = klass.objects.get(id=pk) instance = klass.objects.get(id=pk)
testcase.assertEqual(data, instance.data, testcase.assertEqual(data, instance.data,
"Objects with PK=%d not equal; expected '%s' (%s), got '%s' (%s)" % (pk,data, type(data), instance.data, type(instance.data))) "Objects with PK=%d not equal; expected '%s' (%s), got '%s' (%s)" % (pk,data, type(data), instance.data, type(instance.data)))
def generic_compare(testcase, pk, klass, data): def generic_compare(testcase, pk, klass, data):
instance = klass.objects.get(id=pk) instance = klass.objects.get(id=pk)
testcase.assertEqual(data[0], instance.data) testcase.assertEqual(data[0], instance.data)
testcase.assertEqual(data[1:], [t.data for t in instance.tags.all()]) testcase.assertEqual(data[1:], [t.data for t in instance.tags.all()])
def fk_compare(testcase, pk, klass, data): def fk_compare(testcase, pk, klass, data):
instance = klass.objects.get(id=pk) instance = klass.objects.get(id=pk)
testcase.assertEqual(data, instance.data_id) testcase.assertEqual(data, instance.data_id)
@ -84,7 +84,7 @@ def o2o_compare(testcase, pk, klass, data):
def pk_compare(testcase, pk, klass, data): def pk_compare(testcase, pk, klass, data):
instance = klass.objects.get(data=data) instance = klass.objects.get(data=data)
testcase.assertEqual(data, instance.data) testcase.assertEqual(data, instance.data)
# Define some data types. Each data type is # Define some data types. Each data type is
# actually a pair of functions; one to create # actually a pair of functions; one to create
# and one to compare objects of that type # and one to compare objects of that type
@ -96,7 +96,7 @@ o2o_obj = (o2o_create, o2o_compare)
pk_obj = (pk_create, pk_compare) pk_obj = (pk_create, pk_compare)
test_data = [ test_data = [
# Format: (data type, PK value, Model Class, data) # Format: (data type, PK value, Model Class, data)
(data_obj, 1, BooleanData, True), (data_obj, 1, BooleanData, True),
(data_obj, 2, BooleanData, False), (data_obj, 2, BooleanData, False),
(data_obj, 10, CharData, "Test Char Data"), (data_obj, 10, CharData, "Test Char Data"),
@ -105,6 +105,9 @@ test_data = [
(data_obj, 13, CharData, "null"), (data_obj, 13, CharData, "null"),
(data_obj, 14, CharData, "NULL"), (data_obj, 14, CharData, "NULL"),
(data_obj, 15, CharData, None), (data_obj, 15, CharData, None),
# (We use something that will fit into a latin1 database encoding here,
# because that is still the default used on many system setups.)
(data_obj, 16, CharData, u'\xa5'),
(data_obj, 20, DateData, datetime.date(2006,6,16)), (data_obj, 20, DateData, datetime.date(2006,6,16)),
(data_obj, 21, DateData, None), (data_obj, 21, DateData, None),
(data_obj, 30, DateTimeData, datetime.datetime(2006,6,16,10,42,37)), (data_obj, 30, DateTimeData, datetime.datetime(2006,6,16,10,42,37)),
@ -137,10 +140,10 @@ test_data = [
(data_obj, 131, PositiveSmallIntegerData, None), (data_obj, 131, PositiveSmallIntegerData, None),
(data_obj, 140, SlugData, "this-is-a-slug"), (data_obj, 140, SlugData, "this-is-a-slug"),
(data_obj, 141, SlugData, None), (data_obj, 141, SlugData, None),
(data_obj, 150, SmallData, 12), (data_obj, 150, SmallData, 12),
(data_obj, 151, SmallData, -12), (data_obj, 151, SmallData, -12),
(data_obj, 152, SmallData, 0), (data_obj, 152, SmallData, 0),
(data_obj, 153, SmallData, None), (data_obj, 153, SmallData, None),
(data_obj, 160, TextData, """This is a long piece of text. (data_obj, 160, TextData, """This is a long piece of text.
It contains line breaks. It contains line breaks.
Several of them. Several of them.
@ -188,7 +191,7 @@ The end."""),
(fk_obj, 450, FKDataToField, "UAnchor 1"), (fk_obj, 450, FKDataToField, "UAnchor 1"),
(fk_obj, 451, FKDataToField, "UAnchor 2"), (fk_obj, 451, FKDataToField, "UAnchor 2"),
(fk_obj, 452, FKDataToField, None), (fk_obj, 452, FKDataToField, None),
(data_obj, 500, Anchor, "Anchor 3"), (data_obj, 500, Anchor, "Anchor 3"),
(data_obj, 501, Anchor, "Anchor 4"), (data_obj, 501, Anchor, "Anchor 4"),
(data_obj, 502, UniqueAnchor, "UAnchor 2"), (data_obj, 502, UniqueAnchor, "UAnchor 2"),
@ -215,9 +218,9 @@ The end."""),
(pk_obj, 720, PositiveIntegerPKData, 123456789), (pk_obj, 720, PositiveIntegerPKData, 123456789),
(pk_obj, 730, PositiveSmallIntegerPKData, 12), (pk_obj, 730, PositiveSmallIntegerPKData, 12),
(pk_obj, 740, SlugPKData, "this-is-a-slug"), (pk_obj, 740, SlugPKData, "this-is-a-slug"),
(pk_obj, 750, SmallPKData, 12), (pk_obj, 750, SmallPKData, 12),
(pk_obj, 751, SmallPKData, -12), (pk_obj, 751, SmallPKData, -12),
(pk_obj, 752, SmallPKData, 0), (pk_obj, 752, SmallPKData, 0),
# (pk_obj, 760, TextPKData, """This is a long piece of text. # (pk_obj, 760, TextPKData, """This is a long piece of text.
# It contains line breaks. # It contains line breaks.
# Several of them. # Several of them.
@ -226,7 +229,7 @@ The end."""),
(pk_obj, 780, USStatePKData, "MA"), (pk_obj, 780, USStatePKData, "MA"),
# (pk_obj, 790, XMLPKData, "<foo></foo>"), # (pk_obj, 790, XMLPKData, "<foo></foo>"),
] ]
# Dynamically create serializer tests to ensure that all # Dynamically create serializer tests to ensure that all
# registered serializers are automatically tested. # registered serializers are automatically tested.
class SerializerTests(unittest.TestCase): class SerializerTests(unittest.TestCase):
@ -234,7 +237,7 @@ class SerializerTests(unittest.TestCase):
def serializerTest(format, self): def serializerTest(format, self):
# Clear the database first # Clear the database first
management.flush(verbosity=0, interactive=False) management.flush(verbosity=0, interactive=False)
# Create all the objects defined in the test data # Create all the objects defined in the test data
objects = [] objects = []
@ -245,14 +248,14 @@ def serializerTest(format, self):
transaction.commit() transaction.commit()
transaction.leave_transaction_management() transaction.leave_transaction_management()
# Add the generic tagged objects to the object list # Add the generic tagged objects to the object list
objects.extend(Tag.objects.all()) objects.extend(Tag.objects.all())
# Serialize the test database # Serialize the test database
serialized_data = serializers.serialize(format, objects, indent=2) serialized_data = serializers.serialize(format, objects, indent=2)
# Flush the database and recreate from the serialized data # Flush the database and recreate from the serialized data
management.flush(verbosity=0, interactive=False) management.flush(verbosity=0, interactive=False)
transaction.enter_transaction_management() transaction.enter_transaction_management()
transaction.managed(True) transaction.managed(True)
for obj in serializers.deserialize(format, serialized_data): for obj in serializers.deserialize(format, serialized_data):
@ -260,10 +263,10 @@ def serializerTest(format, self):
transaction.commit() transaction.commit()
transaction.leave_transaction_management() transaction.leave_transaction_management()
# Assert that the deserialized data is the same # Assert that the deserialized data is the same
# as the original source # as the original source
for (func, pk, klass, datum) in test_data: for (func, pk, klass, datum) in test_data:
func[1](self, pk, klass, datum) func[1](self, pk, klass, datum)
for format in serializers.get_serializer_formats(): for format in serializers.get_serializer_formats():
setattr(SerializerTests, 'test_'+format+'_serializer', curry(serializerTest, format)) setattr(SerializerTests, 'test_'+format+'_serializer', curry(serializerTest, format))