1
0
mirror of https://github.com/django/django.git synced 2024-12-22 09:05:43 +00:00

Fixed #35038 -- Created AlterConstraint operation.

This commit is contained in:
Salvo Polizzi 2024-11-28 13:52:58 +01:00 committed by Sarah Boyce
parent b92511b474
commit b82f80906a
10 changed files with 357 additions and 3 deletions

View File

@ -219,6 +219,7 @@ class MigrationAutodetector:
self.generate_altered_unique_together()
self.generate_added_indexes()
self.generate_added_constraints()
self.generate_altered_constraints()
self.generate_altered_db_table()
self._sort_migrations()
@ -1450,6 +1451,19 @@ class MigrationAutodetector:
),
)
def _constraint_should_be_dropped_and_recreated(
self, old_constraint, new_constraint
):
old_path, old_args, old_kwargs = old_constraint.deconstruct()
new_path, new_args, new_kwargs = new_constraint.deconstruct()
for attr in old_constraint.non_db_attrs:
old_kwargs.pop(attr, None)
for attr in new_constraint.non_db_attrs:
new_kwargs.pop(attr, None)
return (old_path, old_args, old_kwargs) != (new_path, new_args, new_kwargs)
def create_altered_constraints(self):
option_name = operations.AddConstraint.option_name
for app_label, model_name in sorted(self.kept_model_keys):
@ -1461,14 +1475,41 @@ class MigrationAutodetector:
old_constraints = old_model_state.options[option_name]
new_constraints = new_model_state.options[option_name]
add_constraints = [c for c in new_constraints if c not in old_constraints]
rem_constraints = [c for c in old_constraints if c not in new_constraints]
alt_constraints = []
alt_constraints_name = []
for old_c in old_constraints:
for new_c in new_constraints:
old_c_dec = old_c.deconstruct()
new_c_dec = new_c.deconstruct()
if (
old_c_dec != new_c_dec
and old_c.name == new_c.name
and not self._constraint_should_be_dropped_and_recreated(
old_c, new_c
)
):
alt_constraints.append(new_c)
alt_constraints_name.append(new_c.name)
add_constraints = [
c
for c in new_constraints
if c not in old_constraints and c.name not in alt_constraints_name
]
rem_constraints = [
c
for c in old_constraints
if c not in new_constraints and c.name not in alt_constraints_name
]
self.altered_constraints.update(
{
(app_label, model_name): {
"added_constraints": add_constraints,
"removed_constraints": rem_constraints,
"altered_constraints": alt_constraints,
}
}
)
@ -1503,6 +1544,23 @@ class MigrationAutodetector:
),
)
def generate_altered_constraints(self):
for (
app_label,
model_name,
), alt_constraints in self.altered_constraints.items():
dependencies = self._get_dependencies_for_model(app_label, model_name)
for constraint in alt_constraints["altered_constraints"]:
self.add_operation(
app_label,
operations.AlterConstraint(
model_name=model_name,
name=constraint.name,
constraint=constraint,
),
dependencies=dependencies,
)
@staticmethod
def _get_dependencies_for_foreign_key(app_label, model_name, field, project_state):
remote_field_model = None

View File

@ -2,6 +2,7 @@ from .fields import AddField, AlterField, RemoveField, RenameField
from .models import (
AddConstraint,
AddIndex,
AlterConstraint,
AlterIndexTogether,
AlterModelManagers,
AlterModelOptions,
@ -36,6 +37,7 @@ __all__ = [
"RenameField",
"AddConstraint",
"RemoveConstraint",
"AlterConstraint",
"SeparateDatabaseAndState",
"RunSQL",
"RunPython",

View File

@ -1230,6 +1230,12 @@ class AddConstraint(IndexOperation):
and self.constraint.name == operation.name
):
return []
if (
isinstance(operation, AlterConstraint)
and self.model_name_lower == operation.model_name_lower
and self.constraint.name == operation.name
):
return [AddConstraint(self.model_name, operation.constraint)]
return super().reduce(operation, app_label)
@ -1274,3 +1280,51 @@ class RemoveConstraint(IndexOperation):
@property
def migration_name_fragment(self):
return "remove_%s_%s" % (self.model_name_lower, self.name.lower())
class AlterConstraint(IndexOperation):
category = OperationCategory.ALTERATION
option_name = "constraints"
def __init__(self, model_name, name, constraint):
self.model_name = model_name
self.name = name
self.constraint = constraint
def state_forwards(self, app_label, state):
state.alter_constraint(
app_label, self.model_name_lower, self.name, self.constraint
)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
pass
def database_backwards(self, app_label, schema_editor, from_state, to_state):
pass
def deconstruct(self):
return (
self.__class__.__name__,
[],
{
"model_name": self.model_name,
"name": self.name,
"constraint": self.constraint,
},
)
def describe(self):
return f"Alter constraint {self.name} on {self.model_name}"
@property
def migration_name_fragment(self):
return "alter_%s_%s" % (self.model_name_lower, self.constraint.name.lower())
def reduce(self, operation, app_label):
if (
isinstance(operation, (AlterConstraint, RemoveConstraint))
and self.model_name_lower == operation.model_name_lower
and self.name == operation.name
):
return [operation]
return super().reduce(operation, app_label)

View File

@ -211,6 +211,14 @@ class ProjectState:
model_state.options[option_name] = [obj for obj in objs if obj.name != obj_name]
self.reload_model(app_label, model_name, delay=True)
def _alter_option(self, app_label, model_name, option_name, obj_name, alt_obj):
model_state = self.models[app_label, model_name]
objs = model_state.options[option_name]
model_state.options[option_name] = [
obj if obj.name != obj_name else alt_obj for obj in objs
]
self.reload_model(app_label, model_name, delay=True)
def add_index(self, app_label, model_name, index):
self._append_option(app_label, model_name, "indexes", index)
@ -237,6 +245,11 @@ class ProjectState:
def remove_constraint(self, app_label, model_name, constraint_name):
self._remove_option(app_label, model_name, "constraints", constraint_name)
def alter_constraint(self, app_label, model_name, constraint_name, constraint):
self._alter_option(
app_label, model_name, "constraints", constraint_name, constraint
)
def add_field(self, app_label, model_name, name, field, preserve_default):
# If preserve default is off, don't use the default for future state.
if not preserve_default:

View File

@ -23,6 +23,8 @@ class BaseConstraint:
violation_error_code = None
violation_error_message = None
non_db_attrs = ("violation_error_code", "violation_error_message")
# RemovedInDjango60Warning: When the deprecation ends, replace with:
# def __init__(
# self, *, name, violation_error_code=None, violation_error_message=None

View File

@ -278,6 +278,16 @@ the model with ``model_name``.
Removes the constraint named ``name`` from the model with ``model_name``.
``AlterConstraint``
-------------------
.. versionadded:: 5.2
.. class:: AlterConstraint(model_name, name, constraint)
Alters the constraint named ``name`` of the model with ``model_name`` with the
new ``constraint`` without affecting the database.
Special Operations
==================

View File

@ -259,7 +259,8 @@ Management Commands
Migrations
~~~~~~~~~~
* ...
* The new operation :class:`.AlterConstraint` is a no-op operation that alters
constraints without dropping and recreating constraints in the database.
Models
~~~~~~

View File

@ -2969,6 +2969,71 @@ class AutodetectorTests(BaseAutodetectorTests):
["CreateModel", "AddField", "AddIndex"],
)
def test_alter_constraint(self):
book_constraint = models.CheckConstraint(
condition=models.Q(title__contains="title"),
name="title_contains_title",
)
book_altered_constraint = models.CheckConstraint(
condition=models.Q(title__contains="title"),
name="title_contains_title",
violation_error_code="error_code",
)
author_altered_constraint = models.CheckConstraint(
condition=models.Q(name__contains="Bob"),
name="name_contains_bob",
violation_error_message="Name doesn't contain Bob",
)
book_check_constraint = copy.deepcopy(self.book)
book_check_constraint_with_error_message = copy.deepcopy(self.book)
author_name_check_constraint_with_error_message = copy.deepcopy(
self.author_name_check_constraint
)
book_check_constraint.options = {"constraints": [book_constraint]}
book_check_constraint_with_error_message.options = {
"constraints": [book_altered_constraint]
}
author_name_check_constraint_with_error_message.options = {
"constraints": [author_altered_constraint]
}
changes = self.get_changes(
[self.author_name_check_constraint, book_check_constraint],
[
author_name_check_constraint_with_error_message,
book_check_constraint_with_error_message,
],
)
self.assertNumberMigrations(changes, "testapp", 1)
self.assertOperationTypes(changes, "testapp", 0, ["AlterConstraint"])
self.assertOperationAttributes(
changes,
"testapp",
0,
0,
model_name="author",
name="name_contains_bob",
constraint=author_altered_constraint,
)
self.assertNumberMigrations(changes, "otherapp", 1)
self.assertOperationTypes(changes, "otherapp", 0, ["AlterConstraint"])
self.assertOperationAttributes(
changes,
"otherapp",
0,
0,
model_name="book",
name="title_contains_title",
constraint=book_altered_constraint,
)
self.assertMigrationDependencies(
changes, "otherapp", 0, [("testapp", "auto_1")]
)
def test_remove_constraints(self):
"""Test change detection of removed constraints."""
changes = self.get_changes(

View File

@ -4366,6 +4366,81 @@ class OperationTests(OperationTestBase):
{"model_name": "Pony", "name": "test_remove_constraint_pony_pink_gt_2"},
)
def test_alter_constraint(self):
constraint = models.UniqueConstraint(
fields=["pink"], name="test_alter_constraint_pony_fields_uq"
)
project_state = self.set_up_test_model(
"test_alterconstraint", constraints=[constraint]
)
new_state = project_state.clone()
violation_error_message = "Pink isn't unique"
uq_constraint = models.UniqueConstraint(
fields=["pink"],
name="test_alter_constraint_pony_fields_uq",
violation_error_message=violation_error_message,
)
uq_operation = migrations.AlterConstraint(
"Pony", "test_alter_constraint_pony_fields_uq", uq_constraint
)
self.assertEqual(
uq_operation.describe(),
"Alter constraint test_alter_constraint_pony_fields_uq on Pony",
)
self.assertEqual(
uq_operation.formatted_description(),
"~ Alter constraint test_alter_constraint_pony_fields_uq on Pony",
)
self.assertEqual(
uq_operation.migration_name_fragment,
"alter_pony_test_alter_constraint_pony_fields_uq",
)
uq_operation.state_forwards("test_alterconstraint", new_state)
self.assertEqual(
project_state.models["test_alterconstraint", "pony"]
.options["constraints"][0]
.violation_error_message,
"Constraint “%(name)s” is violated.",
)
self.assertEqual(
new_state.models["test_alterconstraint", "pony"]
.options["constraints"][0]
.violation_error_message,
violation_error_message,
)
with connection.schema_editor() as editor, self.assertNumQueries(0):
uq_operation.database_forwards(
"test_alterconstraint", editor, project_state, new_state
)
self.assertConstraintExists(
"test_alterconstraint_pony",
"test_alter_constraint_pony_fields_uq",
value=False,
)
with connection.schema_editor() as editor, self.assertNumQueries(0):
uq_operation.database_backwards(
"test_alterconstraint", editor, project_state, new_state
)
self.assertConstraintExists(
"test_alterconstraint_pony",
"test_alter_constraint_pony_fields_uq",
value=False,
)
definition = uq_operation.deconstruct()
self.assertEqual(definition[0], "AlterConstraint")
self.assertEqual(definition[1], [])
self.assertEqual(
definition[2],
{
"model_name": "Pony",
"name": "test_alter_constraint_pony_fields_uq",
"constraint": uq_constraint,
},
)
def test_add_partial_unique_constraint(self):
project_state = self.set_up_test_model("test_addpartialuniqueconstraint")
partial_unique_constraint = models.UniqueConstraint(

View File

@ -1232,6 +1232,80 @@ class OptimizerTests(OptimizerTestBase):
],
)
def test_multiple_alter_constraints(self):
gt_constraint_violation_msg_added = models.CheckConstraint(
condition=models.Q(pink__gt=2),
name="pink_gt_2",
violation_error_message="ERROR",
)
gt_constraint_violation_msg_altered = models.CheckConstraint(
condition=models.Q(pink__gt=2),
name="pink_gt_2",
violation_error_message="error",
)
self.assertOptimizesTo(
[
migrations.AlterConstraint(
"Pony", "pink_gt_2", gt_constraint_violation_msg_added
),
migrations.AlterConstraint(
"Pony", "pink_gt_2", gt_constraint_violation_msg_altered
),
],
[
migrations.AlterConstraint(
"Pony", "pink_gt_2", gt_constraint_violation_msg_altered
)
],
)
other_constraint_violation_msg = models.CheckConstraint(
condition=models.Q(weight__gt=3),
name="pink_gt_3",
violation_error_message="error",
)
self.assertDoesNotOptimize(
[
migrations.AlterConstraint(
"Pony", "pink_gt_2", gt_constraint_violation_msg_added
),
migrations.AlterConstraint(
"Pony", "pink_gt_3", other_constraint_violation_msg
),
]
)
def test_alter_remove_constraint(self):
self.assertOptimizesTo(
[
migrations.AlterConstraint(
"Pony",
"pink_gt_2",
models.CheckConstraint(
condition=models.Q(pink__gt=2), name="pink_gt_2"
),
),
migrations.RemoveConstraint("Pony", "pink_gt_2"),
],
[migrations.RemoveConstraint("Pony", "pink_gt_2")],
)
def test_add_alter_constraint(self):
constraint = models.CheckConstraint(
condition=models.Q(pink__gt=2), name="pink_gt_2"
)
constraint_with_error = models.CheckConstraint(
condition=models.Q(pink__gt=2),
name="pink_gt_2",
violation_error_message="error",
)
self.assertOptimizesTo(
[
migrations.AddConstraint("Pony", constraint),
migrations.AlterConstraint("Pony", "pink_gt_2", constraint_with_error),
],
[migrations.AddConstraint("Pony", constraint_with_error)],
)
def test_create_model_add_index(self):
self.assertOptimizesTo(
[