mirror of
https://github.com/django/django.git
synced 2025-10-24 06:06:09 +00:00
Adding 'sqlmigrate' command and quote_parameter to support it.
This commit is contained in:
52
django/core/management/commands/sqlmigrate.py
Normal file
52
django/core/management/commands/sqlmigrate.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# encoding: utf8
|
||||
from __future__ import unicode_literals
|
||||
from optparse import make_option
|
||||
|
||||
from django.core.management.base import BaseCommand, CommandError
|
||||
from django.db import connections, DEFAULT_DB_ALIAS
|
||||
from django.db.migrations.executor import MigrationExecutor
|
||||
from django.db.migrations.loader import AmbiguityError
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
|
||||
option_list = BaseCommand.option_list + (
|
||||
make_option('--database', action='store', dest='database',
|
||||
default=DEFAULT_DB_ALIAS, help='Nominates a database to create SQL for. '
|
||||
'Defaults to the "default" database.'),
|
||||
make_option('--backwards', action='store_true', dest='backwards',
|
||||
default=False, help='Creates SQL to unapply the migration, rather than to apply it'),
|
||||
)
|
||||
|
||||
help = "Prints the SQL statements for the named migration."
|
||||
|
||||
def handle(self, *args, **options):
|
||||
|
||||
# Get the database we're operating from
|
||||
db = options.get('database')
|
||||
connection = connections[db]
|
||||
|
||||
# Load up an executor to get all the migration data
|
||||
executor = MigrationExecutor(connection)
|
||||
|
||||
# Resolve command-line arguments into a migration
|
||||
if len(args) != 2:
|
||||
raise CommandError("Wrong number of arguments (expecting 'sqlmigrate appname migrationname')")
|
||||
else:
|
||||
app_label, migration_name = args
|
||||
if app_label not in executor.loader.migrated_apps:
|
||||
raise CommandError("App '%s' does not have migrations" % app_label)
|
||||
try:
|
||||
migration = executor.loader.get_migration_by_prefix(app_label, migration_name)
|
||||
except AmbiguityError:
|
||||
raise CommandError("More than one migration matches '%s' in app '%s'. Please be more specific." % (app_label, migration_name))
|
||||
except KeyError:
|
||||
raise CommandError("Cannot find a migration matching '%s' from app '%s'. Is it in INSTALLED_APPS?" % (app_label, migration_name))
|
||||
targets = [(app_label, migration.name)]
|
||||
|
||||
# Make a plan that represents just the requested migrations and show SQL
|
||||
# for it
|
||||
plan = [(executor.loader.graph.nodes[targets[0]], options.get("backwards", False))]
|
||||
sql_statements = executor.collect_sql(plan)
|
||||
for statement in sql_statements:
|
||||
self.stdout.write(statement)
|
@@ -521,7 +521,7 @@ class BaseDatabaseWrapper(object):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def schema_editor(self):
|
||||
def schema_editor(self, *args, **kwargs):
|
||||
"Returns a new instance of this backend's SchemaEditor"
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -958,6 +958,15 @@ class BaseDatabaseOperations(object):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def quote_parameter(self, value):
|
||||
"""
|
||||
Returns a quoted version of the value so it's safe to use in an SQL
|
||||
string. This should NOT be used to prepare SQL statements to send to
|
||||
the database; it is meant for outputting SQL statements to a file
|
||||
or the console for later execution by a developer/DBA.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def random_function_sql(self):
|
||||
"""
|
||||
Returns an SQL expression that returns a random value.
|
||||
|
@@ -305,6 +305,11 @@ class DatabaseOperations(BaseDatabaseOperations):
|
||||
return name # Quoting once is enough.
|
||||
return "`%s`" % name
|
||||
|
||||
def quote_parameter(self, value):
|
||||
# Inner import to allow module to fail to load gracefully
|
||||
import MySQLdb.converters
|
||||
return MySQLdb.escape(value, MySQLdb.converters.conversions)
|
||||
|
||||
def random_function_sql(self):
|
||||
return 'RAND()'
|
||||
|
||||
@@ -518,9 +523,9 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||
table_name, column_name, bad_row[1],
|
||||
referenced_table_name, referenced_column_name))
|
||||
|
||||
def schema_editor(self):
|
||||
def schema_editor(self, *args, **kwargs):
|
||||
"Returns a new instance of this backend's SchemaEditor"
|
||||
return DatabaseSchemaEditor(self)
|
||||
return DatabaseSchemaEditor(self, *args, **kwargs)
|
||||
|
||||
def is_usable(self):
|
||||
try:
|
||||
|
@@ -320,6 +320,16 @@ WHEN (new.%(col_name)s IS NULL)
|
||||
name = name.replace('%', '%%')
|
||||
return name.upper()
|
||||
|
||||
def quote_parameter(self, value):
|
||||
if isinstance(value, (datetime.date, datetime.time, datetime.datetime)):
|
||||
return "'%s'" % value
|
||||
elif isinstance(value, six.string_types):
|
||||
return repr(value)
|
||||
elif isinstance(value, bool):
|
||||
return "1" if value else "0"
|
||||
else:
|
||||
return str(value)
|
||||
|
||||
def random_function_sql(self):
|
||||
return "DBMS_RANDOM.RANDOM"
|
||||
|
||||
@@ -628,9 +638,9 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||
six.reraise(utils.IntegrityError, utils.IntegrityError(*tuple(e.args)), sys.exc_info()[2])
|
||||
raise
|
||||
|
||||
def schema_editor(self):
|
||||
def schema_editor(self, *args, **kwargs):
|
||||
"Returns a new instance of this backend's SchemaEditor"
|
||||
return DatabaseSchemaEditor(self)
|
||||
return DatabaseSchemaEditor(self, *args, **kwargs)
|
||||
|
||||
# Oracle doesn't support savepoint commits. Ignore them.
|
||||
def _savepoint_commit(self, sid):
|
||||
|
@@ -93,11 +93,4 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
|
||||
return self.normalize_name(for_name + "_" + suffix)
|
||||
|
||||
def prepare_default(self, value):
|
||||
if isinstance(value, (datetime.date, datetime.time, datetime.datetime)):
|
||||
return "'%s'" % value
|
||||
elif isinstance(value, six.string_types):
|
||||
return repr(value)
|
||||
elif isinstance(value, bool):
|
||||
return "1" if value else "0"
|
||||
else:
|
||||
return str(value)
|
||||
return self.connection.ops.quote_parameter(value)
|
||||
|
@@ -205,9 +205,9 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||
else:
|
||||
return True
|
||||
|
||||
def schema_editor(self):
|
||||
def schema_editor(self, *args, **kwargs):
|
||||
"Returns a new instance of this backend's SchemaEditor"
|
||||
return DatabaseSchemaEditor(self)
|
||||
return DatabaseSchemaEditor(self, *args, **kwargs)
|
||||
|
||||
@cached_property
|
||||
def psycopg2_version(self):
|
||||
|
@@ -98,6 +98,11 @@ class DatabaseOperations(BaseDatabaseOperations):
|
||||
return name # Quoting once is enough.
|
||||
return '"%s"' % name
|
||||
|
||||
def quote_parameter(self, value):
|
||||
# Inner import so backend fails nicely if it's not present
|
||||
import psycopg2
|
||||
return psycopg2.extensions.adapt(value)
|
||||
|
||||
def set_time_zone_sql(self):
|
||||
return "SET TIME ZONE %s"
|
||||
|
||||
|
@@ -54,14 +54,17 @@ class BaseDatabaseSchemaEditor(object):
|
||||
sql_create_fk = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s FOREIGN KEY (%(column)s) REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED"
|
||||
sql_delete_fk = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
|
||||
|
||||
sql_create_index = "CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s;"
|
||||
sql_create_index = "CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s"
|
||||
sql_delete_index = "DROP INDEX %(name)s"
|
||||
|
||||
sql_create_pk = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s PRIMARY KEY (%(columns)s)"
|
||||
sql_delete_pk = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
|
||||
|
||||
def __init__(self, connection):
|
||||
def __init__(self, connection, collect_sql=False):
|
||||
self.connection = connection
|
||||
self.collect_sql = collect_sql
|
||||
if self.collect_sql:
|
||||
self.collected_sql = []
|
||||
|
||||
# State-managing methods
|
||||
|
||||
@@ -86,7 +89,10 @@ class BaseDatabaseSchemaEditor(object):
|
||||
cursor = self.connection.cursor()
|
||||
# Log the command we're running, then run it
|
||||
logger.debug("%s; (params %r)" % (sql, params))
|
||||
cursor.execute(sql, params)
|
||||
if self.collect_sql:
|
||||
self.collected_sql.append((sql % map(self.connection.ops.quote_parameter, params)) + ";")
|
||||
else:
|
||||
cursor.execute(sql, params)
|
||||
|
||||
def quote_name(self, name):
|
||||
return self.connection.ops.quote_name(name)
|
||||
|
@@ -214,6 +214,25 @@ class DatabaseOperations(BaseDatabaseOperations):
|
||||
return name # Quoting once is enough.
|
||||
return '"%s"' % name
|
||||
|
||||
def quote_parameter(self, value):
|
||||
# Inner import to allow nice failure for backend if not present
|
||||
import _sqlite3
|
||||
try:
|
||||
value = _sqlite3.adapt(value)
|
||||
except _sqlite3.ProgrammingError:
|
||||
pass
|
||||
# Manual emulation of SQLite parameter quoting
|
||||
if isinstance(value, six.integer_types):
|
||||
return str(value)
|
||||
elif isinstance(value, six.string_types):
|
||||
return six.text_type(value)
|
||||
elif isinstance(value, type(True)):
|
||||
return str(int(value))
|
||||
elif value is None:
|
||||
return "NULL"
|
||||
else:
|
||||
raise ValueError("Cannot quote parameter value %r" % value)
|
||||
|
||||
def no_limit_value(self):
|
||||
return -1
|
||||
|
||||
@@ -437,9 +456,9 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||
"""
|
||||
self.cursor().execute("BEGIN")
|
||||
|
||||
def schema_editor(self):
|
||||
def schema_editor(self, *args, **kwargs):
|
||||
"Returns a new instance of this backend's SchemaEditor"
|
||||
return DatabaseSchemaEditor(self)
|
||||
return DatabaseSchemaEditor(self, *args, **kwargs)
|
||||
|
||||
FORMAT_QMARK_REGEX = re.compile(r'(?<!%)%s')
|
||||
|
||||
|
@@ -55,7 +55,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
|
||||
self.create_model(temp_model)
|
||||
# Copy data from the old table
|
||||
field_maps = list(mapping.items())
|
||||
self.execute("INSERT INTO %s (%s) SELECT %s FROM %s;" % (
|
||||
self.execute("INSERT INTO %s (%s) SELECT %s FROM %s" % (
|
||||
self.quote_name(temp_model._meta.db_table),
|
||||
', '.join(x for x, y in field_maps),
|
||||
', '.join(y for x, y in field_maps),
|
||||
@@ -137,7 +137,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
|
||||
# Make a new through table
|
||||
self.create_model(new_field.rel.through)
|
||||
# Copy the data across
|
||||
self.execute("INSERT INTO %s (%s) SELECT %s FROM %s;" % (
|
||||
self.execute("INSERT INTO %s (%s) SELECT %s FROM %s" % (
|
||||
self.quote_name(new_field.rel.through._meta.db_table),
|
||||
', '.join([
|
||||
"id",
|
||||
|
@@ -61,6 +61,22 @@ class MigrationExecutor(object):
|
||||
else:
|
||||
self.unapply_migration(migration, fake=fake)
|
||||
|
||||
def collect_sql(self, plan):
|
||||
"""
|
||||
Takes a migration plan and returns a list of collected SQL
|
||||
statements that represent the best-efforts version of that plan.
|
||||
"""
|
||||
statements = []
|
||||
for migration, backwards in plan:
|
||||
with self.connection.schema_editor(collect_sql=True) as schema_editor:
|
||||
project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False)
|
||||
if not backwards:
|
||||
migration.apply(project_state, schema_editor)
|
||||
else:
|
||||
migration.unapply(project_state, schema_editor)
|
||||
statements.extend(schema_editor.collected_sql)
|
||||
return statements
|
||||
|
||||
def apply_migration(self, migration, fake=False):
|
||||
"""
|
||||
Runs a migration forwards.
|
||||
|
@@ -993,6 +993,24 @@ Prints the CREATE INDEX SQL statements for the given app name(s).
|
||||
The :djadminopt:`--database` option can be used to specify the database for
|
||||
which to print the SQL.
|
||||
|
||||
sqlmigrate <appname> <migrationname>
|
||||
------------------------------------
|
||||
|
||||
.. django-admin:: sqlmigrate
|
||||
|
||||
Prints the SQL for the named migration. This requires an active database
|
||||
connection, which it will use to resolve constraint names; this means you must
|
||||
generate the SQL against a copy of the database you wish to later apply it on.
|
||||
|
||||
The :djadminopt:`--database` option can be used to specify the database for
|
||||
which to generate the SQL.
|
||||
|
||||
.. django-admin-option:: --backwards
|
||||
|
||||
By default, the SQL created is for running the migration in the forwards
|
||||
direction. Pass ``--backwards`` to generate the SQL for
|
||||
un-applying the migration instead.
|
||||
|
||||
sqlsequencereset <appname appname ...>
|
||||
--------------------------------------
|
||||
|
||||
|
@@ -48,6 +48,20 @@ class MigrateTests(MigrationTestBase):
|
||||
self.assertTableNotExists("migrations_tribble")
|
||||
self.assertTableNotExists("migrations_book")
|
||||
|
||||
@override_settings(MIGRATION_MODULES={"migrations": "migrations.test_migrations"})
|
||||
def test_sqlmigrate(self):
|
||||
"""
|
||||
Makes sure that sqlmigrate does something.
|
||||
"""
|
||||
# Test forwards. All the databases agree on CREATE TABLE, at least.
|
||||
stdout = six.StringIO()
|
||||
call_command("sqlmigrate", "migrations", "0001", stdout=stdout)
|
||||
self.assertIn("create table", stdout.getvalue().lower())
|
||||
# And backwards is a DROP TABLE
|
||||
stdout = six.StringIO()
|
||||
call_command("sqlmigrate", "migrations", "0001", stdout=stdout, backwards=True)
|
||||
self.assertIn("drop table", stdout.getvalue().lower())
|
||||
|
||||
|
||||
class MakeMigrationsTests(MigrationTestBase):
|
||||
"""
|
||||
|
Reference in New Issue
Block a user