diff --git a/django/db/migrations/operations/models.py b/django/db/migrations/operations/models.py index c8f7a2627a..2469d01efb 100644 --- a/django/db/migrations/operations/models.py +++ b/django/db/migrations/operations/models.py @@ -979,7 +979,7 @@ class AddIndex(IndexOperation): return [] if isinstance(operation, RenameIndex) and self.index.name == operation.old_name: self.index.name = operation.new_name - return [AddIndex(model_name=self.model_name, index=self.index)] + return [self.__class__(model_name=self.model_name, index=self.index)] return super().reduce(operation, app_label) diff --git a/tests/migrations/test_base.py b/tests/migrations/test_base.py index 0ff1dda1d9..cde4063d04 100644 --- a/tests/migrations/test_base.py +++ b/tests/migrations/test_base.py @@ -7,9 +7,11 @@ from importlib import import_module from django.apps import apps from django.db import connection, connections, migrations, models from django.db.migrations.migration import Migration +from django.db.migrations.optimizer import MigrationOptimizer from django.db.migrations.recorder import MigrationRecorder +from django.db.migrations.serializer import serializer_factory from django.db.migrations.state import ProjectState -from django.test import TransactionTestCase +from django.test import SimpleTestCase, TransactionTestCase from django.test.utils import extend_sys_path from django.utils.module_loading import module_dir @@ -400,3 +402,38 @@ class OperationTestBase(MigrationTestBase): ) ) return self.apply_operations(app_label, ProjectState(), operations) + + +class OptimizerTestBase(SimpleTestCase): + """Common functions to help test the optimizer.""" + + def optimize(self, operations, app_label): + """ + Handy shortcut for getting results + number of loops + """ + optimizer = MigrationOptimizer() + return optimizer.optimize(operations, app_label), optimizer._iterations + + def serialize(self, value): + return serializer_factory(value).serialize()[0] + + def assertOptimizesTo( + self, operations, expected, exact=None, less_than=None, app_label=None + ): + result, iterations = self.optimize(operations, app_label or "migrations") + result = [self.serialize(f) for f in result] + expected = [self.serialize(f) for f in expected] + self.assertEqual(expected, result) + if exact is not None and iterations != exact: + raise self.failureException( + "Optimization did not take exactly %s iterations (it took %s)" + % (exact, iterations) + ) + if less_than is not None and iterations >= less_than: + raise self.failureException( + "Optimization did not take less than %s iterations (it took %s)" + % (less_than, iterations) + ) + + def assertDoesNotOptimize(self, operations, **kwargs): + self.assertOptimizesTo(operations, operations, **kwargs) diff --git a/tests/migrations/test_optimizer.py b/tests/migrations/test_optimizer.py index 3ed30102bf..0a40b50edc 100644 --- a/tests/migrations/test_optimizer.py +++ b/tests/migrations/test_optimizer.py @@ -1,49 +1,17 @@ from django.db import migrations, models from django.db.migrations import operations from django.db.migrations.optimizer import MigrationOptimizer -from django.db.migrations.serializer import serializer_factory from django.db.models.functions import Abs -from django.test import SimpleTestCase from .models import EmptyManager, UnicodeModel +from .test_base import OptimizerTestBase -class OptimizerTests(SimpleTestCase): +class OptimizerTests(OptimizerTestBase): """ - Tests the migration autodetector. + Tests the migration optimizer. """ - def optimize(self, operations, app_label): - """ - Handy shortcut for getting results + number of loops - """ - optimizer = MigrationOptimizer() - return optimizer.optimize(operations, app_label), optimizer._iterations - - def serialize(self, value): - return serializer_factory(value).serialize()[0] - - def assertOptimizesTo( - self, operations, expected, exact=None, less_than=None, app_label=None - ): - result, iterations = self.optimize(operations, app_label or "migrations") - result = [self.serialize(f) for f in result] - expected = [self.serialize(f) for f in expected] - self.assertEqual(expected, result) - if exact is not None and iterations != exact: - raise self.failureException( - "Optimization did not take exactly %s iterations (it took %s)" - % (exact, iterations) - ) - if less_than is not None and iterations >= less_than: - raise self.failureException( - "Optimization did not take less than %s iterations (it took %s)" - % (less_than, iterations) - ) - - def assertDoesNotOptimize(self, operations, **kwargs): - self.assertOptimizesTo(operations, operations, **kwargs) - def test_none_app_label(self): optimizer = MigrationOptimizer() with self.assertRaisesMessage(TypeError, "app_label must be a str"): diff --git a/tests/postgres_tests/test_operations.py b/tests/postgres_tests/test_operations.py index 5780348251..f344d4ae74 100644 --- a/tests/postgres_tests/test_operations.py +++ b/tests/postgres_tests/test_operations.py @@ -1,8 +1,9 @@ import unittest -from migrations.test_base import OperationTestBase +from migrations.test_base import OperationTestBase, OptimizerTestBase from django.db import IntegrityError, NotSupportedError, connection, transaction +from django.db.migrations.operations import RemoveIndex, RenameIndex from django.db.migrations.state import ProjectState from django.db.migrations.writer import OperationWriter from django.db.models import CheckConstraint, Index, Q, UniqueConstraint @@ -30,7 +31,7 @@ except ImportError: @unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests.") @modify_settings(INSTALLED_APPS={"append": "migrations"}) -class AddIndexConcurrentlyTests(OperationTestBase): +class AddIndexConcurrentlyTests(OptimizerTestBase, OperationTestBase): app_label = "test_add_concurrently" def test_requires_atomic_false(self): @@ -129,6 +130,51 @@ class AddIndexConcurrentlyTests(OperationTestBase): ) self.assertIndexNotExists(table_name, ["pink"]) + def test_reduce_add_remove_concurrently(self): + self.assertOptimizesTo( + [ + AddIndexConcurrently( + "Pony", + Index(fields=["pink"], name="pony_pink_idx"), + ), + RemoveIndex("Pony", "pony_pink_idx"), + ], + [], + ) + + def test_reduce_add_remove(self): + self.assertOptimizesTo( + [ + AddIndexConcurrently( + "Pony", + Index(fields=["pink"], name="pony_pink_idx"), + ), + RemoveIndexConcurrently("Pony", "pony_pink_idx"), + ], + [], + ) + + def test_reduce_add_rename(self): + self.assertOptimizesTo( + [ + AddIndexConcurrently( + "Pony", + Index(fields=["pink"], name="pony_pink_idx"), + ), + RenameIndex( + "Pony", + old_name="pony_pink_idx", + new_name="pony_pink_index", + ), + ], + [ + AddIndexConcurrently( + "Pony", + Index(fields=["pink"], name="pony_pink_index"), + ), + ], + ) + @unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests.") @modify_settings(INSTALLED_APPS={"append": "migrations"})