From e6f7f4533c183800c2a9ac526d8ee8887e96ac5d Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Thu, 30 May 2013 18:08:58 +0100 Subject: [PATCH] Add an Executor for end-to-end running --- django/db/migrations/executor.py | 68 ++++++++++++++++++++++ django/db/migrations/migration.py | 48 +++++++++++++++ django/db/migrations/operations/fields.py | 16 ++--- tests/migrations/migrations/0002_second.py | 2 +- tests/migrations/test_executor.py | 35 +++++++++++ tests/migrations/test_loader.py | 2 +- tests/migrations/test_operations.py | 28 ++++++++- 7 files changed, 188 insertions(+), 11 deletions(-) create mode 100644 django/db/migrations/executor.py create mode 100644 tests/migrations/test_executor.py diff --git a/django/db/migrations/executor.py b/django/db/migrations/executor.py new file mode 100644 index 0000000000..e9e98d41fd --- /dev/null +++ b/django/db/migrations/executor.py @@ -0,0 +1,68 @@ +from .loader import MigrationLoader +from .recorder import MigrationRecorder + + +class MigrationExecutor(object): + """ + End-to-end migration execution - loads migrations, and runs them + up or down to a specified set of targets. + """ + + def __init__(self, connection): + self.connection = connection + self.loader = MigrationLoader(self.connection) + self.recorder = MigrationRecorder(self.connection) + + def migration_plan(self, targets): + """ + Given a set of targets, returns a list of (Migration instance, backwards?). + """ + plan = [] + applied = self.recorder.applied_migrations() + for target in targets: + # If the migration is already applied, do backwards mode, + # otherwise do forwards mode. + if target in applied: + for migration in self.loader.graph.backwards_plan(target)[:-1]: + if migration in applied: + plan.append((self.loader.graph.nodes[migration], True)) + applied.remove(migration) + else: + for migration in self.loader.graph.forwards_plan(target): + if migration not in applied: + plan.append((self.loader.graph.nodes[migration], False)) + applied.add(migration) + return plan + + def migrate(self, targets): + """ + Migrates the database up to the given targets. + """ + plan = self.migration_plan(targets) + for migration, backwards in plan: + if not backwards: + self.apply_migration(migration) + else: + self.unapply_migration(migration) + + def apply_migration(self, migration): + """ + Runs a migration forwards. + """ + print "Applying %s" % migration + with self.connection.schema_editor() as schema_editor: + project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False) + migration.apply(project_state, schema_editor) + self.recorder.record_applied(migration.app_label, migration.name) + print "Finished %s" % migration + + def unapply_migration(self, migration): + """ + Runs a migration backwards. + """ + print "Unapplying %s" % migration + with self.connection.schema_editor() as schema_editor: + project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False) + migration.unapply(project_state, schema_editor) + self.recorder.record_unapplied(migration.app_label, migration.name) + print "Finished %s" % migration diff --git a/django/db/migrations/migration.py b/django/db/migrations/migration.py index a8b744a9b4..672e7440ad 100644 --- a/django/db/migrations/migration.py +++ b/django/db/migrations/migration.py @@ -36,6 +36,17 @@ class Migration(object): self.name = name self.app_label = app_label + def __eq__(self, other): + if not isinstance(other, Migration): + return False + return (self.name == other.name) and (self.app_label == other.app_label) + + def __ne__(self, other): + return not (self == other) + + def __repr__(self): + return "" % (self.app_label, self.name) + def mutate_state(self, project_state): """ Takes a ProjectState and returns a new one with the migration's @@ -45,3 +56,40 @@ class Migration(object): for operation in self.operations: operation.state_forwards(self.app_label, new_state) return new_state + + def apply(self, project_state, schema_editor): + """ + Takes a project_state representing all migrations prior to this one + and a schema_editor for a live database and applies the migration + in a forwards order. + + Returns the resulting project state for efficient re-use by following + Migrations. + """ + for operation in self.operations: + # Get the state after the operation has run + new_state = project_state.clone() + operation.state_forwards(self.app_label, new_state) + # Run the operation + operation.database_forwards(self.app_label, schema_editor, project_state, new_state) + # Switch states + project_state = new_state + return project_state + + def unapply(self, project_state, schema_editor): + """ + Takes a project_state representing all migrations prior to this one + and a schema_editor for a live database and applies the migration + in a reverse order. + """ + # We need to pre-calculate the stack of project states + to_run = [] + for operation in self.operations: + new_state = project_state.clone() + operation.state_forwards(self.app_label, new_state) + to_run.append((operation, project_state, new_state)) + project_state = new_state + # Now run them in reverse + to_run.reverse() + for operation, to_state, from_state in to_run: + operation.database_backwards(self.app_label, schema_editor, from_state, to_state) diff --git a/django/db/migrations/operations/fields.py b/django/db/migrations/operations/fields.py index 2ecf77f7ef..efb12b22c3 100644 --- a/django/db/migrations/operations/fields.py +++ b/django/db/migrations/operations/fields.py @@ -16,13 +16,13 @@ class AddField(Operation): def database_forwards(self, app_label, schema_editor, from_state, to_state): app_cache = to_state.render() - model = app_cache.get_model(app_label, self.name) - schema_editor.add_field(model, model._meta.get_field_by_name(self.name)) + model = app_cache.get_model(app_label, self.model_name) + schema_editor.add_field(model, model._meta.get_field_by_name(self.name)[0]) def database_backwards(self, app_label, schema_editor, from_state, to_state): app_cache = from_state.render() - model = app_cache.get_model(app_label, self.name) - schema_editor.remove_field(model, model._meta.get_field_by_name(self.name)) + model = app_cache.get_model(app_label, self.model_name) + schema_editor.remove_field(model, model._meta.get_field_by_name(self.name)[0]) class RemoveField(Operation): @@ -43,10 +43,10 @@ class RemoveField(Operation): def database_forwards(self, app_label, schema_editor, from_state, to_state): app_cache = from_state.render() - model = app_cache.get_model(app_label, self.name) - schema_editor.remove_field(model, model._meta.get_field_by_name(self.name)) + model = app_cache.get_model(app_label, self.model_name) + schema_editor.remove_field(model, model._meta.get_field_by_name(self.name)[0]) def database_backwards(self, app_label, schema_editor, from_state, to_state): app_cache = to_state.render() - model = app_cache.get_model(app_label, self.name) - schema_editor.add_field(model, model._meta.get_field_by_name(self.name)) + model = app_cache.get_model(app_label, self.model_name) + schema_editor.add_field(model, model._meta.get_field_by_name(self.name)[0]) diff --git a/tests/migrations/migrations/0002_second.py b/tests/migrations/migrations/0002_second.py index fbaef11f71..ace9a83347 100644 --- a/tests/migrations/migrations/0002_second.py +++ b/tests/migrations/migrations/0002_second.py @@ -11,7 +11,7 @@ class Migration(migrations.Migration): migrations.RemoveField("Author", "silly_field"), - migrations.AddField("Author", "important", models.BooleanField()), + migrations.AddField("Author", "rating", models.IntegerField(default=0)), migrations.CreateModel( "Book", diff --git a/tests/migrations/test_executor.py b/tests/migrations/test_executor.py new file mode 100644 index 0000000000..629c47de56 --- /dev/null +++ b/tests/migrations/test_executor.py @@ -0,0 +1,35 @@ +from django.test import TransactionTestCase +from django.db import connection +from django.db.migrations.executor import MigrationExecutor + + +class ExecutorTests(TransactionTestCase): + """ + Tests the migration executor (full end-to-end running). + + Bear in mind that if these are failing you should fix the other + test failures first, as they may be propagating into here. + """ + + def test_run(self): + """ + Tests running a simple set of migrations. + """ + executor = MigrationExecutor(connection) + # Let's look at the plan first and make sure it's up to scratch + plan = executor.migration_plan([("migrations", "0002_second")]) + self.assertEqual( + plan, + [ + (executor.loader.graph.nodes["migrations", "0001_initial"], False), + (executor.loader.graph.nodes["migrations", "0002_second"], False), + ], + ) + # Were the tables there before? + self.assertNotIn("migrations_author", connection.introspection.get_table_list(connection.cursor())) + self.assertNotIn("migrations_book", connection.introspection.get_table_list(connection.cursor())) + # Alright, let's try running it + executor.migrate([("migrations", "0002_second")]) + # Are the tables there now? + self.assertIn("migrations_author", connection.introspection.get_table_list(connection.cursor())) + self.assertIn("migrations_book", connection.introspection.get_table_list(connection.cursor())) diff --git a/tests/migrations/test_loader.py b/tests/migrations/test_loader.py index badace57cc..9318f77004 100644 --- a/tests/migrations/test_loader.py +++ b/tests/migrations/test_loader.py @@ -54,7 +54,7 @@ class LoaderTests(TransactionTestCase): author_state = project_state.models["migrations", "author"] self.assertEqual( [x for x, y in author_state.fields], - ["id", "name", "slug", "age", "important"] + ["id", "name", "slug", "age", "rating"] ) book_state = project_state.models["migrations", "book"] diff --git a/tests/migrations/test_operations.py b/tests/migrations/test_operations.py index ea6dea0302..bf8549e092 100644 --- a/tests/migrations/test_operations.py +++ b/tests/migrations/test_operations.py @@ -1,6 +1,6 @@ from django.test import TransactionTestCase from django.db import connection, models, migrations -from django.db.migrations.state import ProjectState, ModelState +from django.db.migrations.state import ProjectState class OperationTests(TransactionTestCase): @@ -16,6 +16,12 @@ class OperationTests(TransactionTestCase): def assertTableNotExists(self, table): self.assertNotIn(table, connection.introspection.get_table_list(connection.cursor())) + def assertColumnExists(self, table, column): + self.assertIn(column, [c.name for c in connection.introspection.get_table_description(connection.cursor(), table)]) + + def assertColumnNotExists(self, table, column): + self.assertNotIn(column, [c.name for c in connection.introspection.get_table_description(connection.cursor(), table)]) + def set_up_test_model(self, app_label): """ Creates a test model state and database table. @@ -82,3 +88,23 @@ class OperationTests(TransactionTestCase): with connection.schema_editor() as editor: operation.database_backwards("test_dlmo", editor, new_state, project_state) self.assertTableExists("test_dlmo_pony") + + def test_add_field(self): + """ + Tests the AddField operation. + """ + project_state = self.set_up_test_model("test_adfl") + # Test the state alteration + operation = migrations.AddField("Pony", "height", models.FloatField(null=True)) + new_state = project_state.clone() + operation.state_forwards("test_adfl", new_state) + self.assertEqual(len(new_state.models["test_adfl", "pony"].fields), 3) + # Test the database alteration + self.assertColumnNotExists("test_adfl_pony", "height") + with connection.schema_editor() as editor: + operation.database_forwards("test_adfl", editor, project_state, new_state) + self.assertColumnExists("test_adfl_pony", "height") + # And test reversal + with connection.schema_editor() as editor: + operation.database_backwards("test_adfl", editor, new_state, project_state) + self.assertColumnNotExists("test_adfl_pony", "height")