From e24e9e0438cf3ba6c557558f68587d135a85e126 Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Mon, 21 Jul 2014 11:36:34 +0100 Subject: [PATCH] Fixed #23014: Renaming not atomic with unique together --- django/db/backends/sqlite3/schema.py | 15 +++++++++++++-- django/db/migrations/autodetector.py | 13 ++++++++++++- django/db/migrations/operations/fields.py | 8 ++++++++ tests/migrations/test_operations.py | 15 +++++++++++++-- 4 files changed, 46 insertions(+), 5 deletions(-) diff --git a/django/db/backends/sqlite3/schema.py b/django/db/backends/sqlite3/schema.py index 9b62b66578..3e122d707d 100644 --- a/django/db/backends/sqlite3/schema.py +++ b/django/db/backends/sqlite3/schema.py @@ -43,7 +43,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): else: raise ValueError("Cannot quote parameter value %r of type %s" % (value, type(value))) - def _remake_table(self, model, create_fields=[], delete_fields=[], alter_fields=[], rename_fields=[], override_uniques=None): + def _remake_table(self, model, create_fields=[], delete_fields=[], alter_fields=[], override_uniques=None): """ Shortcut to transform a model from old_model into new_model """ @@ -52,6 +52,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): # Since mapping might mix column names and default values, # its values must be already quoted. mapping = dict((f.column, self.quote_name(f.column)) for f in model._meta.local_fields) + # This maps field names (not columns) for things like unique_together + rename_mapping = {} # If any of the new or altered fields is introducing a new PK, # remove the old one restore_pk_field = None @@ -77,6 +79,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): del mapping[old_field.column] body[new_field.name] = new_field mapping[new_field.column] = self.quote_name(old_field.column) + rename_mapping[old_field.name] = new_field.name # Remove any deleted fields for field in delete_fields: del body[field.name] @@ -92,11 +95,19 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): # the internal references of some of the provided fields. body = copy.deepcopy(body) + # Work out the new value of unique_together, taking renames into + # account + if override_uniques is None: + override_uniques = [ + [rename_mapping.get(n, n) for n in unique] + for unique in model._meta.unique_together + ] + # Construct a new model for the new state meta_contents = { 'app_label': model._meta.app_label, 'db_table': model._meta.db_table + "__new", - 'unique_together': model._meta.unique_together if override_uniques is None else override_uniques, + 'unique_together': override_uniques, 'apps': apps, } meta = type("Meta", tuple(), meta_contents) diff --git a/django/db/migrations/autodetector.py b/django/db/migrations/autodetector.py index 4fd2246a4a..37824e169c 100644 --- a/django/db/migrations/autodetector.py +++ b/django/db/migrations/autodetector.py @@ -791,7 +791,18 @@ class MigrationAutodetector(object): old_model_name = self.renamed_models.get((app_label, model_name), model_name) old_model_state = self.from_state.models[app_label, old_model_name] new_model_state = self.to_state.models[app_label, model_name] - if old_model_state.options.get(option_name) != new_model_state.options.get(option_name): + # We run the old version through the field renames to account for those + if old_model_state.options.get(option_name) is None: + old_value = None + else: + old_value = [ + [ + self.renamed_fields.get((app_label, model_name, n), n) + for n in unique + ] + for unique in old_model_state.options[option_name] + ] + if old_value != new_model_state.options.get(option_name): self.add_operation( app_label, operation( diff --git a/django/db/migrations/operations/fields.py b/django/db/migrations/operations/fields.py index dee9337e95..d1e98bde3f 100644 --- a/django/db/migrations/operations/fields.py +++ b/django/db/migrations/operations/fields.py @@ -162,9 +162,17 @@ class RenameField(Operation): self.new_name = new_name def state_forwards(self, app_label, state): + # Rename the field state.models[app_label, self.model_name.lower()].fields = [ (self.new_name if n == self.old_name else n, f) for n, f in state.models[app_label, self.model_name.lower()].fields ] + # Fix unique_together to refer to the new field + options = state.models[app_label, self.model_name.lower()].options + if "unique_together" in options: + options['unique_together'] = [ + [self.new_name if n == self.old_name else n for n in unique] + for unique in options['unique_together'] + ] def database_forwards(self, app_label, schema_editor, from_state, to_state): from_model = from_state.render().get_model(app_label, self.model_name) diff --git a/tests/migrations/test_operations.py b/tests/migrations/test_operations.py index 7e3e030518..815838d826 100644 --- a/tests/migrations/test_operations.py +++ b/tests/migrations/test_operations.py @@ -46,7 +46,7 @@ class OperationTestBase(MigrationTestBase): operation.state_forwards(app_label, new_state) return project_state, new_state - def set_up_test_model(self, app_label, second_model=False, third_model=False, related_model=False, mti_model=False, proxy_model=False): + def set_up_test_model(self, app_label, second_model=False, third_model=False, related_model=False, mti_model=False, proxy_model=False, unique_together=False): """ Creates a test model state and database table. """ @@ -85,6 +85,7 @@ class OperationTestBase(MigrationTestBase): ], options={ "swappable": "TEST_SWAP_MODEL", + "unique_together": [["pink", "weight"]] if unique_together else [], }, )] if second_model: @@ -862,7 +863,7 @@ class OperationTests(OperationTestBase): """ Tests the RenameField operation. """ - project_state = self.set_up_test_model("test_rnfl") + project_state = self.set_up_test_model("test_rnfl", unique_together=True) # Test the state alteration operation = migrations.RenameField("Pony", "pink", "blue") self.assertEqual(operation.describe(), "Rename field pink on Pony to blue") @@ -870,6 +871,9 @@ class OperationTests(OperationTestBase): operation.state_forwards("test_rnfl", new_state) self.assertIn("blue", [n for n, f in new_state.models["test_rnfl", "pony"].fields]) self.assertNotIn("pink", [n for n, f in new_state.models["test_rnfl", "pony"].fields]) + # Make sure the unique_together has the renamed column too + self.assertIn("blue", new_state.models["test_rnfl", "pony"].options['unique_together'][0]) + self.assertNotIn("pink", new_state.models["test_rnfl", "pony"].options['unique_together'][0]) # Test the database alteration self.assertColumnExists("test_rnfl_pony", "pink") self.assertColumnNotExists("test_rnfl_pony", "blue") @@ -877,6 +881,13 @@ class OperationTests(OperationTestBase): operation.database_forwards("test_rnfl", editor, project_state, new_state) self.assertColumnExists("test_rnfl_pony", "blue") self.assertColumnNotExists("test_rnfl_pony", "pink") + # Ensure the unique constraint has been ported over + with connection.cursor() as cursor: + cursor.execute("INSERT INTO test_rnfl_pony (blue, weight) VALUES (1, 1)") + with self.assertRaises(IntegrityError): + with atomic(): + cursor.execute("INSERT INTO test_rnfl_pony (blue, weight) VALUES (1, 1)") + cursor.execute("DELETE FROM test_rnfl_pony") # And test reversal with connection.schema_editor() as editor: operation.database_backwards("test_rnfl", editor, new_state, project_state)