mirror of
https://github.com/django/django.git
synced 2025-10-24 06:06:09 +00:00
Fixed #11964 -- Added support for database check constraints.
This commit is contained in:
@@ -172,6 +172,7 @@ class BaseDatabaseFeatures:
|
||||
|
||||
# Does it support CHECK constraints?
|
||||
supports_column_check_constraints = True
|
||||
supports_table_check_constraints = True
|
||||
|
||||
# Does the backend support 'pyformat' style ("... %(name)s ...", {'name': value})
|
||||
# parameter passing? Note this can be provided by the backend even if not
|
||||
|
@@ -63,7 +63,8 @@ class BaseDatabaseSchemaEditor:
|
||||
sql_rename_column = "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s"
|
||||
sql_update_with_default = "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL"
|
||||
|
||||
sql_create_check = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s CHECK (%(check)s)"
|
||||
sql_check = "CONSTRAINT %(name)s CHECK (%(check)s)"
|
||||
sql_create_check = "ALTER TABLE %(table)s ADD %(check)s"
|
||||
sql_delete_check = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
|
||||
|
||||
sql_create_unique = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s UNIQUE (%(columns)s)"
|
||||
@@ -299,10 +300,11 @@ class BaseDatabaseSchemaEditor:
|
||||
for fields in model._meta.unique_together:
|
||||
columns = [model._meta.get_field(field).column for field in fields]
|
||||
self.deferred_sql.append(self._create_unique_sql(model, columns))
|
||||
constraints = [check.constraint_sql(model, self) for check in model._meta.constraints]
|
||||
# Make the table
|
||||
sql = self.sql_create_table % {
|
||||
"table": self.quote_name(model._meta.db_table),
|
||||
"definition": ", ".join(column_sqls)
|
||||
"definition": ", ".join((*column_sqls, *constraints)),
|
||||
}
|
||||
if model._meta.db_tablespace:
|
||||
tablespace_sql = self.connection.ops.tablespace_sql(model._meta.db_tablespace)
|
||||
@@ -343,6 +345,14 @@ class BaseDatabaseSchemaEditor:
|
||||
"""Remove an index from a model."""
|
||||
self.execute(index.remove_sql(model, self))
|
||||
|
||||
def add_constraint(self, model, constraint):
|
||||
"""Add a check constraint to a model."""
|
||||
self.execute(constraint.create_sql(model, self))
|
||||
|
||||
def remove_constraint(self, model, constraint):
|
||||
"""Remove a check constraint from a model."""
|
||||
self.execute(constraint.remove_sql(model, self))
|
||||
|
||||
def alter_unique_together(self, model, old_unique_together, new_unique_together):
|
||||
"""
|
||||
Deal with a model changing its unique_together. The input
|
||||
@@ -752,11 +762,12 @@ class BaseDatabaseSchemaEditor:
|
||||
self.execute(
|
||||
self.sql_create_check % {
|
||||
"table": self.quote_name(model._meta.db_table),
|
||||
"name": self.quote_name(
|
||||
self._create_index_name(model._meta.db_table, [new_field.column], suffix="_check")
|
||||
),
|
||||
"column": self.quote_name(new_field.column),
|
||||
"check": new_db_params['check'],
|
||||
"check": self.sql_check % {
|
||||
'name': self.quote_name(
|
||||
self._create_index_name(model._meta.db_table, [new_field.column], suffix='_check'),
|
||||
),
|
||||
'check': new_db_params['check'],
|
||||
},
|
||||
}
|
||||
)
|
||||
# Drop the default if we need to
|
||||
|
@@ -26,6 +26,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
||||
can_release_savepoints = True
|
||||
atomic_transactions = False
|
||||
supports_column_check_constraints = False
|
||||
supports_table_check_constraints = False
|
||||
can_clone_databases = True
|
||||
supports_temporal_subtraction = True
|
||||
supports_select_intersection = False
|
||||
|
@@ -126,7 +126,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
|
||||
else:
|
||||
super().alter_field(model, old_field, new_field, strict=strict)
|
||||
|
||||
def _remake_table(self, model, create_field=None, delete_field=None, alter_field=None):
|
||||
def _remake_table(self, model, create_field=None, delete_field=None, alter_field=None,
|
||||
add_constraint=None, remove_constraint=None):
|
||||
"""
|
||||
Shortcut to transform a model from old_model into new_model
|
||||
|
||||
@@ -222,6 +223,15 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
|
||||
if delete_field.name not in index.fields
|
||||
]
|
||||
|
||||
constraints = list(model._meta.constraints)
|
||||
if add_constraint:
|
||||
constraints.append(add_constraint)
|
||||
if remove_constraint:
|
||||
constraints = [
|
||||
constraint for constraint in constraints
|
||||
if remove_constraint.name != constraint.name
|
||||
]
|
||||
|
||||
# Construct a new model for the new state
|
||||
meta_contents = {
|
||||
'app_label': model._meta.app_label,
|
||||
@@ -229,6 +239,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
|
||||
'unique_together': unique_together,
|
||||
'index_together': index_together,
|
||||
'indexes': indexes,
|
||||
'constraints': constraints,
|
||||
'apps': apps,
|
||||
}
|
||||
meta = type("Meta", (), meta_contents)
|
||||
@@ -362,3 +373,9 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
|
||||
))
|
||||
# Delete the old through table
|
||||
self.delete_model(old_field.remote_field.through)
|
||||
|
||||
def add_constraint(self, model, constraint):
|
||||
self._remake_table(model, add_constraint=constraint)
|
||||
|
||||
def remove_constraint(self, model, constraint):
|
||||
self._remake_table(model, remove_constraint=constraint)
|
||||
|
@@ -122,6 +122,7 @@ class MigrationAutodetector:
|
||||
# resolve dependencies caused by M2Ms and FKs.
|
||||
self.generated_operations = {}
|
||||
self.altered_indexes = {}
|
||||
self.altered_constraints = {}
|
||||
|
||||
# Prepare some old/new state and model lists, separating
|
||||
# proxy models and ignoring unmigrated apps.
|
||||
@@ -175,7 +176,9 @@ class MigrationAutodetector:
|
||||
# This avoids the same computation in generate_removed_indexes()
|
||||
# and generate_added_indexes().
|
||||
self.create_altered_indexes()
|
||||
self.create_altered_constraints()
|
||||
# Generate index removal operations before field is removed
|
||||
self.generate_removed_constraints()
|
||||
self.generate_removed_indexes()
|
||||
# Generate field operations
|
||||
self.generate_renamed_fields()
|
||||
@@ -185,6 +188,7 @@ class MigrationAutodetector:
|
||||
self.generate_altered_unique_together()
|
||||
self.generate_altered_index_together()
|
||||
self.generate_added_indexes()
|
||||
self.generate_added_constraints()
|
||||
self.generate_altered_db_table()
|
||||
self.generate_altered_order_with_respect_to()
|
||||
|
||||
@@ -533,6 +537,7 @@ class MigrationAutodetector:
|
||||
related_fields[field.name] = field
|
||||
# Are there indexes/unique|index_together to defer?
|
||||
indexes = model_state.options.pop('indexes')
|
||||
constraints = model_state.options.pop('constraints')
|
||||
unique_together = model_state.options.pop('unique_together', None)
|
||||
index_together = model_state.options.pop('index_together', None)
|
||||
order_with_respect_to = model_state.options.pop('order_with_respect_to', None)
|
||||
@@ -601,6 +606,15 @@ class MigrationAutodetector:
|
||||
),
|
||||
dependencies=related_dependencies,
|
||||
)
|
||||
for constraint in constraints:
|
||||
self.add_operation(
|
||||
app_label,
|
||||
operations.AddConstraint(
|
||||
model_name=model_name,
|
||||
constraint=constraint,
|
||||
),
|
||||
dependencies=related_dependencies,
|
||||
)
|
||||
if unique_together:
|
||||
self.add_operation(
|
||||
app_label,
|
||||
@@ -997,6 +1011,46 @@ class MigrationAutodetector:
|
||||
)
|
||||
)
|
||||
|
||||
def create_altered_constraints(self):
|
||||
option_name = operations.AddConstraint.option_name
|
||||
for app_label, model_name in sorted(self.kept_model_keys):
|
||||
old_model_name = self.renamed_models.get((app_label, model_name), model_name)
|
||||
old_model_state = self.from_state.models[app_label, old_model_name]
|
||||
new_model_state = self.to_state.models[app_label, model_name]
|
||||
|
||||
old_constraints = old_model_state.options[option_name]
|
||||
new_constraints = new_model_state.options[option_name]
|
||||
add_constraints = [c for c in new_constraints if c not in old_constraints]
|
||||
rem_constraints = [c for c in old_constraints if c not in new_constraints]
|
||||
|
||||
self.altered_constraints.update({
|
||||
(app_label, model_name): {
|
||||
'added_constraints': add_constraints, 'removed_constraints': rem_constraints,
|
||||
}
|
||||
})
|
||||
|
||||
def generate_added_constraints(self):
|
||||
for (app_label, model_name), alt_constraints in self.altered_constraints.items():
|
||||
for constraint in alt_constraints['added_constraints']:
|
||||
self.add_operation(
|
||||
app_label,
|
||||
operations.AddConstraint(
|
||||
model_name=model_name,
|
||||
constraint=constraint,
|
||||
)
|
||||
)
|
||||
|
||||
def generate_removed_constraints(self):
|
||||
for (app_label, model_name), alt_constraints in self.altered_constraints.items():
|
||||
for constraint in alt_constraints['removed_constraints']:
|
||||
self.add_operation(
|
||||
app_label,
|
||||
operations.RemoveConstraint(
|
||||
model_name=model_name,
|
||||
name=constraint.name,
|
||||
)
|
||||
)
|
||||
|
||||
def _get_dependencies_for_foreign_key(self, field):
|
||||
# Account for FKs to swappable models
|
||||
swappable_setting = getattr(field, 'swappable_setting', None)
|
||||
|
@@ -1,8 +1,9 @@
|
||||
from .fields import AddField, AlterField, RemoveField, RenameField
|
||||
from .models import (
|
||||
AddIndex, AlterIndexTogether, AlterModelManagers, AlterModelOptions,
|
||||
AlterModelTable, AlterOrderWithRespectTo, AlterUniqueTogether, CreateModel,
|
||||
DeleteModel, RemoveIndex, RenameModel,
|
||||
AddConstraint, AddIndex, AlterIndexTogether, AlterModelManagers,
|
||||
AlterModelOptions, AlterModelTable, AlterOrderWithRespectTo,
|
||||
AlterUniqueTogether, CreateModel, DeleteModel, RemoveConstraint,
|
||||
RemoveIndex, RenameModel,
|
||||
)
|
||||
from .special import RunPython, RunSQL, SeparateDatabaseAndState
|
||||
|
||||
@@ -10,6 +11,7 @@ __all__ = [
|
||||
'CreateModel', 'DeleteModel', 'AlterModelTable', 'AlterUniqueTogether',
|
||||
'RenameModel', 'AlterIndexTogether', 'AlterModelOptions', 'AddIndex',
|
||||
'RemoveIndex', 'AddField', 'RemoveField', 'AlterField', 'RenameField',
|
||||
'AddConstraint', 'RemoveConstraint',
|
||||
'SeparateDatabaseAndState', 'RunSQL', 'RunPython',
|
||||
'AlterOrderWithRespectTo', 'AlterModelManagers',
|
||||
]
|
||||
|
@@ -822,3 +822,72 @@ class RemoveIndex(IndexOperation):
|
||||
|
||||
def describe(self):
|
||||
return 'Remove index %s from %s' % (self.name, self.model_name)
|
||||
|
||||
|
||||
class AddConstraint(IndexOperation):
|
||||
option_name = 'constraints'
|
||||
|
||||
def __init__(self, model_name, constraint):
|
||||
self.model_name = model_name
|
||||
self.constraint = constraint
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
model_state = state.models[app_label, self.model_name_lower]
|
||||
constraints = list(model_state.options[self.option_name])
|
||||
constraints.append(self.constraint)
|
||||
model_state.options[self.option_name] = constraints
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
model = to_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, model):
|
||||
schema_editor.add_constraint(model, self.constraint)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
model = to_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, model):
|
||||
schema_editor.remove_constraint(model, self.constraint)
|
||||
|
||||
def deconstruct(self):
|
||||
return self.__class__.__name__, [], {
|
||||
'model_name': self.model_name,
|
||||
'constraint': self.constraint,
|
||||
}
|
||||
|
||||
def describe(self):
|
||||
return 'Create constraint %s on model %s' % (self.constraint.name, self.model_name)
|
||||
|
||||
|
||||
class RemoveConstraint(IndexOperation):
|
||||
option_name = 'constraints'
|
||||
|
||||
def __init__(self, model_name, name):
|
||||
self.model_name = model_name
|
||||
self.name = name
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
model_state = state.models[app_label, self.model_name_lower]
|
||||
constraints = model_state.options[self.option_name]
|
||||
model_state.options[self.option_name] = [c for c in constraints if c.name != self.name]
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
model = from_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, model):
|
||||
from_model_state = from_state.models[app_label, self.model_name_lower]
|
||||
constraint = from_model_state.get_constraint_by_name(self.name)
|
||||
schema_editor.remove_constraint(model, constraint)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
model = to_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, model):
|
||||
to_model_state = to_state.models[app_label, self.model_name_lower]
|
||||
constraint = to_model_state.get_constraint_by_name(self.name)
|
||||
schema_editor.add_constraint(model, constraint)
|
||||
|
||||
def deconstruct(self):
|
||||
return self.__class__.__name__, [], {
|
||||
'model_name': self.model_name,
|
||||
'name': self.name,
|
||||
}
|
||||
|
||||
def describe(self):
|
||||
return 'Remove constraint %s from model %s' % (self.name, self.model_name)
|
||||
|
@@ -362,6 +362,7 @@ class ModelState:
|
||||
self.fields = fields
|
||||
self.options = options or {}
|
||||
self.options.setdefault('indexes', [])
|
||||
self.options.setdefault('constraints', [])
|
||||
self.bases = bases or (models.Model,)
|
||||
self.managers = managers or []
|
||||
# Sanity-check that fields is NOT a dict. It must be ordered.
|
||||
@@ -445,6 +446,8 @@ class ModelState:
|
||||
if not index.name:
|
||||
index.set_name_with_model(model)
|
||||
options['indexes'] = indexes
|
||||
elif name == 'constraints':
|
||||
options['constraints'] = [con.clone() for con in model._meta.constraints]
|
||||
else:
|
||||
options[name] = model._meta.original_attrs[name]
|
||||
# If we're ignoring relationships, remove all field-listing model
|
||||
@@ -585,6 +588,12 @@ class ModelState:
|
||||
return index
|
||||
raise ValueError("No index named %s on model %s" % (name, self.name))
|
||||
|
||||
def get_constraint_by_name(self, name):
|
||||
for constraint in self.options['constraints']:
|
||||
if constraint.name == name:
|
||||
return constraint
|
||||
raise ValueError('No constraint named %s on model %s' % (name, self.name))
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s: '%s.%s'>" % (self.__class__.__name__, self.app_label, self.name)
|
||||
|
||||
|
@@ -2,6 +2,8 @@ from django.core.exceptions import ObjectDoesNotExist
|
||||
from django.db.models import signals
|
||||
from django.db.models.aggregates import * # NOQA
|
||||
from django.db.models.aggregates import __all__ as aggregates_all
|
||||
from django.db.models.constraints import * # NOQA
|
||||
from django.db.models.constraints import __all__ as constraints_all
|
||||
from django.db.models.deletion import (
|
||||
CASCADE, DO_NOTHING, PROTECT, SET, SET_DEFAULT, SET_NULL, ProtectedError,
|
||||
)
|
||||
@@ -30,7 +32,7 @@ from django.db.models.fields.related import ( # isort:skip
|
||||
)
|
||||
|
||||
|
||||
__all__ = aggregates_all + fields_all + indexes_all
|
||||
__all__ = aggregates_all + constraints_all + fields_all + indexes_all
|
||||
__all__ += [
|
||||
'ObjectDoesNotExist', 'signals',
|
||||
'CASCADE', 'DO_NOTHING', 'PROTECT', 'SET', 'SET_DEFAULT', 'SET_NULL',
|
||||
|
@@ -16,6 +16,7 @@ from django.db import (
|
||||
connections, router, transaction,
|
||||
)
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.db.models.constraints import CheckConstraint
|
||||
from django.db.models.deletion import CASCADE, Collector
|
||||
from django.db.models.fields.related import (
|
||||
ForeignObjectRel, OneToOneField, lazy_related_operation, resolve_relation,
|
||||
@@ -1201,6 +1202,7 @@ class Model(metaclass=ModelBase):
|
||||
*cls._check_unique_together(),
|
||||
*cls._check_indexes(),
|
||||
*cls._check_ordering(),
|
||||
*cls._check_constraints(),
|
||||
]
|
||||
|
||||
return errors
|
||||
@@ -1699,6 +1701,29 @@ class Model(metaclass=ModelBase):
|
||||
|
||||
return errors
|
||||
|
||||
@classmethod
|
||||
def _check_constraints(cls):
|
||||
errors = []
|
||||
for db in settings.DATABASES:
|
||||
if not router.allow_migrate_model(db, cls):
|
||||
continue
|
||||
connection = connections[db]
|
||||
if connection.features.supports_table_check_constraints:
|
||||
continue
|
||||
if any(isinstance(constraint, CheckConstraint) for constraint in cls._meta.constraints):
|
||||
errors.append(
|
||||
checks.Warning(
|
||||
'%s does not support check constraints.' % connection.display_name,
|
||||
hint=(
|
||||
"A constraint won't be created. Silence this "
|
||||
"warning if you don't care about it."
|
||||
),
|
||||
obj=cls,
|
||||
id='models.W027',
|
||||
)
|
||||
)
|
||||
return errors
|
||||
|
||||
|
||||
############################################
|
||||
# HELPER FUNCTIONS (CURRIED MODEL METHODS) #
|
||||
|
54
django/db/models/constraints.py
Normal file
54
django/db/models/constraints.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from django.db.models.sql.query import Query
|
||||
|
||||
__all__ = ['CheckConstraint']
|
||||
|
||||
|
||||
class CheckConstraint:
|
||||
def __init__(self, constraint, name):
|
||||
self.constraint = constraint
|
||||
self.name = name
|
||||
|
||||
def constraint_sql(self, model, schema_editor):
|
||||
query = Query(model)
|
||||
where = query.build_where(self.constraint)
|
||||
connection = schema_editor.connection
|
||||
compiler = connection.ops.compiler('SQLCompiler')(query, connection, 'default')
|
||||
sql, params = where.as_sql(compiler, connection)
|
||||
params = tuple(schema_editor.quote_value(p) for p in params)
|
||||
return schema_editor.sql_check % {
|
||||
'name': schema_editor.quote_name(self.name),
|
||||
'check': sql % params,
|
||||
}
|
||||
|
||||
def create_sql(self, model, schema_editor):
|
||||
sql = self.constraint_sql(model, schema_editor)
|
||||
return schema_editor.sql_create_check % {
|
||||
'table': schema_editor.quote_name(model._meta.db_table),
|
||||
'check': sql,
|
||||
}
|
||||
|
||||
def remove_sql(self, model, schema_editor):
|
||||
quote_name = schema_editor.quote_name
|
||||
return schema_editor.sql_delete_check % {
|
||||
'table': quote_name(model._meta.db_table),
|
||||
'name': quote_name(self.name),
|
||||
}
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s: constraint='%s' name='%s'>" % (self.__class__.__name__, self.constraint, self.name)
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
isinstance(other, CheckConstraint) and
|
||||
self.name == other.name and
|
||||
self.constraint == other.constraint
|
||||
)
|
||||
|
||||
def deconstruct(self):
|
||||
path = '%s.%s' % (self.__class__.__module__, self.__class__.__name__)
|
||||
path = path.replace('django.db.models.constraints', 'django.db.models')
|
||||
return (path, (), {'constraint': self.constraint, 'name': self.name})
|
||||
|
||||
def clone(self):
|
||||
_, args, kwargs = self.deconstruct()
|
||||
return self.__class__(*args, **kwargs)
|
@@ -505,8 +505,9 @@ class F(Combinable):
|
||||
def __repr__(self):
|
||||
return "{}({})".format(self.__class__.__name__, self.name)
|
||||
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
||||
return query.resolve_ref(self.name, allow_joins, reuse, summarize)
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None,
|
||||
summarize=False, for_save=False, simple_col=False):
|
||||
return query.resolve_ref(self.name, allow_joins, reuse, summarize, simple_col)
|
||||
|
||||
def asc(self, **kwargs):
|
||||
return OrderBy(self, **kwargs)
|
||||
@@ -542,7 +543,8 @@ class ResolvedOuterRef(F):
|
||||
|
||||
|
||||
class OuterRef(F):
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None,
|
||||
summarize=False, for_save=False, simple_col=False):
|
||||
if isinstance(self.name, self.__class__):
|
||||
return self.name
|
||||
return ResolvedOuterRef(self.name)
|
||||
@@ -746,6 +748,40 @@ class Col(Expression):
|
||||
self.target.get_db_converters(connection))
|
||||
|
||||
|
||||
class SimpleCol(Expression):
|
||||
"""
|
||||
Represents the SQL of a column name without the table name.
|
||||
|
||||
This variant of Col doesn't include the table name (or an alias) to
|
||||
avoid a syntax error in check constraints.
|
||||
"""
|
||||
contains_column_references = True
|
||||
|
||||
def __init__(self, target, output_field=None):
|
||||
if output_field is None:
|
||||
output_field = target
|
||||
super().__init__(output_field=output_field)
|
||||
self.target = target
|
||||
|
||||
def __repr__(self):
|
||||
return '{}({})'.format(self.__class__.__name__, self.target)
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
qn = compiler.quote_name_unless_alias
|
||||
return qn(self.target.column), []
|
||||
|
||||
def get_group_by_cols(self):
|
||||
return [self]
|
||||
|
||||
def get_db_converters(self, connection):
|
||||
if self.target == self.output_field:
|
||||
return self.output_field.get_db_converters(connection)
|
||||
return (
|
||||
self.output_field.get_db_converters(connection) +
|
||||
self.target.get_db_converters(connection)
|
||||
)
|
||||
|
||||
|
||||
class Ref(Expression):
|
||||
"""
|
||||
Reference to column alias of the query. For example, Ref('sum_cost') in
|
||||
|
@@ -32,7 +32,7 @@ DEFAULT_NAMES = (
|
||||
'auto_created', 'index_together', 'apps', 'default_permissions',
|
||||
'select_on_save', 'default_related_name', 'required_db_features',
|
||||
'required_db_vendor', 'base_manager_name', 'default_manager_name',
|
||||
'indexes',
|
||||
'indexes', 'constraints',
|
||||
# For backwards compatibility with Django 1.11. RemovedInDjango30Warning
|
||||
'manager_inheritance_from_future',
|
||||
)
|
||||
@@ -89,6 +89,7 @@ class Options:
|
||||
self.ordering = []
|
||||
self._ordering_clash = False
|
||||
self.indexes = []
|
||||
self.constraints = []
|
||||
self.unique_together = []
|
||||
self.index_together = []
|
||||
self.select_on_save = False
|
||||
|
@@ -18,7 +18,7 @@ from django.core.exceptions import (
|
||||
from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections
|
||||
from django.db.models.aggregates import Count
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.db.models.expressions import Col, Ref
|
||||
from django.db.models.expressions import Col, F, Ref, SimpleCol
|
||||
from django.db.models.fields import Field
|
||||
from django.db.models.fields.related_lookups import MultiColSource
|
||||
from django.db.models.lookups import Lookup
|
||||
@@ -62,6 +62,12 @@ JoinInfo = namedtuple(
|
||||
)
|
||||
|
||||
|
||||
def _get_col(target, field, alias, simple_col):
|
||||
if simple_col:
|
||||
return SimpleCol(target, field)
|
||||
return target.get_col(alias, field)
|
||||
|
||||
|
||||
class RawQuery:
|
||||
"""A single raw SQL query."""
|
||||
|
||||
@@ -1011,15 +1017,24 @@ class Query:
|
||||
def as_sql(self, compiler, connection):
|
||||
return self.get_compiler(connection=connection).as_sql()
|
||||
|
||||
def resolve_lookup_value(self, value, can_reuse, allow_joins):
|
||||
def resolve_lookup_value(self, value, can_reuse, allow_joins, simple_col):
|
||||
if hasattr(value, 'resolve_expression'):
|
||||
value = value.resolve_expression(self, reuse=can_reuse, allow_joins=allow_joins)
|
||||
kwargs = {'reuse': can_reuse, 'allow_joins': allow_joins}
|
||||
if isinstance(value, F):
|
||||
kwargs['simple_col'] = simple_col
|
||||
value = value.resolve_expression(self, **kwargs)
|
||||
elif isinstance(value, (list, tuple)):
|
||||
# The items of the iterable may be expressions and therefore need
|
||||
# to be resolved independently.
|
||||
for sub_value in value:
|
||||
if hasattr(sub_value, 'resolve_expression'):
|
||||
sub_value.resolve_expression(self, reuse=can_reuse, allow_joins=allow_joins)
|
||||
if isinstance(sub_value, F):
|
||||
sub_value.resolve_expression(
|
||||
self, reuse=can_reuse, allow_joins=allow_joins,
|
||||
simple_col=simple_col,
|
||||
)
|
||||
else:
|
||||
sub_value.resolve_expression(self, reuse=can_reuse, allow_joins=allow_joins)
|
||||
return value
|
||||
|
||||
def solve_lookup_type(self, lookup):
|
||||
@@ -1133,7 +1148,7 @@ class Query:
|
||||
|
||||
def build_filter(self, filter_expr, branch_negated=False, current_negated=False,
|
||||
can_reuse=None, allow_joins=True, split_subq=True,
|
||||
reuse_with_filtered_relation=False):
|
||||
reuse_with_filtered_relation=False, simple_col=False):
|
||||
"""
|
||||
Build a WhereNode for a single filter clause but don't add it
|
||||
to this Query. Query.add_q() will then add this filter to the where
|
||||
@@ -1179,7 +1194,7 @@ class Query:
|
||||
raise FieldError("Joined field references are not permitted in this query")
|
||||
|
||||
pre_joins = self.alias_refcount.copy()
|
||||
value = self.resolve_lookup_value(value, can_reuse, allow_joins)
|
||||
value = self.resolve_lookup_value(value, can_reuse, allow_joins, simple_col)
|
||||
used_joins = {k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)}
|
||||
|
||||
clause = self.where_class()
|
||||
@@ -1222,11 +1237,11 @@ class Query:
|
||||
if num_lookups > 1:
|
||||
raise FieldError('Related Field got invalid lookup: {}'.format(lookups[0]))
|
||||
if len(targets) == 1:
|
||||
col = targets[0].get_col(alias, join_info.final_field)
|
||||
col = _get_col(targets[0], join_info.final_field, alias, simple_col)
|
||||
else:
|
||||
col = MultiColSource(alias, targets, join_info.targets, join_info.final_field)
|
||||
else:
|
||||
col = targets[0].get_col(alias, join_info.final_field)
|
||||
col = _get_col(targets[0], join_info.final_field, alias, simple_col)
|
||||
|
||||
condition = self.build_lookup(lookups, col, value)
|
||||
lookup_type = condition.lookup_name
|
||||
@@ -1248,7 +1263,8 @@ class Query:
|
||||
# <=>
|
||||
# NOT (col IS NOT NULL AND col = someval).
|
||||
lookup_class = targets[0].get_lookup('isnull')
|
||||
clause.add(lookup_class(targets[0].get_col(alias, join_info.targets[0]), False), AND)
|
||||
col = _get_col(targets[0], join_info.targets[0], alias, simple_col)
|
||||
clause.add(lookup_class(col, False), AND)
|
||||
return clause, used_joins if not require_outer else ()
|
||||
|
||||
def add_filter(self, filter_clause):
|
||||
@@ -1271,8 +1287,12 @@ class Query:
|
||||
self.where.add(clause, AND)
|
||||
self.demote_joins(existing_inner)
|
||||
|
||||
def build_where(self, q_object):
|
||||
return self._add_q(q_object, used_aliases=set(), allow_joins=False, simple_col=True)[0]
|
||||
|
||||
def _add_q(self, q_object, used_aliases, branch_negated=False,
|
||||
current_negated=False, allow_joins=True, split_subq=True):
|
||||
current_negated=False, allow_joins=True, split_subq=True,
|
||||
simple_col=False):
|
||||
"""Add a Q-object to the current filter."""
|
||||
connector = q_object.connector
|
||||
current_negated = current_negated ^ q_object.negated
|
||||
@@ -1290,7 +1310,7 @@ class Query:
|
||||
child_clause, needed_inner = self.build_filter(
|
||||
child, can_reuse=used_aliases, branch_negated=branch_negated,
|
||||
current_negated=current_negated, allow_joins=allow_joins,
|
||||
split_subq=split_subq,
|
||||
split_subq=split_subq, simple_col=simple_col,
|
||||
)
|
||||
joinpromoter.add_votes(needed_inner)
|
||||
if child_clause:
|
||||
@@ -1559,7 +1579,7 @@ class Query:
|
||||
self.unref_alias(joins.pop())
|
||||
return targets, joins[-1], joins
|
||||
|
||||
def resolve_ref(self, name, allow_joins=True, reuse=None, summarize=False):
|
||||
def resolve_ref(self, name, allow_joins=True, reuse=None, summarize=False, simple_col=False):
|
||||
if not allow_joins and LOOKUP_SEP in name:
|
||||
raise FieldError("Joined field references are not permitted in this query")
|
||||
if name in self.annotations:
|
||||
@@ -1580,7 +1600,7 @@ class Query:
|
||||
"isn't supported")
|
||||
if reuse is not None:
|
||||
reuse.update(join_list)
|
||||
col = targets[0].get_col(join_list[-1], join_info.targets[0])
|
||||
col = _get_col(targets[0], join_info.targets[0], join_list[-1], simple_col)
|
||||
return col
|
||||
|
||||
def split_exclude(self, filter_expr, can_reuse, names_with_path):
|
||||
|
@@ -297,6 +297,7 @@ Models
|
||||
field accessor.
|
||||
* **models.E026**: The model cannot have more than one field with
|
||||
``primary_key=True``.
|
||||
* **models.W027**: ``<database>`` does not support check constraints.
|
||||
|
||||
Security
|
||||
--------
|
||||
|
@@ -207,6 +207,25 @@ Creates an index in the database table for the model with ``model_name``.
|
||||
|
||||
Removes the index named ``name`` from the model with ``model_name``.
|
||||
|
||||
``AddConstraint``
|
||||
-----------------
|
||||
|
||||
.. class:: AddConstraint(model_name, constraint)
|
||||
|
||||
.. versionadded:: 2.2
|
||||
|
||||
Creates a constraint in the database table for the model with ``model_name``.
|
||||
``constraint`` is an instance of :class:`~django.db.models.CheckConstraint`.
|
||||
|
||||
``RemoveConstraint``
|
||||
--------------------
|
||||
|
||||
.. class:: RemoveConstraint(model_name, name)
|
||||
|
||||
.. versionadded:: 2.2
|
||||
|
||||
Removes the constraint named ``name`` from the model with ``model_name``.
|
||||
|
||||
Special Operations
|
||||
==================
|
||||
|
||||
|
46
docs/ref/models/check-constraints.txt
Normal file
46
docs/ref/models/check-constraints.txt
Normal file
@@ -0,0 +1,46 @@
|
||||
===========================
|
||||
Check constraints reference
|
||||
===========================
|
||||
|
||||
.. module:: django.db.models.constraints
|
||||
|
||||
.. currentmodule:: django.db.models
|
||||
|
||||
.. versionadded:: 2.2
|
||||
|
||||
The ``CheckConstraint`` class creates database check constraints. They are
|
||||
added in the model :attr:`Meta.constraints
|
||||
<django.db.models.Options.constraints>` option. This document
|
||||
explains the API references of :class:`CheckConstraint`.
|
||||
|
||||
.. admonition:: Referencing built-in constraints
|
||||
|
||||
Constraints are defined in ``django.db.models.constraints``, but for
|
||||
convenience they're imported into :mod:`django.db.models`. The standard
|
||||
convention is to use ``from django.db import models`` and refer to the
|
||||
constraints as ``models.CheckConstraint``.
|
||||
|
||||
``CheckConstraint`` options
|
||||
===========================
|
||||
|
||||
.. class:: CheckConstraint(constraint, name)
|
||||
|
||||
Creates a check constraint in the database.
|
||||
|
||||
``constraint``
|
||||
--------------
|
||||
|
||||
.. attribute:: CheckConstraint.constraint
|
||||
|
||||
A :class:`Q` object that specifies the condition you want the constraint to
|
||||
enforce.
|
||||
|
||||
For example ``CheckConstraint(Q(age__gte=18), 'age_gte_18')`` ensures the age
|
||||
field is never less than 18.
|
||||
|
||||
``name``
|
||||
--------
|
||||
|
||||
.. attribute:: CheckConstraint.name
|
||||
|
||||
The name of the constraint.
|
@@ -9,6 +9,7 @@ Model API reference. For introductory material, see :doc:`/topics/db/models`.
|
||||
|
||||
fields
|
||||
indexes
|
||||
check-constraints
|
||||
meta
|
||||
relations
|
||||
class
|
||||
|
@@ -451,6 +451,26 @@ Django quotes column and table names behind the scenes.
|
||||
|
||||
index_together = ["pub_date", "deadline"]
|
||||
|
||||
``constraints``
|
||||
---------------
|
||||
|
||||
.. attribute:: Options.constraints
|
||||
|
||||
.. versionadded:: 2.2
|
||||
|
||||
A list of :doc:`constraints </ref/models/check-constraints>` that you want
|
||||
to define on the model::
|
||||
|
||||
from django.db import models
|
||||
|
||||
class Customer(models.Model):
|
||||
age = models.IntegerField()
|
||||
|
||||
class Meta:
|
||||
constraints = [
|
||||
models.CheckConstraint(models.Q(age__gte=18), 'age_gte_18'),
|
||||
]
|
||||
|
||||
``verbose_name``
|
||||
----------------
|
||||
|
||||
|
@@ -30,6 +30,13 @@ officially support the latest release of each series.
|
||||
What's new in Django 2.2
|
||||
========================
|
||||
|
||||
Check Constraints
|
||||
-----------------
|
||||
|
||||
The new :class:`~django.db.models.CheckConstraint` class enables adding custom
|
||||
database constraints. Constraints are added to models using the
|
||||
:attr:`Meta.constraints <django.db.models.Options.constraints>` option.
|
||||
|
||||
Minor features
|
||||
--------------
|
||||
|
||||
@@ -213,7 +220,9 @@ Backwards incompatible changes in 2.2
|
||||
Database backend API
|
||||
--------------------
|
||||
|
||||
* ...
|
||||
* Third-party database backends must implement support for table check
|
||||
constraints or set ``DatabaseFeatures.supports_table_check_constraints`` to
|
||||
``False``.
|
||||
|
||||
:mod:`django.contrib.gis`
|
||||
-------------------------
|
||||
|
0
tests/constraints/__init__.py
Normal file
0
tests/constraints/__init__.py
Normal file
15
tests/constraints/models.py
Normal file
15
tests/constraints/models.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from django.db import models
|
||||
|
||||
|
||||
class Product(models.Model):
|
||||
name = models.CharField(max_length=255)
|
||||
price = models.IntegerField()
|
||||
discounted_price = models.IntegerField()
|
||||
|
||||
class Meta:
|
||||
constraints = [
|
||||
models.CheckConstraint(
|
||||
models.Q(price__gt=models.F('discounted_price')),
|
||||
'price_gt_discounted_price'
|
||||
)
|
||||
]
|
30
tests/constraints/tests.py
Normal file
30
tests/constraints/tests.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from django.db import IntegrityError, models
|
||||
from django.test import TestCase, skipUnlessDBFeature
|
||||
|
||||
from .models import Product
|
||||
|
||||
|
||||
class CheckConstraintTests(TestCase):
|
||||
def test_repr(self):
|
||||
constraint = models.Q(price__gt=models.F('discounted_price'))
|
||||
name = 'price_gt_discounted_price'
|
||||
check = models.CheckConstraint(constraint, name)
|
||||
self.assertEqual(
|
||||
repr(check),
|
||||
"<CheckConstraint: constraint='{}' name='{}'>".format(constraint, name),
|
||||
)
|
||||
|
||||
def test_deconstruction(self):
|
||||
constraint = models.Q(price__gt=models.F('discounted_price'))
|
||||
name = 'price_gt_discounted_price'
|
||||
check = models.CheckConstraint(constraint, name)
|
||||
path, args, kwargs = check.deconstruct()
|
||||
self.assertEqual(path, 'django.db.models.CheckConstraint')
|
||||
self.assertEqual(args, ())
|
||||
self.assertEqual(kwargs, {'constraint': constraint, 'name': name})
|
||||
|
||||
@skipUnlessDBFeature('supports_table_check_constraints')
|
||||
def test_model_constraint(self):
|
||||
Product.objects.create(name='Valid', price=10, discounted_price=5)
|
||||
with self.assertRaises(IntegrityError):
|
||||
Product.objects.create(name='Invalid', price=10, discounted_price=20)
|
@@ -1,10 +1,10 @@
|
||||
import unittest
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.checks import Error
|
||||
from django.core.checks import Error, Warning
|
||||
from django.core.checks.model_checks import _check_lazy_references
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.db import connections, models
|
||||
from django.db import connection, connections, models
|
||||
from django.db.models.signals import post_init
|
||||
from django.test import SimpleTestCase
|
||||
from django.test.utils import isolate_apps, override_settings
|
||||
@@ -972,3 +972,26 @@ class OtherModelTests(SimpleTestCase):
|
||||
id='signals.E001',
|
||||
),
|
||||
])
|
||||
|
||||
|
||||
@isolate_apps('invalid_models_tests')
|
||||
class ConstraintsTests(SimpleTestCase):
|
||||
def test_check_constraints(self):
|
||||
class Model(models.Model):
|
||||
age = models.IntegerField()
|
||||
|
||||
class Meta:
|
||||
constraints = [models.CheckConstraint(models.Q(age__gte=18), 'is_adult')]
|
||||
|
||||
errors = Model.check()
|
||||
warn = Warning(
|
||||
'%s does not support check constraints.' % connection.display_name,
|
||||
hint=(
|
||||
"A constraint won't be created. Silence this warning if you "
|
||||
"don't care about it."
|
||||
),
|
||||
obj=Model,
|
||||
id='models.W027',
|
||||
)
|
||||
expected = [] if connection.features.supports_table_check_constraints else [warn, warn]
|
||||
self.assertCountEqual(errors, expected)
|
||||
|
@@ -61,6 +61,12 @@ class AutodetectorTests(TestCase):
|
||||
("id", models.AutoField(primary_key=True)),
|
||||
("name", models.CharField(max_length=200, default='Ada Lovelace')),
|
||||
])
|
||||
author_name_check_constraint = ModelState("testapp", "Author", [
|
||||
("id", models.AutoField(primary_key=True)),
|
||||
("name", models.CharField(max_length=200)),
|
||||
],
|
||||
{'constraints': [models.CheckConstraint(models.Q(name__contains='Bob'), 'name_contains_bob')]},
|
||||
)
|
||||
author_dates_of_birth_auto_now = ModelState("testapp", "Author", [
|
||||
("id", models.AutoField(primary_key=True)),
|
||||
("date_of_birth", models.DateField(auto_now=True)),
|
||||
@@ -1389,6 +1395,40 @@ class AutodetectorTests(TestCase):
|
||||
added_index = models.Index(fields=['title', 'author'], name='book_author_title_idx')
|
||||
self.assertOperationAttributes(changes, 'otherapp', 0, 1, model_name='book', index=added_index)
|
||||
|
||||
def test_create_model_with_check_constraint(self):
|
||||
"""Test creation of new model with constraints already defined."""
|
||||
author = ModelState('otherapp', 'Author', [
|
||||
('id', models.AutoField(primary_key=True)),
|
||||
('name', models.CharField(max_length=200)),
|
||||
], {'constraints': [models.CheckConstraint(models.Q(name__contains='Bob'), 'name_contains_bob')]})
|
||||
changes = self.get_changes([], [author])
|
||||
added_constraint = models.CheckConstraint(models.Q(name__contains='Bob'), 'name_contains_bob')
|
||||
# Right number of migrations?
|
||||
self.assertEqual(len(changes['otherapp']), 1)
|
||||
# Right number of actions?
|
||||
migration = changes['otherapp'][0]
|
||||
self.assertEqual(len(migration.operations), 2)
|
||||
# Right actions order?
|
||||
self.assertOperationTypes(changes, 'otherapp', 0, ['CreateModel', 'AddConstraint'])
|
||||
self.assertOperationAttributes(changes, 'otherapp', 0, 0, name='Author')
|
||||
self.assertOperationAttributes(changes, 'otherapp', 0, 1, model_name='author', constraint=added_constraint)
|
||||
|
||||
def test_add_constraints(self):
|
||||
"""Test change detection of new constraints."""
|
||||
changes = self.get_changes([self.author_name], [self.author_name_check_constraint])
|
||||
self.assertNumberMigrations(changes, 'testapp', 1)
|
||||
self.assertOperationTypes(changes, 'testapp', 0, ['AddConstraint'])
|
||||
added_constraint = models.CheckConstraint(models.Q(name__contains='Bob'), 'name_contains_bob')
|
||||
self.assertOperationAttributes(changes, 'testapp', 0, 0, model_name='author', constraint=added_constraint)
|
||||
|
||||
def test_remove_constraints(self):
|
||||
"""Test change detection of removed constraints."""
|
||||
changes = self.get_changes([self.author_name_check_constraint], [self.author_name])
|
||||
# Right number/type of migrations?
|
||||
self.assertNumberMigrations(changes, 'testapp', 1)
|
||||
self.assertOperationTypes(changes, 'testapp', 0, ['RemoveConstraint'])
|
||||
self.assertOperationAttributes(changes, 'testapp', 0, 0, model_name='author', name='name_contains_bob')
|
||||
|
||||
def test_add_foo_together(self):
|
||||
"""Tests index/unique_together detection."""
|
||||
changes = self.get_changes([self.author_empty, self.book], [self.author_empty, self.book_foo_together])
|
||||
@@ -1520,7 +1560,7 @@ class AutodetectorTests(TestCase):
|
||||
self.assertNumberMigrations(changes, "testapp", 1)
|
||||
self.assertOperationTypes(changes, "testapp", 0, ["CreateModel"])
|
||||
self.assertOperationAttributes(
|
||||
changes, "testapp", 0, 0, name="AuthorProxy", options={"proxy": True, "indexes": []}
|
||||
changes, "testapp", 0, 0, name="AuthorProxy", options={"proxy": True, "indexes": [], "constraints": []}
|
||||
)
|
||||
# Now, we test turning a proxy model into a non-proxy model
|
||||
# It should delete the proxy then make the real one
|
||||
|
@@ -67,6 +67,17 @@ class MigrationTestBase(TransactionTestCase):
|
||||
def assertIndexNotExists(self, table, columns):
|
||||
return self.assertIndexExists(table, columns, False)
|
||||
|
||||
def assertConstraintExists(self, table, name, value=True, using='default'):
|
||||
with connections[using].cursor() as cursor:
|
||||
constraints = connections[using].introspection.get_constraints(cursor, table).items()
|
||||
self.assertEqual(
|
||||
value,
|
||||
any(c['check'] for n, c in constraints if n == name),
|
||||
)
|
||||
|
||||
def assertConstraintNotExists(self, table, name):
|
||||
return self.assertConstraintExists(table, name, False)
|
||||
|
||||
def assertFKExists(self, table, columns, to, value=True, using='default'):
|
||||
with connections[using].cursor() as cursor:
|
||||
self.assertEqual(
|
||||
|
@@ -53,7 +53,7 @@ class OperationTestBase(MigrationTestBase):
|
||||
def set_up_test_model(
|
||||
self, app_label, second_model=False, third_model=False, index=False, multicol_index=False,
|
||||
related_model=False, mti_model=False, proxy_model=False, manager_model=False,
|
||||
unique_together=False, options=False, db_table=None, index_together=False):
|
||||
unique_together=False, options=False, db_table=None, index_together=False, check_constraint=False):
|
||||
"""
|
||||
Creates a test model state and database table.
|
||||
"""
|
||||
@@ -106,6 +106,11 @@ class OperationTestBase(MigrationTestBase):
|
||||
"Pony",
|
||||
models.Index(fields=["pink", "weight"], name="pony_test_idx")
|
||||
))
|
||||
if check_constraint:
|
||||
operations.append(migrations.AddConstraint(
|
||||
"Pony",
|
||||
models.CheckConstraint(models.Q(pink__gt=2), name="pony_test_constraint")
|
||||
))
|
||||
if second_model:
|
||||
operations.append(migrations.CreateModel(
|
||||
"Stable",
|
||||
@@ -462,6 +467,45 @@ class OperationTests(OperationTestBase):
|
||||
self.assertTableNotExists("test_crummo_unmanagedpony")
|
||||
self.assertTableExists("test_crummo_pony")
|
||||
|
||||
@skipUnlessDBFeature('supports_table_check_constraints')
|
||||
def test_create_model_with_constraint(self):
|
||||
where = models.Q(pink__gt=2)
|
||||
check_constraint = models.CheckConstraint(where, name='test_constraint_pony_pink_gt_2')
|
||||
operation = migrations.CreateModel(
|
||||
"Pony",
|
||||
[
|
||||
("id", models.AutoField(primary_key=True)),
|
||||
("pink", models.IntegerField(default=3)),
|
||||
],
|
||||
options={'constraints': [check_constraint]},
|
||||
)
|
||||
|
||||
# Test the state alteration
|
||||
project_state = ProjectState()
|
||||
new_state = project_state.clone()
|
||||
operation.state_forwards("test_crmo", new_state)
|
||||
self.assertEqual(len(new_state.models['test_crmo', 'pony'].options['constraints']), 1)
|
||||
|
||||
# Test database alteration
|
||||
self.assertTableNotExists("test_crmo_pony")
|
||||
with connection.schema_editor() as editor:
|
||||
operation.database_forwards("test_crmo", editor, project_state, new_state)
|
||||
self.assertTableExists("test_crmo_pony")
|
||||
with connection.cursor() as cursor:
|
||||
with self.assertRaises(IntegrityError):
|
||||
cursor.execute("INSERT INTO test_crmo_pony (id, pink) VALUES (1, 1)")
|
||||
|
||||
# Test reversal
|
||||
with connection.schema_editor() as editor:
|
||||
operation.database_backwards("test_crmo", editor, new_state, project_state)
|
||||
self.assertTableNotExists("test_crmo_pony")
|
||||
|
||||
# Test deconstruction
|
||||
definition = operation.deconstruct()
|
||||
self.assertEqual(definition[0], "CreateModel")
|
||||
self.assertEqual(definition[1], [])
|
||||
self.assertEqual(definition[2]['options']['constraints'], [check_constraint])
|
||||
|
||||
def test_create_model_managers(self):
|
||||
"""
|
||||
The managers on a model are set.
|
||||
@@ -1708,6 +1752,87 @@ class OperationTests(OperationTestBase):
|
||||
operation = migrations.AlterIndexTogether("Pony", None)
|
||||
self.assertEqual(operation.describe(), "Alter index_together for Pony (0 constraint(s))")
|
||||
|
||||
@skipUnlessDBFeature('supports_table_check_constraints')
|
||||
def test_add_constraint(self):
|
||||
"""Test the AddConstraint operation."""
|
||||
project_state = self.set_up_test_model('test_addconstraint')
|
||||
|
||||
where = models.Q(pink__gt=2)
|
||||
check_constraint = models.CheckConstraint(where, name='test_constraint_pony_pink_gt_2')
|
||||
operation = migrations.AddConstraint('Pony', check_constraint)
|
||||
self.assertEqual(operation.describe(), 'Create constraint test_constraint_pony_pink_gt_2 on model Pony')
|
||||
|
||||
new_state = project_state.clone()
|
||||
operation.state_forwards('test_addconstraint', new_state)
|
||||
self.assertEqual(len(new_state.models['test_addconstraint', 'pony'].options['constraints']), 1)
|
||||
|
||||
# Test database alteration
|
||||
with connection.cursor() as cursor:
|
||||
with atomic():
|
||||
cursor.execute("INSERT INTO test_addconstraint_pony (id, pink, weight) VALUES (1, 1, 1.0)")
|
||||
cursor.execute("DELETE FROM test_addconstraint_pony")
|
||||
|
||||
with connection.schema_editor() as editor:
|
||||
operation.database_forwards("test_addconstraint", editor, project_state, new_state)
|
||||
|
||||
with connection.cursor() as cursor:
|
||||
with self.assertRaises(IntegrityError):
|
||||
cursor.execute("INSERT INTO test_addconstraint_pony (id, pink, weight) VALUES (1, 1, 1.0)")
|
||||
|
||||
# Test reversal
|
||||
with connection.schema_editor() as editor:
|
||||
operation.database_backwards("test_addconstraint", editor, new_state, project_state)
|
||||
|
||||
with connection.cursor() as cursor:
|
||||
with atomic():
|
||||
cursor.execute("INSERT INTO test_addconstraint_pony (id, pink, weight) VALUES (1, 1, 1.0)")
|
||||
cursor.execute("DELETE FROM test_addconstraint_pony")
|
||||
|
||||
# Test deconstruction
|
||||
definition = operation.deconstruct()
|
||||
self.assertEqual(definition[0], "AddConstraint")
|
||||
self.assertEqual(definition[1], [])
|
||||
self.assertEqual(definition[2], {'model_name': "Pony", 'constraint': check_constraint})
|
||||
|
||||
@skipUnlessDBFeature('supports_table_check_constraints')
|
||||
def test_remove_constraint(self):
|
||||
"""Test the RemoveConstraint operation."""
|
||||
project_state = self.set_up_test_model("test_removeconstraint", check_constraint=True)
|
||||
self.assertTableExists("test_removeconstraint_pony")
|
||||
operation = migrations.RemoveConstraint("Pony", "pony_test_constraint")
|
||||
self.assertEqual(operation.describe(), "Remove constraint pony_test_constraint from model Pony")
|
||||
new_state = project_state.clone()
|
||||
operation.state_forwards("test_removeconstraint", new_state)
|
||||
# Test state alteration
|
||||
self.assertEqual(len(new_state.models["test_removeconstraint", "pony"].options['constraints']), 0)
|
||||
|
||||
with connection.cursor() as cursor:
|
||||
with self.assertRaises(IntegrityError):
|
||||
cursor.execute("INSERT INTO test_removeconstraint_pony (id, pink, weight) VALUES (1, 1, 1.0)")
|
||||
|
||||
# Test database alteration
|
||||
with connection.schema_editor() as editor:
|
||||
operation.database_forwards("test_removeconstraint", editor, project_state, new_state)
|
||||
|
||||
with connection.cursor() as cursor:
|
||||
with atomic():
|
||||
cursor.execute("INSERT INTO test_removeconstraint_pony (id, pink, weight) VALUES (1, 1, 1.0)")
|
||||
cursor.execute("DELETE FROM test_removeconstraint_pony")
|
||||
|
||||
# Test reversal
|
||||
with connection.schema_editor() as editor:
|
||||
operation.database_backwards("test_removeconstraint", editor, new_state, project_state)
|
||||
|
||||
with connection.cursor() as cursor:
|
||||
with self.assertRaises(IntegrityError):
|
||||
cursor.execute("INSERT INTO test_removeconstraint_pony (id, pink, weight) VALUES (1, 1, 1.0)")
|
||||
|
||||
# Test deconstruction
|
||||
definition = operation.deconstruct()
|
||||
self.assertEqual(definition[0], "RemoveConstraint")
|
||||
self.assertEqual(definition[1], [])
|
||||
self.assertEqual(definition[2], {'model_name': "Pony", 'name': "pony_test_constraint"})
|
||||
|
||||
def test_alter_model_options(self):
|
||||
"""
|
||||
Tests the AlterModelOptions operation.
|
||||
|
@@ -127,7 +127,12 @@ class StateTests(SimpleTestCase):
|
||||
self.assertIs(author_state.fields[3][1].null, True)
|
||||
self.assertEqual(
|
||||
author_state.options,
|
||||
{"unique_together": {("name", "bio")}, "index_together": {("bio", "age")}, "indexes": []}
|
||||
{
|
||||
"unique_together": {("name", "bio")},
|
||||
"index_together": {("bio", "age")},
|
||||
"indexes": [],
|
||||
"constraints": [],
|
||||
}
|
||||
)
|
||||
self.assertEqual(author_state.bases, (models.Model,))
|
||||
|
||||
@@ -139,14 +144,17 @@ class StateTests(SimpleTestCase):
|
||||
self.assertEqual(book_state.fields[3][1].__class__.__name__, "ManyToManyField")
|
||||
self.assertEqual(
|
||||
book_state.options,
|
||||
{"verbose_name": "tome", "db_table": "test_tome", "indexes": [book_index]},
|
||||
{"verbose_name": "tome", "db_table": "test_tome", "indexes": [book_index], "constraints": []},
|
||||
)
|
||||
self.assertEqual(book_state.bases, (models.Model,))
|
||||
|
||||
self.assertEqual(author_proxy_state.app_label, "migrations")
|
||||
self.assertEqual(author_proxy_state.name, "AuthorProxy")
|
||||
self.assertEqual(author_proxy_state.fields, [])
|
||||
self.assertEqual(author_proxy_state.options, {"proxy": True, "ordering": ["name"], "indexes": []})
|
||||
self.assertEqual(
|
||||
author_proxy_state.options,
|
||||
{"proxy": True, "ordering": ["name"], "indexes": [], "constraints": []},
|
||||
)
|
||||
self.assertEqual(author_proxy_state.bases, ("migrations.author",))
|
||||
|
||||
self.assertEqual(sub_author_state.app_label, "migrations")
|
||||
@@ -1002,7 +1010,7 @@ class ModelStateTests(SimpleTestCase):
|
||||
self.assertEqual(author_state.fields[1][1].max_length, 255)
|
||||
self.assertIs(author_state.fields[2][1].null, False)
|
||||
self.assertIs(author_state.fields[3][1].null, True)
|
||||
self.assertEqual(author_state.options, {'swappable': 'TEST_SWAPPABLE_MODEL', 'indexes': []})
|
||||
self.assertEqual(author_state.options, {'swappable': 'TEST_SWAPPABLE_MODEL', 'indexes': [], "constraints": []})
|
||||
self.assertEqual(author_state.bases, (models.Model,))
|
||||
self.assertEqual(author_state.managers, [])
|
||||
|
||||
@@ -1047,7 +1055,7 @@ class ModelStateTests(SimpleTestCase):
|
||||
self.assertEqual(station_state.fields[2][1].null, False)
|
||||
self.assertEqual(
|
||||
station_state.options,
|
||||
{'abstract': False, 'swappable': 'TEST_SWAPPABLE_MODEL', 'indexes': []}
|
||||
{'abstract': False, 'swappable': 'TEST_SWAPPABLE_MODEL', 'indexes': [], 'constraints': []}
|
||||
)
|
||||
self.assertEqual(station_state.bases, ('migrations.searchablelocation',))
|
||||
self.assertEqual(station_state.managers, [])
|
||||
@@ -1129,6 +1137,21 @@ class ModelStateTests(SimpleTestCase):
|
||||
index_names = [index.name for index in model_state.options['indexes']]
|
||||
self.assertEqual(index_names, ['foo_idx'])
|
||||
|
||||
@isolate_apps('migrations')
|
||||
def test_from_model_constraints(self):
|
||||
class ModelWithConstraints(models.Model):
|
||||
size = models.IntegerField()
|
||||
|
||||
class Meta:
|
||||
constraints = [models.CheckConstraint(models.Q(size__gt=1), 'size_gt_1')]
|
||||
|
||||
state = ModelState.from_model(ModelWithConstraints)
|
||||
model_constraints = ModelWithConstraints._meta.constraints
|
||||
state_constraints = state.options['constraints']
|
||||
self.assertEqual(model_constraints, state_constraints)
|
||||
self.assertIsNot(model_constraints, state_constraints)
|
||||
self.assertIsNot(model_constraints[0], state_constraints[0])
|
||||
|
||||
|
||||
class RelatedModelsTests(SimpleTestCase):
|
||||
|
||||
|
95
tests/queries/test_query.py
Normal file
95
tests/queries/test_query.py
Normal file
@@ -0,0 +1,95 @@
|
||||
from datetime import datetime
|
||||
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db.models import CharField, F, Q
|
||||
from django.db.models.expressions import SimpleCol
|
||||
from django.db.models.fields.related_lookups import RelatedIsNull
|
||||
from django.db.models.functions import Lower
|
||||
from django.db.models.lookups import Exact, GreaterThan, IsNull, LessThan
|
||||
from django.db.models.sql.query import Query
|
||||
from django.db.models.sql.where import OR
|
||||
from django.test import TestCase
|
||||
|
||||
from .models import Author, Item, ObjectC, Ranking
|
||||
|
||||
|
||||
class TestQuery(TestCase):
|
||||
def test_simple_query(self):
|
||||
query = Query(Author)
|
||||
where = query.build_where(Q(num__gt=2))
|
||||
lookup = where.children[0]
|
||||
self.assertIsInstance(lookup, GreaterThan)
|
||||
self.assertEqual(lookup.rhs, 2)
|
||||
self.assertEqual(lookup.lhs.target, Author._meta.get_field('num'))
|
||||
|
||||
def test_complex_query(self):
|
||||
query = Query(Author)
|
||||
where = query.build_where(Q(num__gt=2) | Q(num__lt=0))
|
||||
self.assertEqual(where.connector, OR)
|
||||
|
||||
lookup = where.children[0]
|
||||
self.assertIsInstance(lookup, GreaterThan)
|
||||
self.assertEqual(lookup.rhs, 2)
|
||||
self.assertEqual(lookup.lhs.target, Author._meta.get_field('num'))
|
||||
|
||||
lookup = where.children[1]
|
||||
self.assertIsInstance(lookup, LessThan)
|
||||
self.assertEqual(lookup.rhs, 0)
|
||||
self.assertEqual(lookup.lhs.target, Author._meta.get_field('num'))
|
||||
|
||||
def test_multiple_fields(self):
|
||||
query = Query(Item)
|
||||
where = query.build_where(Q(modified__gt=F('created')))
|
||||
lookup = where.children[0]
|
||||
self.assertIsInstance(lookup, GreaterThan)
|
||||
self.assertIsInstance(lookup.rhs, SimpleCol)
|
||||
self.assertIsInstance(lookup.lhs, SimpleCol)
|
||||
self.assertEqual(lookup.rhs.target, Item._meta.get_field('created'))
|
||||
self.assertEqual(lookup.lhs.target, Item._meta.get_field('modified'))
|
||||
|
||||
def test_transform(self):
|
||||
query = Query(Author)
|
||||
CharField.register_lookup(Lower, 'lower')
|
||||
try:
|
||||
where = query.build_where(~Q(name__lower='foo'))
|
||||
finally:
|
||||
CharField._unregister_lookup(Lower, 'lower')
|
||||
lookup = where.children[0]
|
||||
self.assertIsInstance(lookup, Exact)
|
||||
self.assertIsInstance(lookup.lhs, Lower)
|
||||
self.assertIsInstance(lookup.lhs.lhs, SimpleCol)
|
||||
self.assertEqual(lookup.lhs.lhs.target, Author._meta.get_field('name'))
|
||||
|
||||
def test_negated_nullable(self):
|
||||
query = Query(Item)
|
||||
where = query.build_where(~Q(modified__lt=datetime(2017, 1, 1)))
|
||||
self.assertTrue(where.negated)
|
||||
lookup = where.children[0]
|
||||
self.assertIsInstance(lookup, LessThan)
|
||||
self.assertEqual(lookup.lhs.target, Item._meta.get_field('modified'))
|
||||
lookup = where.children[1]
|
||||
self.assertIsInstance(lookup, IsNull)
|
||||
self.assertEqual(lookup.lhs.target, Item._meta.get_field('modified'))
|
||||
|
||||
def test_foreign_key(self):
|
||||
query = Query(Item)
|
||||
msg = 'Joined field references are not permitted in this query'
|
||||
with self.assertRaisesMessage(FieldError, msg):
|
||||
query.build_where(Q(creator__num__gt=2))
|
||||
|
||||
def test_foreign_key_f(self):
|
||||
query = Query(Ranking)
|
||||
with self.assertRaises(FieldError):
|
||||
query.build_where(Q(rank__gt=F('author__num')))
|
||||
|
||||
def test_foreign_key_exclusive(self):
|
||||
query = Query(ObjectC)
|
||||
where = query.build_where(Q(objecta=None) | Q(objectb=None))
|
||||
a_isnull = where.children[0]
|
||||
self.assertIsInstance(a_isnull, RelatedIsNull)
|
||||
self.assertIsInstance(a_isnull.lhs, SimpleCol)
|
||||
self.assertEqual(a_isnull.lhs.target, ObjectC._meta.get_field('objecta'))
|
||||
b_isnull = where.children[1]
|
||||
self.assertIsInstance(b_isnull, RelatedIsNull)
|
||||
self.assertIsInstance(b_isnull.lhs, SimpleCol)
|
||||
self.assertEqual(b_isnull.lhs.target, ObjectC._meta.get_field('objectb'))
|
Reference in New Issue
Block a user