1
0
mirror of https://github.com/django/django.git synced 2025-01-22 00:02:15 +00:00

Fixed #11964 -- Added support for database check constraints.

This commit is contained in:
Ian Foote 2016-11-05 13:12:12 +00:00 committed by Tim Graham
parent 6fbfb5cb96
commit 952f05a6db
29 changed files with 799 additions and 39 deletions

View File

@ -172,6 +172,7 @@ class BaseDatabaseFeatures:
# Does it support CHECK constraints? # Does it support CHECK constraints?
supports_column_check_constraints = True supports_column_check_constraints = True
supports_table_check_constraints = True
# Does the backend support 'pyformat' style ("... %(name)s ...", {'name': value}) # Does the backend support 'pyformat' style ("... %(name)s ...", {'name': value})
# parameter passing? Note this can be provided by the backend even if not # parameter passing? Note this can be provided by the backend even if not

View File

@ -63,7 +63,8 @@ class BaseDatabaseSchemaEditor:
sql_rename_column = "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s" sql_rename_column = "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s"
sql_update_with_default = "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL" sql_update_with_default = "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL"
sql_create_check = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s CHECK (%(check)s)" sql_check = "CONSTRAINT %(name)s CHECK (%(check)s)"
sql_create_check = "ALTER TABLE %(table)s ADD %(check)s"
sql_delete_check = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s" sql_delete_check = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
sql_create_unique = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s UNIQUE (%(columns)s)" sql_create_unique = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s UNIQUE (%(columns)s)"
@ -299,10 +300,11 @@ class BaseDatabaseSchemaEditor:
for fields in model._meta.unique_together: for fields in model._meta.unique_together:
columns = [model._meta.get_field(field).column for field in fields] columns = [model._meta.get_field(field).column for field in fields]
self.deferred_sql.append(self._create_unique_sql(model, columns)) self.deferred_sql.append(self._create_unique_sql(model, columns))
constraints = [check.constraint_sql(model, self) for check in model._meta.constraints]
# Make the table # Make the table
sql = self.sql_create_table % { sql = self.sql_create_table % {
"table": self.quote_name(model._meta.db_table), "table": self.quote_name(model._meta.db_table),
"definition": ", ".join(column_sqls) "definition": ", ".join((*column_sqls, *constraints)),
} }
if model._meta.db_tablespace: if model._meta.db_tablespace:
tablespace_sql = self.connection.ops.tablespace_sql(model._meta.db_tablespace) tablespace_sql = self.connection.ops.tablespace_sql(model._meta.db_tablespace)
@ -343,6 +345,14 @@ class BaseDatabaseSchemaEditor:
"""Remove an index from a model.""" """Remove an index from a model."""
self.execute(index.remove_sql(model, self)) self.execute(index.remove_sql(model, self))
def add_constraint(self, model, constraint):
"""Add a check constraint to a model."""
self.execute(constraint.create_sql(model, self))
def remove_constraint(self, model, constraint):
"""Remove a check constraint from a model."""
self.execute(constraint.remove_sql(model, self))
def alter_unique_together(self, model, old_unique_together, new_unique_together): def alter_unique_together(self, model, old_unique_together, new_unique_together):
""" """
Deal with a model changing its unique_together. The input Deal with a model changing its unique_together. The input
@ -752,11 +762,12 @@ class BaseDatabaseSchemaEditor:
self.execute( self.execute(
self.sql_create_check % { self.sql_create_check % {
"table": self.quote_name(model._meta.db_table), "table": self.quote_name(model._meta.db_table),
"name": self.quote_name( "check": self.sql_check % {
self._create_index_name(model._meta.db_table, [new_field.column], suffix="_check") 'name': self.quote_name(
), self._create_index_name(model._meta.db_table, [new_field.column], suffix='_check'),
"column": self.quote_name(new_field.column), ),
"check": new_db_params['check'], 'check': new_db_params['check'],
},
} }
) )
# Drop the default if we need to # Drop the default if we need to

View File

@ -26,6 +26,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
can_release_savepoints = True can_release_savepoints = True
atomic_transactions = False atomic_transactions = False
supports_column_check_constraints = False supports_column_check_constraints = False
supports_table_check_constraints = False
can_clone_databases = True can_clone_databases = True
supports_temporal_subtraction = True supports_temporal_subtraction = True
supports_select_intersection = False supports_select_intersection = False

View File

@ -126,7 +126,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
else: else:
super().alter_field(model, old_field, new_field, strict=strict) super().alter_field(model, old_field, new_field, strict=strict)
def _remake_table(self, model, create_field=None, delete_field=None, alter_field=None): def _remake_table(self, model, create_field=None, delete_field=None, alter_field=None,
add_constraint=None, remove_constraint=None):
""" """
Shortcut to transform a model from old_model into new_model Shortcut to transform a model from old_model into new_model
@ -222,6 +223,15 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
if delete_field.name not in index.fields if delete_field.name not in index.fields
] ]
constraints = list(model._meta.constraints)
if add_constraint:
constraints.append(add_constraint)
if remove_constraint:
constraints = [
constraint for constraint in constraints
if remove_constraint.name != constraint.name
]
# Construct a new model for the new state # Construct a new model for the new state
meta_contents = { meta_contents = {
'app_label': model._meta.app_label, 'app_label': model._meta.app_label,
@ -229,6 +239,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
'unique_together': unique_together, 'unique_together': unique_together,
'index_together': index_together, 'index_together': index_together,
'indexes': indexes, 'indexes': indexes,
'constraints': constraints,
'apps': apps, 'apps': apps,
} }
meta = type("Meta", (), meta_contents) meta = type("Meta", (), meta_contents)
@ -362,3 +373,9 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
)) ))
# Delete the old through table # Delete the old through table
self.delete_model(old_field.remote_field.through) self.delete_model(old_field.remote_field.through)
def add_constraint(self, model, constraint):
self._remake_table(model, add_constraint=constraint)
def remove_constraint(self, model, constraint):
self._remake_table(model, remove_constraint=constraint)

View File

@ -122,6 +122,7 @@ class MigrationAutodetector:
# resolve dependencies caused by M2Ms and FKs. # resolve dependencies caused by M2Ms and FKs.
self.generated_operations = {} self.generated_operations = {}
self.altered_indexes = {} self.altered_indexes = {}
self.altered_constraints = {}
# Prepare some old/new state and model lists, separating # Prepare some old/new state and model lists, separating
# proxy models and ignoring unmigrated apps. # proxy models and ignoring unmigrated apps.
@ -175,7 +176,9 @@ class MigrationAutodetector:
# This avoids the same computation in generate_removed_indexes() # This avoids the same computation in generate_removed_indexes()
# and generate_added_indexes(). # and generate_added_indexes().
self.create_altered_indexes() self.create_altered_indexes()
self.create_altered_constraints()
# Generate index removal operations before field is removed # Generate index removal operations before field is removed
self.generate_removed_constraints()
self.generate_removed_indexes() self.generate_removed_indexes()
# Generate field operations # Generate field operations
self.generate_renamed_fields() self.generate_renamed_fields()
@ -185,6 +188,7 @@ class MigrationAutodetector:
self.generate_altered_unique_together() self.generate_altered_unique_together()
self.generate_altered_index_together() self.generate_altered_index_together()
self.generate_added_indexes() self.generate_added_indexes()
self.generate_added_constraints()
self.generate_altered_db_table() self.generate_altered_db_table()
self.generate_altered_order_with_respect_to() self.generate_altered_order_with_respect_to()
@ -533,6 +537,7 @@ class MigrationAutodetector:
related_fields[field.name] = field related_fields[field.name] = field
# Are there indexes/unique|index_together to defer? # Are there indexes/unique|index_together to defer?
indexes = model_state.options.pop('indexes') indexes = model_state.options.pop('indexes')
constraints = model_state.options.pop('constraints')
unique_together = model_state.options.pop('unique_together', None) unique_together = model_state.options.pop('unique_together', None)
index_together = model_state.options.pop('index_together', None) index_together = model_state.options.pop('index_together', None)
order_with_respect_to = model_state.options.pop('order_with_respect_to', None) order_with_respect_to = model_state.options.pop('order_with_respect_to', None)
@ -601,6 +606,15 @@ class MigrationAutodetector:
), ),
dependencies=related_dependencies, dependencies=related_dependencies,
) )
for constraint in constraints:
self.add_operation(
app_label,
operations.AddConstraint(
model_name=model_name,
constraint=constraint,
),
dependencies=related_dependencies,
)
if unique_together: if unique_together:
self.add_operation( self.add_operation(
app_label, app_label,
@ -997,6 +1011,46 @@ class MigrationAutodetector:
) )
) )
def create_altered_constraints(self):
option_name = operations.AddConstraint.option_name
for app_label, model_name in sorted(self.kept_model_keys):
old_model_name = self.renamed_models.get((app_label, model_name), model_name)
old_model_state = self.from_state.models[app_label, old_model_name]
new_model_state = self.to_state.models[app_label, model_name]
old_constraints = old_model_state.options[option_name]
new_constraints = new_model_state.options[option_name]
add_constraints = [c for c in new_constraints if c not in old_constraints]
rem_constraints = [c for c in old_constraints if c not in new_constraints]
self.altered_constraints.update({
(app_label, model_name): {
'added_constraints': add_constraints, 'removed_constraints': rem_constraints,
}
})
def generate_added_constraints(self):
for (app_label, model_name), alt_constraints in self.altered_constraints.items():
for constraint in alt_constraints['added_constraints']:
self.add_operation(
app_label,
operations.AddConstraint(
model_name=model_name,
constraint=constraint,
)
)
def generate_removed_constraints(self):
for (app_label, model_name), alt_constraints in self.altered_constraints.items():
for constraint in alt_constraints['removed_constraints']:
self.add_operation(
app_label,
operations.RemoveConstraint(
model_name=model_name,
name=constraint.name,
)
)
def _get_dependencies_for_foreign_key(self, field): def _get_dependencies_for_foreign_key(self, field):
# Account for FKs to swappable models # Account for FKs to swappable models
swappable_setting = getattr(field, 'swappable_setting', None) swappable_setting = getattr(field, 'swappable_setting', None)

View File

@ -1,8 +1,9 @@
from .fields import AddField, AlterField, RemoveField, RenameField from .fields import AddField, AlterField, RemoveField, RenameField
from .models import ( from .models import (
AddIndex, AlterIndexTogether, AlterModelManagers, AlterModelOptions, AddConstraint, AddIndex, AlterIndexTogether, AlterModelManagers,
AlterModelTable, AlterOrderWithRespectTo, AlterUniqueTogether, CreateModel, AlterModelOptions, AlterModelTable, AlterOrderWithRespectTo,
DeleteModel, RemoveIndex, RenameModel, AlterUniqueTogether, CreateModel, DeleteModel, RemoveConstraint,
RemoveIndex, RenameModel,
) )
from .special import RunPython, RunSQL, SeparateDatabaseAndState from .special import RunPython, RunSQL, SeparateDatabaseAndState
@ -10,6 +11,7 @@ __all__ = [
'CreateModel', 'DeleteModel', 'AlterModelTable', 'AlterUniqueTogether', 'CreateModel', 'DeleteModel', 'AlterModelTable', 'AlterUniqueTogether',
'RenameModel', 'AlterIndexTogether', 'AlterModelOptions', 'AddIndex', 'RenameModel', 'AlterIndexTogether', 'AlterModelOptions', 'AddIndex',
'RemoveIndex', 'AddField', 'RemoveField', 'AlterField', 'RenameField', 'RemoveIndex', 'AddField', 'RemoveField', 'AlterField', 'RenameField',
'AddConstraint', 'RemoveConstraint',
'SeparateDatabaseAndState', 'RunSQL', 'RunPython', 'SeparateDatabaseAndState', 'RunSQL', 'RunPython',
'AlterOrderWithRespectTo', 'AlterModelManagers', 'AlterOrderWithRespectTo', 'AlterModelManagers',
] ]

View File

@ -822,3 +822,72 @@ class RemoveIndex(IndexOperation):
def describe(self): def describe(self):
return 'Remove index %s from %s' % (self.name, self.model_name) return 'Remove index %s from %s' % (self.name, self.model_name)
class AddConstraint(IndexOperation):
option_name = 'constraints'
def __init__(self, model_name, constraint):
self.model_name = model_name
self.constraint = constraint
def state_forwards(self, app_label, state):
model_state = state.models[app_label, self.model_name_lower]
constraints = list(model_state.options[self.option_name])
constraints.append(self.constraint)
model_state.options[self.option_name] = constraints
def database_forwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, model):
schema_editor.add_constraint(model, self.constraint)
def database_backwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, model):
schema_editor.remove_constraint(model, self.constraint)
def deconstruct(self):
return self.__class__.__name__, [], {
'model_name': self.model_name,
'constraint': self.constraint,
}
def describe(self):
return 'Create constraint %s on model %s' % (self.constraint.name, self.model_name)
class RemoveConstraint(IndexOperation):
option_name = 'constraints'
def __init__(self, model_name, name):
self.model_name = model_name
self.name = name
def state_forwards(self, app_label, state):
model_state = state.models[app_label, self.model_name_lower]
constraints = model_state.options[self.option_name]
model_state.options[self.option_name] = [c for c in constraints if c.name != self.name]
def database_forwards(self, app_label, schema_editor, from_state, to_state):
model = from_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, model):
from_model_state = from_state.models[app_label, self.model_name_lower]
constraint = from_model_state.get_constraint_by_name(self.name)
schema_editor.remove_constraint(model, constraint)
def database_backwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, model):
to_model_state = to_state.models[app_label, self.model_name_lower]
constraint = to_model_state.get_constraint_by_name(self.name)
schema_editor.add_constraint(model, constraint)
def deconstruct(self):
return self.__class__.__name__, [], {
'model_name': self.model_name,
'name': self.name,
}
def describe(self):
return 'Remove constraint %s from model %s' % (self.name, self.model_name)

View File

@ -362,6 +362,7 @@ class ModelState:
self.fields = fields self.fields = fields
self.options = options or {} self.options = options or {}
self.options.setdefault('indexes', []) self.options.setdefault('indexes', [])
self.options.setdefault('constraints', [])
self.bases = bases or (models.Model,) self.bases = bases or (models.Model,)
self.managers = managers or [] self.managers = managers or []
# Sanity-check that fields is NOT a dict. It must be ordered. # Sanity-check that fields is NOT a dict. It must be ordered.
@ -445,6 +446,8 @@ class ModelState:
if not index.name: if not index.name:
index.set_name_with_model(model) index.set_name_with_model(model)
options['indexes'] = indexes options['indexes'] = indexes
elif name == 'constraints':
options['constraints'] = [con.clone() for con in model._meta.constraints]
else: else:
options[name] = model._meta.original_attrs[name] options[name] = model._meta.original_attrs[name]
# If we're ignoring relationships, remove all field-listing model # If we're ignoring relationships, remove all field-listing model
@ -585,6 +588,12 @@ class ModelState:
return index return index
raise ValueError("No index named %s on model %s" % (name, self.name)) raise ValueError("No index named %s on model %s" % (name, self.name))
def get_constraint_by_name(self, name):
for constraint in self.options['constraints']:
if constraint.name == name:
return constraint
raise ValueError('No constraint named %s on model %s' % (name, self.name))
def __repr__(self): def __repr__(self):
return "<%s: '%s.%s'>" % (self.__class__.__name__, self.app_label, self.name) return "<%s: '%s.%s'>" % (self.__class__.__name__, self.app_label, self.name)

View File

@ -2,6 +2,8 @@ from django.core.exceptions import ObjectDoesNotExist
from django.db.models import signals from django.db.models import signals
from django.db.models.aggregates import * # NOQA from django.db.models.aggregates import * # NOQA
from django.db.models.aggregates import __all__ as aggregates_all from django.db.models.aggregates import __all__ as aggregates_all
from django.db.models.constraints import * # NOQA
from django.db.models.constraints import __all__ as constraints_all
from django.db.models.deletion import ( from django.db.models.deletion import (
CASCADE, DO_NOTHING, PROTECT, SET, SET_DEFAULT, SET_NULL, ProtectedError, CASCADE, DO_NOTHING, PROTECT, SET, SET_DEFAULT, SET_NULL, ProtectedError,
) )
@ -30,7 +32,7 @@ from django.db.models.fields.related import ( # isort:skip
) )
__all__ = aggregates_all + fields_all + indexes_all __all__ = aggregates_all + constraints_all + fields_all + indexes_all
__all__ += [ __all__ += [
'ObjectDoesNotExist', 'signals', 'ObjectDoesNotExist', 'signals',
'CASCADE', 'DO_NOTHING', 'PROTECT', 'SET', 'SET_DEFAULT', 'SET_NULL', 'CASCADE', 'DO_NOTHING', 'PROTECT', 'SET', 'SET_DEFAULT', 'SET_NULL',

View File

@ -16,6 +16,7 @@ from django.db import (
connections, router, transaction, connections, router, transaction,
) )
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.constraints import CheckConstraint
from django.db.models.deletion import CASCADE, Collector from django.db.models.deletion import CASCADE, Collector
from django.db.models.fields.related import ( from django.db.models.fields.related import (
ForeignObjectRel, OneToOneField, lazy_related_operation, resolve_relation, ForeignObjectRel, OneToOneField, lazy_related_operation, resolve_relation,
@ -1201,6 +1202,7 @@ class Model(metaclass=ModelBase):
*cls._check_unique_together(), *cls._check_unique_together(),
*cls._check_indexes(), *cls._check_indexes(),
*cls._check_ordering(), *cls._check_ordering(),
*cls._check_constraints(),
] ]
return errors return errors
@ -1699,6 +1701,29 @@ class Model(metaclass=ModelBase):
return errors return errors
@classmethod
def _check_constraints(cls):
errors = []
for db in settings.DATABASES:
if not router.allow_migrate_model(db, cls):
continue
connection = connections[db]
if connection.features.supports_table_check_constraints:
continue
if any(isinstance(constraint, CheckConstraint) for constraint in cls._meta.constraints):
errors.append(
checks.Warning(
'%s does not support check constraints.' % connection.display_name,
hint=(
"A constraint won't be created. Silence this "
"warning if you don't care about it."
),
obj=cls,
id='models.W027',
)
)
return errors
############################################ ############################################
# HELPER FUNCTIONS (CURRIED MODEL METHODS) # # HELPER FUNCTIONS (CURRIED MODEL METHODS) #

View File

@ -0,0 +1,54 @@
from django.db.models.sql.query import Query
__all__ = ['CheckConstraint']
class CheckConstraint:
def __init__(self, constraint, name):
self.constraint = constraint
self.name = name
def constraint_sql(self, model, schema_editor):
query = Query(model)
where = query.build_where(self.constraint)
connection = schema_editor.connection
compiler = connection.ops.compiler('SQLCompiler')(query, connection, 'default')
sql, params = where.as_sql(compiler, connection)
params = tuple(schema_editor.quote_value(p) for p in params)
return schema_editor.sql_check % {
'name': schema_editor.quote_name(self.name),
'check': sql % params,
}
def create_sql(self, model, schema_editor):
sql = self.constraint_sql(model, schema_editor)
return schema_editor.sql_create_check % {
'table': schema_editor.quote_name(model._meta.db_table),
'check': sql,
}
def remove_sql(self, model, schema_editor):
quote_name = schema_editor.quote_name
return schema_editor.sql_delete_check % {
'table': quote_name(model._meta.db_table),
'name': quote_name(self.name),
}
def __repr__(self):
return "<%s: constraint='%s' name='%s'>" % (self.__class__.__name__, self.constraint, self.name)
def __eq__(self, other):
return (
isinstance(other, CheckConstraint) and
self.name == other.name and
self.constraint == other.constraint
)
def deconstruct(self):
path = '%s.%s' % (self.__class__.__module__, self.__class__.__name__)
path = path.replace('django.db.models.constraints', 'django.db.models')
return (path, (), {'constraint': self.constraint, 'name': self.name})
def clone(self):
_, args, kwargs = self.deconstruct()
return self.__class__(*args, **kwargs)

View File

@ -505,8 +505,9 @@ class F(Combinable):
def __repr__(self): def __repr__(self):
return "{}({})".format(self.__class__.__name__, self.name) return "{}({})".format(self.__class__.__name__, self.name)
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): def resolve_expression(self, query=None, allow_joins=True, reuse=None,
return query.resolve_ref(self.name, allow_joins, reuse, summarize) summarize=False, for_save=False, simple_col=False):
return query.resolve_ref(self.name, allow_joins, reuse, summarize, simple_col)
def asc(self, **kwargs): def asc(self, **kwargs):
return OrderBy(self, **kwargs) return OrderBy(self, **kwargs)
@ -542,7 +543,8 @@ class ResolvedOuterRef(F):
class OuterRef(F): class OuterRef(F):
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): def resolve_expression(self, query=None, allow_joins=True, reuse=None,
summarize=False, for_save=False, simple_col=False):
if isinstance(self.name, self.__class__): if isinstance(self.name, self.__class__):
return self.name return self.name
return ResolvedOuterRef(self.name) return ResolvedOuterRef(self.name)
@ -746,6 +748,40 @@ class Col(Expression):
self.target.get_db_converters(connection)) self.target.get_db_converters(connection))
class SimpleCol(Expression):
"""
Represents the SQL of a column name without the table name.
This variant of Col doesn't include the table name (or an alias) to
avoid a syntax error in check constraints.
"""
contains_column_references = True
def __init__(self, target, output_field=None):
if output_field is None:
output_field = target
super().__init__(output_field=output_field)
self.target = target
def __repr__(self):
return '{}({})'.format(self.__class__.__name__, self.target)
def as_sql(self, compiler, connection):
qn = compiler.quote_name_unless_alias
return qn(self.target.column), []
def get_group_by_cols(self):
return [self]
def get_db_converters(self, connection):
if self.target == self.output_field:
return self.output_field.get_db_converters(connection)
return (
self.output_field.get_db_converters(connection) +
self.target.get_db_converters(connection)
)
class Ref(Expression): class Ref(Expression):
""" """
Reference to column alias of the query. For example, Ref('sum_cost') in Reference to column alias of the query. For example, Ref('sum_cost') in

View File

@ -32,7 +32,7 @@ DEFAULT_NAMES = (
'auto_created', 'index_together', 'apps', 'default_permissions', 'auto_created', 'index_together', 'apps', 'default_permissions',
'select_on_save', 'default_related_name', 'required_db_features', 'select_on_save', 'default_related_name', 'required_db_features',
'required_db_vendor', 'base_manager_name', 'default_manager_name', 'required_db_vendor', 'base_manager_name', 'default_manager_name',
'indexes', 'indexes', 'constraints',
# For backwards compatibility with Django 1.11. RemovedInDjango30Warning # For backwards compatibility with Django 1.11. RemovedInDjango30Warning
'manager_inheritance_from_future', 'manager_inheritance_from_future',
) )
@ -89,6 +89,7 @@ class Options:
self.ordering = [] self.ordering = []
self._ordering_clash = False self._ordering_clash = False
self.indexes = [] self.indexes = []
self.constraints = []
self.unique_together = [] self.unique_together = []
self.index_together = [] self.index_together = []
self.select_on_save = False self.select_on_save = False

View File

@ -18,7 +18,7 @@ from django.core.exceptions import (
from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections
from django.db.models.aggregates import Count from django.db.models.aggregates import Count
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import Col, Ref from django.db.models.expressions import Col, F, Ref, SimpleCol
from django.db.models.fields import Field from django.db.models.fields import Field
from django.db.models.fields.related_lookups import MultiColSource from django.db.models.fields.related_lookups import MultiColSource
from django.db.models.lookups import Lookup from django.db.models.lookups import Lookup
@ -62,6 +62,12 @@ JoinInfo = namedtuple(
) )
def _get_col(target, field, alias, simple_col):
if simple_col:
return SimpleCol(target, field)
return target.get_col(alias, field)
class RawQuery: class RawQuery:
"""A single raw SQL query.""" """A single raw SQL query."""
@ -1011,15 +1017,24 @@ class Query:
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
return self.get_compiler(connection=connection).as_sql() return self.get_compiler(connection=connection).as_sql()
def resolve_lookup_value(self, value, can_reuse, allow_joins): def resolve_lookup_value(self, value, can_reuse, allow_joins, simple_col):
if hasattr(value, 'resolve_expression'): if hasattr(value, 'resolve_expression'):
value = value.resolve_expression(self, reuse=can_reuse, allow_joins=allow_joins) kwargs = {'reuse': can_reuse, 'allow_joins': allow_joins}
if isinstance(value, F):
kwargs['simple_col'] = simple_col
value = value.resolve_expression(self, **kwargs)
elif isinstance(value, (list, tuple)): elif isinstance(value, (list, tuple)):
# The items of the iterable may be expressions and therefore need # The items of the iterable may be expressions and therefore need
# to be resolved independently. # to be resolved independently.
for sub_value in value: for sub_value in value:
if hasattr(sub_value, 'resolve_expression'): if hasattr(sub_value, 'resolve_expression'):
sub_value.resolve_expression(self, reuse=can_reuse, allow_joins=allow_joins) if isinstance(sub_value, F):
sub_value.resolve_expression(
self, reuse=can_reuse, allow_joins=allow_joins,
simple_col=simple_col,
)
else:
sub_value.resolve_expression(self, reuse=can_reuse, allow_joins=allow_joins)
return value return value
def solve_lookup_type(self, lookup): def solve_lookup_type(self, lookup):
@ -1133,7 +1148,7 @@ class Query:
def build_filter(self, filter_expr, branch_negated=False, current_negated=False, def build_filter(self, filter_expr, branch_negated=False, current_negated=False,
can_reuse=None, allow_joins=True, split_subq=True, can_reuse=None, allow_joins=True, split_subq=True,
reuse_with_filtered_relation=False): reuse_with_filtered_relation=False, simple_col=False):
""" """
Build a WhereNode for a single filter clause but don't add it Build a WhereNode for a single filter clause but don't add it
to this Query. Query.add_q() will then add this filter to the where to this Query. Query.add_q() will then add this filter to the where
@ -1179,7 +1194,7 @@ class Query:
raise FieldError("Joined field references are not permitted in this query") raise FieldError("Joined field references are not permitted in this query")
pre_joins = self.alias_refcount.copy() pre_joins = self.alias_refcount.copy()
value = self.resolve_lookup_value(value, can_reuse, allow_joins) value = self.resolve_lookup_value(value, can_reuse, allow_joins, simple_col)
used_joins = {k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)} used_joins = {k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)}
clause = self.where_class() clause = self.where_class()
@ -1222,11 +1237,11 @@ class Query:
if num_lookups > 1: if num_lookups > 1:
raise FieldError('Related Field got invalid lookup: {}'.format(lookups[0])) raise FieldError('Related Field got invalid lookup: {}'.format(lookups[0]))
if len(targets) == 1: if len(targets) == 1:
col = targets[0].get_col(alias, join_info.final_field) col = _get_col(targets[0], join_info.final_field, alias, simple_col)
else: else:
col = MultiColSource(alias, targets, join_info.targets, join_info.final_field) col = MultiColSource(alias, targets, join_info.targets, join_info.final_field)
else: else:
col = targets[0].get_col(alias, join_info.final_field) col = _get_col(targets[0], join_info.final_field, alias, simple_col)
condition = self.build_lookup(lookups, col, value) condition = self.build_lookup(lookups, col, value)
lookup_type = condition.lookup_name lookup_type = condition.lookup_name
@ -1248,7 +1263,8 @@ class Query:
# <=> # <=>
# NOT (col IS NOT NULL AND col = someval). # NOT (col IS NOT NULL AND col = someval).
lookup_class = targets[0].get_lookup('isnull') lookup_class = targets[0].get_lookup('isnull')
clause.add(lookup_class(targets[0].get_col(alias, join_info.targets[0]), False), AND) col = _get_col(targets[0], join_info.targets[0], alias, simple_col)
clause.add(lookup_class(col, False), AND)
return clause, used_joins if not require_outer else () return clause, used_joins if not require_outer else ()
def add_filter(self, filter_clause): def add_filter(self, filter_clause):
@ -1271,8 +1287,12 @@ class Query:
self.where.add(clause, AND) self.where.add(clause, AND)
self.demote_joins(existing_inner) self.demote_joins(existing_inner)
def build_where(self, q_object):
return self._add_q(q_object, used_aliases=set(), allow_joins=False, simple_col=True)[0]
def _add_q(self, q_object, used_aliases, branch_negated=False, def _add_q(self, q_object, used_aliases, branch_negated=False,
current_negated=False, allow_joins=True, split_subq=True): current_negated=False, allow_joins=True, split_subq=True,
simple_col=False):
"""Add a Q-object to the current filter.""" """Add a Q-object to the current filter."""
connector = q_object.connector connector = q_object.connector
current_negated = current_negated ^ q_object.negated current_negated = current_negated ^ q_object.negated
@ -1290,7 +1310,7 @@ class Query:
child_clause, needed_inner = self.build_filter( child_clause, needed_inner = self.build_filter(
child, can_reuse=used_aliases, branch_negated=branch_negated, child, can_reuse=used_aliases, branch_negated=branch_negated,
current_negated=current_negated, allow_joins=allow_joins, current_negated=current_negated, allow_joins=allow_joins,
split_subq=split_subq, split_subq=split_subq, simple_col=simple_col,
) )
joinpromoter.add_votes(needed_inner) joinpromoter.add_votes(needed_inner)
if child_clause: if child_clause:
@ -1559,7 +1579,7 @@ class Query:
self.unref_alias(joins.pop()) self.unref_alias(joins.pop())
return targets, joins[-1], joins return targets, joins[-1], joins
def resolve_ref(self, name, allow_joins=True, reuse=None, summarize=False): def resolve_ref(self, name, allow_joins=True, reuse=None, summarize=False, simple_col=False):
if not allow_joins and LOOKUP_SEP in name: if not allow_joins and LOOKUP_SEP in name:
raise FieldError("Joined field references are not permitted in this query") raise FieldError("Joined field references are not permitted in this query")
if name in self.annotations: if name in self.annotations:
@ -1580,7 +1600,7 @@ class Query:
"isn't supported") "isn't supported")
if reuse is not None: if reuse is not None:
reuse.update(join_list) reuse.update(join_list)
col = targets[0].get_col(join_list[-1], join_info.targets[0]) col = _get_col(targets[0], join_info.targets[0], join_list[-1], simple_col)
return col return col
def split_exclude(self, filter_expr, can_reuse, names_with_path): def split_exclude(self, filter_expr, can_reuse, names_with_path):

View File

@ -297,6 +297,7 @@ Models
field accessor. field accessor.
* **models.E026**: The model cannot have more than one field with * **models.E026**: The model cannot have more than one field with
``primary_key=True``. ``primary_key=True``.
* **models.W027**: ``<database>`` does not support check constraints.
Security Security
-------- --------

View File

@ -207,6 +207,25 @@ Creates an index in the database table for the model with ``model_name``.
Removes the index named ``name`` from the model with ``model_name``. Removes the index named ``name`` from the model with ``model_name``.
``AddConstraint``
-----------------
.. class:: AddConstraint(model_name, constraint)
.. versionadded:: 2.2
Creates a constraint in the database table for the model with ``model_name``.
``constraint`` is an instance of :class:`~django.db.models.CheckConstraint`.
``RemoveConstraint``
--------------------
.. class:: RemoveConstraint(model_name, name)
.. versionadded:: 2.2
Removes the constraint named ``name`` from the model with ``model_name``.
Special Operations Special Operations
================== ==================

View File

@ -0,0 +1,46 @@
===========================
Check constraints reference
===========================
.. module:: django.db.models.constraints
.. currentmodule:: django.db.models
.. versionadded:: 2.2
The ``CheckConstraint`` class creates database check constraints. They are
added in the model :attr:`Meta.constraints
<django.db.models.Options.constraints>` option. This document
explains the API references of :class:`CheckConstraint`.
.. admonition:: Referencing built-in constraints
Constraints are defined in ``django.db.models.constraints``, but for
convenience they're imported into :mod:`django.db.models`. The standard
convention is to use ``from django.db import models`` and refer to the
constraints as ``models.CheckConstraint``.
``CheckConstraint`` options
===========================
.. class:: CheckConstraint(constraint, name)
Creates a check constraint in the database.
``constraint``
--------------
.. attribute:: CheckConstraint.constraint
A :class:`Q` object that specifies the condition you want the constraint to
enforce.
For example ``CheckConstraint(Q(age__gte=18), 'age_gte_18')`` ensures the age
field is never less than 18.
``name``
--------
.. attribute:: CheckConstraint.name
The name of the constraint.

View File

@ -9,6 +9,7 @@ Model API reference. For introductory material, see :doc:`/topics/db/models`.
fields fields
indexes indexes
check-constraints
meta meta
relations relations
class class

View File

@ -451,6 +451,26 @@ Django quotes column and table names behind the scenes.
index_together = ["pub_date", "deadline"] index_together = ["pub_date", "deadline"]
``constraints``
---------------
.. attribute:: Options.constraints
.. versionadded:: 2.2
A list of :doc:`constraints </ref/models/check-constraints>` that you want
to define on the model::
from django.db import models
class Customer(models.Model):
age = models.IntegerField()
class Meta:
constraints = [
models.CheckConstraint(models.Q(age__gte=18), 'age_gte_18'),
]
``verbose_name`` ``verbose_name``
---------------- ----------------

View File

@ -30,6 +30,13 @@ officially support the latest release of each series.
What's new in Django 2.2 What's new in Django 2.2
======================== ========================
Check Constraints
-----------------
The new :class:`~django.db.models.CheckConstraint` class enables adding custom
database constraints. Constraints are added to models using the
:attr:`Meta.constraints <django.db.models.Options.constraints>` option.
Minor features Minor features
-------------- --------------
@ -213,7 +220,9 @@ Backwards incompatible changes in 2.2
Database backend API Database backend API
-------------------- --------------------
* ... * Third-party database backends must implement support for table check
constraints or set ``DatabaseFeatures.supports_table_check_constraints`` to
``False``.
:mod:`django.contrib.gis` :mod:`django.contrib.gis`
------------------------- -------------------------

View File

View File

@ -0,0 +1,15 @@
from django.db import models
class Product(models.Model):
name = models.CharField(max_length=255)
price = models.IntegerField()
discounted_price = models.IntegerField()
class Meta:
constraints = [
models.CheckConstraint(
models.Q(price__gt=models.F('discounted_price')),
'price_gt_discounted_price'
)
]

View File

@ -0,0 +1,30 @@
from django.db import IntegrityError, models
from django.test import TestCase, skipUnlessDBFeature
from .models import Product
class CheckConstraintTests(TestCase):
def test_repr(self):
constraint = models.Q(price__gt=models.F('discounted_price'))
name = 'price_gt_discounted_price'
check = models.CheckConstraint(constraint, name)
self.assertEqual(
repr(check),
"<CheckConstraint: constraint='{}' name='{}'>".format(constraint, name),
)
def test_deconstruction(self):
constraint = models.Q(price__gt=models.F('discounted_price'))
name = 'price_gt_discounted_price'
check = models.CheckConstraint(constraint, name)
path, args, kwargs = check.deconstruct()
self.assertEqual(path, 'django.db.models.CheckConstraint')
self.assertEqual(args, ())
self.assertEqual(kwargs, {'constraint': constraint, 'name': name})
@skipUnlessDBFeature('supports_table_check_constraints')
def test_model_constraint(self):
Product.objects.create(name='Valid', price=10, discounted_price=5)
with self.assertRaises(IntegrityError):
Product.objects.create(name='Invalid', price=10, discounted_price=20)

View File

@ -1,10 +1,10 @@
import unittest import unittest
from django.conf import settings from django.conf import settings
from django.core.checks import Error from django.core.checks import Error, Warning
from django.core.checks.model_checks import _check_lazy_references from django.core.checks.model_checks import _check_lazy_references
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.db import connections, models from django.db import connection, connections, models
from django.db.models.signals import post_init from django.db.models.signals import post_init
from django.test import SimpleTestCase from django.test import SimpleTestCase
from django.test.utils import isolate_apps, override_settings from django.test.utils import isolate_apps, override_settings
@ -972,3 +972,26 @@ class OtherModelTests(SimpleTestCase):
id='signals.E001', id='signals.E001',
), ),
]) ])
@isolate_apps('invalid_models_tests')
class ConstraintsTests(SimpleTestCase):
def test_check_constraints(self):
class Model(models.Model):
age = models.IntegerField()
class Meta:
constraints = [models.CheckConstraint(models.Q(age__gte=18), 'is_adult')]
errors = Model.check()
warn = Warning(
'%s does not support check constraints.' % connection.display_name,
hint=(
"A constraint won't be created. Silence this warning if you "
"don't care about it."
),
obj=Model,
id='models.W027',
)
expected = [] if connection.features.supports_table_check_constraints else [warn, warn]
self.assertCountEqual(errors, expected)

View File

@ -61,6 +61,12 @@ class AutodetectorTests(TestCase):
("id", models.AutoField(primary_key=True)), ("id", models.AutoField(primary_key=True)),
("name", models.CharField(max_length=200, default='Ada Lovelace')), ("name", models.CharField(max_length=200, default='Ada Lovelace')),
]) ])
author_name_check_constraint = ModelState("testapp", "Author", [
("id", models.AutoField(primary_key=True)),
("name", models.CharField(max_length=200)),
],
{'constraints': [models.CheckConstraint(models.Q(name__contains='Bob'), 'name_contains_bob')]},
)
author_dates_of_birth_auto_now = ModelState("testapp", "Author", [ author_dates_of_birth_auto_now = ModelState("testapp", "Author", [
("id", models.AutoField(primary_key=True)), ("id", models.AutoField(primary_key=True)),
("date_of_birth", models.DateField(auto_now=True)), ("date_of_birth", models.DateField(auto_now=True)),
@ -1389,6 +1395,40 @@ class AutodetectorTests(TestCase):
added_index = models.Index(fields=['title', 'author'], name='book_author_title_idx') added_index = models.Index(fields=['title', 'author'], name='book_author_title_idx')
self.assertOperationAttributes(changes, 'otherapp', 0, 1, model_name='book', index=added_index) self.assertOperationAttributes(changes, 'otherapp', 0, 1, model_name='book', index=added_index)
def test_create_model_with_check_constraint(self):
"""Test creation of new model with constraints already defined."""
author = ModelState('otherapp', 'Author', [
('id', models.AutoField(primary_key=True)),
('name', models.CharField(max_length=200)),
], {'constraints': [models.CheckConstraint(models.Q(name__contains='Bob'), 'name_contains_bob')]})
changes = self.get_changes([], [author])
added_constraint = models.CheckConstraint(models.Q(name__contains='Bob'), 'name_contains_bob')
# Right number of migrations?
self.assertEqual(len(changes['otherapp']), 1)
# Right number of actions?
migration = changes['otherapp'][0]
self.assertEqual(len(migration.operations), 2)
# Right actions order?
self.assertOperationTypes(changes, 'otherapp', 0, ['CreateModel', 'AddConstraint'])
self.assertOperationAttributes(changes, 'otherapp', 0, 0, name='Author')
self.assertOperationAttributes(changes, 'otherapp', 0, 1, model_name='author', constraint=added_constraint)
def test_add_constraints(self):
"""Test change detection of new constraints."""
changes = self.get_changes([self.author_name], [self.author_name_check_constraint])
self.assertNumberMigrations(changes, 'testapp', 1)
self.assertOperationTypes(changes, 'testapp', 0, ['AddConstraint'])
added_constraint = models.CheckConstraint(models.Q(name__contains='Bob'), 'name_contains_bob')
self.assertOperationAttributes(changes, 'testapp', 0, 0, model_name='author', constraint=added_constraint)
def test_remove_constraints(self):
"""Test change detection of removed constraints."""
changes = self.get_changes([self.author_name_check_constraint], [self.author_name])
# Right number/type of migrations?
self.assertNumberMigrations(changes, 'testapp', 1)
self.assertOperationTypes(changes, 'testapp', 0, ['RemoveConstraint'])
self.assertOperationAttributes(changes, 'testapp', 0, 0, model_name='author', name='name_contains_bob')
def test_add_foo_together(self): def test_add_foo_together(self):
"""Tests index/unique_together detection.""" """Tests index/unique_together detection."""
changes = self.get_changes([self.author_empty, self.book], [self.author_empty, self.book_foo_together]) changes = self.get_changes([self.author_empty, self.book], [self.author_empty, self.book_foo_together])
@ -1520,7 +1560,7 @@ class AutodetectorTests(TestCase):
self.assertNumberMigrations(changes, "testapp", 1) self.assertNumberMigrations(changes, "testapp", 1)
self.assertOperationTypes(changes, "testapp", 0, ["CreateModel"]) self.assertOperationTypes(changes, "testapp", 0, ["CreateModel"])
self.assertOperationAttributes( self.assertOperationAttributes(
changes, "testapp", 0, 0, name="AuthorProxy", options={"proxy": True, "indexes": []} changes, "testapp", 0, 0, name="AuthorProxy", options={"proxy": True, "indexes": [], "constraints": []}
) )
# Now, we test turning a proxy model into a non-proxy model # Now, we test turning a proxy model into a non-proxy model
# It should delete the proxy then make the real one # It should delete the proxy then make the real one

View File

@ -67,6 +67,17 @@ class MigrationTestBase(TransactionTestCase):
def assertIndexNotExists(self, table, columns): def assertIndexNotExists(self, table, columns):
return self.assertIndexExists(table, columns, False) return self.assertIndexExists(table, columns, False)
def assertConstraintExists(self, table, name, value=True, using='default'):
with connections[using].cursor() as cursor:
constraints = connections[using].introspection.get_constraints(cursor, table).items()
self.assertEqual(
value,
any(c['check'] for n, c in constraints if n == name),
)
def assertConstraintNotExists(self, table, name):
return self.assertConstraintExists(table, name, False)
def assertFKExists(self, table, columns, to, value=True, using='default'): def assertFKExists(self, table, columns, to, value=True, using='default'):
with connections[using].cursor() as cursor: with connections[using].cursor() as cursor:
self.assertEqual( self.assertEqual(

View File

@ -53,7 +53,7 @@ class OperationTestBase(MigrationTestBase):
def set_up_test_model( def set_up_test_model(
self, app_label, second_model=False, third_model=False, index=False, multicol_index=False, self, app_label, second_model=False, third_model=False, index=False, multicol_index=False,
related_model=False, mti_model=False, proxy_model=False, manager_model=False, related_model=False, mti_model=False, proxy_model=False, manager_model=False,
unique_together=False, options=False, db_table=None, index_together=False): unique_together=False, options=False, db_table=None, index_together=False, check_constraint=False):
""" """
Creates a test model state and database table. Creates a test model state and database table.
""" """
@ -106,6 +106,11 @@ class OperationTestBase(MigrationTestBase):
"Pony", "Pony",
models.Index(fields=["pink", "weight"], name="pony_test_idx") models.Index(fields=["pink", "weight"], name="pony_test_idx")
)) ))
if check_constraint:
operations.append(migrations.AddConstraint(
"Pony",
models.CheckConstraint(models.Q(pink__gt=2), name="pony_test_constraint")
))
if second_model: if second_model:
operations.append(migrations.CreateModel( operations.append(migrations.CreateModel(
"Stable", "Stable",
@ -462,6 +467,45 @@ class OperationTests(OperationTestBase):
self.assertTableNotExists("test_crummo_unmanagedpony") self.assertTableNotExists("test_crummo_unmanagedpony")
self.assertTableExists("test_crummo_pony") self.assertTableExists("test_crummo_pony")
@skipUnlessDBFeature('supports_table_check_constraints')
def test_create_model_with_constraint(self):
where = models.Q(pink__gt=2)
check_constraint = models.CheckConstraint(where, name='test_constraint_pony_pink_gt_2')
operation = migrations.CreateModel(
"Pony",
[
("id", models.AutoField(primary_key=True)),
("pink", models.IntegerField(default=3)),
],
options={'constraints': [check_constraint]},
)
# Test the state alteration
project_state = ProjectState()
new_state = project_state.clone()
operation.state_forwards("test_crmo", new_state)
self.assertEqual(len(new_state.models['test_crmo', 'pony'].options['constraints']), 1)
# Test database alteration
self.assertTableNotExists("test_crmo_pony")
with connection.schema_editor() as editor:
operation.database_forwards("test_crmo", editor, project_state, new_state)
self.assertTableExists("test_crmo_pony")
with connection.cursor() as cursor:
with self.assertRaises(IntegrityError):
cursor.execute("INSERT INTO test_crmo_pony (id, pink) VALUES (1, 1)")
# Test reversal
with connection.schema_editor() as editor:
operation.database_backwards("test_crmo", editor, new_state, project_state)
self.assertTableNotExists("test_crmo_pony")
# Test deconstruction
definition = operation.deconstruct()
self.assertEqual(definition[0], "CreateModel")
self.assertEqual(definition[1], [])
self.assertEqual(definition[2]['options']['constraints'], [check_constraint])
def test_create_model_managers(self): def test_create_model_managers(self):
""" """
The managers on a model are set. The managers on a model are set.
@ -1708,6 +1752,87 @@ class OperationTests(OperationTestBase):
operation = migrations.AlterIndexTogether("Pony", None) operation = migrations.AlterIndexTogether("Pony", None)
self.assertEqual(operation.describe(), "Alter index_together for Pony (0 constraint(s))") self.assertEqual(operation.describe(), "Alter index_together for Pony (0 constraint(s))")
@skipUnlessDBFeature('supports_table_check_constraints')
def test_add_constraint(self):
"""Test the AddConstraint operation."""
project_state = self.set_up_test_model('test_addconstraint')
where = models.Q(pink__gt=2)
check_constraint = models.CheckConstraint(where, name='test_constraint_pony_pink_gt_2')
operation = migrations.AddConstraint('Pony', check_constraint)
self.assertEqual(operation.describe(), 'Create constraint test_constraint_pony_pink_gt_2 on model Pony')
new_state = project_state.clone()
operation.state_forwards('test_addconstraint', new_state)
self.assertEqual(len(new_state.models['test_addconstraint', 'pony'].options['constraints']), 1)
# Test database alteration
with connection.cursor() as cursor:
with atomic():
cursor.execute("INSERT INTO test_addconstraint_pony (id, pink, weight) VALUES (1, 1, 1.0)")
cursor.execute("DELETE FROM test_addconstraint_pony")
with connection.schema_editor() as editor:
operation.database_forwards("test_addconstraint", editor, project_state, new_state)
with connection.cursor() as cursor:
with self.assertRaises(IntegrityError):
cursor.execute("INSERT INTO test_addconstraint_pony (id, pink, weight) VALUES (1, 1, 1.0)")
# Test reversal
with connection.schema_editor() as editor:
operation.database_backwards("test_addconstraint", editor, new_state, project_state)
with connection.cursor() as cursor:
with atomic():
cursor.execute("INSERT INTO test_addconstraint_pony (id, pink, weight) VALUES (1, 1, 1.0)")
cursor.execute("DELETE FROM test_addconstraint_pony")
# Test deconstruction
definition = operation.deconstruct()
self.assertEqual(definition[0], "AddConstraint")
self.assertEqual(definition[1], [])
self.assertEqual(definition[2], {'model_name': "Pony", 'constraint': check_constraint})
@skipUnlessDBFeature('supports_table_check_constraints')
def test_remove_constraint(self):
"""Test the RemoveConstraint operation."""
project_state = self.set_up_test_model("test_removeconstraint", check_constraint=True)
self.assertTableExists("test_removeconstraint_pony")
operation = migrations.RemoveConstraint("Pony", "pony_test_constraint")
self.assertEqual(operation.describe(), "Remove constraint pony_test_constraint from model Pony")
new_state = project_state.clone()
operation.state_forwards("test_removeconstraint", new_state)
# Test state alteration
self.assertEqual(len(new_state.models["test_removeconstraint", "pony"].options['constraints']), 0)
with connection.cursor() as cursor:
with self.assertRaises(IntegrityError):
cursor.execute("INSERT INTO test_removeconstraint_pony (id, pink, weight) VALUES (1, 1, 1.0)")
# Test database alteration
with connection.schema_editor() as editor:
operation.database_forwards("test_removeconstraint", editor, project_state, new_state)
with connection.cursor() as cursor:
with atomic():
cursor.execute("INSERT INTO test_removeconstraint_pony (id, pink, weight) VALUES (1, 1, 1.0)")
cursor.execute("DELETE FROM test_removeconstraint_pony")
# Test reversal
with connection.schema_editor() as editor:
operation.database_backwards("test_removeconstraint", editor, new_state, project_state)
with connection.cursor() as cursor:
with self.assertRaises(IntegrityError):
cursor.execute("INSERT INTO test_removeconstraint_pony (id, pink, weight) VALUES (1, 1, 1.0)")
# Test deconstruction
definition = operation.deconstruct()
self.assertEqual(definition[0], "RemoveConstraint")
self.assertEqual(definition[1], [])
self.assertEqual(definition[2], {'model_name': "Pony", 'name': "pony_test_constraint"})
def test_alter_model_options(self): def test_alter_model_options(self):
""" """
Tests the AlterModelOptions operation. Tests the AlterModelOptions operation.

View File

@ -127,7 +127,12 @@ class StateTests(SimpleTestCase):
self.assertIs(author_state.fields[3][1].null, True) self.assertIs(author_state.fields[3][1].null, True)
self.assertEqual( self.assertEqual(
author_state.options, author_state.options,
{"unique_together": {("name", "bio")}, "index_together": {("bio", "age")}, "indexes": []} {
"unique_together": {("name", "bio")},
"index_together": {("bio", "age")},
"indexes": [],
"constraints": [],
}
) )
self.assertEqual(author_state.bases, (models.Model,)) self.assertEqual(author_state.bases, (models.Model,))
@ -139,14 +144,17 @@ class StateTests(SimpleTestCase):
self.assertEqual(book_state.fields[3][1].__class__.__name__, "ManyToManyField") self.assertEqual(book_state.fields[3][1].__class__.__name__, "ManyToManyField")
self.assertEqual( self.assertEqual(
book_state.options, book_state.options,
{"verbose_name": "tome", "db_table": "test_tome", "indexes": [book_index]}, {"verbose_name": "tome", "db_table": "test_tome", "indexes": [book_index], "constraints": []},
) )
self.assertEqual(book_state.bases, (models.Model,)) self.assertEqual(book_state.bases, (models.Model,))
self.assertEqual(author_proxy_state.app_label, "migrations") self.assertEqual(author_proxy_state.app_label, "migrations")
self.assertEqual(author_proxy_state.name, "AuthorProxy") self.assertEqual(author_proxy_state.name, "AuthorProxy")
self.assertEqual(author_proxy_state.fields, []) self.assertEqual(author_proxy_state.fields, [])
self.assertEqual(author_proxy_state.options, {"proxy": True, "ordering": ["name"], "indexes": []}) self.assertEqual(
author_proxy_state.options,
{"proxy": True, "ordering": ["name"], "indexes": [], "constraints": []},
)
self.assertEqual(author_proxy_state.bases, ("migrations.author",)) self.assertEqual(author_proxy_state.bases, ("migrations.author",))
self.assertEqual(sub_author_state.app_label, "migrations") self.assertEqual(sub_author_state.app_label, "migrations")
@ -1002,7 +1010,7 @@ class ModelStateTests(SimpleTestCase):
self.assertEqual(author_state.fields[1][1].max_length, 255) self.assertEqual(author_state.fields[1][1].max_length, 255)
self.assertIs(author_state.fields[2][1].null, False) self.assertIs(author_state.fields[2][1].null, False)
self.assertIs(author_state.fields[3][1].null, True) self.assertIs(author_state.fields[3][1].null, True)
self.assertEqual(author_state.options, {'swappable': 'TEST_SWAPPABLE_MODEL', 'indexes': []}) self.assertEqual(author_state.options, {'swappable': 'TEST_SWAPPABLE_MODEL', 'indexes': [], "constraints": []})
self.assertEqual(author_state.bases, (models.Model,)) self.assertEqual(author_state.bases, (models.Model,))
self.assertEqual(author_state.managers, []) self.assertEqual(author_state.managers, [])
@ -1047,7 +1055,7 @@ class ModelStateTests(SimpleTestCase):
self.assertEqual(station_state.fields[2][1].null, False) self.assertEqual(station_state.fields[2][1].null, False)
self.assertEqual( self.assertEqual(
station_state.options, station_state.options,
{'abstract': False, 'swappable': 'TEST_SWAPPABLE_MODEL', 'indexes': []} {'abstract': False, 'swappable': 'TEST_SWAPPABLE_MODEL', 'indexes': [], 'constraints': []}
) )
self.assertEqual(station_state.bases, ('migrations.searchablelocation',)) self.assertEqual(station_state.bases, ('migrations.searchablelocation',))
self.assertEqual(station_state.managers, []) self.assertEqual(station_state.managers, [])
@ -1129,6 +1137,21 @@ class ModelStateTests(SimpleTestCase):
index_names = [index.name for index in model_state.options['indexes']] index_names = [index.name for index in model_state.options['indexes']]
self.assertEqual(index_names, ['foo_idx']) self.assertEqual(index_names, ['foo_idx'])
@isolate_apps('migrations')
def test_from_model_constraints(self):
class ModelWithConstraints(models.Model):
size = models.IntegerField()
class Meta:
constraints = [models.CheckConstraint(models.Q(size__gt=1), 'size_gt_1')]
state = ModelState.from_model(ModelWithConstraints)
model_constraints = ModelWithConstraints._meta.constraints
state_constraints = state.options['constraints']
self.assertEqual(model_constraints, state_constraints)
self.assertIsNot(model_constraints, state_constraints)
self.assertIsNot(model_constraints[0], state_constraints[0])
class RelatedModelsTests(SimpleTestCase): class RelatedModelsTests(SimpleTestCase):

View File

@ -0,0 +1,95 @@
from datetime import datetime
from django.core.exceptions import FieldError
from django.db.models import CharField, F, Q
from django.db.models.expressions import SimpleCol
from django.db.models.fields.related_lookups import RelatedIsNull
from django.db.models.functions import Lower
from django.db.models.lookups import Exact, GreaterThan, IsNull, LessThan
from django.db.models.sql.query import Query
from django.db.models.sql.where import OR
from django.test import TestCase
from .models import Author, Item, ObjectC, Ranking
class TestQuery(TestCase):
def test_simple_query(self):
query = Query(Author)
where = query.build_where(Q(num__gt=2))
lookup = where.children[0]
self.assertIsInstance(lookup, GreaterThan)
self.assertEqual(lookup.rhs, 2)
self.assertEqual(lookup.lhs.target, Author._meta.get_field('num'))
def test_complex_query(self):
query = Query(Author)
where = query.build_where(Q(num__gt=2) | Q(num__lt=0))
self.assertEqual(where.connector, OR)
lookup = where.children[0]
self.assertIsInstance(lookup, GreaterThan)
self.assertEqual(lookup.rhs, 2)
self.assertEqual(lookup.lhs.target, Author._meta.get_field('num'))
lookup = where.children[1]
self.assertIsInstance(lookup, LessThan)
self.assertEqual(lookup.rhs, 0)
self.assertEqual(lookup.lhs.target, Author._meta.get_field('num'))
def test_multiple_fields(self):
query = Query(Item)
where = query.build_where(Q(modified__gt=F('created')))
lookup = where.children[0]
self.assertIsInstance(lookup, GreaterThan)
self.assertIsInstance(lookup.rhs, SimpleCol)
self.assertIsInstance(lookup.lhs, SimpleCol)
self.assertEqual(lookup.rhs.target, Item._meta.get_field('created'))
self.assertEqual(lookup.lhs.target, Item._meta.get_field('modified'))
def test_transform(self):
query = Query(Author)
CharField.register_lookup(Lower, 'lower')
try:
where = query.build_where(~Q(name__lower='foo'))
finally:
CharField._unregister_lookup(Lower, 'lower')
lookup = where.children[0]
self.assertIsInstance(lookup, Exact)
self.assertIsInstance(lookup.lhs, Lower)
self.assertIsInstance(lookup.lhs.lhs, SimpleCol)
self.assertEqual(lookup.lhs.lhs.target, Author._meta.get_field('name'))
def test_negated_nullable(self):
query = Query(Item)
where = query.build_where(~Q(modified__lt=datetime(2017, 1, 1)))
self.assertTrue(where.negated)
lookup = where.children[0]
self.assertIsInstance(lookup, LessThan)
self.assertEqual(lookup.lhs.target, Item._meta.get_field('modified'))
lookup = where.children[1]
self.assertIsInstance(lookup, IsNull)
self.assertEqual(lookup.lhs.target, Item._meta.get_field('modified'))
def test_foreign_key(self):
query = Query(Item)
msg = 'Joined field references are not permitted in this query'
with self.assertRaisesMessage(FieldError, msg):
query.build_where(Q(creator__num__gt=2))
def test_foreign_key_f(self):
query = Query(Ranking)
with self.assertRaises(FieldError):
query.build_where(Q(rank__gt=F('author__num')))
def test_foreign_key_exclusive(self):
query = Query(ObjectC)
where = query.build_where(Q(objecta=None) | Q(objectb=None))
a_isnull = where.children[0]
self.assertIsInstance(a_isnull, RelatedIsNull)
self.assertIsInstance(a_isnull.lhs, SimpleCol)
self.assertEqual(a_isnull.lhs.target, ObjectC._meta.get_field('objecta'))
b_isnull = where.children[1]
self.assertIsInstance(b_isnull, RelatedIsNull)
self.assertIsInstance(b_isnull.lhs, SimpleCol)
self.assertEqual(b_isnull.lhs.target, ObjectC._meta.get_field('objectb'))