From 98f23a8af0be7e87535426c5c83058e2682bfdf8 Mon Sep 17 00:00:00 2001 From: Matthijs Kooijman Date: Mon, 2 Dec 2019 00:42:06 +0100 Subject: [PATCH] Fixed #26552 -- Deferred constraint checks when reloading the database with data for tests. deserialize_db_from_string() loads the full serialized database contents, which might contain forward references and cycles. That caused IntegrityError because constraints were checked immediately. Now, it loads data in a transaction with constraint checks deferred until the end of the transaction. --- django/db/backends/base/creation.py | 13 +++++++++++-- tests/backends/base/test_creation.py | 28 ++++++++++++++++++++++++++++ tests/backends/models.py | 1 + 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/django/db/backends/base/creation.py b/django/db/backends/base/creation.py index c50fc90171..230954ddfc 100644 --- a/django/db/backends/base/creation.py +++ b/django/db/backends/base/creation.py @@ -6,6 +6,7 @@ from django.apps import apps from django.conf import settings from django.core import serializers from django.db import router +from django.db.transaction import atomic # The prefix to put on the default database name when creating # the test database. @@ -126,8 +127,16 @@ class BaseDatabaseCreation: the serialize_db_to_string() method. """ data = StringIO(data) - for obj in serializers.deserialize("json", data, using=self.connection.alias): - obj.save() + # Load data in a transaction to handle forward references and cycles. + with atomic(using=self.connection.alias): + # Disable constraint checks, because some databases (MySQL) doesn't + # support deferred checks. + with self.connection.constraint_checks_disabled(): + for obj in serializers.deserialize('json', data, using=self.connection.alias): + obj.save() + # Manually check for any invalid keys that might have been added, + # because constraint checks were disabled. + self.connection.check_constraints() def _get_database_display_str(self, verbosity, database_name): """ diff --git a/tests/backends/base/test_creation.py b/tests/backends/base/test_creation.py index b91466911a..f627f2e7c8 100644 --- a/tests/backends/base/test_creation.py +++ b/tests/backends/base/test_creation.py @@ -7,6 +7,8 @@ from django.db.backends.base.creation import ( ) from django.test import SimpleTestCase +from ..models import Object, ObjectReference + def get_connection_copy(): # Get a copy of the default connection. (Can't use django.db.connection @@ -73,3 +75,29 @@ class TestDbCreationTests(SimpleTestCase): finally: with mock.patch.object(creation, '_destroy_test_db'): creation.destroy_test_db(old_database_name, verbosity=0) + + +class TestDeserializeDbFromString(SimpleTestCase): + databases = {'default'} + + def test_circular_reference(self): + # deserialize_db_from_string() handles circular references. + data = """ + [ + { + "model": "backends.object", + "pk": 1, + "fields": {"obj_ref": 1, "related_objects": []} + }, + { + "model": "backends.objectreference", + "pk": 1, + "fields": {"obj": 1} + } + ] + """ + connection.creation.deserialize_db_from_string(data) + obj = Object.objects.get() + obj_ref = ObjectReference.objects.get() + self.assertEqual(obj.obj_ref, obj_ref) + self.assertEqual(obj_ref.obj, obj) diff --git a/tests/backends/models.py b/tests/backends/models.py index 1fa8d44e63..277a3a1203 100644 --- a/tests/backends/models.py +++ b/tests/backends/models.py @@ -89,6 +89,7 @@ class Item(models.Model): class Object(models.Model): related_objects = models.ManyToManyField("self", db_constraint=False, symmetrical=False) + obj_ref = models.ForeignKey('ObjectReference', models.CASCADE, null=True) def __str__(self): return str(self.id)