mirror of
				https://github.com/django/django.git
				synced 2025-10-25 22:56:12 +00:00 
			
		
		
		
	Add some field schema alteration methods and tests.
This commit is contained in:
		| @@ -419,6 +419,9 @@ class BaseDatabaseFeatures(object): | ||||
|     # Can we roll back DDL in a transaction? | ||||
|     can_rollback_ddl = False | ||||
|  | ||||
|     # Can we issue more than one ALTER COLUMN clause in an ALTER TABLE? | ||||
|     supports_combined_alters = False | ||||
|  | ||||
|     def __init__(self, connection): | ||||
|         self.connection = connection | ||||
|  | ||||
|   | ||||
| @@ -85,6 +85,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): | ||||
|     supports_tablespaces = True | ||||
|     can_distinct_on_fields = True | ||||
|     can_rollback_ddl = True | ||||
|     supports_combined_alters = True | ||||
|  | ||||
| class DatabaseWrapper(BaseDatabaseWrapper): | ||||
|     vendor = 'postgresql' | ||||
|   | ||||
| @@ -5,6 +5,7 @@ from django.conf import settings | ||||
| from django.db import transaction | ||||
| from django.db.utils import load_backend | ||||
| from django.utils.log import getLogger | ||||
| from django.db.models.fields.related import ManyToManyField | ||||
|  | ||||
| logger = getLogger('django.db.backends.schema') | ||||
|  | ||||
| @@ -29,11 +30,15 @@ class BaseDatabaseSchemaEditor(object): | ||||
|     sql_rename_table = "ALTER TABLE %(old_table)s RENAME TO %(new_table)s" | ||||
|     sql_delete_table = "DROP TABLE %(table)s CASCADE" | ||||
|  | ||||
|     sql_create_column = "ALTER TABLE %(table)s ADD COLUMN %(definition)s" | ||||
|     sql_create_column = "ALTER TABLE %(table)s ADD COLUMN %(column)s %(definition)s" | ||||
|     sql_alter_column = "ALTER TABLE %(table)s %(changes)s" | ||||
|     sql_alter_column_type = "ALTER COLUMN %(column)s TYPE %(type)s" | ||||
|     sql_alter_column_null = "ALTER COLUMN %(column)s DROP NOT NULL" | ||||
|     sql_alter_column_not_null = "ALTER COLUMN %(column)s SET NOT NULL" | ||||
|     sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s CASCADE;" | ||||
|     sql_alter_column_default = "ALTER COLUMN %(column)s SET DEFAULT %(default)s" | ||||
|     sql_alter_column_no_default = "ALTER COLUMN %(column)s DROP DEFAULT" | ||||
|     sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s CASCADE" | ||||
|     sql_rename_column = "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s" | ||||
|  | ||||
|     sql_create_check = "ADD CONSTRAINT %(name)s CHECK (%(check)s)" | ||||
|     sql_delete_check = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s" | ||||
| @@ -91,50 +96,7 @@ class BaseDatabaseSchemaEditor(object): | ||||
|     def quote_name(self, name): | ||||
|         return self.connection.ops.quote_name(name) | ||||
|  | ||||
|     # Actions | ||||
|  | ||||
|     def create_model(self, model): | ||||
|         """ | ||||
|         Takes a model and creates a table for it in the database. | ||||
|         Will also create any accompanying indexes or unique constraints. | ||||
|         """ | ||||
|         # Do nothing if this is an unmanaged or proxy model | ||||
|         if not model._meta.managed or model._meta.proxy: | ||||
|             return [], {} | ||||
|         # Create column SQL, add FK deferreds if needed | ||||
|         column_sqls = [] | ||||
|         for field in model._meta.local_fields: | ||||
|             # SQL | ||||
|             definition = self.column_sql(model, field) | ||||
|             if definition is None: | ||||
|                 continue | ||||
|             column_sqls.append("%s %s" % ( | ||||
|                 self.quote_name(field.column), | ||||
|                 definition, | ||||
|             )) | ||||
|             # FK | ||||
|             if field.rel: | ||||
|                 to_table = field.rel.to._meta.db_table | ||||
|                 to_column = field.rel.to._meta.get_field(field.rel.field_name).column | ||||
|                 self.deferred_sql.append( | ||||
|                     self.sql_create_fk % { | ||||
|                         "name": '%s_refs_%s_%x' % ( | ||||
|                             field.column, | ||||
|                             to_column, | ||||
|                             abs(hash((model._meta.db_table, to_table))) | ||||
|                         ), | ||||
|                         "table": self.quote_name(model._meta.db_table), | ||||
|                         "column": self.quote_name(field.column), | ||||
|                         "to_table": self.quote_name(to_table), | ||||
|                         "to_column": self.quote_name(to_column), | ||||
|                     } | ||||
|                 ) | ||||
|         # Make the table | ||||
|         sql = self.sql_create_table % { | ||||
|             "table": model._meta.db_table, | ||||
|             "definition": ", ".join(column_sqls) | ||||
|         } | ||||
|         self.execute(sql) | ||||
|     # Field <-> database mapping functions | ||||
|  | ||||
|     def column_sql(self, model, field, include_default=False): | ||||
|         """ | ||||
| @@ -143,6 +105,7 @@ class BaseDatabaseSchemaEditor(object): | ||||
|         """ | ||||
|         # Get the column's type and use that as the basis of the SQL | ||||
|         sql = field.db_type(connection=self.connection) | ||||
|         params = [] | ||||
|         # Check for fields that aren't actually columns (e.g. M2M) | ||||
|         if sql is None: | ||||
|             return None | ||||
| @@ -168,11 +131,232 @@ class BaseDatabaseSchemaEditor(object): | ||||
|             sql += " UNIQUE" | ||||
|         # If we were told to include a default value, do so | ||||
|         if include_default: | ||||
|             raise NotImplementedError() | ||||
|             sql += " DEFAULT %s" | ||||
|             params += [self.effective_default(field)] | ||||
|         # Return the sql | ||||
|         return sql | ||||
|         return sql, params | ||||
|  | ||||
|     def effective_default(self, field): | ||||
|         "Returns a field's effective database default value" | ||||
|         if field.has_default(): | ||||
|             default = field.get_default() | ||||
|         elif not field.null and field.blank and field.empty_strings_allowed: | ||||
|             default = "" | ||||
|         else: | ||||
|             default = None | ||||
|         # If it's a callable, call it | ||||
|         if callable(default): | ||||
|             default = default() | ||||
|         return default | ||||
|  | ||||
|     # Actions | ||||
|  | ||||
|     def create_model(self, model): | ||||
|         """ | ||||
|         Takes a model and creates a table for it in the database. | ||||
|         Will also create any accompanying indexes or unique constraints. | ||||
|         """ | ||||
|         # Do nothing if this is an unmanaged or proxy model | ||||
|         if not model._meta.managed or model._meta.proxy: | ||||
|             return | ||||
|         # Create column SQL, add FK deferreds if needed | ||||
|         column_sqls = [] | ||||
|         params = [] | ||||
|         for field in model._meta.local_fields: | ||||
|             # SQL | ||||
|             definition, extra_params = self.column_sql(model, field) | ||||
|             if definition is None: | ||||
|                 continue | ||||
|             column_sqls.append("%s %s" % ( | ||||
|                 self.quote_name(field.column), | ||||
|                 definition, | ||||
|             )) | ||||
|             params.extend(extra_params) | ||||
|             # FK | ||||
|             if field.rel: | ||||
|                 to_table = field.rel.to._meta.db_table | ||||
|                 to_column = field.rel.to._meta.get_field(field.rel.field_name).column | ||||
|                 self.deferred_sql.append( | ||||
|                     self.sql_create_fk % { | ||||
|                         "name": '%s_refs_%s_%x' % ( | ||||
|                             field.column, | ||||
|                             to_column, | ||||
|                             abs(hash((model._meta.db_table, to_table))) | ||||
|                         ), | ||||
|                         "table": self.quote_name(model._meta.db_table), | ||||
|                         "column": self.quote_name(field.column), | ||||
|                         "to_table": self.quote_name(to_table), | ||||
|                         "to_column": self.quote_name(to_column), | ||||
|                     } | ||||
|                 ) | ||||
|         # Make the table | ||||
|         sql = self.sql_create_table % { | ||||
|             "table": model._meta.db_table, | ||||
|             "definition": ", ".join(column_sqls) | ||||
|         } | ||||
|         self.execute(sql, params) | ||||
|  | ||||
|     def delete_model(self, model): | ||||
|         """ | ||||
|         Deletes a model from the database. | ||||
|         """ | ||||
|         # Do nothing if this is an unmanaged or proxy model | ||||
|         if not model._meta.managed or model._meta.proxy: | ||||
|             return | ||||
|         # Delete the table | ||||
|         self.execute(self.sql_delete_table % { | ||||
|             "table": self.quote_name(model._meta.db_table), | ||||
|         }) | ||||
|  | ||||
|     def create_field(self, model, field, keep_default=False): | ||||
|         """ | ||||
|         Creates a field on a model. | ||||
|         Usually involves adding a column, but may involve adding a | ||||
|         table instead (for M2M fields) | ||||
|         """ | ||||
|         # Special-case implicit M2M tables | ||||
|         if isinstance(field, ManyToManyField) and field.rel.through._meta.auto_created: | ||||
|             return self.create_model(field.rel.through) | ||||
|         # Get the column's definition | ||||
|         definition, params = self.column_sql(model, field, include_default=True) | ||||
|         # It might not actually have a column behind it | ||||
|         if definition is None: | ||||
|             return | ||||
|         # Build the SQL and run it | ||||
|         sql = self.sql_create_column % { | ||||
|             "table": self.quote_name(model._meta.db_table), | ||||
|             "column": self.quote_name(field.column), | ||||
|             "definition": definition, | ||||
|         } | ||||
|         self.execute(sql, params) | ||||
|         # Drop the default if we need to | ||||
|         # (Django usually does not use in-database defaults) | ||||
|         if not keep_default and field.default is not None: | ||||
|             sql = self.sql_alter_column % { | ||||
|                 "table": self.quote_name(model._meta.db_table), | ||||
|                 "changes": self.sql_alter_column_no_default % { | ||||
|                     "column": self.quote_name(field.column), | ||||
|                 } | ||||
|             } | ||||
|         # Add any FK constraints later | ||||
|         if field.rel: | ||||
|             to_table = field.rel.to._meta.db_table | ||||
|             to_column = field.rel.to._meta.get_field(field.rel.field_name).column | ||||
|             self.deferred_sql.append( | ||||
|                 self.sql_create_fk % { | ||||
|                     "name": '%s_refs_%s_%x' % ( | ||||
|                         field.column, | ||||
|                         to_column, | ||||
|                         abs(hash((model._meta.db_table, to_table))) | ||||
|                     ), | ||||
|                     "table": self.quote_name(model._meta.db_table), | ||||
|                     "column": self.quote_name(field.column), | ||||
|                     "to_table": self.quote_name(to_table), | ||||
|                     "to_column": self.quote_name(to_column), | ||||
|                 } | ||||
|             ) | ||||
|  | ||||
|     def delete_field(self, model, field): | ||||
|         """ | ||||
|         Removes a field from a model. Usually involves deleting a column, | ||||
|         but for M2Ms may involve deleting a table. | ||||
|         """ | ||||
|         # Special-case implicit M2M tables | ||||
|         if isinstance(field, ManyToManyField) and field.rel.through._meta.auto_created: | ||||
|             return self.delete_model(field.rel.through) | ||||
|         # Get the column's definition | ||||
|         definition, params = self.column_sql(model, field) | ||||
|         # It might not actually have a column behind it | ||||
|         if definition is None: | ||||
|             return | ||||
|         # Delete the column | ||||
|         sql = self.sql_delete_column % { | ||||
|             "table": self.quote_name(model._meta.db_table), | ||||
|             "column": self.quote_name(field.column), | ||||
|         } | ||||
|         self.execute(sql) | ||||
|  | ||||
|     def alter_field(self, model, old_field, new_field): | ||||
|         """ | ||||
|         Allows a field's type, uniqueness, nullability, default, column, | ||||
|         constraints etc. to be modified. | ||||
|         Requires a copy of the old field as well so we can only perform | ||||
|         changes that are required. | ||||
|         """ | ||||
|         # Ensure this field is even column-based | ||||
|         old_type = old_field.db_type(connection=self.connection) | ||||
|         new_type = new_field.db_type(connection=self.connection) | ||||
|         if old_type is None and new_type is None: | ||||
|             # TODO: Handle M2M fields being repointed | ||||
|             return | ||||
|         elif old_type is None or new_type is None: | ||||
|             raise ValueError("Cannot alter field %s into %s - they are not compatible types" % ( | ||||
|                     old_field, | ||||
|                     new_field, | ||||
|                 )) | ||||
|         # First, have they renamed the column? | ||||
|         if old_field.column != new_field.column: | ||||
|             self.execute(self.sql_rename_column % { | ||||
|                 "table": self.quote_name(model._meta.db_table), | ||||
|                 "old_column": self.quote_name(old_field.column), | ||||
|                 "new_column": self.quote_name(new_field.column), | ||||
|             }) | ||||
|         # Next, start accumulating actions to do | ||||
|         actions = [] | ||||
|         # Type change? | ||||
|         if old_type != new_type: | ||||
|             actions.append(( | ||||
|                 self.sql_alter_column_type % { | ||||
|                     "column": self.quote_name(new_field.column), | ||||
|                     "type": new_type, | ||||
|                 }, | ||||
|                 [], | ||||
|             )) | ||||
|         # Default change? | ||||
|         old_default = self.effective_default(old_field) | ||||
|         new_default = self.effective_default(new_field) | ||||
|         if old_default != new_default: | ||||
|             if new_default is None: | ||||
|                 actions.append(( | ||||
|                     self.sql_alter_column_no_default % { | ||||
|                         "column": self.quote_name(new_field.column), | ||||
|                     }, | ||||
|                     [], | ||||
|                 )) | ||||
|             else: | ||||
|                 actions.append(( | ||||
|                     self.sql_alter_column_default % { | ||||
|                         "column": self.quote_name(new_field.column), | ||||
|                         "default": "%s", | ||||
|                     }, | ||||
|                     [new_default], | ||||
|                 )) | ||||
|         # Nullability change? | ||||
|         if old_field.null != new_field.null: | ||||
|             if new_field.null: | ||||
|                 actions.append(( | ||||
|                     self.sql_alter_column_null % { | ||||
|                         "column": self.quote_name(new_field.column), | ||||
|                     }, | ||||
|                     [], | ||||
|                 )) | ||||
|             else: | ||||
|                 actions.append(( | ||||
|                     self.sql_alter_column_null % { | ||||
|                         "column": self.quote_name(new_field.column), | ||||
|                     }, | ||||
|                     [], | ||||
|                 )) | ||||
|         # Combine actions together if we can (e.g. postgres) | ||||
|         if self.connection.features.supports_combined_alters: | ||||
|             sql, params = tuple(zip(*actions)) | ||||
|             actions = [(", ".join(sql), params)] | ||||
|         # Apply those actions | ||||
|         for sql, params in actions: | ||||
|             self.execute( | ||||
|                 self.sql_alter_column % { | ||||
|                     "table": self.quote_name(model._meta.db_table), | ||||
|                     "changes": sql, | ||||
|                 }, | ||||
|                 params, | ||||
|             ) | ||||
|   | ||||
| @@ -2,8 +2,9 @@ from __future__ import absolute_import | ||||
| import copy | ||||
| import datetime | ||||
| from django.test import TestCase | ||||
| from django.db.models.loading import cache | ||||
| from django.db import connection, DatabaseError, IntegrityError | ||||
| from django.db.models.fields import IntegerField, TextField | ||||
| from django.db.models.loading import cache | ||||
| from .models import Author, Book | ||||
|  | ||||
|  | ||||
| @@ -18,6 +19,8 @@ class SchemaTests(TestCase): | ||||
|  | ||||
|     models = [Author, Book] | ||||
|  | ||||
|     # Utility functions | ||||
|  | ||||
|     def setUp(self): | ||||
|         # Make sure we're in manual transaction mode | ||||
|         connection.commit_unless_managed() | ||||
| @@ -51,6 +54,18 @@ class SchemaTests(TestCase): | ||||
|         cache.app_store = self.old_app_store | ||||
|         cache._get_models_cache = {} | ||||
|  | ||||
|     def column_classes(self, model): | ||||
|         cursor = connection.cursor() | ||||
|         return dict( | ||||
|             (d[0], (connection.introspection.get_field_type(d[1], d), d)) | ||||
|             for d in connection.introspection.get_table_description( | ||||
|                 cursor, | ||||
|                 model._meta.db_table, | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|     # Tests | ||||
|  | ||||
|     def test_creation_deletion(self): | ||||
|         """ | ||||
|         Tries creating a model's table, and then deleting it. | ||||
| @@ -100,3 +115,60 @@ class SchemaTests(TestCase): | ||||
|                 pub_date = datetime.datetime.now(), | ||||
|             ) | ||||
|             connection.commit() | ||||
|  | ||||
|     def test_create_field(self): | ||||
|         """ | ||||
|         Tests adding fields to models | ||||
|         """ | ||||
|         # Create the table | ||||
|         editor = connection.schema_editor() | ||||
|         editor.start() | ||||
|         editor.create_model(Author) | ||||
|         editor.commit() | ||||
|         # Ensure there's no age field | ||||
|         columns = self.column_classes(Author) | ||||
|         self.assertNotIn("age", columns) | ||||
|         # Alter the name field to a TextField | ||||
|         new_field = IntegerField(null=True) | ||||
|         new_field.set_attributes_from_name("age") | ||||
|         editor = connection.schema_editor() | ||||
|         editor.start() | ||||
|         editor.create_field( | ||||
|             Author, | ||||
|             new_field, | ||||
|         ) | ||||
|         editor.commit() | ||||
|         # Ensure the field is right afterwards | ||||
|         columns = self.column_classes(Author) | ||||
|         self.assertEqual(columns['age'][0], "IntegerField") | ||||
|         self.assertEqual(columns['age'][1][6], True) | ||||
|  | ||||
|     def test_alter(self): | ||||
|         """ | ||||
|         Tests simple altering of fields | ||||
|         """ | ||||
|         # Create the table | ||||
|         editor = connection.schema_editor() | ||||
|         editor.start() | ||||
|         editor.create_model(Author) | ||||
|         editor.commit() | ||||
|         # Ensure the field is right to begin with | ||||
|         columns = self.column_classes(Author) | ||||
|         self.assertEqual(columns['name'][0], "CharField") | ||||
|         self.assertEqual(columns['name'][1][3], 255) | ||||
|         self.assertEqual(columns['name'][1][6], False) | ||||
|         # Alter the name field to a TextField | ||||
|         new_field = TextField(null=True) | ||||
|         new_field.set_attributes_from_name("name") | ||||
|         editor = connection.schema_editor() | ||||
|         editor.start() | ||||
|         editor.alter_field( | ||||
|             Author, | ||||
|             Author._meta.get_field_by_name("name")[0], | ||||
|             new_field, | ||||
|         ) | ||||
|         editor.commit() | ||||
|         # Ensure the field is right afterwards | ||||
|         columns = self.column_classes(Author) | ||||
|         self.assertEqual(columns['name'][0], "TextField") | ||||
|         self.assertEqual(columns['name'][1][6], True) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user