mirror of
https://github.com/django/django.git
synced 2024-12-22 17:16:24 +00:00
Fix migration planner to fully understand squashed migrations. And test.
This commit is contained in:
parent
4cfbde71a3
commit
5ab8b5d72c
@ -11,7 +11,6 @@ class MigrationExecutor(object):
|
||||
def __init__(self, connection, progress_callback=None):
|
||||
self.connection = connection
|
||||
self.loader = MigrationLoader(self.connection)
|
||||
self.loader.load_disk()
|
||||
self.recorder = MigrationRecorder(self.connection)
|
||||
self.progress_callback = progress_callback
|
||||
|
||||
@ -20,7 +19,7 @@ class MigrationExecutor(object):
|
||||
Given a set of targets, returns a list of (Migration instance, backwards?).
|
||||
"""
|
||||
plan = []
|
||||
applied = self.recorder.applied_migrations()
|
||||
applied = set(self.loader.applied_migrations)
|
||||
for target in targets:
|
||||
# If the target is (appname, None), that means unmigrate everything
|
||||
if target[1] is None:
|
||||
@ -87,7 +86,13 @@ class MigrationExecutor(object):
|
||||
with self.connection.schema_editor() as schema_editor:
|
||||
project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False)
|
||||
migration.apply(project_state, schema_editor)
|
||||
self.recorder.record_applied(migration.app_label, migration.name)
|
||||
# For replacement migrations, record individual statuses
|
||||
if migration.replaces:
|
||||
for app_label, name in migration.replaces:
|
||||
self.recorder.record_applied(app_label, name)
|
||||
else:
|
||||
self.recorder.record_applied(migration.app_label, migration.name)
|
||||
# Report prgress
|
||||
if self.progress_callback:
|
||||
self.progress_callback("apply_success", migration)
|
||||
|
||||
@ -101,6 +106,12 @@ class MigrationExecutor(object):
|
||||
with self.connection.schema_editor() as schema_editor:
|
||||
project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False)
|
||||
migration.unapply(project_state, schema_editor)
|
||||
self.recorder.record_unapplied(migration.app_label, migration.name)
|
||||
# For replacement migrations, record individual statuses
|
||||
if migration.replaces:
|
||||
for app_label, name in migration.replaces:
|
||||
self.recorder.record_unapplied(app_label, name)
|
||||
else:
|
||||
self.recorder.record_unapplied(migration.app_label, migration.name)
|
||||
# Report progress
|
||||
if self.progress_callback:
|
||||
self.progress_callback("unapply_success", migration)
|
||||
|
@ -1,9 +1,10 @@
|
||||
import os
|
||||
import sys
|
||||
from importlib import import_module
|
||||
from django.utils.functional import cached_property
|
||||
from django.db.models.loading import cache
|
||||
from django.db.migrations.recorder import MigrationRecorder
|
||||
from django.db.migrations.graph import MigrationGraph
|
||||
from django.utils import six
|
||||
from django.conf import settings
|
||||
|
||||
|
||||
@ -32,10 +33,12 @@ class MigrationLoader(object):
|
||||
in memory.
|
||||
"""
|
||||
|
||||
def __init__(self, connection):
|
||||
def __init__(self, connection, load=True):
|
||||
self.connection = connection
|
||||
self.disk_migrations = None
|
||||
self.applied_migrations = None
|
||||
if load:
|
||||
self.build_graph()
|
||||
|
||||
@classmethod
|
||||
def migrations_module(cls, app_label):
|
||||
@ -55,6 +58,7 @@ class MigrationLoader(object):
|
||||
# Get the migrations module directory
|
||||
app_label = app.__name__.split(".")[-2]
|
||||
module_name = self.migrations_module(app_label)
|
||||
was_loaded = module_name in sys.modules
|
||||
try:
|
||||
module = import_module(module_name)
|
||||
except ImportError as e:
|
||||
@ -71,6 +75,9 @@ class MigrationLoader(object):
|
||||
# Module is not a package (e.g. migrations.py).
|
||||
if not hasattr(module, '__path__'):
|
||||
continue
|
||||
# Force a reload if it's already loaded (tests need this)
|
||||
if was_loaded:
|
||||
six.moves.reload_module(module)
|
||||
self.migrated_apps.add(app_label)
|
||||
directory = os.path.dirname(module.__file__)
|
||||
# Scan for .py[c|o] files
|
||||
@ -107,9 +114,6 @@ class MigrationLoader(object):
|
||||
|
||||
def get_migration_by_prefix(self, app_label, name_prefix):
|
||||
"Returns the migration(s) which match the given app label and name _prefix_"
|
||||
# Make sure we have the disk data
|
||||
if self.disk_migrations is None:
|
||||
self.load_disk()
|
||||
# Do the search
|
||||
results = []
|
||||
for l, n in self.disk_migrations:
|
||||
@ -122,18 +126,17 @@ class MigrationLoader(object):
|
||||
else:
|
||||
return self.disk_migrations[results[0]]
|
||||
|
||||
@cached_property
|
||||
def graph(self):
|
||||
def build_graph(self):
|
||||
"""
|
||||
Builds a migration dependency graph using both the disk and database.
|
||||
You'll need to rebuild the graph if you apply migrations. This isn't
|
||||
usually a problem as generally migration stuff runs in a one-shot process.
|
||||
"""
|
||||
# Make sure we have the disk data
|
||||
if self.disk_migrations is None:
|
||||
self.load_disk()
|
||||
# And the database data
|
||||
if self.applied_migrations is None:
|
||||
recorder = MigrationRecorder(self.connection)
|
||||
self.applied_migrations = recorder.applied_migrations()
|
||||
# Load disk data
|
||||
self.load_disk()
|
||||
# Load database data
|
||||
recorder = MigrationRecorder(self.connection)
|
||||
self.applied_migrations = recorder.applied_migrations()
|
||||
# Do a first pass to separate out replacing and non-replacing migrations
|
||||
normal = {}
|
||||
replacing = {}
|
||||
@ -152,12 +155,12 @@ class MigrationLoader(object):
|
||||
# Carry out replacements if we can - that is, if all replaced migrations
|
||||
# are either unapplied or missing.
|
||||
for key, migration in replacing.items():
|
||||
# Do the check
|
||||
can_replace = True
|
||||
for target in migration.replaces:
|
||||
if target in self.applied_migrations:
|
||||
can_replace = False
|
||||
break
|
||||
# Ensure this replacement migration is not in applied_migrations
|
||||
self.applied_migrations.discard(key)
|
||||
# Do the check. We can replace if all our replace targets are
|
||||
# applied, or if all of them are unapplied.
|
||||
applied_statuses = [(target in self.applied_migrations) for target in migration.replaces]
|
||||
can_replace = all(applied_statuses) or (not any(applied_statuses))
|
||||
if not can_replace:
|
||||
continue
|
||||
# Alright, time to replace. Step through the replaced migrations
|
||||
@ -171,14 +174,16 @@ class MigrationLoader(object):
|
||||
normal[child_key].dependencies.remove(replaced)
|
||||
normal[child_key].dependencies.append(key)
|
||||
normal[key] = migration
|
||||
# Mark the replacement as applied if all its replaced ones are
|
||||
if all(applied_statuses):
|
||||
self.applied_migrations.add(key)
|
||||
# Finally, make a graph and load everything into it
|
||||
graph = MigrationGraph()
|
||||
self.graph = MigrationGraph()
|
||||
for key, migration in normal.items():
|
||||
graph.add_node(key, migration)
|
||||
self.graph.add_node(key, migration)
|
||||
for key, migration in normal.items():
|
||||
for parent in migration.dependencies:
|
||||
graph.add_dependency(key, parent)
|
||||
return graph
|
||||
self.graph.add_dependency(key, parent)
|
||||
|
||||
|
||||
class BadMigrationError(Exception):
|
||||
|
@ -39,6 +39,11 @@ class Migration(object):
|
||||
def __init__(self, name, app_label):
|
||||
self.name = name
|
||||
self.app_label = app_label
|
||||
# Copy dependencies & other attrs as we might mutate them at runtime
|
||||
self.operations = list(self.__class__.operations)
|
||||
self.dependencies = list(self.__class__.dependencies)
|
||||
self.run_before = list(self.__class__.run_before)
|
||||
self.replaces = list(self.__class__.replaces)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, Migration):
|
||||
|
@ -38,7 +38,58 @@ class ExecutorTests(TransactionTestCase):
|
||||
# Are the tables there now?
|
||||
self.assertIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
|
||||
self.assertIn("migrations_book", connection.introspection.get_table_list(connection.cursor()))
|
||||
# Rebuild the graph to reflect the new DB state
|
||||
executor.loader.build_graph()
|
||||
# Alright, let's undo what we did
|
||||
plan = executor.migration_plan([("migrations", None)])
|
||||
self.assertEqual(
|
||||
plan,
|
||||
[
|
||||
(executor.loader.graph.nodes["migrations", "0002_second"], True),
|
||||
(executor.loader.graph.nodes["migrations", "0001_initial"], True),
|
||||
],
|
||||
)
|
||||
executor.migrate([("migrations", None)])
|
||||
# Are the tables gone?
|
||||
self.assertNotIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
|
||||
self.assertNotIn("migrations_book", connection.introspection.get_table_list(connection.cursor()))
|
||||
|
||||
@override_settings(MIGRATION_MODULES={"migrations": "migrations.test_migrations_squashed"})
|
||||
def test_run_with_squashed(self):
|
||||
"""
|
||||
Tests running a squashed migration from zero (should ignore what it replaces)
|
||||
"""
|
||||
executor = MigrationExecutor(connection)
|
||||
executor.recorder.flush()
|
||||
# Check our leaf node is the squashed one
|
||||
leaves = [key for key in executor.loader.graph.leaf_nodes() if key[0] == "migrations"]
|
||||
self.assertEqual(leaves, [("migrations", "0001_squashed_0002")])
|
||||
# Check the plan
|
||||
plan = executor.migration_plan([("migrations", "0001_squashed_0002")])
|
||||
self.assertEqual(
|
||||
plan,
|
||||
[
|
||||
(executor.loader.graph.nodes["migrations", "0001_squashed_0002"], False),
|
||||
],
|
||||
)
|
||||
# Were the tables there before?
|
||||
self.assertNotIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
|
||||
self.assertNotIn("migrations_book", connection.introspection.get_table_list(connection.cursor()))
|
||||
# Alright, let's try running it
|
||||
executor.migrate([("migrations", "0001_squashed_0002")])
|
||||
# Are the tables there now?
|
||||
self.assertIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
|
||||
self.assertIn("migrations_book", connection.introspection.get_table_list(connection.cursor()))
|
||||
# Rebuild the graph to reflect the new DB state
|
||||
executor.loader.build_graph()
|
||||
# Alright, let's undo what we did. Should also just use squashed.
|
||||
plan = executor.migration_plan([("migrations", None)])
|
||||
self.assertEqual(
|
||||
plan,
|
||||
[
|
||||
(executor.loader.graph.nodes["migrations", "0001_squashed_0002"], True),
|
||||
],
|
||||
)
|
||||
executor.migrate([("migrations", None)])
|
||||
# Are the tables gone?
|
||||
self.assertNotIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
|
||||
@ -70,6 +121,8 @@ class ExecutorTests(TransactionTestCase):
|
||||
)
|
||||
# Fake-apply all migrations
|
||||
executor.migrate([("migrations", "0002_second"), ("sessions", "0001_initial")], fake=True)
|
||||
# Rebuild the graph to reflect the new DB state
|
||||
executor.loader.build_graph()
|
||||
# Now plan a second time and make sure it's empty
|
||||
plan = executor.migration_plan([("migrations", "0002_second"), ("sessions", "0001_initial")])
|
||||
self.assertEqual(plan, [])
|
||||
|
@ -82,21 +82,34 @@ class LoaderTests(TestCase):
|
||||
migration_loader.get_migration_by_prefix("migrations", "blarg")
|
||||
|
||||
def test_load_import_error(self):
|
||||
migration_loader = MigrationLoader(connection)
|
||||
|
||||
with override_settings(MIGRATION_MODULES={"migrations": "migrations.faulty_migrations.import_error"}):
|
||||
with self.assertRaises(ImportError):
|
||||
migration_loader.load_disk()
|
||||
MigrationLoader(connection)
|
||||
|
||||
def test_load_module_file(self):
|
||||
migration_loader = MigrationLoader(connection)
|
||||
|
||||
with override_settings(MIGRATION_MODULES={"migrations": "migrations.faulty_migrations.file"}):
|
||||
migration_loader.load_disk()
|
||||
MigrationLoader(connection)
|
||||
|
||||
@skipIf(six.PY2, "PY2 doesn't load empty dirs.")
|
||||
def test_load_empty_dir(self):
|
||||
migration_loader = MigrationLoader(connection)
|
||||
|
||||
with override_settings(MIGRATION_MODULES={"migrations": "migrations.faulty_migrations.namespace"}):
|
||||
migration_loader.load_disk()
|
||||
MigrationLoader(connection)
|
||||
|
||||
@override_settings(MIGRATION_MODULES={"migrations": "migrations.test_migrations_squashed"})
|
||||
def test_loading_squashed(self):
|
||||
"Tests loading a squashed migration"
|
||||
migration_loader = MigrationLoader(connection)
|
||||
recorder = MigrationRecorder(connection)
|
||||
# Loading with nothing applied should just give us the one node
|
||||
self.assertEqual(
|
||||
len(migration_loader.graph.nodes),
|
||||
1,
|
||||
)
|
||||
# However, fake-apply one migration and it should now use the old two
|
||||
recorder.record_applied("migrations", "0001_initial")
|
||||
migration_loader.build_graph()
|
||||
self.assertEqual(
|
||||
len(migration_loader.graph.nodes),
|
||||
2,
|
||||
)
|
||||
recorder.flush()
|
||||
|
27
tests/migrations/test_migrations_squashed/0001_initial.py
Normal file
27
tests/migrations/test_migrations_squashed/0001_initial.py
Normal file
@ -0,0 +1,27 @@
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
operations = [
|
||||
|
||||
migrations.CreateModel(
|
||||
"Author",
|
||||
[
|
||||
("id", models.AutoField(primary_key=True)),
|
||||
("name", models.CharField(max_length=255)),
|
||||
("slug", models.SlugField(null=True)),
|
||||
("age", models.IntegerField(default=0)),
|
||||
("silly_field", models.BooleanField(default=False)),
|
||||
],
|
||||
),
|
||||
|
||||
migrations.CreateModel(
|
||||
"Tribble",
|
||||
[
|
||||
("id", models.AutoField(primary_key=True)),
|
||||
("fluffy", models.BooleanField(default=True)),
|
||||
],
|
||||
)
|
||||
|
||||
]
|
@ -0,0 +1,32 @@
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
replaces = [
|
||||
("migrations", "0001_initial"),
|
||||
("migrations", "0002_second"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
|
||||
migrations.CreateModel(
|
||||
"Author",
|
||||
[
|
||||
("id", models.AutoField(primary_key=True)),
|
||||
("name", models.CharField(max_length=255)),
|
||||
("slug", models.SlugField(null=True)),
|
||||
("age", models.IntegerField(default=0)),
|
||||
("rating", models.IntegerField(default=0)),
|
||||
],
|
||||
),
|
||||
|
||||
migrations.CreateModel(
|
||||
"Book",
|
||||
[
|
||||
("id", models.AutoField(primary_key=True)),
|
||||
("author", models.ForeignKey("migrations.Author", null=True)),
|
||||
],
|
||||
),
|
||||
|
||||
]
|
24
tests/migrations/test_migrations_squashed/0002_second.py
Normal file
24
tests/migrations/test_migrations_squashed/0002_second.py
Normal file
@ -0,0 +1,24 @@
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [("migrations", "0001_initial")]
|
||||
|
||||
operations = [
|
||||
|
||||
migrations.DeleteModel("Tribble"),
|
||||
|
||||
migrations.RemoveField("Author", "silly_field"),
|
||||
|
||||
migrations.AddField("Author", "rating", models.IntegerField(default=0)),
|
||||
|
||||
migrations.CreateModel(
|
||||
"Book",
|
||||
[
|
||||
("id", models.AutoField(primary_key=True)),
|
||||
("author", models.ForeignKey("migrations.Author", null=True)),
|
||||
],
|
||||
)
|
||||
|
||||
]
|
Loading…
Reference in New Issue
Block a user