diff --git a/tests/serializers/test_data.py b/tests/serializers/test_data.py index c913c59dd3..f8763f6e42 100644 --- a/tests/serializers/test_data.py +++ b/tests/serializers/test_data.py @@ -10,6 +10,7 @@ forward, backwards and self references. import datetime import decimal import uuid +from collections import namedtuple from django.core import serializers from django.db import connection, models @@ -239,22 +240,22 @@ def inherited_compare(testcase, pk, klass, data): testcase.assertEqual(value, getattr(instance, key)) -# Define some data types. Each data type is -# actually a pair of functions; one to create -# and one to compare objects of that type -data_obj = (data_create, data_compare) -generic_obj = (generic_create, generic_compare) -fk_obj = (fk_create, fk_compare) -m2m_obj = (m2m_create, m2m_compare) -im2m_obj = (im2m_create, im2m_compare) -im_obj = (im_create, im_compare) -o2o_obj = (o2o_create, o2o_compare) -pk_obj = (pk_create, pk_compare) -inherited_obj = (inherited_create, inherited_compare) +# Define some test helpers. Each has a pair of functions: one to create objects and one +# to make assertions against objects of a particular type. +TestHelper = namedtuple("TestHelper", ["create_object", "compare_object"]) +data_obj = TestHelper(data_create, data_compare) +generic_obj = TestHelper(generic_create, generic_compare) +fk_obj = TestHelper(fk_create, fk_compare) +m2m_obj = TestHelper(m2m_create, m2m_compare) +im2m_obj = TestHelper(im2m_create, im2m_compare) +im_obj = TestHelper(im_create, im_compare) +o2o_obj = TestHelper(o2o_create, o2o_compare) +pk_obj = TestHelper(pk_create, pk_compare) +inherited_obj = TestHelper(inherited_create, inherited_compare) uuid_obj = uuid.uuid4() test_data = [ - # Format: (data type, PK value, Model Class, data) + # Format: (test helper, PK value, Model Class, data) (data_obj, 1, BinaryData, memoryview(b"\x05\xFD\x00")), (data_obj, 5, BooleanData, True), (data_obj, 6, BooleanData, False), @@ -410,35 +411,36 @@ class SerializerDataTests(TestCase): def assert_serializer(self, format, data): - # Create all the objects defined in the test data + # Create all the objects defined in the test data. objects = [] - instance_count = {} - for func, pk, klass, datum in test_data: + for test_helper, pk, model, data_value in data: with connection.constraint_checks_disabled(): - objects.extend(func[0](pk, klass, datum)) + objects.extend(test_helper.create_object(pk, model, data_value)) - # Get a count of the number of objects created for each class - for klass in instance_count: - instance_count[klass] = klass.objects.count() + # Get a count of the number of objects created for each model class. + instance_counts = {} + for _, _, model, _ in data: + if model not in instance_counts: + instance_counts[model] = model.objects.count() - # Add the generic tagged objects to the object list + # Add the generic tagged objects to the object list. objects.extend(Tag.objects.all()) - # Serialize the test database + # Serialize the test database. serialized_data = serializers.serialize(format, objects, indent=2) for obj in serializers.deserialize(format, serialized_data): obj.save() - # Assert that the deserialized data is the same - # as the original source - for func, pk, klass, datum in test_data: - func[1](self, pk, klass, datum) + # Assert that the deserialized data is the same as the original source. + for test_helper, pk, model, data_value in data: + with self.subTest(model=model, data_value=data_value): + test_helper.compare_object(self, pk, model, data_value) - # Assert that the number of objects deserialized is the - # same as the number that was serialized. - for klass, count in instance_count.items(): - self.assertEqual(count, klass.objects.count()) + # Assert no new objects were created. + for model, count in instance_counts.items(): + with self.subTest(model=model, count=count): + self.assertEqual(count, model.objects.count()) def serializerTest(self, format):