From ca9c3cd39fade827cced1b5198dd37bb80c208b0 Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Fri, 7 Sep 2012 15:40:59 -0400 Subject: [PATCH] Add check constraint support - needed a few Field changes --- django/db/backends/__init__.py | 3 + django/db/backends/creation.py | 1 + django/db/backends/mysql/base.py | 1 + .../backends/postgresql_psycopg2/creation.py | 9 ++- .../postgresql_psycopg2/introspection.py | 2 +- django/db/backends/schema.py | 61 ++++++++++++++----- django/db/backends/sqlite3/base.py | 1 + django/db/backends/sqlite3/schema.py | 6 +- django/db/models/fields/__init__.py | 26 +++++++- django/db/models/fields/related.py | 6 ++ tests/modeltests/schema/models.py | 1 + tests/modeltests/schema/tests.py | 50 +++++++++++++++ 12 files changed, 143 insertions(+), 24 deletions(-) diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index 39883de35c..021d9bd450 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -435,6 +435,9 @@ class BaseDatabaseFeatures(object): # Does it support foreign keys? supports_foreign_keys = True + # Does it support CHECK constraints? + supports_check_constraints = True + def __init__(self, connection): self.connection = connection diff --git a/django/db/backends/creation.py b/django/db/backends/creation.py index 0cc01cc876..52d5edf57d 100644 --- a/django/db/backends/creation.py +++ b/django/db/backends/creation.py @@ -18,6 +18,7 @@ class BaseDatabaseCreation(object): destruction of test databases. """ data_types = {} + data_type_check_constraints = {} def __init__(self, connection): self.connection = connection diff --git a/django/db/backends/mysql/base.py b/django/db/backends/mysql/base.py index 10649b64b9..4694dcd46f 100644 --- a/django/db/backends/mysql/base.py +++ b/django/db/backends/mysql/base.py @@ -170,6 +170,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): requires_explicit_null_ordering_when_grouping = True allows_primary_key_0 = False uses_savepoints = True + supports_check_constraints = False def __init__(self, connection): super(DatabaseFeatures, self).__init__(connection) diff --git a/django/db/backends/postgresql_psycopg2/creation.py b/django/db/backends/postgresql_psycopg2/creation.py index ca389b9046..f131d14abe 100644 --- a/django/db/backends/postgresql_psycopg2/creation.py +++ b/django/db/backends/postgresql_psycopg2/creation.py @@ -26,14 +26,19 @@ class DatabaseCreation(BaseDatabaseCreation): 'GenericIPAddressField': 'inet', 'NullBooleanField': 'boolean', 'OneToOneField': 'integer', - 'PositiveIntegerField': 'integer CHECK ("%(column)s" >= 0)', - 'PositiveSmallIntegerField': 'smallint CHECK ("%(column)s" >= 0)', + 'PositiveIntegerField': 'integer', + 'PositiveSmallIntegerField': 'smallint', 'SlugField': 'varchar(%(max_length)s)', 'SmallIntegerField': 'smallint', 'TextField': 'text', 'TimeField': 'time', } + data_type_check_constraints = { + 'PositiveIntegerField': '"%(column)s" >= 0', + 'PositiveSmallIntegerField': '"%(column)s" >= 0', + } + def sql_table_creation_suffix(self): assert self.connection.settings_dict['TEST_COLLATION'] is None, "PostgreSQL does not support collation setting at database creation time." if self.connection.settings_dict['TEST_CHARSET']: diff --git a/django/db/backends/postgresql_psycopg2/introspection.py b/django/db/backends/postgresql_psycopg2/introspection.py index 580d16d1fb..5a29932859 100644 --- a/django/db/backends/postgresql_psycopg2/introspection.py +++ b/django/db/backends/postgresql_psycopg2/introspection.py @@ -137,7 +137,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): kc.table_schema = %s AND kc.table_name = %s """, ["public", table_name]) - for constraint, column, kind in cursor.fetchall(): + for constraint, column in cursor.fetchall(): # If we're the first column, make the record if constraint not in constraints: constraints[constraint] = { diff --git a/django/db/backends/schema.py b/django/db/backends/schema.py index bd86d52e88..974a18cc34 100644 --- a/django/db/backends/schema.py +++ b/django/db/backends/schema.py @@ -19,9 +19,6 @@ class BaseDatabaseSchemaEditor(object): then the relevant actions, and then commit(). This is necessary to allow things like circular foreign key references - FKs will only be created once commit() is called. - - TODO: - - Check constraints (PosIntField) """ # Overrideable SQL templates @@ -41,7 +38,7 @@ class BaseDatabaseSchemaEditor(object): sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s CASCADE" sql_rename_column = "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s" - sql_create_check = "ADD CONSTRAINT %(name)s CHECK (%(check)s)" + sql_create_check = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s CHECK (%(check)s)" sql_delete_check = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s" sql_create_unique = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s UNIQUE (%(columns)s)" @@ -105,7 +102,8 @@ class BaseDatabaseSchemaEditor(object): The field must already have had set_attributes_from_name called. """ # Get the column's type and use that as the basis of the SQL - sql = field.db_type(connection=self.connection) + db_params = field.db_parameters(connection=self.connection) + sql = db_params['type'] params = [] # Check for fields that aren't actually columns (e.g. M2M) if sql is None: @@ -169,6 +167,11 @@ class BaseDatabaseSchemaEditor(object): definition, extra_params = self.column_sql(model, field) if definition is None: continue + # Check constraints can go on the column SQL here + db_params = field.db_parameters(connection=self.connection) + if db_params['check']: + definition += " CHECK (%s)" % db_params['check'] + # Add the SQL to our big list column_sqls.append("%s %s" % ( self.quote_name(field.column), definition, @@ -295,6 +298,10 @@ class BaseDatabaseSchemaEditor(object): # It might not actually have a column behind it if definition is None: return + # Check constraints can go on the column SQL here + db_params = field.db_parameters(connection=self.connection) + if db_params['check']: + definition += " CHECK (%s)" % db_params['check'] # Build the SQL and run it sql = self.sql_create_column % { "table": self.quote_name(model._meta.db_table), @@ -358,8 +365,10 @@ class BaseDatabaseSchemaEditor(object): If strict is true, raises errors if the old column does not match old_field precisely. """ # Ensure this field is even column-based - old_type = old_field.db_type(connection=self.connection) - new_type = self._type_for_alter(new_field) + old_db_params = old_field.db_parameters(connection=self.connection) + old_type = old_db_params['type'] + new_db_params = new_field.db_parameters(connection=self.connection) + new_type = new_db_params['type'] if old_type is None and new_type is None and (old_field.rel.through and new_field.rel.through and old_field.rel.through._meta.auto_created and new_field.rel.through._meta.auto_created): return self._alter_many_to_many(model, old_field, new_field, strict) elif old_type is None or new_type is None: @@ -417,6 +426,22 @@ class BaseDatabaseSchemaEditor(object): "name": fk_name, } ) + # Change check constraints? + if old_db_params['check'] != new_db_params['check'] and old_db_params['check']: + constraint_names = self._constraint_names(model, [old_field.column], check=True) + if strict and len(constraint_names) != 1: + raise ValueError("Found wrong number (%s) of check constraints for %s.%s" % ( + len(constraint_names), + model._meta.db_table, + old_field.column, + )) + for constraint_name in constraint_names: + self.execute( + self.sql_delete_check % { + "table": self.quote_name(model._meta.db_table), + "name": constraint_name, + } + ) # Have they renamed the column? if old_field.column != new_field.column: self.execute(self.sql_rename_column % { @@ -543,6 +568,16 @@ class BaseDatabaseSchemaEditor(object): "to_column": self.quote_name(new_field.rel.get_related_field().column), } ) + # Does it have check constraints we need to add? + if old_db_params['check'] != new_db_params['check'] and new_db_params['check']: + self.execute( + self.sql_create_check % { + "table": self.quote_name(model._meta.db_table), + "name": self._create_index_name(model, [new_field.column], suffix="_check"), + "column": self.quote_name(new_field.column), + "check": new_db_params['check'], + } + ) def _alter_many_to_many(self, model, old_field, new_field, strict): "Alters M2Ms to repoint their to= endpoints." @@ -555,14 +590,6 @@ class BaseDatabaseSchemaEditor(object): new_field.rel.through._meta.get_field_by_name(new_field.m2m_reverse_field_name())[0], ) - def _type_for_alter(self, field): - """ - Returns a field's type suitable for ALTER COLUMN. - By default it just returns field.db_type(). - To be overriden by backend specific subclasses - """ - return field.db_type(connection=self.connection) - def _create_index_name(self, model, column_names, suffix=""): "Generates a unique name for an index/unique constraint." # If there is just one column in the index, use a default algorithm from Django @@ -581,7 +608,7 @@ class BaseDatabaseSchemaEditor(object): index_name = '%s%s' % (table_name[:(self.connection.features.max_index_name_length - len(part))], part) return index_name - def _constraint_names(self, model, column_names=None, unique=None, primary_key=None, index=None, foreign_key=None): + def _constraint_names(self, model, column_names=None, unique=None, primary_key=None, index=None, foreign_key=None, check=None): "Returns all constraint names matching the columns and conditions" column_names = set(column_names) if column_names else None constraints = self.connection.introspection.get_constraints(self.connection.cursor(), model._meta.db_table) @@ -594,6 +621,8 @@ class BaseDatabaseSchemaEditor(object): continue if index is not None and infodict['index'] != index: continue + if check is not None and infodict['check'] != check: + continue if foreign_key is not None and not infodict['foreign_key']: continue result.append(name) diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py index 45e7264e5c..8e30c7f22d 100644 --- a/django/db/backends/sqlite3/base.py +++ b/django/db/backends/sqlite3/base.py @@ -97,6 +97,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): has_bulk_insert = True can_combine_inserts_with_and_without_auto_increment_pk = False supports_foreign_keys = False + supports_check_constraints = False @cached_property def supports_stddev(self): diff --git a/django/db/backends/sqlite3/schema.py b/django/db/backends/sqlite3/schema.py index e660f26c87..6149a4e772 100644 --- a/django/db/backends/sqlite3/schema.py +++ b/django/db/backends/sqlite3/schema.py @@ -99,8 +99,10 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): def alter_field(self, model, old_field, new_field, strict=False): # Ensure this field is even column-based - old_type = old_field.db_type(connection=self.connection) - new_type = self._type_for_alter(new_field) + old_db_params = old_field.db_parameters(connection=self.connection) + old_type = old_db_params['type'] + new_db_params = new_field.db_parameters(connection=self.connection) + new_type = new_db_params['type'] if old_type is None and new_type is None and (old_field.rel.through and new_field.rel.through and old_field.rel.through._meta.auto_created and new_field.rel.through._meta.auto_created): return self._alter_many_to_many(model, old_field, new_field, strict) elif old_type is None or new_type is None: diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index 58ae3413f3..a0b09c9fec 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -232,12 +232,32 @@ class Field(object): # mapped to one of the built-in Django field types. In this case, you # can implement db_type() instead of get_internal_type() to specify # exactly which wacky database column type you want to use. + params = self.db_parameters(connection) + if params['type']: + if params['check']: + return "%s CHECK (%s)" % (params['type'], params['check']) + else: + return params['type'] + return None + + def db_parameters(self, connection): + """ + Replacement for db_type, providing a range of different return + values (type, checks) + """ data = DictWrapper(self.__dict__, connection.ops.quote_name, "qn_") try: - return (connection.creation.data_types[self.get_internal_type()] - % data) + type_string = connection.creation.data_types[self.get_internal_type()] % data except KeyError: - return None + type_string = None + try: + check_string = connection.creation.data_type_check_constraints[self.get_internal_type()] % data + except KeyError: + check_string = None + return { + "type": type_string, + "check": check_string, + } @property def unique(self): diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 08cc0a747f..37bf4e8072 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -1050,6 +1050,9 @@ class ForeignKey(RelatedField, Field): return IntegerField().db_type(connection=connection) return rel_field.db_type(connection=connection) + def db_parameters(self, connection): + return {"type": self.db_type(connection), "check": []} + class OneToOneField(ForeignKey): """ A OneToOneField is essentially the same as a ForeignKey, with the exception @@ -1292,3 +1295,6 @@ class ManyToManyField(RelatedField, Field): # A ManyToManyField is not represented by a single column, # so return None. return None + + def db_parameters(self, connection): + return {"type": None, "check": None} diff --git a/tests/modeltests/schema/models.py b/tests/modeltests/schema/models.py index 76a8cf3687..f3d6d09e2e 100644 --- a/tests/modeltests/schema/models.py +++ b/tests/modeltests/schema/models.py @@ -7,6 +7,7 @@ from django.db import models class Author(models.Model): name = models.CharField(max_length=255) + height = models.PositiveIntegerField(null=True, blank=True) class Meta: managed = False diff --git a/tests/modeltests/schema/tests.py b/tests/modeltests/schema/tests.py index b3fc5d1c80..7d8602eff7 100644 --- a/tests/modeltests/schema/tests.py +++ b/tests/modeltests/schema/tests.py @@ -347,6 +347,56 @@ class SchemaTests(TestCase): else: self.fail("No FK constraint for tag_id found") + @skipUnless(connection.features.supports_check_constraints, "No check constraints") + def test_check_constraints(self): + """ + Tests creating/deleting CHECK constraints + """ + # Create the tables + editor = connection.schema_editor() + editor.start() + editor.create_model(Author) + editor.commit() + # Ensure the constraint exists + constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table) + for name, details in constraints.items(): + if details['columns'] == set(["height"]) and details['check']: + break + else: + self.fail("No check constraint for height found") + # Alter the column to remove it + new_field = IntegerField(null=True, blank=True) + new_field.set_attributes_from_name("height") + editor = connection.schema_editor() + editor.start() + editor.alter_field( + Author, + Author._meta.get_field_by_name("height")[0], + new_field, + strict = True, + ) + editor.commit() + constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table) + for name, details in constraints.items(): + if details['columns'] == set(["height"]) and details['check']: + self.fail("Check constraint for height found") + # Alter the column to re-add it + editor = connection.schema_editor() + editor.start() + editor.alter_field( + Author, + new_field, + Author._meta.get_field_by_name("height")[0], + strict = True, + ) + editor.commit() + constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table) + for name, details in constraints.items(): + if details['columns'] == set(["height"]) and details['check']: + break + else: + self.fail("No check constraint for height found") + def test_unique(self): """ Tests removing and adding unique constraints to a single column.