mirror of
https://github.com/django/django.git
synced 2025-01-23 08:39:17 +00:00
Fix altering of SERIAL columns and InnoDB being picky about FK changes
This commit is contained in:
parent
cee4fe7307
commit
5db028affb
@ -2,4 +2,57 @@ from django.db.backends.schema import BaseDatabaseSchemaEditor
|
||||
|
||||
|
||||
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
|
||||
pass
|
||||
|
||||
sql_create_sequence = "CREATE SEQUENCE %(sequence)s"
|
||||
sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE"
|
||||
sql_set_sequence_max = "SELECT setval('%(sequence)s', MAX(%(column)s)) FROM %(table)s"
|
||||
|
||||
def _alter_column_type_sql(self, table, column, type):
|
||||
"""
|
||||
Makes ALTER TYPE with SERIAL make sense.
|
||||
"""
|
||||
if type.lower() == "serial":
|
||||
sequence_name = "%s_%s_seq" % (table, column)
|
||||
return (
|
||||
(
|
||||
self.sql_alter_column_type % {
|
||||
"column": self.quote_name(column),
|
||||
"type": "integer",
|
||||
},
|
||||
[],
|
||||
),
|
||||
[
|
||||
(
|
||||
self.sql_delete_sequence % {
|
||||
"sequence": sequence_name,
|
||||
},
|
||||
[],
|
||||
),
|
||||
(
|
||||
self.sql_create_sequence % {
|
||||
"sequence": sequence_name,
|
||||
},
|
||||
[],
|
||||
),
|
||||
(
|
||||
self.sql_alter_column % {
|
||||
"table": table,
|
||||
"changes": self.sql_alter_column_default % {
|
||||
"column": column,
|
||||
"default": "nextval('%s')" % sequence_name,
|
||||
}
|
||||
},
|
||||
[],
|
||||
),
|
||||
(
|
||||
self.sql_set_sequence_max % {
|
||||
"table": table,
|
||||
"column": column,
|
||||
"sequence": sequence_name,
|
||||
},
|
||||
[],
|
||||
),
|
||||
],
|
||||
)
|
||||
else:
|
||||
return super(DatabaseSchemaEditor, self)._alter_column_type_sql(table, column, type)
|
||||
|
@ -498,6 +498,18 @@ class BaseDatabaseSchemaEditor(object):
|
||||
"name": fk_name,
|
||||
}
|
||||
)
|
||||
# Drop incoming FK constraints if we're a primary key and things are going
|
||||
# to change.
|
||||
if old_field.primary_key and new_field.primary_key and old_type != new_type:
|
||||
for rel in new_field.model._meta.get_all_related_objects():
|
||||
rel_fk_names = self._constraint_names(rel.model, [rel.field.column], foreign_key=True)
|
||||
for fk_name in rel_fk_names:
|
||||
self.execute(
|
||||
self.sql_delete_fk % {
|
||||
"table": self.quote_name(rel.model._meta.db_table),
|
||||
"name": fk_name,
|
||||
}
|
||||
)
|
||||
# Change check constraints?
|
||||
if old_db_params['check'] != new_db_params['check'] and old_db_params['check']:
|
||||
constraint_names = self._constraint_names(model, [old_field.column], check=True)
|
||||
@ -524,15 +536,12 @@ class BaseDatabaseSchemaEditor(object):
|
||||
})
|
||||
# Next, start accumulating actions to do
|
||||
actions = []
|
||||
post_actions = []
|
||||
# Type change?
|
||||
if old_type != new_type:
|
||||
actions.append((
|
||||
self.sql_alter_column_type % {
|
||||
"column": self.quote_name(new_field.column),
|
||||
"type": new_type,
|
||||
},
|
||||
[],
|
||||
))
|
||||
fragment, other_actions = self._alter_column_type_sql(model._meta.db_table, new_field.column, new_type)
|
||||
actions.append(fragment)
|
||||
post_actions.extend(other_actions)
|
||||
# Default change?
|
||||
old_default = self.effective_default(old_field)
|
||||
new_default = self.effective_default(new_field)
|
||||
@ -596,6 +605,9 @@ class BaseDatabaseSchemaEditor(object):
|
||||
},
|
||||
params,
|
||||
)
|
||||
if post_actions:
|
||||
for sql, params in post_actions:
|
||||
self.execute(sql, params)
|
||||
# Added a unique?
|
||||
if not old_field.unique and new_field.unique:
|
||||
self.execute(
|
||||
@ -619,7 +631,7 @@ class BaseDatabaseSchemaEditor(object):
|
||||
# referring to us.
|
||||
rels_to_update = []
|
||||
if old_field.primary_key and new_field.primary_key and old_type != new_type:
|
||||
rels_to_update.extend(model._meta.get_all_related_objects())
|
||||
rels_to_update.extend(new_field.model._meta.get_all_related_objects())
|
||||
# Changed to become primary key?
|
||||
# Note that we don't detect unsetting of a PK, as we assume another field
|
||||
# will always come along and replace it.
|
||||
@ -647,8 +659,8 @@ class BaseDatabaseSchemaEditor(object):
|
||||
}
|
||||
)
|
||||
# Update all referencing columns
|
||||
rels_to_update.extend(model._meta.get_all_related_objects())
|
||||
# Handle out type alters on the other end of rels from the PK stuff above
|
||||
rels_to_update.extend(new_field.model._meta.get_all_related_objects())
|
||||
# Handle our type alters on the other end of rels from the PK stuff above
|
||||
for rel in rels_to_update:
|
||||
rel_db_params = rel.field.db_parameters(connection=self.connection)
|
||||
rel_type = rel_db_params['type']
|
||||
@ -672,6 +684,18 @@ class BaseDatabaseSchemaEditor(object):
|
||||
"to_column": self.quote_name(new_field.rel.get_related_field().column),
|
||||
}
|
||||
)
|
||||
# Rebuild FKs that pointed to us if we previously had to drop them
|
||||
if old_field.primary_key and new_field.primary_key and old_type != new_type:
|
||||
for rel in new_field.model._meta.get_all_related_objects():
|
||||
self.execute(
|
||||
self.sql_create_fk % {
|
||||
"table": self.quote_name(rel.model._meta.db_table),
|
||||
"name": self._create_index_name(rel.model, [rel.field.column], suffix="_fk"),
|
||||
"column": self.quote_name(rel.field.column),
|
||||
"to_table": self.quote_name(model._meta.db_table),
|
||||
"to_column": self.quote_name(new_field.column),
|
||||
}
|
||||
)
|
||||
# Does it have check constraints we need to add?
|
||||
if old_db_params['check'] != new_db_params['check'] and new_db_params['check']:
|
||||
self.execute(
|
||||
@ -686,6 +710,27 @@ class BaseDatabaseSchemaEditor(object):
|
||||
if self.connection.features.connection_persists_old_columns:
|
||||
self.connection.close()
|
||||
|
||||
def _alter_column_type_sql(self, table, column, type):
|
||||
"""
|
||||
Hook to specialise column type alteration for different backends,
|
||||
for cases when a creation type is different to an alteration type
|
||||
(e.g. SERIAL in PostgreSQL, PostGIS fields).
|
||||
|
||||
Should return two things; an SQL fragment of (sql, params) to insert
|
||||
into an ALTER TABLE statement, and a list of extra (sql, params) tuples
|
||||
to run once the field is altered.
|
||||
"""
|
||||
return (
|
||||
(
|
||||
self.sql_alter_column_type % {
|
||||
"column": self.quote_name(column),
|
||||
"type": type,
|
||||
},
|
||||
[],
|
||||
),
|
||||
[],
|
||||
)
|
||||
|
||||
def _alter_many_to_many(self, model, old_field, new_field, strict):
|
||||
"""
|
||||
Alters M2Ms to repoint their to= endpoints.
|
||||
|
@ -24,9 +24,10 @@ class AddField(Operation):
|
||||
state.models[app_label, self.model_name.lower()].fields.append((self.name, field))
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
from_model = from_state.render().get_model(app_label, self.model_name)
|
||||
to_model = to_state.render().get_model(app_label, self.model_name)
|
||||
if router.allow_migrate(schema_editor.connection.alias, to_model):
|
||||
schema_editor.add_field(to_model, to_model._meta.get_field_by_name(self.name)[0])
|
||||
schema_editor.add_field(from_model, to_model._meta.get_field_by_name(self.name)[0])
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
from_model = from_state.render().get_model(app_label, self.model_name)
|
||||
@ -73,9 +74,10 @@ class RemoveField(Operation):
|
||||
schema_editor.remove_field(from_model, from_model._meta.get_field_by_name(self.name)[0])
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
from_model = from_state.render().get_model(app_label, self.model_name)
|
||||
to_model = to_state.render().get_model(app_label, self.model_name)
|
||||
if router.allow_migrate(schema_editor.connection.alias, to_model):
|
||||
schema_editor.add_field(to_model, to_model._meta.get_field_by_name(self.name)[0])
|
||||
schema_editor.add_field(from_model, to_model._meta.get_field_by_name(self.name)[0])
|
||||
|
||||
def describe(self):
|
||||
return "Remove field %s from %s" % (self.name, self.model_name)
|
||||
@ -107,7 +109,7 @@ class AlterField(Operation):
|
||||
to_model = to_state.render().get_model(app_label, self.model_name)
|
||||
if router.allow_migrate(schema_editor.connection.alias, to_model):
|
||||
schema_editor.alter_field(
|
||||
to_model,
|
||||
from_model,
|
||||
from_model._meta.get_field_by_name(self.name)[0],
|
||||
to_model._meta.get_field_by_name(self.name)[0],
|
||||
)
|
||||
@ -153,7 +155,7 @@ class RenameField(Operation):
|
||||
to_model = to_state.render().get_model(app_label, self.model_name)
|
||||
if router.allow_migrate(schema_editor.connection.alias, to_model):
|
||||
schema_editor.alter_field(
|
||||
to_model,
|
||||
from_model,
|
||||
from_model._meta.get_field_by_name(self.old_name)[0],
|
||||
to_model._meta.get_field_by_name(self.new_name)[0],
|
||||
)
|
||||
@ -163,7 +165,7 @@ class RenameField(Operation):
|
||||
to_model = to_state.render().get_model(app_label, self.model_name)
|
||||
if router.allow_migrate(schema_editor.connection.alias, to_model):
|
||||
schema_editor.alter_field(
|
||||
to_model,
|
||||
from_model,
|
||||
from_model._meta.get_field_by_name(self.new_name)[0],
|
||||
to_model._meta.get_field_by_name(self.old_name)[0],
|
||||
)
|
||||
|
@ -1,3 +1,4 @@
|
||||
import unittest
|
||||
from django.db import connection, models, migrations, router
|
||||
from django.db.models.fields import NOT_PROVIDED
|
||||
from django.db.transaction import atomic
|
||||
@ -13,7 +14,7 @@ class OperationTests(MigrationTestBase):
|
||||
both forwards and backwards.
|
||||
"""
|
||||
|
||||
def set_up_test_model(self, app_label, second_model=False):
|
||||
def set_up_test_model(self, app_label, second_model=False, related_model=False):
|
||||
"""
|
||||
Creates a test model state and database table.
|
||||
"""
|
||||
@ -38,6 +39,14 @@ class OperationTests(MigrationTestBase):
|
||||
)]
|
||||
if second_model:
|
||||
operations.append(migrations.CreateModel("Stable", [("id", models.AutoField(primary_key=True))]))
|
||||
if related_model:
|
||||
operations.append(migrations.CreateModel(
|
||||
"Rider",
|
||||
[
|
||||
("id", models.AutoField(primary_key=True)),
|
||||
("pony", models.ForeignKey("Pony")),
|
||||
],
|
||||
))
|
||||
project_state = ProjectState()
|
||||
for operation in operations:
|
||||
operation.state_forwards(app_label, project_state)
|
||||
@ -269,6 +278,52 @@ class OperationTests(MigrationTestBase):
|
||||
operation.database_backwards("test_alfl", editor, new_state, project_state)
|
||||
self.assertColumnNotNull("test_alfl_pony", "pink")
|
||||
|
||||
def test_alter_field_pk(self):
|
||||
"""
|
||||
Tests the AlterField operation on primary keys (for things like PostgreSQL's SERIAL weirdness)
|
||||
"""
|
||||
project_state = self.set_up_test_model("test_alflpk")
|
||||
# Test the state alteration
|
||||
operation = migrations.AlterField("Pony", "id", models.IntegerField(primary_key=True))
|
||||
new_state = project_state.clone()
|
||||
operation.state_forwards("test_alflpk", new_state)
|
||||
self.assertIsInstance(project_state.models["test_alflpk", "pony"].get_field_by_name("id"), models.AutoField)
|
||||
self.assertIsInstance(new_state.models["test_alflpk", "pony"].get_field_by_name("id"), models.IntegerField)
|
||||
# Test the database alteration
|
||||
with connection.schema_editor() as editor:
|
||||
operation.database_forwards("test_alflpk", editor, project_state, new_state)
|
||||
# And test reversal
|
||||
with connection.schema_editor() as editor:
|
||||
operation.database_backwards("test_alflpk", editor, new_state, project_state)
|
||||
|
||||
@unittest.skipUnless(connection.features.supports_foreign_keys, "No FK support")
|
||||
def test_alter_field_pk_fk(self):
|
||||
"""
|
||||
Tests the AlterField operation on primary keys changes any FKs pointing to it.
|
||||
"""
|
||||
project_state = self.set_up_test_model("test_alflpkfk", related_model=True)
|
||||
# Test the state alteration
|
||||
operation = migrations.AlterField("Pony", "id", models.FloatField(primary_key=True))
|
||||
new_state = project_state.clone()
|
||||
operation.state_forwards("test_alflpkfk", new_state)
|
||||
self.assertIsInstance(project_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.AutoField)
|
||||
self.assertIsInstance(new_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.FloatField)
|
||||
# Test the database alteration
|
||||
id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0]
|
||||
fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0]
|
||||
self.assertEqual(id_type, fk_type)
|
||||
with connection.schema_editor() as editor:
|
||||
operation.database_forwards("test_alflpkfk", editor, project_state, new_state)
|
||||
id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0]
|
||||
fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0]
|
||||
self.assertEqual(id_type, fk_type)
|
||||
# And test reversal
|
||||
with connection.schema_editor() as editor:
|
||||
operation.database_backwards("test_alflpkfk", editor, new_state, project_state)
|
||||
id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0]
|
||||
fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0]
|
||||
self.assertEqual(id_type, fk_type)
|
||||
|
||||
def test_rename_field(self):
|
||||
"""
|
||||
Tests the RenameField operation.
|
||||
|
@ -636,6 +636,7 @@ class SchemaTests(TransactionTestCase):
|
||||
# Alter to change the PK
|
||||
new_field = SlugField(primary_key=True)
|
||||
new_field.set_attributes_from_name("slug")
|
||||
new_field.model = Tag
|
||||
with connection.schema_editor() as editor:
|
||||
editor.remove_field(Tag, Tag._meta.get_field_by_name("id")[0])
|
||||
editor.alter_field(
|
||||
|
Loading…
x
Reference in New Issue
Block a user