mirror of
				https://github.com/django/django.git
				synced 2025-10-25 06:36:07 +00:00 
			
		
		
		
	Fix migration planner to fully understand squashed migrations. And test.
This commit is contained in:
		| @@ -11,7 +11,6 @@ class MigrationExecutor(object): | |||||||
|     def __init__(self, connection, progress_callback=None): |     def __init__(self, connection, progress_callback=None): | ||||||
|         self.connection = connection |         self.connection = connection | ||||||
|         self.loader = MigrationLoader(self.connection) |         self.loader = MigrationLoader(self.connection) | ||||||
|         self.loader.load_disk() |  | ||||||
|         self.recorder = MigrationRecorder(self.connection) |         self.recorder = MigrationRecorder(self.connection) | ||||||
|         self.progress_callback = progress_callback |         self.progress_callback = progress_callback | ||||||
|  |  | ||||||
| @@ -20,7 +19,7 @@ class MigrationExecutor(object): | |||||||
|         Given a set of targets, returns a list of (Migration instance, backwards?). |         Given a set of targets, returns a list of (Migration instance, backwards?). | ||||||
|         """ |         """ | ||||||
|         plan = [] |         plan = [] | ||||||
|         applied = self.recorder.applied_migrations() |         applied = set(self.loader.applied_migrations) | ||||||
|         for target in targets: |         for target in targets: | ||||||
|             # If the target is (appname, None), that means unmigrate everything |             # If the target is (appname, None), that means unmigrate everything | ||||||
|             if target[1] is None: |             if target[1] is None: | ||||||
| @@ -87,7 +86,13 @@ class MigrationExecutor(object): | |||||||
|             with self.connection.schema_editor() as schema_editor: |             with self.connection.schema_editor() as schema_editor: | ||||||
|                 project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False) |                 project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False) | ||||||
|                 migration.apply(project_state, schema_editor) |                 migration.apply(project_state, schema_editor) | ||||||
|  |         # For replacement migrations, record individual statuses | ||||||
|  |         if migration.replaces: | ||||||
|  |             for app_label, name in migration.replaces: | ||||||
|  |                 self.recorder.record_applied(app_label, name) | ||||||
|  |         else: | ||||||
|             self.recorder.record_applied(migration.app_label, migration.name) |             self.recorder.record_applied(migration.app_label, migration.name) | ||||||
|  |         # Report prgress | ||||||
|         if self.progress_callback: |         if self.progress_callback: | ||||||
|             self.progress_callback("apply_success", migration) |             self.progress_callback("apply_success", migration) | ||||||
|  |  | ||||||
| @@ -101,6 +106,12 @@ class MigrationExecutor(object): | |||||||
|             with self.connection.schema_editor() as schema_editor: |             with self.connection.schema_editor() as schema_editor: | ||||||
|                 project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False) |                 project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False) | ||||||
|                 migration.unapply(project_state, schema_editor) |                 migration.unapply(project_state, schema_editor) | ||||||
|  |         # For replacement migrations, record individual statuses | ||||||
|  |         if migration.replaces: | ||||||
|  |             for app_label, name in migration.replaces: | ||||||
|  |                 self.recorder.record_unapplied(app_label, name) | ||||||
|  |         else: | ||||||
|             self.recorder.record_unapplied(migration.app_label, migration.name) |             self.recorder.record_unapplied(migration.app_label, migration.name) | ||||||
|  |         # Report progress | ||||||
|         if self.progress_callback: |         if self.progress_callback: | ||||||
|             self.progress_callback("unapply_success", migration) |             self.progress_callback("unapply_success", migration) | ||||||
|   | |||||||
| @@ -1,9 +1,10 @@ | |||||||
| import os | import os | ||||||
|  | import sys | ||||||
| from importlib import import_module | from importlib import import_module | ||||||
| from django.utils.functional import cached_property |  | ||||||
| from django.db.models.loading import cache | from django.db.models.loading import cache | ||||||
| from django.db.migrations.recorder import MigrationRecorder | from django.db.migrations.recorder import MigrationRecorder | ||||||
| from django.db.migrations.graph import MigrationGraph | from django.db.migrations.graph import MigrationGraph | ||||||
|  | from django.utils import six | ||||||
| from django.conf import settings | from django.conf import settings | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -32,10 +33,12 @@ class MigrationLoader(object): | |||||||
|     in memory. |     in memory. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, connection): |     def __init__(self, connection, load=True): | ||||||
|         self.connection = connection |         self.connection = connection | ||||||
|         self.disk_migrations = None |         self.disk_migrations = None | ||||||
|         self.applied_migrations = None |         self.applied_migrations = None | ||||||
|  |         if load: | ||||||
|  |             self.build_graph() | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def migrations_module(cls, app_label): |     def migrations_module(cls, app_label): | ||||||
| @@ -55,6 +58,7 @@ class MigrationLoader(object): | |||||||
|             # Get the migrations module directory |             # Get the migrations module directory | ||||||
|             app_label = app.__name__.split(".")[-2] |             app_label = app.__name__.split(".")[-2] | ||||||
|             module_name = self.migrations_module(app_label) |             module_name = self.migrations_module(app_label) | ||||||
|  |             was_loaded = module_name in sys.modules | ||||||
|             try: |             try: | ||||||
|                 module = import_module(module_name) |                 module = import_module(module_name) | ||||||
|             except ImportError as e: |             except ImportError as e: | ||||||
| @@ -71,6 +75,9 @@ class MigrationLoader(object): | |||||||
|                 # Module is not a package (e.g. migrations.py). |                 # Module is not a package (e.g. migrations.py). | ||||||
|                 if not hasattr(module, '__path__'): |                 if not hasattr(module, '__path__'): | ||||||
|                     continue |                     continue | ||||||
|  |                 # Force a reload if it's already loaded (tests need this) | ||||||
|  |                 if was_loaded: | ||||||
|  |                     six.moves.reload_module(module) | ||||||
|             self.migrated_apps.add(app_label) |             self.migrated_apps.add(app_label) | ||||||
|             directory = os.path.dirname(module.__file__) |             directory = os.path.dirname(module.__file__) | ||||||
|             # Scan for .py[c|o] files |             # Scan for .py[c|o] files | ||||||
| @@ -107,9 +114,6 @@ class MigrationLoader(object): | |||||||
|  |  | ||||||
|     def get_migration_by_prefix(self, app_label, name_prefix): |     def get_migration_by_prefix(self, app_label, name_prefix): | ||||||
|         "Returns the migration(s) which match the given app label and name _prefix_" |         "Returns the migration(s) which match the given app label and name _prefix_" | ||||||
|         # Make sure we have the disk data |  | ||||||
|         if self.disk_migrations is None: |  | ||||||
|             self.load_disk() |  | ||||||
|         # Do the search |         # Do the search | ||||||
|         results = [] |         results = [] | ||||||
|         for l, n in self.disk_migrations: |         for l, n in self.disk_migrations: | ||||||
| @@ -122,16 +126,15 @@ class MigrationLoader(object): | |||||||
|         else: |         else: | ||||||
|             return self.disk_migrations[results[0]] |             return self.disk_migrations[results[0]] | ||||||
|  |  | ||||||
|     @cached_property |     def build_graph(self): | ||||||
|     def graph(self): |  | ||||||
|         """ |         """ | ||||||
|         Builds a migration dependency graph using both the disk and database. |         Builds a migration dependency graph using both the disk and database. | ||||||
|  |         You'll need to rebuild the graph if you apply migrations. This isn't | ||||||
|  |         usually a problem as generally migration stuff runs in a one-shot process. | ||||||
|         """ |         """ | ||||||
|         # Make sure we have the disk data |         # Load disk data | ||||||
|         if self.disk_migrations is None: |  | ||||||
|         self.load_disk() |         self.load_disk() | ||||||
|         # And the database data |         # Load database data | ||||||
|         if self.applied_migrations is None: |  | ||||||
|         recorder = MigrationRecorder(self.connection) |         recorder = MigrationRecorder(self.connection) | ||||||
|         self.applied_migrations = recorder.applied_migrations() |         self.applied_migrations = recorder.applied_migrations() | ||||||
|         # Do a first pass to separate out replacing and non-replacing migrations |         # Do a first pass to separate out replacing and non-replacing migrations | ||||||
| @@ -152,12 +155,12 @@ class MigrationLoader(object): | |||||||
|         # Carry out replacements if we can - that is, if all replaced migrations |         # Carry out replacements if we can - that is, if all replaced migrations | ||||||
|         # are either unapplied or missing. |         # are either unapplied or missing. | ||||||
|         for key, migration in replacing.items(): |         for key, migration in replacing.items(): | ||||||
|             # Do the check |             # Ensure this replacement migration is not in applied_migrations | ||||||
|             can_replace = True |             self.applied_migrations.discard(key) | ||||||
|             for target in migration.replaces: |             # Do the check. We can replace if all our replace targets are | ||||||
|                 if target in self.applied_migrations: |             # applied, or if all of them are unapplied. | ||||||
|                     can_replace = False |             applied_statuses = [(target in self.applied_migrations) for target in migration.replaces] | ||||||
|                     break |             can_replace = all(applied_statuses) or (not any(applied_statuses)) | ||||||
|             if not can_replace: |             if not can_replace: | ||||||
|                 continue |                 continue | ||||||
|             # Alright, time to replace. Step through the replaced migrations |             # Alright, time to replace. Step through the replaced migrations | ||||||
| @@ -171,14 +174,16 @@ class MigrationLoader(object): | |||||||
|                     normal[child_key].dependencies.remove(replaced) |                     normal[child_key].dependencies.remove(replaced) | ||||||
|                     normal[child_key].dependencies.append(key) |                     normal[child_key].dependencies.append(key) | ||||||
|             normal[key] = migration |             normal[key] = migration | ||||||
|  |             # Mark the replacement as applied if all its replaced ones are | ||||||
|  |             if all(applied_statuses): | ||||||
|  |                 self.applied_migrations.add(key) | ||||||
|         # Finally, make a graph and load everything into it |         # Finally, make a graph and load everything into it | ||||||
|         graph = MigrationGraph() |         self.graph = MigrationGraph() | ||||||
|         for key, migration in normal.items(): |         for key, migration in normal.items(): | ||||||
|             graph.add_node(key, migration) |             self.graph.add_node(key, migration) | ||||||
|         for key, migration in normal.items(): |         for key, migration in normal.items(): | ||||||
|             for parent in migration.dependencies: |             for parent in migration.dependencies: | ||||||
|                 graph.add_dependency(key, parent) |                 self.graph.add_dependency(key, parent) | ||||||
|         return graph |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class BadMigrationError(Exception): | class BadMigrationError(Exception): | ||||||
|   | |||||||
| @@ -39,6 +39,11 @@ class Migration(object): | |||||||
|     def __init__(self, name, app_label): |     def __init__(self, name, app_label): | ||||||
|         self.name = name |         self.name = name | ||||||
|         self.app_label = app_label |         self.app_label = app_label | ||||||
|  |         # Copy dependencies & other attrs as we might mutate them at runtime | ||||||
|  |         self.operations = list(self.__class__.operations) | ||||||
|  |         self.dependencies = list(self.__class__.dependencies) | ||||||
|  |         self.run_before = list(self.__class__.run_before) | ||||||
|  |         self.replaces = list(self.__class__.replaces) | ||||||
|  |  | ||||||
|     def __eq__(self, other): |     def __eq__(self, other): | ||||||
|         if not isinstance(other, Migration): |         if not isinstance(other, Migration): | ||||||
|   | |||||||
| @@ -38,7 +38,58 @@ class ExecutorTests(TransactionTestCase): | |||||||
|         # Are the tables there now? |         # Are the tables there now? | ||||||
|         self.assertIn("migrations_author", connection.introspection.get_table_list(connection.cursor())) |         self.assertIn("migrations_author", connection.introspection.get_table_list(connection.cursor())) | ||||||
|         self.assertIn("migrations_book", connection.introspection.get_table_list(connection.cursor())) |         self.assertIn("migrations_book", connection.introspection.get_table_list(connection.cursor())) | ||||||
|  |         # Rebuild the graph to reflect the new DB state | ||||||
|  |         executor.loader.build_graph() | ||||||
|         # Alright, let's undo what we did |         # Alright, let's undo what we did | ||||||
|  |         plan = executor.migration_plan([("migrations", None)]) | ||||||
|  |         self.assertEqual( | ||||||
|  |             plan, | ||||||
|  |             [ | ||||||
|  |                 (executor.loader.graph.nodes["migrations", "0002_second"], True), | ||||||
|  |                 (executor.loader.graph.nodes["migrations", "0001_initial"], True), | ||||||
|  |             ], | ||||||
|  |         ) | ||||||
|  |         executor.migrate([("migrations", None)]) | ||||||
|  |         # Are the tables gone? | ||||||
|  |         self.assertNotIn("migrations_author", connection.introspection.get_table_list(connection.cursor())) | ||||||
|  |         self.assertNotIn("migrations_book", connection.introspection.get_table_list(connection.cursor())) | ||||||
|  |  | ||||||
|  |     @override_settings(MIGRATION_MODULES={"migrations": "migrations.test_migrations_squashed"}) | ||||||
|  |     def test_run_with_squashed(self): | ||||||
|  |         """ | ||||||
|  |         Tests running a squashed migration from zero (should ignore what it replaces) | ||||||
|  |         """ | ||||||
|  |         executor = MigrationExecutor(connection) | ||||||
|  |         executor.recorder.flush() | ||||||
|  |         # Check our leaf node is the squashed one | ||||||
|  |         leaves = [key for key in executor.loader.graph.leaf_nodes() if key[0] == "migrations"] | ||||||
|  |         self.assertEqual(leaves, [("migrations", "0001_squashed_0002")]) | ||||||
|  |         # Check the plan | ||||||
|  |         plan = executor.migration_plan([("migrations", "0001_squashed_0002")]) | ||||||
|  |         self.assertEqual( | ||||||
|  |             plan, | ||||||
|  |             [ | ||||||
|  |                 (executor.loader.graph.nodes["migrations", "0001_squashed_0002"], False), | ||||||
|  |             ], | ||||||
|  |         ) | ||||||
|  |         # Were the tables there before? | ||||||
|  |         self.assertNotIn("migrations_author", connection.introspection.get_table_list(connection.cursor())) | ||||||
|  |         self.assertNotIn("migrations_book", connection.introspection.get_table_list(connection.cursor())) | ||||||
|  |         # Alright, let's try running it | ||||||
|  |         executor.migrate([("migrations", "0001_squashed_0002")]) | ||||||
|  |         # Are the tables there now? | ||||||
|  |         self.assertIn("migrations_author", connection.introspection.get_table_list(connection.cursor())) | ||||||
|  |         self.assertIn("migrations_book", connection.introspection.get_table_list(connection.cursor())) | ||||||
|  |         # Rebuild the graph to reflect the new DB state | ||||||
|  |         executor.loader.build_graph() | ||||||
|  |         # Alright, let's undo what we did. Should also just use squashed. | ||||||
|  |         plan = executor.migration_plan([("migrations", None)]) | ||||||
|  |         self.assertEqual( | ||||||
|  |             plan, | ||||||
|  |             [ | ||||||
|  |                 (executor.loader.graph.nodes["migrations", "0001_squashed_0002"], True), | ||||||
|  |             ], | ||||||
|  |         ) | ||||||
|         executor.migrate([("migrations", None)]) |         executor.migrate([("migrations", None)]) | ||||||
|         # Are the tables gone? |         # Are the tables gone? | ||||||
|         self.assertNotIn("migrations_author", connection.introspection.get_table_list(connection.cursor())) |         self.assertNotIn("migrations_author", connection.introspection.get_table_list(connection.cursor())) | ||||||
| @@ -70,6 +121,8 @@ class ExecutorTests(TransactionTestCase): | |||||||
|         ) |         ) | ||||||
|         # Fake-apply all migrations |         # Fake-apply all migrations | ||||||
|         executor.migrate([("migrations", "0002_second"), ("sessions", "0001_initial")], fake=True) |         executor.migrate([("migrations", "0002_second"), ("sessions", "0001_initial")], fake=True) | ||||||
|  |         # Rebuild the graph to reflect the new DB state | ||||||
|  |         executor.loader.build_graph() | ||||||
|         # Now plan a second time and make sure it's empty |         # Now plan a second time and make sure it's empty | ||||||
|         plan = executor.migration_plan([("migrations", "0002_second"), ("sessions", "0001_initial")]) |         plan = executor.migration_plan([("migrations", "0002_second"), ("sessions", "0001_initial")]) | ||||||
|         self.assertEqual(plan, []) |         self.assertEqual(plan, []) | ||||||
|   | |||||||
| @@ -82,21 +82,34 @@ class LoaderTests(TestCase): | |||||||
|             migration_loader.get_migration_by_prefix("migrations", "blarg") |             migration_loader.get_migration_by_prefix("migrations", "blarg") | ||||||
|  |  | ||||||
|     def test_load_import_error(self): |     def test_load_import_error(self): | ||||||
|         migration_loader = MigrationLoader(connection) |  | ||||||
|  |  | ||||||
|         with override_settings(MIGRATION_MODULES={"migrations": "migrations.faulty_migrations.import_error"}): |         with override_settings(MIGRATION_MODULES={"migrations": "migrations.faulty_migrations.import_error"}): | ||||||
|             with self.assertRaises(ImportError): |             with self.assertRaises(ImportError): | ||||||
|                 migration_loader.load_disk() |                 MigrationLoader(connection) | ||||||
|  |  | ||||||
|     def test_load_module_file(self): |     def test_load_module_file(self): | ||||||
|         migration_loader = MigrationLoader(connection) |  | ||||||
|  |  | ||||||
|         with override_settings(MIGRATION_MODULES={"migrations": "migrations.faulty_migrations.file"}): |         with override_settings(MIGRATION_MODULES={"migrations": "migrations.faulty_migrations.file"}): | ||||||
|             migration_loader.load_disk() |             MigrationLoader(connection) | ||||||
|  |  | ||||||
|     @skipIf(six.PY2, "PY2 doesn't load empty dirs.") |     @skipIf(six.PY2, "PY2 doesn't load empty dirs.") | ||||||
|     def test_load_empty_dir(self): |     def test_load_empty_dir(self): | ||||||
|         migration_loader = MigrationLoader(connection) |  | ||||||
|  |  | ||||||
|         with override_settings(MIGRATION_MODULES={"migrations": "migrations.faulty_migrations.namespace"}): |         with override_settings(MIGRATION_MODULES={"migrations": "migrations.faulty_migrations.namespace"}): | ||||||
|             migration_loader.load_disk() |             MigrationLoader(connection) | ||||||
|  |  | ||||||
|  |     @override_settings(MIGRATION_MODULES={"migrations": "migrations.test_migrations_squashed"}) | ||||||
|  |     def test_loading_squashed(self): | ||||||
|  |         "Tests loading a squashed migration" | ||||||
|  |         migration_loader = MigrationLoader(connection) | ||||||
|  |         recorder = MigrationRecorder(connection) | ||||||
|  |         # Loading with nothing applied should just give us the one node | ||||||
|  |         self.assertEqual( | ||||||
|  |             len(migration_loader.graph.nodes), | ||||||
|  |             1, | ||||||
|  |         ) | ||||||
|  |         # However, fake-apply one migration and it should now use the old two | ||||||
|  |         recorder.record_applied("migrations", "0001_initial") | ||||||
|  |         migration_loader.build_graph() | ||||||
|  |         self.assertEqual( | ||||||
|  |             len(migration_loader.graph.nodes), | ||||||
|  |             2, | ||||||
|  |         ) | ||||||
|  |         recorder.flush() | ||||||
|   | |||||||
							
								
								
									
										27
									
								
								tests/migrations/test_migrations_squashed/0001_initial.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								tests/migrations/test_migrations_squashed/0001_initial.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,27 @@ | |||||||
|  | from django.db import migrations, models | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Migration(migrations.Migration): | ||||||
|  |  | ||||||
|  |     operations = [ | ||||||
|  |  | ||||||
|  |         migrations.CreateModel( | ||||||
|  |             "Author", | ||||||
|  |             [ | ||||||
|  |                 ("id", models.AutoField(primary_key=True)), | ||||||
|  |                 ("name", models.CharField(max_length=255)), | ||||||
|  |                 ("slug", models.SlugField(null=True)), | ||||||
|  |                 ("age", models.IntegerField(default=0)), | ||||||
|  |                 ("silly_field", models.BooleanField(default=False)), | ||||||
|  |             ], | ||||||
|  |         ), | ||||||
|  |  | ||||||
|  |         migrations.CreateModel( | ||||||
|  |             "Tribble", | ||||||
|  |             [ | ||||||
|  |                 ("id", models.AutoField(primary_key=True)), | ||||||
|  |                 ("fluffy", models.BooleanField(default=True)), | ||||||
|  |             ], | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     ] | ||||||
| @@ -0,0 +1,32 @@ | |||||||
|  | from django.db import migrations, models | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Migration(migrations.Migration): | ||||||
|  |  | ||||||
|  |     replaces = [ | ||||||
|  |         ("migrations", "0001_initial"), | ||||||
|  |         ("migrations", "0002_second"), | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |     operations = [ | ||||||
|  |  | ||||||
|  |         migrations.CreateModel( | ||||||
|  |             "Author", | ||||||
|  |             [ | ||||||
|  |                 ("id", models.AutoField(primary_key=True)), | ||||||
|  |                 ("name", models.CharField(max_length=255)), | ||||||
|  |                 ("slug", models.SlugField(null=True)), | ||||||
|  |                 ("age", models.IntegerField(default=0)), | ||||||
|  |                 ("rating", models.IntegerField(default=0)), | ||||||
|  |             ], | ||||||
|  |         ), | ||||||
|  |  | ||||||
|  |         migrations.CreateModel( | ||||||
|  |             "Book", | ||||||
|  |             [ | ||||||
|  |                 ("id", models.AutoField(primary_key=True)), | ||||||
|  |                 ("author", models.ForeignKey("migrations.Author", null=True)), | ||||||
|  |             ], | ||||||
|  |         ), | ||||||
|  |  | ||||||
|  |     ] | ||||||
							
								
								
									
										24
									
								
								tests/migrations/test_migrations_squashed/0002_second.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								tests/migrations/test_migrations_squashed/0002_second.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,24 @@ | |||||||
|  | from django.db import migrations, models | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Migration(migrations.Migration): | ||||||
|  |  | ||||||
|  |     dependencies = [("migrations", "0001_initial")] | ||||||
|  |  | ||||||
|  |     operations = [ | ||||||
|  |  | ||||||
|  |         migrations.DeleteModel("Tribble"), | ||||||
|  |  | ||||||
|  |         migrations.RemoveField("Author", "silly_field"), | ||||||
|  |  | ||||||
|  |         migrations.AddField("Author", "rating", models.IntegerField(default=0)), | ||||||
|  |  | ||||||
|  |         migrations.CreateModel( | ||||||
|  |             "Book", | ||||||
|  |             [ | ||||||
|  |                 ("id", models.AutoField(primary_key=True)), | ||||||
|  |                 ("author", models.ForeignKey("migrations.Author", null=True)), | ||||||
|  |             ], | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     ] | ||||||
		Reference in New Issue
	
	Block a user