diff --git a/django/db/migrations/autodetector.py b/django/db/migrations/autodetector.py index 353b992258..1dc6377494 100644 --- a/django/db/migrations/autodetector.py +++ b/django/db/migrations/autodetector.py @@ -219,6 +219,7 @@ class MigrationAutodetector: self.generate_altered_unique_together() self.generate_added_indexes() self.generate_added_constraints() + self.generate_altered_constraints() self.generate_altered_db_table() self._sort_migrations() @@ -1450,6 +1451,19 @@ class MigrationAutodetector: ), ) + def _constraint_should_be_dropped_and_recreated( + self, old_constraint, new_constraint + ): + old_path, old_args, old_kwargs = old_constraint.deconstruct() + new_path, new_args, new_kwargs = new_constraint.deconstruct() + + for attr in old_constraint.non_db_attrs: + old_kwargs.pop(attr, None) + for attr in new_constraint.non_db_attrs: + new_kwargs.pop(attr, None) + + return (old_path, old_args, old_kwargs) != (new_path, new_args, new_kwargs) + def create_altered_constraints(self): option_name = operations.AddConstraint.option_name for app_label, model_name in sorted(self.kept_model_keys): @@ -1461,14 +1475,41 @@ class MigrationAutodetector: old_constraints = old_model_state.options[option_name] new_constraints = new_model_state.options[option_name] - add_constraints = [c for c in new_constraints if c not in old_constraints] - rem_constraints = [c for c in old_constraints if c not in new_constraints] + + alt_constraints = [] + alt_constraints_name = [] + + for old_c in old_constraints: + for new_c in new_constraints: + old_c_dec = old_c.deconstruct() + new_c_dec = new_c.deconstruct() + if ( + old_c_dec != new_c_dec + and old_c.name == new_c.name + and not self._constraint_should_be_dropped_and_recreated( + old_c, new_c + ) + ): + alt_constraints.append(new_c) + alt_constraints_name.append(new_c.name) + + add_constraints = [ + c + for c in new_constraints + if c not in old_constraints and c.name not in alt_constraints_name + ] + rem_constraints = [ + c + for c in old_constraints + if c not in new_constraints and c.name not in alt_constraints_name + ] self.altered_constraints.update( { (app_label, model_name): { "added_constraints": add_constraints, "removed_constraints": rem_constraints, + "altered_constraints": alt_constraints, } } ) @@ -1503,6 +1544,23 @@ class MigrationAutodetector: ), ) + def generate_altered_constraints(self): + for ( + app_label, + model_name, + ), alt_constraints in self.altered_constraints.items(): + dependencies = self._get_dependencies_for_model(app_label, model_name) + for constraint in alt_constraints["altered_constraints"]: + self.add_operation( + app_label, + operations.AlterConstraint( + model_name=model_name, + name=constraint.name, + constraint=constraint, + ), + dependencies=dependencies, + ) + @staticmethod def _get_dependencies_for_foreign_key(app_label, model_name, field, project_state): remote_field_model = None diff --git a/django/db/migrations/operations/__init__.py b/django/db/migrations/operations/__init__.py index 90dbdf8256..012f2027d4 100644 --- a/django/db/migrations/operations/__init__.py +++ b/django/db/migrations/operations/__init__.py @@ -2,6 +2,7 @@ from .fields import AddField, AlterField, RemoveField, RenameField from .models import ( AddConstraint, AddIndex, + AlterConstraint, AlterIndexTogether, AlterModelManagers, AlterModelOptions, @@ -36,6 +37,7 @@ __all__ = [ "RenameField", "AddConstraint", "RemoveConstraint", + "AlterConstraint", "SeparateDatabaseAndState", "RunSQL", "RunPython", diff --git a/django/db/migrations/operations/models.py b/django/db/migrations/operations/models.py index 2469d01efb..40526b94c3 100644 --- a/django/db/migrations/operations/models.py +++ b/django/db/migrations/operations/models.py @@ -1230,6 +1230,12 @@ class AddConstraint(IndexOperation): and self.constraint.name == operation.name ): return [] + if ( + isinstance(operation, AlterConstraint) + and self.model_name_lower == operation.model_name_lower + and self.constraint.name == operation.name + ): + return [AddConstraint(self.model_name, operation.constraint)] return super().reduce(operation, app_label) @@ -1274,3 +1280,51 @@ class RemoveConstraint(IndexOperation): @property def migration_name_fragment(self): return "remove_%s_%s" % (self.model_name_lower, self.name.lower()) + + +class AlterConstraint(IndexOperation): + category = OperationCategory.ALTERATION + option_name = "constraints" + + def __init__(self, model_name, name, constraint): + self.model_name = model_name + self.name = name + self.constraint = constraint + + def state_forwards(self, app_label, state): + state.alter_constraint( + app_label, self.model_name_lower, self.name, self.constraint + ) + + def database_forwards(self, app_label, schema_editor, from_state, to_state): + pass + + def database_backwards(self, app_label, schema_editor, from_state, to_state): + pass + + def deconstruct(self): + return ( + self.__class__.__name__, + [], + { + "model_name": self.model_name, + "name": self.name, + "constraint": self.constraint, + }, + ) + + def describe(self): + return f"Alter constraint {self.name} on {self.model_name}" + + @property + def migration_name_fragment(self): + return "alter_%s_%s" % (self.model_name_lower, self.constraint.name.lower()) + + def reduce(self, operation, app_label): + if ( + isinstance(operation, (AlterConstraint, RemoveConstraint)) + and self.model_name_lower == operation.model_name_lower + and self.name == operation.name + ): + return [operation] + return super().reduce(operation, app_label) diff --git a/django/db/migrations/state.py b/django/db/migrations/state.py index e13de5ba6f..f3b70196db 100644 --- a/django/db/migrations/state.py +++ b/django/db/migrations/state.py @@ -211,6 +211,14 @@ class ProjectState: model_state.options[option_name] = [obj for obj in objs if obj.name != obj_name] self.reload_model(app_label, model_name, delay=True) + def _alter_option(self, app_label, model_name, option_name, obj_name, alt_obj): + model_state = self.models[app_label, model_name] + objs = model_state.options[option_name] + model_state.options[option_name] = [ + obj if obj.name != obj_name else alt_obj for obj in objs + ] + self.reload_model(app_label, model_name, delay=True) + def add_index(self, app_label, model_name, index): self._append_option(app_label, model_name, "indexes", index) @@ -237,6 +245,11 @@ class ProjectState: def remove_constraint(self, app_label, model_name, constraint_name): self._remove_option(app_label, model_name, "constraints", constraint_name) + def alter_constraint(self, app_label, model_name, constraint_name, constraint): + self._alter_option( + app_label, model_name, "constraints", constraint_name, constraint + ) + def add_field(self, app_label, model_name, name, field, preserve_default): # If preserve default is off, don't use the default for future state. if not preserve_default: diff --git a/django/db/models/constraints.py b/django/db/models/constraints.py index 788e2b635b..00829aee28 100644 --- a/django/db/models/constraints.py +++ b/django/db/models/constraints.py @@ -23,6 +23,8 @@ class BaseConstraint: violation_error_code = None violation_error_message = None + non_db_attrs = ("violation_error_code", "violation_error_message") + # RemovedInDjango60Warning: When the deprecation ends, replace with: # def __init__( # self, *, name, violation_error_code=None, violation_error_message=None diff --git a/docs/ref/migration-operations.txt b/docs/ref/migration-operations.txt index 9e90b78623..b969b3dbfd 100644 --- a/docs/ref/migration-operations.txt +++ b/docs/ref/migration-operations.txt @@ -278,6 +278,16 @@ the model with ``model_name``. Removes the constraint named ``name`` from the model with ``model_name``. +``AlterConstraint`` +------------------- + +.. versionadded:: 5.2 + +.. class:: AlterConstraint(model_name, name, constraint) + +Alters the constraint named ``name`` of the model with ``model_name`` with the +new ``constraint`` without affecting the database. + Special Operations ================== diff --git a/docs/releases/5.2.txt b/docs/releases/5.2.txt index f1ffe07569..2ec90429d0 100644 --- a/docs/releases/5.2.txt +++ b/docs/releases/5.2.txt @@ -259,7 +259,8 @@ Management Commands Migrations ~~~~~~~~~~ -* ... +* The new operation :class:`.AlterConstraint` is a no-op operation that alters + constraints without dropping and recreating constraints in the database. Models ~~~~~~ diff --git a/tests/migrations/test_autodetector.py b/tests/migrations/test_autodetector.py index d93114564a..de62170eb3 100644 --- a/tests/migrations/test_autodetector.py +++ b/tests/migrations/test_autodetector.py @@ -2969,6 +2969,71 @@ class AutodetectorTests(BaseAutodetectorTests): ["CreateModel", "AddField", "AddIndex"], ) + def test_alter_constraint(self): + book_constraint = models.CheckConstraint( + condition=models.Q(title__contains="title"), + name="title_contains_title", + ) + book_altered_constraint = models.CheckConstraint( + condition=models.Q(title__contains="title"), + name="title_contains_title", + violation_error_code="error_code", + ) + author_altered_constraint = models.CheckConstraint( + condition=models.Q(name__contains="Bob"), + name="name_contains_bob", + violation_error_message="Name doesn't contain Bob", + ) + + book_check_constraint = copy.deepcopy(self.book) + book_check_constraint_with_error_message = copy.deepcopy(self.book) + author_name_check_constraint_with_error_message = copy.deepcopy( + self.author_name_check_constraint + ) + + book_check_constraint.options = {"constraints": [book_constraint]} + book_check_constraint_with_error_message.options = { + "constraints": [book_altered_constraint] + } + author_name_check_constraint_with_error_message.options = { + "constraints": [author_altered_constraint] + } + + changes = self.get_changes( + [self.author_name_check_constraint, book_check_constraint], + [ + author_name_check_constraint_with_error_message, + book_check_constraint_with_error_message, + ], + ) + + self.assertNumberMigrations(changes, "testapp", 1) + self.assertOperationTypes(changes, "testapp", 0, ["AlterConstraint"]) + self.assertOperationAttributes( + changes, + "testapp", + 0, + 0, + model_name="author", + name="name_contains_bob", + constraint=author_altered_constraint, + ) + + self.assertNumberMigrations(changes, "otherapp", 1) + self.assertOperationTypes(changes, "otherapp", 0, ["AlterConstraint"]) + self.assertOperationAttributes( + changes, + "otherapp", + 0, + 0, + model_name="book", + name="title_contains_title", + constraint=book_altered_constraint, + ) + self.assertMigrationDependencies( + changes, "otherapp", 0, [("testapp", "auto_1")] + ) + def test_remove_constraints(self): """Test change detection of removed constraints.""" changes = self.get_changes( diff --git a/tests/migrations/test_operations.py b/tests/migrations/test_operations.py index 3ac813b899..da0ec93dcd 100644 --- a/tests/migrations/test_operations.py +++ b/tests/migrations/test_operations.py @@ -4366,6 +4366,81 @@ class OperationTests(OperationTestBase): {"model_name": "Pony", "name": "test_remove_constraint_pony_pink_gt_2"}, ) + def test_alter_constraint(self): + constraint = models.UniqueConstraint( + fields=["pink"], name="test_alter_constraint_pony_fields_uq" + ) + project_state = self.set_up_test_model( + "test_alterconstraint", constraints=[constraint] + ) + + new_state = project_state.clone() + violation_error_message = "Pink isn't unique" + uq_constraint = models.UniqueConstraint( + fields=["pink"], + name="test_alter_constraint_pony_fields_uq", + violation_error_message=violation_error_message, + ) + uq_operation = migrations.AlterConstraint( + "Pony", "test_alter_constraint_pony_fields_uq", uq_constraint + ) + self.assertEqual( + uq_operation.describe(), + "Alter constraint test_alter_constraint_pony_fields_uq on Pony", + ) + self.assertEqual( + uq_operation.formatted_description(), + "~ Alter constraint test_alter_constraint_pony_fields_uq on Pony", + ) + self.assertEqual( + uq_operation.migration_name_fragment, + "alter_pony_test_alter_constraint_pony_fields_uq", + ) + + uq_operation.state_forwards("test_alterconstraint", new_state) + self.assertEqual( + project_state.models["test_alterconstraint", "pony"] + .options["constraints"][0] + .violation_error_message, + "Constraint ā€œ%(name)sā€ is violated.", + ) + self.assertEqual( + new_state.models["test_alterconstraint", "pony"] + .options["constraints"][0] + .violation_error_message, + violation_error_message, + ) + + with connection.schema_editor() as editor, self.assertNumQueries(0): + uq_operation.database_forwards( + "test_alterconstraint", editor, project_state, new_state + ) + self.assertConstraintExists( + "test_alterconstraint_pony", + "test_alter_constraint_pony_fields_uq", + value=False, + ) + with connection.schema_editor() as editor, self.assertNumQueries(0): + uq_operation.database_backwards( + "test_alterconstraint", editor, project_state, new_state + ) + self.assertConstraintExists( + "test_alterconstraint_pony", + "test_alter_constraint_pony_fields_uq", + value=False, + ) + definition = uq_operation.deconstruct() + self.assertEqual(definition[0], "AlterConstraint") + self.assertEqual(definition[1], []) + self.assertEqual( + definition[2], + { + "model_name": "Pony", + "name": "test_alter_constraint_pony_fields_uq", + "constraint": uq_constraint, + }, + ) + def test_add_partial_unique_constraint(self): project_state = self.set_up_test_model("test_addpartialuniqueconstraint") partial_unique_constraint = models.UniqueConstraint( diff --git a/tests/migrations/test_optimizer.py b/tests/migrations/test_optimizer.py index 0a40b50edc..a871e67a45 100644 --- a/tests/migrations/test_optimizer.py +++ b/tests/migrations/test_optimizer.py @@ -1232,6 +1232,80 @@ class OptimizerTests(OptimizerTestBase): ], ) + def test_multiple_alter_constraints(self): + gt_constraint_violation_msg_added = models.CheckConstraint( + condition=models.Q(pink__gt=2), + name="pink_gt_2", + violation_error_message="ERROR", + ) + gt_constraint_violation_msg_altered = models.CheckConstraint( + condition=models.Q(pink__gt=2), + name="pink_gt_2", + violation_error_message="error", + ) + self.assertOptimizesTo( + [ + migrations.AlterConstraint( + "Pony", "pink_gt_2", gt_constraint_violation_msg_added + ), + migrations.AlterConstraint( + "Pony", "pink_gt_2", gt_constraint_violation_msg_altered + ), + ], + [ + migrations.AlterConstraint( + "Pony", "pink_gt_2", gt_constraint_violation_msg_altered + ) + ], + ) + other_constraint_violation_msg = models.CheckConstraint( + condition=models.Q(weight__gt=3), + name="pink_gt_3", + violation_error_message="error", + ) + self.assertDoesNotOptimize( + [ + migrations.AlterConstraint( + "Pony", "pink_gt_2", gt_constraint_violation_msg_added + ), + migrations.AlterConstraint( + "Pony", "pink_gt_3", other_constraint_violation_msg + ), + ] + ) + + def test_alter_remove_constraint(self): + self.assertOptimizesTo( + [ + migrations.AlterConstraint( + "Pony", + "pink_gt_2", + models.CheckConstraint( + condition=models.Q(pink__gt=2), name="pink_gt_2" + ), + ), + migrations.RemoveConstraint("Pony", "pink_gt_2"), + ], + [migrations.RemoveConstraint("Pony", "pink_gt_2")], + ) + + def test_add_alter_constraint(self): + constraint = models.CheckConstraint( + condition=models.Q(pink__gt=2), name="pink_gt_2" + ) + constraint_with_error = models.CheckConstraint( + condition=models.Q(pink__gt=2), + name="pink_gt_2", + violation_error_message="error", + ) + self.assertOptimizesTo( + [ + migrations.AddConstraint("Pony", constraint), + migrations.AlterConstraint("Pony", "pink_gt_2", constraint_with_error), + ], + [migrations.AddConstraint("Pony", constraint_with_error)], + ) + def test_create_model_add_index(self): self.assertOptimizesTo( [