
Fixed #5461 -- Refactored the database backend code to use classes for the creation and introspection modules. Introduces a new validation module for DB-specific validation. This is a backwards incompatible change; see the wiki for details.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@8296 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Russell Keith-Magee 2008-08-11 12:11:25 +00:00
parent cec69eb70d
commit 9dc4ba875f
43 changed files with 1528 additions and 1423 deletions
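At a glance, the refactor replaces a handful of module-level helpers with objects attached to the database connection. A minimal before/after sketch of the affected call sites, based on the diffs below (a configured DATABASE_ENGINE is assumed):

    # Old API: module-level functions and lazily imported backend modules.
    from django.db import connection, get_introspection_module, runshell
    from django.test.utils import create_test_db

    create_test_db(verbosity=1)
    tables = get_introspection_module().get_table_list(connection.cursor())
    runshell()

    # New API (this commit): per-backend classes exposed on the connection.
    from django.db import connection

    connection.creation.create_test_db(verbosity=1)
    tables = connection.introspection.get_table_list(connection.cursor())
    connection.client.runshell()
    # Backend-specific model validation now lives on connection.validation.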

View File

@@ -1,5 +1,5 @@
-from django.test.utils import create_test_db
 def create_spatial_db(test=True, verbosity=1, autoclobber=False):
 if not test: raise NotImplementedError('This uses `create_test_db` from test/utils.py')
-create_test_db(verbosity, autoclobber)
+from django.db import connection
+connection.creation.create_test_db(verbosity, autoclobber)

View File

@@ -1,8 +1,6 @@
-from django.db.backends.oracle.creation import create_test_db
 def create_spatial_db(test=True, verbosity=1, autoclobber=False):
 "A wrapper over the Oracle `create_test_db` routine."
 if not test: raise NotImplementedError('This uses `create_test_db` from db/backends/oracle/creation.py')
-from django.conf import settings
 from django.db import connection
-create_test_db(settings, connection, verbosity, autoclobber)
+connection.creation.create_test_db(verbosity, autoclobber)

View File

@@ -1,7 +1,7 @@
 from django.conf import settings
 from django.core.management import call_command
 from django.db import connection
-from django.test.utils import _set_autocommit, TEST_DATABASE_PREFIX
+from django.db.backends.creation import TEST_DATABASE_PREFIX
 import os, re, sys
 def getstatusoutput(cmd):
@@ -40,7 +40,7 @@ def _create_with_cursor(db_name, verbosity=1, autoclobber=False):
 create_sql += ' OWNER %s' % settings.DATABASE_USER
 cursor = connection.cursor()
-_set_autocommit(connection)
+connection.creation.set_autocommit(connection)
 try:
 # Trying to create the database first.

View File

@@ -67,11 +67,9 @@ class Command(InspectCommand):
 def handle_inspection(self):
 "Overloaded from Django's version to handle geographic database tables."
-from django.db import connection, get_introspection_module
+from django.db import connection
 import keyword
-introspection_module = get_introspection_module()
 geo_cols = self.geometry_columns()
 table2model = lambda table_name: table_name.title().replace('_', '')
@@ -88,20 +86,20 @@ class Command(InspectCommand):
 yield ''
 yield 'from django.contrib.gis.db import models'
 yield ''
-for table_name in introspection_module.get_table_list(cursor):
+for table_name in connection.introspection.get_table_list(cursor):
 # Getting the geographic table dictionary.
 geo_table = geo_cols.get(table_name, {})
 yield 'class %s(models.Model):' % table2model(table_name)
 try:
-relations = introspection_module.get_relations(cursor, table_name)
+relations = connection.introspection.get_relations(cursor, table_name)
 except NotImplementedError:
 relations = {}
 try:
-indexes = introspection_module.get_indexes(cursor, table_name)
+indexes = connection.introspection.get_indexes(cursor, table_name)
 except NotImplementedError:
 indexes = {}
-for i, row in enumerate(introspection_module.get_table_description(cursor, table_name)):
+for i, row in enumerate(connection.introspection.get_table_description(cursor, table_name)):
 att_name, iatt_name = row[0].lower(), row[0]
 comment_notes = [] # Holds Field notes, to be displayed in a Python comment.
 extra_params = {} # Holds Field parameters such as 'db_column'.
@@ -133,12 +131,12 @@ class Command(InspectCommand):
 if srid != 4326: extra_params['srid'] = srid
 else:
 try:
-field_type = introspection_module.DATA_TYPES_REVERSE[row[1]]
+field_type = connection.introspection.data_types_reverse[row[1]]
 except KeyError:
 field_type = 'TextField'
 comment_notes.append('This field type is a guess.')
-# This is a hook for DATA_TYPES_REVERSE to return a tuple of
+# This is a hook for data_types_reverse to return a tuple of
 # (field_type, extra_params_dict).
 if type(field_type) is tuple:
 field_type, new_params = field_type

View File

@@ -6,5 +6,5 @@ class Command(NoArgsCommand):
 requires_model_validation = False
 def handle_noargs(self, **options):
-from django.db import runshell
+from django.db import connection
-runshell()
+connection.client.runshell()

View File

@@ -13,11 +13,9 @@ class Command(NoArgsCommand):
 raise CommandError("Database inspection isn't supported for the currently selected database backend.")
 def handle_inspection(self):
-from django.db import connection, get_introspection_module
+from django.db import connection
 import keyword
-introspection_module = get_introspection_module()
 table2model = lambda table_name: table_name.title().replace('_', '')
 cursor = connection.cursor()
@@ -32,17 +30,17 @@ class Command(NoArgsCommand):
 yield ''
 yield 'from django.db import models'
 yield ''
-for table_name in introspection_module.get_table_list(cursor):
+for table_name in connection.introspection.get_table_list(cursor):
 yield 'class %s(models.Model):' % table2model(table_name)
 try:
-relations = introspection_module.get_relations(cursor, table_name)
+relations = connection.introspection.get_relations(cursor, table_name)
 except NotImplementedError:
 relations = {}
 try:
-indexes = introspection_module.get_indexes(cursor, table_name)
+indexes = connection.introspection.get_indexes(cursor, table_name)
 except NotImplementedError:
 indexes = {}
-for i, row in enumerate(introspection_module.get_table_description(cursor, table_name)):
+for i, row in enumerate(connection.introspection.get_table_description(cursor, table_name)):
 att_name = row[0].lower()
 comment_notes = [] # Holds Field notes, to be displayed in a Python comment.
 extra_params = {} # Holds Field parameters such as 'db_column'.
@@ -65,7 +63,7 @@ class Command(NoArgsCommand):
 extra_params['db_column'] = att_name
 else:
 try:
-field_type = introspection_module.DATA_TYPES_REVERSE[row[1]]
+field_type = connection.introspection.data_types_reverse[row[1]]
 except KeyError:
 field_type = 'TextField'
 comment_notes.append('This field type is a guess.')
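The introspection API exercised by inspectdb can also be driven directly. A rough, illustrative sketch of walking every table and guessing field types through the new connection.introspection object (error handling trimmed):

    from django.db import connection

    cursor = connection.cursor()
    for table_name in connection.introspection.get_table_list(cursor):
        print "Table:", table_name
        for row in connection.introspection.get_table_description(cursor, table_name):
            # row[0] is the column name, row[1] the backend's type code.
            field_type = connection.introspection.data_types_reverse.get(row[1], 'TextField')
            print "    %s -> %s" % (row[0], field_type)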

View File

@@ -21,7 +21,7 @@ class Command(NoArgsCommand):
 def handle_noargs(self, **options):
 from django.db import connection, transaction, models
 from django.conf import settings
-from django.core.management.sql import table_names, installed_models, sql_model_create, sql_for_pending_references, many_to_many_sql_for_model, custom_sql_for_model, sql_indexes_for_model, emit_post_sync_signal
+from django.core.management.sql import custom_sql_for_model, emit_post_sync_signal
 verbosity = int(options.get('verbosity', 1))
 interactive = options.get('interactive')
@@ -50,16 +50,9 @@ class Command(NoArgsCommand):
 cursor = connection.cursor()
-if connection.features.uses_case_insensitive_names:
-table_name_converter = lambda x: x.upper()
-else:
-table_name_converter = lambda x: x
-# Get a list of all existing database tables, so we know what needs to
-# be added.
-tables = [table_name_converter(name) for name in table_names()]
 # Get a list of already installed *models* so that references work right.
-seen_models = installed_models(tables)
+tables = connection.introspection.table_names()
+seen_models = connection.introspection.installed_models(tables)
 created_models = set()
 pending_references = {}
@@ -71,21 +64,21 @@
 # Create the model's database table, if it doesn't already exist.
 if verbosity >= 2:
 print "Processing %s.%s model" % (app_name, model._meta.object_name)
-if table_name_converter(model._meta.db_table) in tables:
+if connection.introspection.table_name_converter(model._meta.db_table) in tables:
 continue
-sql, references = sql_model_create(model, self.style, seen_models)
+sql, references = connection.creation.sql_create_model(model, self.style, seen_models)
 seen_models.add(model)
 created_models.add(model)
 for refto, refs in references.items():
 pending_references.setdefault(refto, []).extend(refs)
 if refto in seen_models:
-sql.extend(sql_for_pending_references(refto, self.style, pending_references))
+sql.extend(connection.creation.sql_for_pending_references(refto, self.style, pending_references))
-sql.extend(sql_for_pending_references(model, self.style, pending_references))
+sql.extend(connection.creation.sql_for_pending_references(model, self.style, pending_references))
 if verbosity >= 1:
 print "Creating table %s" % model._meta.db_table
 for statement in sql:
 cursor.execute(statement)
-tables.append(table_name_converter(model._meta.db_table))
+tables.append(connection.introspection.table_name_converter(model._meta.db_table))
 # Create the m2m tables. This must be done after all tables have been created
 # to ensure that all referred tables will exist.
@@ -94,7 +87,7 @@ class Command(NoArgsCommand):
 model_list = models.get_models(app)
 for model in model_list:
 if model in created_models:
-sql = many_to_many_sql_for_model(model, self.style)
+sql = connection.creation.sql_for_many_to_many(model, self.style)
 if sql:
 if verbosity >= 2:
 print "Creating many-to-many tables for %s.%s model" % (app_name, model._meta.object_name)
@@ -140,7 +133,7 @@ class Command(NoArgsCommand):
 app_name = app.__name__.split('.')[-2]
 for model in models.get_models(app):
 if model in created_models:
-index_sql = sql_indexes_for_model(model, self.style)
+index_sql = connection.creation.sql_indexes_for_model(model, self.style)
 if index_sql:
 if verbosity >= 1:
 print "Installing index for %s.%s model" % (app_name, model._meta.object_name)

View File

@@ -18,13 +18,13 @@ class Command(BaseCommand):
 def handle(self, *fixture_labels, **options):
 from django.core.management import call_command
-from django.test.utils import create_test_db
+from django.db import connection
 verbosity = int(options.get('verbosity', 1))
 addrport = options.get('addrport')
 # Create a test database.
-db_name = create_test_db(verbosity=verbosity)
+db_name = connection.creation.create_test_db(verbosity=verbosity)
 # Import the fixture data into the test database.
 call_command('loaddata', *fixture_labels, **{'verbosity': verbosity})

View File

@@ -7,65 +7,9 @@ try:
 except NameError:
 from sets import Set as set # Python 2.3 fallback
def table_names():
"Returns a list of all table names that exist in the database."
from django.db import connection, get_introspection_module
cursor = connection.cursor()
return set(get_introspection_module().get_table_list(cursor))
def django_table_names(only_existing=False):
"""
Returns a list of all table names that have associated Django models and
are in INSTALLED_APPS.
If only_existing is True, the resulting list will only include the tables
that actually exist in the database.
"""
from django.db import models
tables = set()
for app in models.get_apps():
for model in models.get_models(app):
tables.add(model._meta.db_table)
tables.update([f.m2m_db_table() for f in model._meta.local_many_to_many])
if only_existing:
tables = [t for t in tables if t in table_names()]
return tables
def installed_models(table_list):
"Returns a set of all models that are installed, given a list of existing table names."
from django.db import connection, models
all_models = []
for app in models.get_apps():
for model in models.get_models(app):
all_models.append(model)
if connection.features.uses_case_insensitive_names:
converter = lambda x: x.upper()
else:
converter = lambda x: x
return set([m for m in all_models if converter(m._meta.db_table) in map(converter, table_list)])
def sequence_list():
"Returns a list of information about all DB sequences for all models in all apps."
from django.db import models
apps = models.get_apps()
sequence_list = []
for app in apps:
for model in models.get_models(app):
for f in model._meta.local_fields:
if isinstance(f, models.AutoField):
sequence_list.append({'table': model._meta.db_table, 'column': f.column})
break # Only one AutoField is allowed per model, so don't bother continuing.
for f in model._meta.local_many_to_many:
sequence_list.append({'table': f.m2m_db_table(), 'column': None})
return sequence_list
 def sql_create(app, style):
 "Returns a list of the CREATE TABLE SQL statements for the given app."
-from django.db import models
+from django.db import connection, models
 from django.conf import settings
 if settings.DATABASE_ENGINE == 'dummy':
@@ -81,23 +25,24 @@ def sql_create(app, style):
 # we can be conservative).
 app_models = models.get_models(app)
 final_output = []
-known_models = set([model for model in installed_models(table_names()) if model not in app_models])
+tables = connection.introspection.table_names()
+known_models = set([model for model in connection.introspection.installed_models(tables) if model not in app_models])
 pending_references = {}
 for model in app_models:
-output, references = sql_model_create(model, style, known_models)
+output, references = connection.creation.sql_create_model(model, style, known_models)
 final_output.extend(output)
 for refto, refs in references.items():
 pending_references.setdefault(refto, []).extend(refs)
 if refto in known_models:
-final_output.extend(sql_for_pending_references(refto, style, pending_references))
+final_output.extend(connection.creation.sql_for_pending_references(refto, style, pending_references))
-final_output.extend(sql_for_pending_references(model, style, pending_references))
+final_output.extend(connection.creation.sql_for_pending_references(model, style, pending_references))
 # Keep track of the fact that we've created the table for this model.
 known_models.add(model)
 # Create the many-to-many join tables.
 for model in app_models:
-final_output.extend(many_to_many_sql_for_model(model, style))
+final_output.extend(connection.creation.sql_for_many_to_many(model, style))
 # Handle references to tables that are from other apps
 # but don't exist physically.
@@ -106,7 +51,7 @@ def sql_create(app, style):
 alter_sql = []
 for model in not_installed_models:
 alter_sql.extend(['-- ' + sql for sql in
-sql_for_pending_references(model, style, pending_references)])
+connection.creation.sql_for_pending_references(model, style, pending_references)])
 if alter_sql:
 final_output.append('-- The following references should be added but depend on non-existent tables:')
 final_output.extend(alter_sql)
@@ -115,10 +60,9 @@ def sql_create(app, style):
 def sql_delete(app, style):
 "Returns a list of the DROP TABLE SQL statements for the given app."
-from django.db import connection, models, get_introspection_module
+from django.db import connection, models
 from django.db.backends.util import truncate_name
 from django.contrib.contenttypes import generic
-introspection = get_introspection_module()
 # This should work even if a connection isn't available
 try:
@@ -128,16 +72,11 @@ def sql_delete(app, style):
 # Figure out which tables already exist
 if cursor:
-table_names = introspection.get_table_list(cursor)
+table_names = connection.introspection.get_table_list(cursor)
 else:
 table_names = []
-if connection.features.uses_case_insensitive_names:
-table_name_converter = lambda x: x.upper()
-else:
-table_name_converter = lambda x: x
 output = []
-qn = connection.ops.quote_name
 # Output DROP TABLE statements for standard application tables.
 to_delete = set()
@@ -145,7 +84,7 @@ def sql_delete(app, style):
 references_to_delete = {}
 app_models = models.get_models(app)
 for model in app_models:
-if cursor and table_name_converter(model._meta.db_table) in table_names:
+if cursor and connection.introspection.table_name_converter(model._meta.db_table) in table_names:
 # The table exists, so it needs to be dropped
 opts = model._meta
 for f in opts.local_fields:
@@ -155,40 +94,15 @@ def sql_delete(app, style):
 to_delete.add(model)
 for model in app_models:
-if cursor and table_name_converter(model._meta.db_table) in table_names:
+if connection.introspection.table_name_converter(model._meta.db_table) in table_names:
-# Drop the table now
+output.extend(connection.creation.sql_destroy_model(model, references_to_delete, style))
-output.append('%s %s;' % (style.SQL_KEYWORD('DROP TABLE'),
-style.SQL_TABLE(qn(model._meta.db_table))))
-if connection.features.supports_constraints and model in references_to_delete:
-for rel_class, f in references_to_delete[model]:
-table = rel_class._meta.db_table
-col = f.column
-r_table = model._meta.db_table
-r_col = model._meta.get_field(f.rel.field_name).column
-r_name = '%s_refs_%s_%x' % (col, r_col, abs(hash((table, r_table))))
-output.append('%s %s %s %s;' % \
-(style.SQL_KEYWORD('ALTER TABLE'),
-style.SQL_TABLE(qn(table)),
-style.SQL_KEYWORD(connection.ops.drop_foreignkey_sql()),
-style.SQL_FIELD(truncate_name(r_name, connection.ops.max_name_length()))))
-del references_to_delete[model]
-if model._meta.has_auto_field:
-ds = connection.ops.drop_sequence_sql(model._meta.db_table)
-if ds:
-output.append(ds)
 # Output DROP TABLE statements for many-to-many tables.
 for model in app_models:
 opts = model._meta
 for f in opts.local_many_to_many:
-if not f.creates_table:
+if cursor and connection.introspection.table_name_converter(f.m2m_db_table()) in table_names:
-continue
+output.extend(connection.creation.sql_destroy_many_to_many(model, f, style))
-if cursor and table_name_converter(f.m2m_db_table()) in table_names:
-output.append("%s %s;" % (style.SQL_KEYWORD('DROP TABLE'),
-style.SQL_TABLE(qn(f.m2m_db_table()))))
-ds = connection.ops.drop_sequence_sql("%s_%s" % (model._meta.db_table, f.column))
-if ds:
-output.append(ds)
 app_label = app_models[0]._meta.app_label
@@ -213,10 +127,10 @@ def sql_flush(style, only_django=False):
 """
 from django.db import connection
 if only_django:
-tables = django_table_names()
+tables = connection.introspection.django_table_names()
 else:
-tables = table_names()
+tables = connection.introspection.table_names()
-statements = connection.ops.sql_flush(style, tables, sequence_list())
+statements = connection.ops.sql_flush(style, tables, connection.introspection.sequence_list())
 return statements
 def sql_custom(app, style):
@@ -234,198 +148,16 @@ def sql_custom(app, style):
 def sql_indexes(app, style):
 "Returns a list of the CREATE INDEX SQL statements for all models in the given app."
-from django.db import models
+from django.db import connection, models
 output = []
 for model in models.get_models(app):
-output.extend(sql_indexes_for_model(model, style))
+output.extend(connection.creation.sql_indexes_for_model(model, style))
 return output
 def sql_all(app, style):
 "Returns a list of CREATE TABLE SQL, initial-data inserts, and CREATE INDEX SQL for the given module."
 return sql_create(app, style) + sql_custom(app, style) + sql_indexes(app, style)
def sql_model_create(model, style, known_models=set()):
"""
Returns the SQL required to create a single model, as a tuple of:
(list_of_sql, pending_references_dict)
"""
from django.db import connection, models
opts = model._meta
final_output = []
table_output = []
pending_references = {}
qn = connection.ops.quote_name
inline_references = connection.features.inline_fk_references
for f in opts.local_fields:
col_type = f.db_type()
tablespace = f.db_tablespace or opts.db_tablespace
if col_type is None:
# Skip ManyToManyFields, because they're not represented as
# database columns in this table.
continue
# Make the definition (e.g. 'foo VARCHAR(30)') for this field.
field_output = [style.SQL_FIELD(qn(f.column)),
style.SQL_COLTYPE(col_type)]
field_output.append(style.SQL_KEYWORD('%sNULL' % (not f.null and 'NOT ' or '')))
if f.primary_key:
field_output.append(style.SQL_KEYWORD('PRIMARY KEY'))
elif f.unique:
field_output.append(style.SQL_KEYWORD('UNIQUE'))
if tablespace and connection.features.supports_tablespaces and f.unique:
# We must specify the index tablespace inline, because we
# won't be generating a CREATE INDEX statement for this field.
field_output.append(connection.ops.tablespace_sql(tablespace, inline=True))
if f.rel:
if inline_references and f.rel.to in known_models:
field_output.append(style.SQL_KEYWORD('REFERENCES') + ' ' + \
style.SQL_TABLE(qn(f.rel.to._meta.db_table)) + ' (' + \
style.SQL_FIELD(qn(f.rel.to._meta.get_field(f.rel.field_name).column)) + ')' +
connection.ops.deferrable_sql()
)
else:
# We haven't yet created the table to which this field
# is related, so save it for later.
pr = pending_references.setdefault(f.rel.to, []).append((model, f))
table_output.append(' '.join(field_output))
if opts.order_with_respect_to:
table_output.append(style.SQL_FIELD(qn('_order')) + ' ' + \
style.SQL_COLTYPE(models.IntegerField().db_type()) + ' ' + \
style.SQL_KEYWORD('NULL'))
for field_constraints in opts.unique_together:
table_output.append(style.SQL_KEYWORD('UNIQUE') + ' (%s)' % \
", ".join([style.SQL_FIELD(qn(opts.get_field(f).column)) for f in field_constraints]))
full_statement = [style.SQL_KEYWORD('CREATE TABLE') + ' ' + style.SQL_TABLE(qn(opts.db_table)) + ' (']
for i, line in enumerate(table_output): # Combine and add commas.
full_statement.append(' %s%s' % (line, i < len(table_output)-1 and ',' or ''))
full_statement.append(')')
if opts.db_tablespace and connection.features.supports_tablespaces:
full_statement.append(connection.ops.tablespace_sql(opts.db_tablespace))
full_statement.append(';')
final_output.append('\n'.join(full_statement))
if opts.has_auto_field:
# Add any extra SQL needed to support auto-incrementing primary keys.
auto_column = opts.auto_field.db_column or opts.auto_field.name
autoinc_sql = connection.ops.autoinc_sql(opts.db_table, auto_column)
if autoinc_sql:
for stmt in autoinc_sql:
final_output.append(stmt)
return final_output, pending_references
def sql_for_pending_references(model, style, pending_references):
"""
Returns any ALTER TABLE statements to add constraints after the fact.
"""
from django.db import connection
from django.db.backends.util import truncate_name
qn = connection.ops.quote_name
final_output = []
if connection.features.supports_constraints:
opts = model._meta
if model in pending_references:
for rel_class, f in pending_references[model]:
rel_opts = rel_class._meta
r_table = rel_opts.db_table
r_col = f.column
table = opts.db_table
col = opts.get_field(f.rel.field_name).column
# For MySQL, r_name must be unique in the first 64 characters.
# So we are careful with character usage here.
r_name = '%s_refs_%s_%x' % (r_col, col, abs(hash((r_table, table))))
final_output.append(style.SQL_KEYWORD('ALTER TABLE') + ' %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)%s;' % \
(qn(r_table), truncate_name(r_name, connection.ops.max_name_length()),
qn(r_col), qn(table), qn(col),
connection.ops.deferrable_sql()))
del pending_references[model]
return final_output
def many_to_many_sql_for_model(model, style):
from django.db import connection, models
from django.contrib.contenttypes import generic
from django.db.backends.util import truncate_name
opts = model._meta
final_output = []
qn = connection.ops.quote_name
inline_references = connection.features.inline_fk_references
for f in opts.local_many_to_many:
if f.creates_table:
tablespace = f.db_tablespace or opts.db_tablespace
if tablespace and connection.features.supports_tablespaces:
tablespace_sql = ' ' + connection.ops.tablespace_sql(tablespace, inline=True)
else:
tablespace_sql = ''
table_output = [style.SQL_KEYWORD('CREATE TABLE') + ' ' + \
style.SQL_TABLE(qn(f.m2m_db_table())) + ' (']
table_output.append(' %s %s %s%s,' %
(style.SQL_FIELD(qn('id')),
style.SQL_COLTYPE(models.AutoField(primary_key=True).db_type()),
style.SQL_KEYWORD('NOT NULL PRIMARY KEY'),
tablespace_sql))
if inline_references:
deferred = []
table_output.append(' %s %s %s %s (%s)%s,' %
(style.SQL_FIELD(qn(f.m2m_column_name())),
style.SQL_COLTYPE(models.ForeignKey(model).db_type()),
style.SQL_KEYWORD('NOT NULL REFERENCES'),
style.SQL_TABLE(qn(opts.db_table)),
style.SQL_FIELD(qn(opts.pk.column)),
connection.ops.deferrable_sql()))
table_output.append(' %s %s %s %s (%s)%s,' %
(style.SQL_FIELD(qn(f.m2m_reverse_name())),
style.SQL_COLTYPE(models.ForeignKey(f.rel.to).db_type()),
style.SQL_KEYWORD('NOT NULL REFERENCES'),
style.SQL_TABLE(qn(f.rel.to._meta.db_table)),
style.SQL_FIELD(qn(f.rel.to._meta.pk.column)),
connection.ops.deferrable_sql()))
else:
table_output.append(' %s %s %s,' %
(style.SQL_FIELD(qn(f.m2m_column_name())),
style.SQL_COLTYPE(models.ForeignKey(model).db_type()),
style.SQL_KEYWORD('NOT NULL')))
table_output.append(' %s %s %s,' %
(style.SQL_FIELD(qn(f.m2m_reverse_name())),
style.SQL_COLTYPE(models.ForeignKey(f.rel.to).db_type()),
style.SQL_KEYWORD('NOT NULL')))
deferred = [
(f.m2m_db_table(), f.m2m_column_name(), opts.db_table,
opts.pk.column),
( f.m2m_db_table(), f.m2m_reverse_name(),
f.rel.to._meta.db_table, f.rel.to._meta.pk.column)
]
table_output.append(' %s (%s, %s)%s' %
(style.SQL_KEYWORD('UNIQUE'),
style.SQL_FIELD(qn(f.m2m_column_name())),
style.SQL_FIELD(qn(f.m2m_reverse_name())),
tablespace_sql))
table_output.append(')')
if opts.db_tablespace and connection.features.supports_tablespaces:
# f.db_tablespace is only for indices, so ignore its value here.
table_output.append(connection.ops.tablespace_sql(opts.db_tablespace))
table_output.append(';')
final_output.append('\n'.join(table_output))
for r_table, r_col, table, col in deferred:
r_name = '%s_refs_%s_%x' % (r_col, col,
abs(hash((r_table, table))))
final_output.append(style.SQL_KEYWORD('ALTER TABLE') + ' %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)%s;' %
(qn(r_table),
truncate_name(r_name, connection.ops.max_name_length()),
qn(r_col), qn(table), qn(col),
connection.ops.deferrable_sql()))
# Add any extra SQL needed to support auto-incrementing PKs
autoinc_sql = connection.ops.autoinc_sql(f.m2m_db_table(), 'id')
if autoinc_sql:
for stmt in autoinc_sql:
final_output.append(stmt)
return final_output
def custom_sql_for_model(model, style): def custom_sql_for_model(model, style):
from django.db import models from django.db import models
from django.conf import settings from django.conf import settings
@ -461,28 +193,6 @@ def custom_sql_for_model(model, style):
return output return output
def sql_indexes_for_model(model, style):
"Returns the CREATE INDEX SQL statements for a single model"
from django.db import connection
output = []
qn = connection.ops.quote_name
for f in model._meta.local_fields:
if f.db_index and not f.unique:
tablespace = f.db_tablespace or model._meta.db_tablespace
if tablespace and connection.features.supports_tablespaces:
tablespace_sql = ' ' + connection.ops.tablespace_sql(tablespace)
else:
tablespace_sql = ''
output.append(
style.SQL_KEYWORD('CREATE INDEX') + ' ' + \
style.SQL_TABLE(qn('%s_%s' % (model._meta.db_table, f.column))) + ' ' + \
style.SQL_KEYWORD('ON') + ' ' + \
style.SQL_TABLE(qn(model._meta.db_table)) + ' ' + \
"(%s)" % style.SQL_FIELD(qn(f.column)) + \
"%s;" % tablespace_sql
)
return output
 def emit_post_sync_signal(created_models, verbosity, interactive):
 from django.db import models
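None of the helpers deleted above have gone away: table_names, django_table_names, installed_models and sequence_list resurface as methods on connection.introspection, and the sql_* builders as methods on connection.creation (see the new base classes later in this commit). For instance, the flush logic shown above now reads, roughly:

    from django.core.management.color import no_style
    from django.db import connection

    style = no_style()
    tables = connection.introspection.django_table_names(only_existing=True)
    seqs = connection.introspection.sequence_list()
    for statement in connection.ops.sql_flush(style, tables, seqs):
        print statement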

View File

@@ -61,11 +61,8 @@ def get_validation_errors(outfile, app=None):
 if f.db_index not in (None, True, False):
 e.add(opts, '"%s": "db_index" should be either None, True or False.' % f.name)
-# Check that max_length <= 255 if using older MySQL versions.
+# Perform any backend-specific field validation.
-if settings.DATABASE_ENGINE == 'mysql':
+connection.validation.validate_field(e, opts, f)
-db_version = connection.get_server_version()
-if db_version < (5, 0, 3) and isinstance(f, (models.CharField, models.CommaSeparatedIntegerField, models.SlugField)) and f.max_length > 255:
-e.add(opts, '"%s": %s cannot have a "max_length" greater than 255 when you are using a version of MySQL prior to 5.0.3 (you are using %s).' % (f.name, f.__class__.__name__, '.'.join([str(n) for n in db_version[:3]])))
 # Check to see if the related field will clash with any existing
 # fields, m2m fields, m2m related objects or related objects

View File

@@ -14,14 +14,12 @@ try:
 # backends that ships with Django, so look there first.
 _import_path = 'django.db.backends.'
 backend = __import__('%s%s.base' % (_import_path, settings.DATABASE_ENGINE), {}, {}, [''])
-creation = __import__('%s%s.creation' % (_import_path, settings.DATABASE_ENGINE), {}, {}, [''])
 except ImportError, e:
 # If the import failed, we might be looking for a database backend
 # distributed external to Django. So we'll try that next.
 try:
 _import_path = ''
 backend = __import__('%s.base' % settings.DATABASE_ENGINE, {}, {}, [''])
-creation = __import__('%s.creation' % settings.DATABASE_ENGINE, {}, {}, [''])
 except ImportError, e_user:
 # The database backend wasn't found. Display a helpful error message
 # listing all possible (built-in) database backends.
@@ -29,27 +27,11 @@ except ImportError, e:
 available_backends = [f for f in os.listdir(backend_dir) if not f.startswith('_') and not f.startswith('.') and not f.endswith('.py') and not f.endswith('.pyc')]
 available_backends.sort()
 if settings.DATABASE_ENGINE not in available_backends:
-raise ImproperlyConfigured, "%r isn't an available database backend. Available options are: %s" % \
+raise ImproperlyConfigured, "%r isn't an available database backend. Available options are: %s\nError was: %s" % \
-(settings.DATABASE_ENGINE, ", ".join(map(repr, available_backends)))
+(settings.DATABASE_ENGINE, ", ".join(map(repr, available_backends, e_user)))
 else:
 raise # If there's some other error, this must be an error in Django itself.
-def _import_database_module(import_path='', module_name=''):
-"""Lazily import a database module when requested."""
-return __import__('%s%s.%s' % (import_path, settings.DATABASE_ENGINE, module_name), {}, {}, [''])
-# We don't want to import the introspect module unless someone asks for it, so
-# lazily load it on demmand.
-get_introspection_module = curry(_import_database_module, _import_path, 'introspection')
-def get_creation_module():
-return creation
-# We want runshell() to work the same way, but we have to treat it a
-# little differently (since it just runs instead of returning a module like
-# the above) and wrap the lazily-loaded runshell() method.
-runshell = lambda: _import_database_module(_import_path, "client").runshell()
 # Convenient aliases for backend bits.
 connection = backend.DatabaseWrapper(**settings.DATABASE_OPTIONS)
 DatabaseError = backend.DatabaseError

View File

@@ -42,14 +42,9 @@ class BaseDatabaseWrapper(local):
 return util.CursorDebugWrapper(cursor, self)
 class BaseDatabaseFeatures(object):
-allows_group_by_ordinal = True
-inline_fk_references = True
 # True if django.db.backend.utils.typecast_timestamp is used on values
 # returned from dates() calls.
 needs_datetime_string_cast = True
-supports_constraints = True
-supports_tablespaces = False
-uses_case_insensitive_names = False
 uses_custom_query_class = False
 empty_fetchmany_value = []
 update_can_self_select = True
@@ -253,12 +248,12 @@ class BaseDatabaseOperations(object):
 """
 return "BEGIN;"
-def tablespace_sql(self, tablespace, inline=False):
+def sql_for_tablespace(self, tablespace, inline=False):
 """
-Returns the tablespace SQL, or None if the backend doesn't use
+Returns the SQL that will be appended to tables or rows to define
-tablespaces.
+a tablespace. Returns '' if the backend doesn't use tablespaces.
 """
-return None
+return ''
 def prep_for_like_query(self, x):
 """Prepares a value for use in a LIKE query."""
@@ -325,3 +320,89 @@ class BaseDatabaseOperations(object):
 """
 return self.year_lookup_bounds(value)
class BaseDatabaseIntrospection(object):
"""
This class encapsulates all backend-specific introspection utilities
"""
data_types_reverse = {}
def __init__(self, connection):
self.connection = connection
def table_name_converter(self, name):
"""Apply a conversion to the name for the purposes of comparison.
The default table name converter is for case sensitive comparison.
"""
return name
def table_names(self):
"Returns a list of names of all tables that exist in the database."
cursor = self.connection.cursor()
return self.get_table_list(cursor)
def django_table_names(self, only_existing=False):
"""
Returns a list of all table names that have associated Django models and
are in INSTALLED_APPS.
If only_existing is True, the resulting list will only include the tables
that actually exist in the database.
"""
from django.db import models
tables = set()
for app in models.get_apps():
for model in models.get_models(app):
tables.add(model._meta.db_table)
tables.update([f.m2m_db_table() for f in model._meta.local_many_to_many])
if only_existing:
tables = [t for t in tables if t in self.table_names()]
return tables
def installed_models(self, tables):
"Returns a set of all models represented by the provided list of table names."
from django.db import models
all_models = []
for app in models.get_apps():
for model in models.get_models(app):
all_models.append(model)
return set([m for m in all_models
if self.table_name_converter(m._meta.db_table) in map(self.table_name_converter, tables)
])
def sequence_list(self):
"Returns a list of information about all DB sequences for all models in all apps."
from django.db import models
apps = models.get_apps()
sequence_list = []
for app in apps:
for model in models.get_models(app):
for f in model._meta.local_fields:
if isinstance(f, models.AutoField):
sequence_list.append({'table': model._meta.db_table, 'column': f.column})
break # Only one AutoField is allowed per model, so don't bother continuing.
for f in model._meta.local_many_to_many:
sequence_list.append({'table': f.m2m_db_table(), 'column': None})
return sequence_list
class BaseDatabaseClient(object):
"""
This class encapsualtes all backend-specific methods for opening a
client shell
"""
def runshell(self):
raise NotImplementedError()
class BaseDatabaseValidation(object):
"""
This class encapsualtes all backend-specific model validation.
"""
def validate_field(self, errors, opts, f):
"By default, there is no backend-specific validation"
pass
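For authors of external backends, the base classes above are the extension points: subclass them and override only what differs. A skeletal, hypothetical example (the information_schema query is a stand-in for whatever the backend actually needs):

    from django.db.backends import BaseDatabaseClient, BaseDatabaseIntrospection
    from django.db.backends import BaseDatabaseValidation

    class DatabaseClient(BaseDatabaseClient):
        def runshell(self):
            # Launch the backend's own interactive shell here.
            raise NotImplementedError

    class DatabaseIntrospection(BaseDatabaseIntrospection):
        # Map backend type codes to Django field class names.
        data_types_reverse = {}

        def get_table_list(self, cursor):
            cursor.execute("SELECT table_name FROM information_schema.tables")
            return [row[0] for row in cursor.fetchall()]

    class DatabaseValidation(BaseDatabaseValidation):
        def validate_field(self, errors, opts, f):
            # Flag anything this backend cannot represent.
            pass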

View File

@@ -1,7 +1,396 @@
-class BaseCreation(object):
+import sys
+import time
+from django.conf import settings
+from django.core.management import call_command
+# The prefix to put on the default database name when creating
+# the test database.
+TEST_DATABASE_PREFIX = 'test_'
+class BaseDatabaseCreation(object):
 """
 This class encapsulates all backend-specific differences that pertain to
 database *creation*, such as the column types to use for particular Django
-Fields.
+Fields, the SQL used to create and destroy tables, and the creation and
+destruction of test databases.
 """
-pass
+data_types = {}
def __init__(self, connection):
self.connection = connection
def sql_create_model(self, model, style, known_models=set()):
"""
Returns the SQL required to create a single model, as a tuple of:
(list_of_sql, pending_references_dict)
"""
from django.db import models
opts = model._meta
final_output = []
table_output = []
pending_references = {}
qn = self.connection.ops.quote_name
for f in opts.local_fields:
col_type = f.db_type()
tablespace = f.db_tablespace or opts.db_tablespace
if col_type is None:
# Skip ManyToManyFields, because they're not represented as
# database columns in this table.
continue
# Make the definition (e.g. 'foo VARCHAR(30)') for this field.
field_output = [style.SQL_FIELD(qn(f.column)),
style.SQL_COLTYPE(col_type)]
field_output.append(style.SQL_KEYWORD('%sNULL' % (not f.null and 'NOT ' or '')))
if f.primary_key:
field_output.append(style.SQL_KEYWORD('PRIMARY KEY'))
elif f.unique:
field_output.append(style.SQL_KEYWORD('UNIQUE'))
if tablespace and f.unique:
# We must specify the index tablespace inline, because we
# won't be generating a CREATE INDEX statement for this field.
field_output.append(self.connection.ops.tablespace_sql(tablespace, inline=True))
if f.rel:
ref_output, pending = self.sql_for_inline_foreign_key_references(f, known_models, style)
if pending:
pr = pending_references.setdefault(f.rel.to, []).append((model, f))
else:
field_output.extend(ref_output)
table_output.append(' '.join(field_output))
if opts.order_with_respect_to:
table_output.append(style.SQL_FIELD(qn('_order')) + ' ' + \
style.SQL_COLTYPE(models.IntegerField().db_type()) + ' ' + \
style.SQL_KEYWORD('NULL'))
for field_constraints in opts.unique_together:
table_output.append(style.SQL_KEYWORD('UNIQUE') + ' (%s)' % \
", ".join([style.SQL_FIELD(qn(opts.get_field(f).column)) for f in field_constraints]))
full_statement = [style.SQL_KEYWORD('CREATE TABLE') + ' ' + style.SQL_TABLE(qn(opts.db_table)) + ' (']
for i, line in enumerate(table_output): # Combine and add commas.
full_statement.append(' %s%s' % (line, i < len(table_output)-1 and ',' or ''))
full_statement.append(')')
if opts.db_tablespace:
full_statement.append(self.connection.ops.tablespace_sql(opts.db_tablespace))
full_statement.append(';')
final_output.append('\n'.join(full_statement))
if opts.has_auto_field:
# Add any extra SQL needed to support auto-incrementing primary keys.
auto_column = opts.auto_field.db_column or opts.auto_field.name
autoinc_sql = self.connection.ops.autoinc_sql(opts.db_table, auto_column)
if autoinc_sql:
for stmt in autoinc_sql:
final_output.append(stmt)
return final_output, pending_references
def sql_for_inline_foreign_key_references(self, field, known_models, style):
"Return the SQL snippet defining the foreign key reference for a field"
qn = self.connection.ops.quote_name
if field.rel.to in known_models:
output = [style.SQL_KEYWORD('REFERENCES') + ' ' + \
style.SQL_TABLE(qn(field.rel.to._meta.db_table)) + ' (' + \
style.SQL_FIELD(qn(field.rel.to._meta.get_field(field.rel.field_name).column)) + ')' +
self.connection.ops.deferrable_sql()
]
pending = False
else:
# We haven't yet created the table to which this field
# is related, so save it for later.
output = []
pending = True
return output, pending
def sql_for_pending_references(self, model, style, pending_references):
"Returns any ALTER TABLE statements to add constraints after the fact."
from django.db.backends.util import truncate_name
qn = self.connection.ops.quote_name
final_output = []
opts = model._meta
if model in pending_references:
for rel_class, f in pending_references[model]:
rel_opts = rel_class._meta
r_table = rel_opts.db_table
r_col = f.column
table = opts.db_table
col = opts.get_field(f.rel.field_name).column
# For MySQL, r_name must be unique in the first 64 characters.
# So we are careful with character usage here.
r_name = '%s_refs_%s_%x' % (r_col, col, abs(hash((r_table, table))))
final_output.append(style.SQL_KEYWORD('ALTER TABLE') + ' %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)%s;' % \
(qn(r_table), truncate_name(r_name, self.connection.ops.max_name_length()),
qn(r_col), qn(table), qn(col),
self.connection.ops.deferrable_sql()))
del pending_references[model]
return final_output
def sql_for_many_to_many(self, model, style):
"Return the CREATE TABLE statments for all the many-to-many tables defined on a model"
output = []
for f in model._meta.local_many_to_many:
output.extend(self.sql_for_many_to_many_field(model, f, style))
return output
def sql_for_many_to_many_field(self, model, f, style):
"Return the CREATE TABLE statements for a single m2m field"
from django.db import models
from django.db.backends.util import truncate_name
output = []
if f.creates_table:
opts = model._meta
qn = self.connection.ops.quote_name
tablespace = f.db_tablespace or opts.db_tablespace
if tablespace:
sql = self.connection.ops.tablespace_sql(tablespace, inline=True)
if sql:
tablespace_sql = ' ' + sql
else:
tablespace_sql = ''
else:
tablespace_sql = ''
table_output = [style.SQL_KEYWORD('CREATE TABLE') + ' ' + \
style.SQL_TABLE(qn(f.m2m_db_table())) + ' (']
table_output.append(' %s %s %s%s,' %
(style.SQL_FIELD(qn('id')),
style.SQL_COLTYPE(models.AutoField(primary_key=True).db_type()),
style.SQL_KEYWORD('NOT NULL PRIMARY KEY'),
tablespace_sql))
deferred = []
inline_output, deferred = self.sql_for_inline_many_to_many_references(model, f, style)
table_output.extend(inline_output)
table_output.append(' %s (%s, %s)%s' %
(style.SQL_KEYWORD('UNIQUE'),
style.SQL_FIELD(qn(f.m2m_column_name())),
style.SQL_FIELD(qn(f.m2m_reverse_name())),
tablespace_sql))
table_output.append(')')
if opts.db_tablespace:
# f.db_tablespace is only for indices, so ignore its value here.
table_output.append(self.connection.ops.tablespace_sql(opts.db_tablespace))
table_output.append(';')
output.append('\n'.join(table_output))
for r_table, r_col, table, col in deferred:
r_name = '%s_refs_%s_%x' % (r_col, col,
abs(hash((r_table, table))))
output.append(style.SQL_KEYWORD('ALTER TABLE') + ' %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)%s;' %
(qn(r_table),
truncate_name(r_name, self.connection.ops.max_name_length()),
qn(r_col), qn(table), qn(col),
self.connection.ops.deferrable_sql()))
# Add any extra SQL needed to support auto-incrementing PKs
autoinc_sql = self.connection.ops.autoinc_sql(f.m2m_db_table(), 'id')
if autoinc_sql:
for stmt in autoinc_sql:
output.append(stmt)
return output
def sql_for_inline_many_to_many_references(self, model, field, style):
"Create the references to other tables required by a many-to-many table"
from django.db import models
opts = model._meta
qn = self.connection.ops.quote_name
table_output = [
' %s %s %s %s (%s)%s,' %
(style.SQL_FIELD(qn(field.m2m_column_name())),
style.SQL_COLTYPE(models.ForeignKey(model).db_type()),
style.SQL_KEYWORD('NOT NULL REFERENCES'),
style.SQL_TABLE(qn(opts.db_table)),
style.SQL_FIELD(qn(opts.pk.column)),
self.connection.ops.deferrable_sql()),
' %s %s %s %s (%s)%s,' %
(style.SQL_FIELD(qn(field.m2m_reverse_name())),
style.SQL_COLTYPE(models.ForeignKey(field.rel.to).db_type()),
style.SQL_KEYWORD('NOT NULL REFERENCES'),
style.SQL_TABLE(qn(field.rel.to._meta.db_table)),
style.SQL_FIELD(qn(field.rel.to._meta.pk.column)),
self.connection.ops.deferrable_sql())
]
deferred = []
return table_output, deferred
def sql_indexes_for_model(self, model, style):
"Returns the CREATE INDEX SQL statements for a single model"
output = []
for f in model._meta.local_fields:
output.extend(self.sql_indexes_for_field(model, f, style))
return output
def sql_indexes_for_field(self, model, f, style):
"Return the CREATE INDEX SQL statements for a single model field"
if f.db_index and not f.unique:
qn = self.connection.ops.quote_name
tablespace = f.db_tablespace or model._meta.db_tablespace
if tablespace:
sql = self.connection.ops.tablespace_sql(tablespace)
if sql:
tablespace_sql = ' ' + sql
else:
tablespace_sql = ''
else:
tablespace_sql = ''
output = [style.SQL_KEYWORD('CREATE INDEX') + ' ' +
style.SQL_TABLE(qn('%s_%s' % (model._meta.db_table, f.column))) + ' ' +
style.SQL_KEYWORD('ON') + ' ' +
style.SQL_TABLE(qn(model._meta.db_table)) + ' ' +
"(%s)" % style.SQL_FIELD(qn(f.column)) +
"%s;" % tablespace_sql]
else:
output = []
return output
def sql_destroy_model(self, model, references_to_delete, style):
"Return the DROP TABLE and restraint dropping statements for a single model"
# Drop the table now
qn = self.connection.ops.quote_name
output = ['%s %s;' % (style.SQL_KEYWORD('DROP TABLE'),
style.SQL_TABLE(qn(model._meta.db_table)))]
if model in references_to_delete:
output.extend(self.sql_remove_table_constraints(model, references_to_delete))
if model._meta.has_auto_field:
ds = self.connection.ops.drop_sequence_sql(model._meta.db_table)
if ds:
output.append(ds)
return output
def sql_remove_table_constraints(self, model, references_to_delete):
output = []
for rel_class, f in references_to_delete[model]:
table = rel_class._meta.db_table
col = f.column
r_table = model._meta.db_table
r_col = model._meta.get_field(f.rel.field_name).column
r_name = '%s_refs_%s_%x' % (col, r_col, abs(hash((table, r_table))))
output.append('%s %s %s %s;' % \
(style.SQL_KEYWORD('ALTER TABLE'),
style.SQL_TABLE(qn(table)),
style.SQL_KEYWORD(self.connection.ops.drop_foreignkey_sql()),
style.SQL_FIELD(truncate_name(r_name, self.connection.ops.max_name_length()))))
del references_to_delete[model]
return output
def sql_destroy_many_to_many(self, model, f, style):
"Returns the DROP TABLE statements for a single m2m field"
qn = self.connection.ops.quote_name
output = []
if f.creates_table:
output.append("%s %s;" % (style.SQL_KEYWORD('DROP TABLE'),
style.SQL_TABLE(qn(f.m2m_db_table()))))
ds = self.connection.ops.drop_sequence_sql("%s_%s" % (model._meta.db_table, f.column))
if ds:
output.append(ds)
return output
def create_test_db(self, verbosity=1, autoclobber=False):
"""
Creates a test database, prompting the user for confirmation if the
database already exists. Returns the name of the test database created.
"""
if verbosity >= 1:
print "Creating test database..."
test_database_name = self._create_test_db(verbosity, autoclobber)
self.connection.close()
settings.DATABASE_NAME = test_database_name
call_command('syncdb', verbosity=verbosity, interactive=False)
if settings.CACHE_BACKEND.startswith('db://'):
cache_name = settings.CACHE_BACKEND[len('db://'):]
call_command('createcachetable', cache_name)
# Get a cursor (even though we don't need one yet). This has
# the side effect of initializing the test database.
cursor = self.connection.cursor()
return test_database_name
def _create_test_db(self, verbosity, autoclobber):
"Internal implementation - creates the test db tables."
suffix = self.sql_table_creation_suffix()
if settings.TEST_DATABASE_NAME:
test_database_name = settings.TEST_DATABASE_NAME
else:
test_database_name = TEST_DATABASE_PREFIX + settings.DATABASE_NAME
qn = self.connection.ops.quote_name
# Create the test database and connect to it. We need to autocommit
# if the database supports it because PostgreSQL doesn't allow
# CREATE/DROP DATABASE statements within transactions.
cursor = self.connection.cursor()
self.set_autocommit()
try:
cursor.execute("CREATE DATABASE %s %s" % (qn(test_database_name), suffix))
except Exception, e:
sys.stderr.write("Got an error creating the test database: %s\n" % e)
if not autoclobber:
confirm = raw_input("Type 'yes' if you would like to try deleting the test database '%s', or 'no' to cancel: " % test_database_name)
if autoclobber or confirm == 'yes':
try:
if verbosity >= 1:
print "Destroying old test database..."
cursor.execute("DROP DATABASE %s" % qn(test_database_name))
if verbosity >= 1:
print "Creating test database..."
cursor.execute("CREATE DATABASE %s %s" % (qn(test_database_name), suffix))
except Exception, e:
sys.stderr.write("Got an error recreating the test database: %s\n" % e)
sys.exit(2)
else:
print "Tests cancelled."
sys.exit(1)
return test_database_name
def destroy_test_db(self, old_database_name, verbosity=1):
"""
Destroy a test database, prompting the user for confirmation if the
database already exists. Returns the name of the test database created.
"""
if verbosity >= 1:
print "Destroying test database..."
self.connection.close()
test_database_name = settings.DATABASE_NAME
settings.DATABASE_NAME = old_database_name
self._destroy_test_db(test_database_name, verbosity)
def _destroy_test_db(self, test_database_name, verbosity):
"Internal implementation - remove the test db tables."
# Remove the test database to clean up after
# ourselves. Connect to the previous database (not the test database)
# to do so, because it's not allowed to delete a database while being
# connected to it.
cursor = self.connection.cursor()
self.set_autocommit()
time.sleep(1) # To avoid "database is being accessed by other users" errors.
cursor.execute("DROP DATABASE %s" % self.connection.ops.quote_name(test_database_name))
self.connection.close()
def set_autocommit(self):
"Make sure a connection is in autocommit mode."
if hasattr(self.connection.connection, "autocommit"):
if callable(self.connection.connection.autocommit):
self.connection.connection.autocommit(True)
else:
self.connection.connection.autocommit = True
elif hasattr(self.connection.connection, "set_isolation_level"):
self.connection.connection.set_isolation_level(0)
def sql_table_creation_suffix(self):
"SQL to append to the end of the test table creation statements"
return ''
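Taken together, create_test_db and destroy_test_db define the test-database lifecycle a test runner is expected to follow; a condensed sketch (error handling omitted):

    from django.conf import settings
    from django.db import connection

    old_name = settings.DATABASE_NAME
    test_name = connection.creation.create_test_db(verbosity=1, autoclobber=False)
    try:
        pass  # run the test suite against the freshly created test database
    finally:
        connection.creation.destroy_test_db(old_name, verbosity=1)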

View File

@@ -8,7 +8,8 @@ ImproperlyConfigured.
 """
 from django.core.exceptions import ImproperlyConfigured
-from django.db.backends import BaseDatabaseFeatures, BaseDatabaseOperations
+from django.db.backends import *
+from django.db.backends.creation import BaseDatabaseCreation
 def complain(*args, **kwargs):
 raise ImproperlyConfigured, "You haven't set the DATABASE_ENGINE setting yet."
@@ -25,16 +26,30 @@ class IntegrityError(DatabaseError):
 class DatabaseOperations(BaseDatabaseOperations):
 quote_name = complain
+class DatabaseClient(BaseDatabaseClient):
+runshell = complain
+class DatabaseIntrospection(BaseDatabaseIntrospection):
+get_table_list = complain
+get_table_description = complain
+get_relations = complain
+get_indexes = complain
 class DatabaseWrapper(object):
-features = BaseDatabaseFeatures()
-ops = DatabaseOperations()
 operators = {}
 cursor = complain
 _commit = complain
 _rollback = ignore
-def __init__(self, **kwargs):
+def __init__(self, *args, **kwargs):
-pass
+super(DatabaseWrapper, self).__init__(*args, **kwargs)
+self.features = BaseDatabaseFeatures()
+self.ops = DatabaseOperations()
+self.client = DatabaseClient()
+self.creation = BaseDatabaseCreation(self)
+self.introspection = DatabaseIntrospection(self)
+self.validation = BaseDatabaseValidation()
 def close(self):
 pass

View File

@ -1,3 +0,0 @@
from django.db.backends.dummy.base import complain
runshell = complain

View File

@ -1 +0,0 @@
DATA_TYPES = {}

View File

@ -1,8 +0,0 @@
from django.db.backends.dummy.base import complain
get_table_list = complain
get_table_description = complain
get_relations = complain
get_indexes = complain
DATA_TYPES_REVERSE = {}

View File

@ -4,7 +4,12 @@ MySQL database backend for Django.
 Requires MySQLdb: http://sourceforge.net/projects/mysql-python
 """
-from django.db.backends import BaseDatabaseWrapper, BaseDatabaseFeatures, BaseDatabaseOperations, util
+from django.db.backends import *
+from django.db.backends.mysql.client import DatabaseClient
+from django.db.backends.mysql.creation import DatabaseCreation
+from django.db.backends.mysql.introspection import DatabaseIntrospection
+from django.db.backends.mysql.validation import DatabaseValidation

 try:
     import MySQLdb as Database
 except ImportError, e:
@ -60,7 +65,6 @@ server_version_re = re.compile(r'(\d{1,2})\.(\d{1,2})\.(\d{1,2})')
 # TRADITIONAL will automatically cause most warnings to be treated as errors.

 class DatabaseFeatures(BaseDatabaseFeatures):
-    inline_fk_references = False
     empty_fetchmany_value = ()
     update_can_self_select = False
@ -142,8 +146,7 @@ class DatabaseOperations(BaseDatabaseOperations):
         return [first % value, second % value]

 class DatabaseWrapper(BaseDatabaseWrapper):
-    features = DatabaseFeatures()
-    ops = DatabaseOperations()
     operators = {
         'exact': '= BINARY %s',
         'iexact': 'LIKE %s',
@ -165,6 +168,13 @@ class DatabaseWrapper(BaseDatabaseWrapper):
         super(DatabaseWrapper, self).__init__(**kwargs)
         self.server_version = None

+        self.features = DatabaseFeatures()
+        self.ops = DatabaseOperations()
+        self.client = DatabaseClient()
+        self.creation = DatabaseCreation(self)
+        self.introspection = DatabaseIntrospection(self)
+        self.validation = DatabaseValidation()
+
     def _valid_connection(self):
         if self.connection is not None:
             try:

View File

@ -1,7 +1,9 @@
+from django.db.backends import BaseDatabaseClient
 from django.conf import settings
 import os

-def runshell():
-    args = ['']
-    db = settings.DATABASE_OPTIONS.get('db', settings.DATABASE_NAME)
-    user = settings.DATABASE_OPTIONS.get('user', settings.DATABASE_USER)
+class DatabaseClient(BaseDatabaseClient):
+    def runshell(self):
+        args = ['']
+        db = settings.DATABASE_OPTIONS.get('db', settings.DATABASE_NAME)
+        user = settings.DATABASE_OPTIONS.get('user', settings.DATABASE_USER)

View File

@ -1,8 +1,12 @@
-# This dictionary maps Field objects to their associated MySQL column
-# types, as strings. Column-type strings can contain format strings; they'll
-# be interpolated against the values of Field.__dict__ before being output.
-# If a column type is set to None, it won't be included in the output.
-DATA_TYPES = {
+from django.conf import settings
+from django.db.backends.creation import BaseDatabaseCreation
+
+class DatabaseCreation(BaseDatabaseCreation):
+    # This dictionary maps Field objects to their associated MySQL column
+    # types, as strings. Column-type strings can contain format strings; they'll
+    # be interpolated against the values of Field.__dict__ before being output.
+    # If a column type is set to None, it won't be included in the output.
+    data_types = {
        'AutoField': 'integer AUTO_INCREMENT',
        'BooleanField': 'bool',
        'CharField': 'varchar(%(max_length)s)',
@ -25,4 +29,40 @@ DATA_TYPES = {
        'TextField': 'longtext',
        'TimeField': 'time',
        'USStateField': 'varchar(2)',
     }
def sql_table_creation_suffix(self):
suffix = []
if settings.TEST_DATABASE_CHARSET:
suffix.append('CHARACTER SET %s' % settings.TEST_DATABASE_CHARSET)
if settings.TEST_DATABASE_COLLATION:
suffix.append('COLLATE %s' % settings.TEST_DATABASE_COLLATION)
return ' '.join(suffix)
def sql_for_inline_foreign_key_references(self, field, known_models, style):
"All inline references are pending under MySQL"
return [], True
def sql_for_inline_many_to_many_references(self, model, field, style):
from django.db import models
opts = model._meta
qn = self.connection.ops.quote_name
table_output = [
' %s %s %s,' %
(style.SQL_FIELD(qn(field.m2m_column_name())),
style.SQL_COLTYPE(models.ForeignKey(model).db_type()),
style.SQL_KEYWORD('NOT NULL')),
' %s %s %s,' %
(style.SQL_FIELD(qn(field.m2m_reverse_name())),
style.SQL_COLTYPE(models.ForeignKey(field.rel.to).db_type()),
style.SQL_KEYWORD('NOT NULL'))
]
deferred = [
(field.m2m_db_table(), field.m2m_column_name(), opts.db_table,
opts.pk.column),
(field.m2m_db_table(), field.m2m_reverse_name(),
field.rel.to._meta.db_table, field.rel.to._meta.pk.column)
]
return table_output, deferred
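
The suffix hook above is what lets a MySQL test database pick up an explicit character set and collation. A hypothetical settings fragment (the values are placeholders) that exercises it:

# Hypothetical settings fragment; the values are placeholders. With these
# set, sql_table_creation_suffix() above returns
# "CHARACTER SET utf8 COLLATE utf8_general_ci", which _create_test_db()
# appends to its CREATE DATABASE statement.
TEST_DATABASE_CHARSET = 'utf8'
TEST_DATABASE_COLLATION = 'utf8_general_ci'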

View File

@ -1,80 +1,12 @@
-from django.db.backends.mysql.base import DatabaseOperations
+from django.db.backends import BaseDatabaseIntrospection
 from MySQLdb import ProgrammingError, OperationalError
 from MySQLdb.constants import FIELD_TYPE
 import re

-quote_name = DatabaseOperations().quote_name

 foreign_key_re = re.compile(r"\sCONSTRAINT `[^`]*` FOREIGN KEY \(`([^`]*)`\) REFERENCES `([^`]*)` \(`([^`]*)`\)")

-def get_table_list(cursor):
-    "Returns a list of table names in the current database."
+class DatabaseIntrospection(BaseDatabaseIntrospection):
+    data_types_reverse = {
cursor.execute("SHOW TABLES")
return [row[0] for row in cursor.fetchall()]
def get_table_description(cursor, table_name):
"Returns a description of the table, with the DB-API cursor.description interface."
cursor.execute("SELECT * FROM %s LIMIT 1" % quote_name(table_name))
return cursor.description
def _name_to_index(cursor, table_name):
"""
Returns a dictionary of {field_name: field_index} for the given table.
Indexes are 0-based.
"""
return dict([(d[0], i) for i, d in enumerate(get_table_description(cursor, table_name))])
def get_relations(cursor, table_name):
"""
Returns a dictionary of {field_index: (field_index_other_table, other_table)}
representing all relationships to the given table. Indexes are 0-based.
"""
my_field_dict = _name_to_index(cursor, table_name)
constraints = []
relations = {}
try:
# This should work for MySQL 5.0.
cursor.execute("""
SELECT column_name, referenced_table_name, referenced_column_name
FROM information_schema.key_column_usage
WHERE table_name = %s
AND table_schema = DATABASE()
AND referenced_table_name IS NOT NULL
AND referenced_column_name IS NOT NULL""", [table_name])
constraints.extend(cursor.fetchall())
except (ProgrammingError, OperationalError):
# Fall back to "SHOW CREATE TABLE", for previous MySQL versions.
# Go through all constraints and save the equal matches.
cursor.execute("SHOW CREATE TABLE %s" % quote_name(table_name))
for row in cursor.fetchall():
pos = 0
while True:
match = foreign_key_re.search(row[1], pos)
if match == None:
break
pos = match.end()
constraints.append(match.groups())
for my_fieldname, other_table, other_field in constraints:
other_field_index = _name_to_index(cursor, other_table)[other_field]
my_field_index = my_field_dict[my_fieldname]
relations[my_field_index] = (other_field_index, other_table)
return relations
def get_indexes(cursor, table_name):
"""
Returns a dictionary of fieldname -> infodict for the given table,
where each infodict is in the format:
{'primary_key': boolean representing whether it's the primary key,
'unique': boolean representing whether it's a unique index}
"""
cursor.execute("SHOW INDEX FROM %s" % quote_name(table_name))
indexes = {}
for row in cursor.fetchall():
indexes[row[4]] = {'primary_key': (row[2] == 'PRIMARY'), 'unique': not bool(row[1])}
return indexes
DATA_TYPES_REVERSE = {
FIELD_TYPE.BLOB: 'TextField', FIELD_TYPE.BLOB: 'TextField',
FIELD_TYPE.CHAR: 'CharField', FIELD_TYPE.CHAR: 'CharField',
FIELD_TYPE.DECIMAL: 'DecimalField', FIELD_TYPE.DECIMAL: 'DecimalField',
@ -93,4 +25,73 @@ DATA_TYPES_REVERSE = {
FIELD_TYPE.MEDIUM_BLOB: 'TextField', FIELD_TYPE.MEDIUM_BLOB: 'TextField',
FIELD_TYPE.LONG_BLOB: 'TextField', FIELD_TYPE.LONG_BLOB: 'TextField',
FIELD_TYPE.VAR_STRING: 'CharField', FIELD_TYPE.VAR_STRING: 'CharField',
} }
def get_table_list(self, cursor):
"Returns a list of table names in the current database."
cursor.execute("SHOW TABLES")
return [row[0] for row in cursor.fetchall()]
def get_table_description(self, cursor, table_name):
"Returns a description of the table, with the DB-API cursor.description interface."
cursor.execute("SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name))
return cursor.description
def _name_to_index(self, cursor, table_name):
"""
Returns a dictionary of {field_name: field_index} for the given table.
Indexes are 0-based.
"""
return dict([(d[0], i) for i, d in enumerate(self.get_table_description(cursor, table_name))])
def get_relations(self, cursor, table_name):
"""
Returns a dictionary of {field_index: (field_index_other_table, other_table)}
representing all relationships to the given table. Indexes are 0-based.
"""
my_field_dict = self._name_to_index(cursor, table_name)
constraints = []
relations = {}
try:
# This should work for MySQL 5.0.
cursor.execute("""
SELECT column_name, referenced_table_name, referenced_column_name
FROM information_schema.key_column_usage
WHERE table_name = %s
AND table_schema = DATABASE()
AND referenced_table_name IS NOT NULL
AND referenced_column_name IS NOT NULL""", [table_name])
constraints.extend(cursor.fetchall())
except (ProgrammingError, OperationalError):
# Fall back to "SHOW CREATE TABLE", for previous MySQL versions.
# Go through all constraints and save the equal matches.
cursor.execute("SHOW CREATE TABLE %s" % self.connection.ops.quote_name(table_name))
for row in cursor.fetchall():
pos = 0
while True:
match = foreign_key_re.search(row[1], pos)
if match == None:
break
pos = match.end()
constraints.append(match.groups())
for my_fieldname, other_table, other_field in constraints:
other_field_index = self._name_to_index(cursor, other_table)[other_field]
my_field_index = my_field_dict[my_fieldname]
relations[my_field_index] = (other_field_index, other_table)
return relations
def get_indexes(self, cursor, table_name):
"""
Returns a dictionary of fieldname -> infodict for the given table,
where each infodict is in the format:
{'primary_key': boolean representing whether it's the primary key,
'unique': boolean representing whether it's a unique index}
"""
cursor.execute("SHOW INDEX FROM %s" % self.connection.ops.quote_name(table_name))
indexes = {}
for row in cursor.fetchall():
indexes[row[4]] = {'primary_key': (row[2] == 'PRIMARY'), 'unique': not bool(row[1])}
return indexes
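
Since introspection now hangs off the connection object, callers such as inspectdb go through connection.introspection instead of a free-standing module. A rough usage sketch, assuming a configured project; it is illustrative rather than code from this changeset:

# Sketch only: listing tables, indexes and relations through the new
# per-connection introspection object.
from django.db import connection

cursor = connection.cursor()
for table in connection.introspection.get_table_list(cursor):
    indexes = connection.introspection.get_indexes(cursor, table)
    try:
        relations = connection.introspection.get_relations(cursor, table)
    except NotImplementedError:
        relations = {}  # e.g. the sqlite3 backend (see below)
    print table, sorted(indexes.keys()), relations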

View File

@ -0,0 +1,13 @@
from django.db.backends import BaseDatabaseValidation
class DatabaseValidation(BaseDatabaseValidation):
def validate_field(self, errors, opts, f):
"Prior to MySQL 5.0.3, character fields could not exceed 255 characters"
from django.db import models
from django.db import connection
db_version = connection.get_server_version()
if db_version < (5, 0, 3) and isinstance(f, (models.CharField, models.CommaSeparatedIntegerField, models.SlugField)) and f.max_length > 255:
errors.add(opts,
'"%s": %s cannot have a "max_length" greater than 255 when you are using a version of MySQL prior to 5.0.3 (you are using %s).' %
(f.name, f.__class__.__name__, '.'.join([str(n) for n in db_version[:3]])))
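
The new validation hook is invoked during model validation; the sketch below only illustrates the shape of that call. The ErrorCollector class is a stand-in for Django's own error collector, not part of this patch.

# Sketch only: exercising connection.validation.validate_field() by hand.
from django.db import connection

class ErrorCollector(object):
    "Minimal stand-in for the collector used by model validation."
    def __init__(self):
        self.messages = []
    def add(self, opts, message):
        self.messages.append((opts, message))

def check_field_lengths(model):
    errors = ErrorCollector()
    for field in model._meta.fields:
        connection.validation.validate_field(errors, model._meta, field)
    return errors.messages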

View File

@ -8,8 +8,11 @@ import os
import datetime import datetime
import time import time
from django.db.backends import BaseDatabaseWrapper, BaseDatabaseFeatures, BaseDatabaseOperations, util from django.db.backends import *
from django.db.backends.oracle import query from django.db.backends.oracle import query
from django.db.backends.oracle.client import DatabaseClient
from django.db.backends.oracle.creation import DatabaseCreation
from django.db.backends.oracle.introspection import DatabaseIntrospection
from django.utils.encoding import smart_str, force_unicode from django.utils.encoding import smart_str, force_unicode
# Oracle takes client-side character set encoding from the environment. # Oracle takes client-side character set encoding from the environment.
@ -24,11 +27,8 @@ DatabaseError = Database.Error
IntegrityError = Database.IntegrityError IntegrityError = Database.IntegrityError
class DatabaseFeatures(BaseDatabaseFeatures): class DatabaseFeatures(BaseDatabaseFeatures):
allows_group_by_ordinal = False
empty_fetchmany_value = () empty_fetchmany_value = ()
needs_datetime_string_cast = False needs_datetime_string_cast = False
supports_tablespaces = True
uses_case_insensitive_names = True
uses_custom_query_class = True uses_custom_query_class = True
interprets_empty_strings_as_nulls = True interprets_empty_strings_as_nulls = True
@ -194,10 +194,8 @@ class DatabaseOperations(BaseDatabaseOperations):
return [first % value, second % value] return [first % value, second % value]
class DatabaseWrapper(BaseDatabaseWrapper): class DatabaseWrapper(BaseDatabaseWrapper):
features = DatabaseFeatures()
ops = DatabaseOperations()
operators = { operators = {
'exact': '= %s', 'exact': '= %s',
'iexact': '= UPPER(%s)', 'iexact': '= UPPER(%s)',
@ -214,6 +212,16 @@ class DatabaseWrapper(BaseDatabaseWrapper):
} }
oracle_version = None oracle_version = None
def __init__(self, *args, **kwargs):
super(DatabaseWrapper, self).__init__(*args, **kwargs)
self.features = DatabaseFeatures()
self.ops = DatabaseOperations()
self.client = DatabaseClient()
self.creation = DatabaseCreation(self)
self.introspection = DatabaseIntrospection(self)
self.validation = BaseDatabaseValidation()
def _valid_connection(self): def _valid_connection(self):
return self.connection is not None return self.connection is not None

View File

@ -1,7 +1,9 @@
+from django.db.backends import BaseDatabaseClient
 from django.conf import settings
 import os

-def runshell():
-    dsn = settings.DATABASE_USER
-    if settings.DATABASE_PASSWORD:
-        dsn += "/%s" % settings.DATABASE_PASSWORD
+class DatabaseClient(BaseDatabaseClient):
+    def runshell(self):
+        dsn = settings.DATABASE_USER
+        if settings.DATABASE_PASSWORD:
+            dsn += "/%s" % settings.DATABASE_PASSWORD

View File

@ -1,15 +1,21 @@
 import sys, time
+from django.conf import settings
 from django.core import management
+from django.db.backends.creation import BaseDatabaseCreation

-# This dictionary maps Field objects to their associated Oracle column
-# types, as strings. Column-type strings can contain format strings; they'll
-# be interpolated against the values of Field.__dict__ before being output.
-# If a column type is set to None, it won't be included in the output.
-#
-# Any format strings starting with "qn_" are quoted before being used in the
-# output (the "qn_" prefix is stripped before the lookup is performed.
-DATA_TYPES = {
+TEST_DATABASE_PREFIX = 'test_'
+PASSWORD = 'Im_a_lumberjack'
+
+class DatabaseCreation(BaseDatabaseCreation):
+    # This dictionary maps Field objects to their associated Oracle column
+    # types, as strings. Column-type strings can contain format strings; they'll
+    # be interpolated against the values of Field.__dict__ before being output.
+    # If a column type is set to None, it won't be included in the output.
+    #
+    # Any format strings starting with "qn_" are quoted before being used in the
+    # output (the "qn_" prefix is stripped before the lookup is performed.
+    data_types = {
        'AutoField': 'NUMBER(11)',
        'BooleanField': 'NUMBER(1) CHECK (%(qn_column)s IN (0,1))',
        'CharField': 'NVARCHAR2(%(max_length)s)',
@ -33,18 +39,14 @@ DATA_TYPES = {
        'TimeField': 'TIMESTAMP',
        'URLField': 'VARCHAR2(%(max_length)s)',
        'USStateField': 'CHAR(2)',
     }

-TEST_DATABASE_PREFIX = 'test_'
-PASSWORD = 'Im_a_lumberjack'
-REMEMBER = {}
-
-def create_test_db(settings, connection, verbosity=1, autoclobber=False):
-    TEST_DATABASE_NAME = _test_database_name(settings)
-    TEST_DATABASE_USER = _test_database_user(settings)
-    TEST_DATABASE_PASSWD = _test_database_passwd(settings)
-    TEST_DATABASE_TBLSPACE = _test_database_tblspace(settings)
-    TEST_DATABASE_TBLSPACE_TMP = _test_database_tblspace_tmp(settings)
+    def _create_test_db(self, verbosity, autoclobber):
+        TEST_DATABASE_NAME = self._test_database_name(settings)
+        TEST_DATABASE_USER = self._test_database_user(settings)
+        TEST_DATABASE_PASSWD = self._test_database_passwd(settings)
+        TEST_DATABASE_TBLSPACE = self._test_database_tblspace(settings)
+        TEST_DATABASE_TBLSPACE_TMP = self._test_database_tblspace_tmp(settings)
parameters = { parameters = {
'dbname': TEST_DATABASE_NAME, 'dbname': TEST_DATABASE_NAME,
@ -54,15 +56,15 @@ def create_test_db(settings, connection, verbosity=1, autoclobber=False):
'tblspace_temp': TEST_DATABASE_TBLSPACE_TMP, 'tblspace_temp': TEST_DATABASE_TBLSPACE_TMP,
} }
REMEMBER['user'] = settings.DATABASE_USER self.remember['user'] = settings.DATABASE_USER
REMEMBER['passwd'] = settings.DATABASE_PASSWORD self.remember['passwd'] = settings.DATABASE_PASSWORD
cursor = connection.cursor() cursor = self.connection.cursor()
if _test_database_create(settings): if self._test_database_create(settings):
if verbosity >= 1: if verbosity >= 1:
print 'Creating test database...' print 'Creating test database...'
try: try:
_create_test_db(cursor, parameters, verbosity) self._execute_test_db_creation(cursor, parameters, verbosity)
except Exception, e: except Exception, e:
sys.stderr.write("Got an error creating the test database: %s\n" % e) sys.stderr.write("Got an error creating the test database: %s\n" % e)
if not autoclobber: if not autoclobber:
@ -71,10 +73,10 @@ def create_test_db(settings, connection, verbosity=1, autoclobber=False):
try: try:
if verbosity >= 1: if verbosity >= 1:
print "Destroying old test database..." print "Destroying old test database..."
_destroy_test_db(cursor, parameters, verbosity) self._execute_test_db_destruction(cursor, parameters, verbosity)
if verbosity >= 1: if verbosity >= 1:
print "Creating test database..." print "Creating test database..."
_create_test_db(cursor, parameters, verbosity) self._execute_test_db_creation(cursor, parameters, verbosity)
except Exception, e: except Exception, e:
sys.stderr.write("Got an error recreating the test database: %s\n" % e) sys.stderr.write("Got an error recreating the test database: %s\n" % e)
sys.exit(2) sys.exit(2)
@ -82,11 +84,11 @@ def create_test_db(settings, connection, verbosity=1, autoclobber=False):
print "Tests cancelled." print "Tests cancelled."
sys.exit(1) sys.exit(1)
if _test_user_create(settings): if self._test_user_create(settings):
if verbosity >= 1: if verbosity >= 1:
print "Creating test user..." print "Creating test user..."
try: try:
_create_test_user(cursor, parameters, verbosity) self._create_test_user(cursor, parameters, verbosity)
except Exception, e: except Exception, e:
sys.stderr.write("Got an error creating the test user: %s\n" % e) sys.stderr.write("Got an error creating the test user: %s\n" % e)
if not autoclobber: if not autoclobber:
@ -95,10 +97,10 @@ def create_test_db(settings, connection, verbosity=1, autoclobber=False):
try: try:
if verbosity >= 1: if verbosity >= 1:
print "Destroying old test user..." print "Destroying old test user..."
_destroy_test_user(cursor, parameters, verbosity) self._destroy_test_user(cursor, parameters, verbosity)
if verbosity >= 1: if verbosity >= 1:
print "Creating test user..." print "Creating test user..."
_create_test_user(cursor, parameters, verbosity) self._create_test_user(cursor, parameters, verbosity)
except Exception, e: except Exception, e:
sys.stderr.write("Got an error recreating the test user: %s\n" % e) sys.stderr.write("Got an error recreating the test user: %s\n" % e)
sys.exit(2) sys.exit(2)
@ -106,28 +108,24 @@ def create_test_db(settings, connection, verbosity=1, autoclobber=False):
print "Tests cancelled." print "Tests cancelled."
sys.exit(1) sys.exit(1)
connection.close()
settings.DATABASE_USER = TEST_DATABASE_USER settings.DATABASE_USER = TEST_DATABASE_USER
settings.DATABASE_PASSWORD = TEST_DATABASE_PASSWD settings.DATABASE_PASSWORD = TEST_DATABASE_PASSWD
management.call_command('syncdb', verbosity=verbosity, interactive=False) return TEST_DATABASE_NAME
# Get a cursor (even though we don't need one yet). This has def _destroy_test_db(self, test_database_name, verbosity=1):
# the side effect of initializing the test database. """
cursor = connection.cursor() Destroy a test database, prompting the user for confirmation if the
database already exists. Returns the name of the test database created.
"""
TEST_DATABASE_NAME = self._test_database_name(settings)
TEST_DATABASE_USER = self._test_database_user(settings)
TEST_DATABASE_PASSWD = self._test_database_passwd(settings)
TEST_DATABASE_TBLSPACE = self._test_database_tblspace(settings)
TEST_DATABASE_TBLSPACE_TMP = self._test_database_tblspace_tmp(settings)
def destroy_test_db(settings, connection, old_database_name, verbosity=1): settings.DATABASE_USER = self.remember['user']
connection.close() settings.DATABASE_PASSWORD = self.remember['passwd']
TEST_DATABASE_NAME = _test_database_name(settings)
TEST_DATABASE_USER = _test_database_user(settings)
TEST_DATABASE_PASSWD = _test_database_passwd(settings)
TEST_DATABASE_TBLSPACE = _test_database_tblspace(settings)
TEST_DATABASE_TBLSPACE_TMP = _test_database_tblspace_tmp(settings)
settings.DATABASE_NAME = old_database_name
settings.DATABASE_USER = REMEMBER['user']
settings.DATABASE_PASSWORD = REMEMBER['passwd']
parameters = { parameters = {
'dbname': TEST_DATABASE_NAME, 'dbname': TEST_DATABASE_NAME,
@ -137,22 +135,22 @@ def destroy_test_db(settings, connection, old_database_name, verbosity=1):
'tblspace_temp': TEST_DATABASE_TBLSPACE_TMP, 'tblspace_temp': TEST_DATABASE_TBLSPACE_TMP,
} }
REMEMBER['user'] = settings.DATABASE_USER self.remember['user'] = settings.DATABASE_USER
REMEMBER['passwd'] = settings.DATABASE_PASSWORD self.remember['passwd'] = settings.DATABASE_PASSWORD
cursor = connection.cursor() cursor = self.connection.cursor()
time.sleep(1) # To avoid "database is being accessed by other users" errors. time.sleep(1) # To avoid "database is being accessed by other users" errors.
if _test_user_create(settings): if self._test_user_create(settings):
if verbosity >= 1: if verbosity >= 1:
print 'Destroying test user...' print 'Destroying test user...'
_destroy_test_user(cursor, parameters, verbosity) self._destroy_test_user(cursor, parameters, verbosity)
if _test_database_create(settings): if self._test_database_create(settings):
if verbosity >= 1: if verbosity >= 1:
print 'Destroying test database...' print 'Destroying test database tables...'
_destroy_test_db(cursor, parameters, verbosity) self._execute_test_db_destruction(cursor, parameters, verbosity)
connection.close() self.connection.close()
def _create_test_db(cursor, parameters, verbosity): def _execute_test_db_creation(cursor, parameters, verbosity):
if verbosity >= 2: if verbosity >= 2:
print "_create_test_db(): dbname = %s" % parameters['dbname'] print "_create_test_db(): dbname = %s" % parameters['dbname']
statements = [ statements = [
@ -167,7 +165,7 @@ def _create_test_db(cursor, parameters, verbosity):
] ]
_execute_statements(cursor, statements, parameters, verbosity) _execute_statements(cursor, statements, parameters, verbosity)
def _create_test_user(cursor, parameters, verbosity): def _create_test_user(cursor, parameters, verbosity):
if verbosity >= 2: if verbosity >= 2:
print "_create_test_user(): username = %s" % parameters['user'] print "_create_test_user(): username = %s" % parameters['user']
statements = [ statements = [
@ -180,16 +178,16 @@ def _create_test_user(cursor, parameters, verbosity):
] ]
_execute_statements(cursor, statements, parameters, verbosity) _execute_statements(cursor, statements, parameters, verbosity)
def _destroy_test_db(cursor, parameters, verbosity): def _execute_test_db_destruction(cursor, parameters, verbosity):
if verbosity >= 2: if verbosity >= 2:
print "_destroy_test_db(): dbname=%s" % parameters['dbname'] print "_execute_test_db_destruction(): dbname=%s" % parameters['dbname']
statements = [ statements = [
'DROP TABLESPACE %(tblspace)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS', 'DROP TABLESPACE %(tblspace)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS',
'DROP TABLESPACE %(tblspace_temp)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS', 'DROP TABLESPACE %(tblspace_temp)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS',
] ]
_execute_statements(cursor, statements, parameters, verbosity) _execute_statements(cursor, statements, parameters, verbosity)
def _destroy_test_user(cursor, parameters, verbosity): def _destroy_test_user(cursor, parameters, verbosity):
if verbosity >= 2: if verbosity >= 2:
print "_destroy_test_user(): user=%s" % parameters['user'] print "_destroy_test_user(): user=%s" % parameters['user']
print "Be patient. This can take some time..." print "Be patient. This can take some time..."
@ -198,7 +196,7 @@ def _destroy_test_user(cursor, parameters, verbosity):
] ]
_execute_statements(cursor, statements, parameters, verbosity) _execute_statements(cursor, statements, parameters, verbosity)
def _execute_statements(cursor, statements, parameters, verbosity): def _execute_statements(cursor, statements, parameters, verbosity):
for template in statements: for template in statements:
stmt = template % parameters stmt = template % parameters
if verbosity >= 2: if verbosity >= 2:
@ -209,7 +207,7 @@ def _execute_statements(cursor, statements, parameters, verbosity):
sys.stderr.write("Failed (%s)\n" % (err)) sys.stderr.write("Failed (%s)\n" % (err))
raise raise
def _test_database_name(settings): def _test_database_name(settings):
name = TEST_DATABASE_PREFIX + settings.DATABASE_NAME name = TEST_DATABASE_PREFIX + settings.DATABASE_NAME
try: try:
if settings.TEST_DATABASE_NAME: if settings.TEST_DATABASE_NAME:
@ -220,7 +218,7 @@ def _test_database_name(settings):
raise raise
return name return name
def _test_database_create(settings): def _test_database_create(settings):
name = True name = True
try: try:
if settings.TEST_DATABASE_CREATE: if settings.TEST_DATABASE_CREATE:
@ -233,7 +231,7 @@ def _test_database_create(settings):
raise raise
return name return name
def _test_user_create(settings): def _test_user_create(settings):
name = True name = True
try: try:
if settings.TEST_USER_CREATE: if settings.TEST_USER_CREATE:
@ -246,7 +244,7 @@ def _test_user_create(settings):
raise raise
return name return name
def _test_database_user(settings): def _test_database_user(settings):
name = TEST_DATABASE_PREFIX + settings.DATABASE_NAME name = TEST_DATABASE_PREFIX + settings.DATABASE_NAME
try: try:
if settings.TEST_DATABASE_USER: if settings.TEST_DATABASE_USER:
@ -257,7 +255,7 @@ def _test_database_user(settings):
raise raise
return name return name
def _test_database_passwd(settings): def _test_database_passwd(settings):
name = PASSWORD name = PASSWORD
try: try:
if settings.TEST_DATABASE_PASSWD: if settings.TEST_DATABASE_PASSWD:
@ -268,7 +266,7 @@ def _test_database_passwd(settings):
raise raise
return name return name
def _test_database_tblspace(settings): def _test_database_tblspace(settings):
name = TEST_DATABASE_PREFIX + settings.DATABASE_NAME name = TEST_DATABASE_PREFIX + settings.DATABASE_NAME
try: try:
if settings.TEST_DATABASE_TBLSPACE: if settings.TEST_DATABASE_TBLSPACE:
@ -279,7 +277,7 @@ def _test_database_tblspace(settings):
raise raise
return name return name
def _test_database_tblspace_tmp(settings): def _test_database_tblspace_tmp(settings):
name = TEST_DATABASE_PREFIX + settings.DATABASE_NAME + '_temp' name = TEST_DATABASE_PREFIX + settings.DATABASE_NAME + '_temp'
try: try:
if settings.TEST_DATABASE_TBLSPACE_TMP: if settings.TEST_DATABASE_TBLSPACE_TMP:

View File

@ -1,37 +1,52 @@
-from django.db.backends.oracle.base import DatabaseOperations
-import re
+from django.db.backends import BaseDatabaseIntrospection
 import cx_Oracle
+import re

-quote_name = DatabaseOperations().quote_name

 foreign_key_re = re.compile(r"\sCONSTRAINT `[^`]*` FOREIGN KEY \(`([^`]*)`\) REFERENCES `([^`]*)` \(`([^`]*)`\)")

-def get_table_list(cursor):
+class DatabaseIntrospection(BaseDatabaseIntrospection):
# Maps type objects to Django Field types.
data_types_reverse = {
cx_Oracle.CLOB: 'TextField',
cx_Oracle.DATETIME: 'DateTimeField',
cx_Oracle.FIXED_CHAR: 'CharField',
cx_Oracle.NCLOB: 'TextField',
cx_Oracle.NUMBER: 'DecimalField',
cx_Oracle.STRING: 'CharField',
cx_Oracle.TIMESTAMP: 'DateTimeField',
}
def get_table_list(self, cursor):
"Returns a list of table names in the current database." "Returns a list of table names in the current database."
cursor.execute("SELECT TABLE_NAME FROM USER_TABLES") cursor.execute("SELECT TABLE_NAME FROM USER_TABLES")
return [row[0].upper() for row in cursor.fetchall()] return [row[0].upper() for row in cursor.fetchall()]
def get_table_description(cursor, table_name): def get_table_description(self, cursor, table_name):
"Returns a description of the table, with the DB-API cursor.description interface." "Returns a description of the table, with the DB-API cursor.description interface."
cursor.execute("SELECT * FROM %s WHERE ROWNUM < 2" % quote_name(table_name)) cursor.execute("SELECT * FROM %s WHERE ROWNUM < 2" % self.connection.ops.quote_name(table_name))
return cursor.description return cursor.description
def _name_to_index(cursor, table_name): def table_name_converter(self, name):
"Table name comparison is case insensitive under Oracle"
return name.upper()
def _name_to_index(self, cursor, table_name):
""" """
Returns a dictionary of {field_name: field_index} for the given table. Returns a dictionary of {field_name: field_index} for the given table.
Indexes are 0-based. Indexes are 0-based.
""" """
return dict([(d[0], i) for i, d in enumerate(get_table_description(cursor, table_name))]) return dict([(d[0], i) for i, d in enumerate(self.get_table_description(cursor, table_name))])
def get_relations(cursor, table_name): def get_relations(self, cursor, table_name):
""" """
Returns a dictionary of {field_index: (field_index_other_table, other_table)} Returns a dictionary of {field_index: (field_index_other_table, other_table)}
representing all relationships to the given table. Indexes are 0-based. representing all relationships to the given table. Indexes are 0-based.
""" """
cursor.execute(""" cursor.execute("""
SELECT ta.column_id - 1, tb.table_name, tb.column_id - 1 SELECT ta.column_id - 1, tb.table_name, tb.column_id - 1
FROM user_constraints, USER_CONS_COLUMNS ca, USER_CONS_COLUMNS cb, FROM user_constraints, USER_CONS_COLUMNS ca, USER_CONS_COLUMNS cb,
user_tab_cols ta, user_tab_cols tb user_tab_cols ta, user_tab_cols tb
WHERE user_constraints.table_name = %s AND WHERE user_constraints.table_name = %s AND
ta.table_name = %s AND ta.table_name = %s AND
ta.column_name = ca.column_name AND ta.column_name = ca.column_name AND
ca.table_name = %s AND ca.table_name = %s AND
@ -46,7 +61,7 @@ WHERE user_constraints.table_name = %s AND
relations[row[0]] = (row[2], row[1]) relations[row[0]] = (row[2], row[1])
return relations return relations
def get_indexes(cursor, table_name): def get_indexes(self, cursor, table_name):
""" """
Returns a dictionary of fieldname -> infodict for the given table, Returns a dictionary of fieldname -> infodict for the given table,
where each infodict is in the format: where each infodict is in the format:
@ -57,7 +72,7 @@ def get_indexes(cursor, table_name):
# first associated field name # first associated field name
# "We were in the nick of time; you were in great peril!" # "We were in the nick of time; you were in great peril!"
sql = """ sql = """
WITH primarycols AS ( WITH primarycols AS (
SELECT user_cons_columns.table_name, user_cons_columns.column_name, 1 AS PRIMARYCOL SELECT user_cons_columns.table_name, user_cons_columns.column_name, 1 AS PRIMARYCOL
FROM user_cons_columns, user_constraints FROM user_cons_columns, user_constraints
WHERE user_cons_columns.constraint_name = user_constraints.constraint_name AND WHERE user_cons_columns.constraint_name = user_constraints.constraint_name AND
@ -69,11 +84,11 @@ WITH primarycols AS (
WHERE uniqueness = 'UNIQUE' AND WHERE uniqueness = 'UNIQUE' AND
user_indexes.index_name = user_ind_columns.index_name AND user_indexes.index_name = user_ind_columns.index_name AND
user_ind_columns.table_name = %s) user_ind_columns.table_name = %s)
SELECT allcols.column_name, primarycols.primarycol, uniquecols.UNIQUECOL SELECT allcols.column_name, primarycols.primarycol, uniquecols.UNIQUECOL
FROM (SELECT column_name FROM primarycols UNION SELECT column_name FROM FROM (SELECT column_name FROM primarycols UNION SELECT column_name FROM
uniquecols) allcols, uniquecols) allcols,
primarycols, uniquecols primarycols, uniquecols
WHERE allcols.column_name = primarycols.column_name (+) AND WHERE allcols.column_name = primarycols.column_name (+) AND
allcols.column_name = uniquecols.column_name (+) allcols.column_name = uniquecols.column_name (+)
""" """
cursor.execute(sql, [table_name, table_name]) cursor.execute(sql, [table_name, table_name])
@ -86,13 +101,3 @@ WHERE allcols.column_name = primarycols.column_name (+) AND
indexes[row[0]] = {'primary_key': row[1], 'unique': row[2]} indexes[row[0]] = {'primary_key': row[1], 'unique': row[2]}
return indexes return indexes
# Maps type objects to Django Field types.
DATA_TYPES_REVERSE = {
cx_Oracle.CLOB: 'TextField',
cx_Oracle.DATETIME: 'DateTimeField',
cx_Oracle.FIXED_CHAR: 'CharField',
cx_Oracle.NCLOB: 'TextField',
cx_Oracle.NUMBER: 'DecimalField',
cx_Oracle.STRING: 'CharField',
cx_Oracle.TIMESTAMP: 'DateTimeField',
}

View File

@ -4,9 +4,13 @@ PostgreSQL database backend for Django.
Requires psycopg 1: http://initd.org/projects/psycopg1 Requires psycopg 1: http://initd.org/projects/psycopg1
""" """
from django.utils.encoding import smart_str, smart_unicode from django.db.backends import *
from django.db.backends import BaseDatabaseWrapper, BaseDatabaseFeatures, util from django.db.backends.postgresql.client import DatabaseClient
from django.db.backends.postgresql.creation import DatabaseCreation
from django.db.backends.postgresql.introspection import DatabaseIntrospection
from django.db.backends.postgresql.operations import DatabaseOperations from django.db.backends.postgresql.operations import DatabaseOperations
from django.utils.encoding import smart_str, smart_unicode
try: try:
import psycopg as Database import psycopg as Database
except ImportError, e: except ImportError, e:
@ -59,12 +63,7 @@ class UnicodeCursorWrapper(object):
def __iter__(self): def __iter__(self):
return iter(self.cursor) return iter(self.cursor)
class DatabaseFeatures(BaseDatabaseFeatures):
pass # This backend uses all the defaults.
class DatabaseWrapper(BaseDatabaseWrapper): class DatabaseWrapper(BaseDatabaseWrapper):
features = DatabaseFeatures()
ops = DatabaseOperations()
operators = { operators = {
'exact': '= %s', 'exact': '= %s',
'iexact': 'ILIKE %s', 'iexact': 'ILIKE %s',
@ -82,6 +81,16 @@ class DatabaseWrapper(BaseDatabaseWrapper):
'iendswith': 'ILIKE %s', 'iendswith': 'ILIKE %s',
} }
def __init__(self, *args, **kwargs):
super(DatabaseWrapper, self).__init__(*args, **kwargs)
self.features = BaseDatabaseFeatures()
self.ops = DatabaseOperations()
self.client = DatabaseClient()
self.creation = DatabaseCreation(self)
self.introspection = DatabaseIntrospection(self)
self.validation = BaseDatabaseValidation()
def _cursor(self, settings): def _cursor(self, settings):
set_tz = False set_tz = False
if self.connection is None: if self.connection is None:

View File

@ -1,7 +1,9 @@
+from django.db.backends import BaseDatabaseClient
 from django.conf import settings
 import os

-def runshell():
-    args = ['psql']
-    if settings.DATABASE_USER:
-        args += ["-U", settings.DATABASE_USER]
+class DatabaseClient(BaseDatabaseClient):
+    def runshell(self):
+        args = ['psql']
+        if settings.DATABASE_USER:
+            args += ["-U", settings.DATABASE_USER]

View File

@ -1,8 +1,12 @@
-# This dictionary maps Field objects to their associated PostgreSQL column
-# types, as strings. Column-type strings can contain format strings; they'll
-# be interpolated against the values of Field.__dict__ before being output.
-# If a column type is set to None, it won't be included in the output.
-DATA_TYPES = {
+from django.conf import settings
+from django.db.backends.creation import BaseDatabaseCreation
+
+class DatabaseCreation(BaseDatabaseCreation):
+    # This dictionary maps Field objects to their associated PostgreSQL column
+    # types, as strings. Column-type strings can contain format strings; they'll
+    # be interpolated against the values of Field.__dict__ before being output.
+    # If a column type is set to None, it won't be included in the output.
+    data_types = {
        'AutoField': 'serial',
        'BooleanField': 'boolean',
        'CharField': 'varchar(%(max_length)s)',
@ -25,4 +29,10 @@ DATA_TYPES = {
        'TextField': 'text',
        'TimeField': 'time',
        'USStateField': 'varchar(2)',
     }
def sql_table_creation_suffix(self):
assert settings.TEST_DATABASE_COLLATION is None, "PostgreSQL does not support collation setting at database creation time."
if settings.TEST_DATABASE_CHARSET:
return "WITH ENCODING '%s'" % settings.TEST_DATABASE_CHARSET
return ''

View File

@ -1,8 +1,24 @@
-from django.db.backends.postgresql.base import DatabaseOperations
+from django.db.backends import BaseDatabaseIntrospection

-quote_name = DatabaseOperations().quote_name
+class DatabaseIntrospection(BaseDatabaseIntrospection):
# Maps type codes to Django Field types.
data_types_reverse = {
16: 'BooleanField',
21: 'SmallIntegerField',
23: 'IntegerField',
25: 'TextField',
701: 'FloatField',
869: 'IPAddressField',
1043: 'CharField',
1082: 'DateField',
1083: 'TimeField',
1114: 'DateTimeField',
1184: 'DateTimeField',
1266: 'TimeField',
1700: 'DecimalField',
}
def get_table_list(cursor): def get_table_list(self, cursor):
"Returns a list of table names in the current database." "Returns a list of table names in the current database."
cursor.execute(""" cursor.execute("""
SELECT c.relname SELECT c.relname
@ -13,12 +29,12 @@ def get_table_list(cursor):
AND pg_catalog.pg_table_is_visible(c.oid)""") AND pg_catalog.pg_table_is_visible(c.oid)""")
return [row[0] for row in cursor.fetchall()] return [row[0] for row in cursor.fetchall()]
def get_table_description(cursor, table_name): def get_table_description(self, cursor, table_name):
"Returns a description of the table, with the DB-API cursor.description interface." "Returns a description of the table, with the DB-API cursor.description interface."
cursor.execute("SELECT * FROM %s LIMIT 1" % quote_name(table_name)) cursor.execute("SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name))
return cursor.description return cursor.description
def get_relations(cursor, table_name): def get_relations(self, cursor, table_name):
""" """
Returns a dictionary of {field_index: (field_index_other_table, other_table)} Returns a dictionary of {field_index: (field_index_other_table, other_table)}
representing all relationships to the given table. Indexes are 0-based. representing all relationships to the given table. Indexes are 0-based.
@ -39,7 +55,7 @@ def get_relations(cursor, table_name):
continue continue
return relations return relations
def get_indexes(cursor, table_name): def get_indexes(self, cursor, table_name):
""" """
Returns a dictionary of fieldname -> infodict for the given table, Returns a dictionary of fieldname -> infodict for the given table,
where each infodict is in the format: where each infodict is in the format:
@ -68,19 +84,3 @@ def get_indexes(cursor, table_name):
indexes[row[0]] = {'primary_key': row[3], 'unique': row[2]} indexes[row[0]] = {'primary_key': row[3], 'unique': row[2]}
return indexes return indexes
# Maps type codes to Django Field types.
DATA_TYPES_REVERSE = {
16: 'BooleanField',
21: 'SmallIntegerField',
23: 'IntegerField',
25: 'TextField',
701: 'FloatField',
869: 'IPAddressField',
1043: 'CharField',
1082: 'DateField',
1083: 'TimeField',
1114: 'DateTimeField',
1184: 'DateTimeField',
1266: 'TimeField',
1700: 'DecimalField',
}

View File

@ -4,8 +4,12 @@ PostgreSQL database backend for Django.
Requires psycopg 2: http://initd.org/projects/psycopg2 Requires psycopg 2: http://initd.org/projects/psycopg2
""" """
from django.db.backends import BaseDatabaseWrapper, BaseDatabaseFeatures from django.db.backends import *
from django.db.backends.postgresql.operations import DatabaseOperations as PostgresqlDatabaseOperations from django.db.backends.postgresql.operations import DatabaseOperations as PostgresqlDatabaseOperations
from django.db.backends.postgresql.client import DatabaseClient
from django.db.backends.postgresql.creation import DatabaseCreation
from django.db.backends.postgresql_psycopg2.introspection import DatabaseIntrospection
from django.utils.safestring import SafeUnicode from django.utils.safestring import SafeUnicode
try: try:
import psycopg2 as Database import psycopg2 as Database
@ -31,8 +35,6 @@ class DatabaseOperations(PostgresqlDatabaseOperations):
return cursor.query return cursor.query
class DatabaseWrapper(BaseDatabaseWrapper): class DatabaseWrapper(BaseDatabaseWrapper):
features = DatabaseFeatures()
ops = DatabaseOperations()
operators = { operators = {
'exact': '= %s', 'exact': '= %s',
'iexact': 'ILIKE %s', 'iexact': 'ILIKE %s',
@ -50,6 +52,16 @@ class DatabaseWrapper(BaseDatabaseWrapper):
'iendswith': 'ILIKE %s', 'iendswith': 'ILIKE %s',
} }
def __init__(self, *args, **kwargs):
super(DatabaseWrapper, self).__init__(*args, **kwargs)
self.features = DatabaseFeatures()
self.ops = DatabaseOperations()
self.client = DatabaseClient()
self.creation = DatabaseCreation(self)
self.introspection = DatabaseIntrospection(self)
self.validation = BaseDatabaseValidation()
def _cursor(self, settings): def _cursor(self, settings):
set_tz = False set_tz = False
if self.connection is None: if self.connection is None:

View File

@ -1 +0,0 @@
from django.db.backends.postgresql.client import *

View File

@ -1 +0,0 @@
from django.db.backends.postgresql.creation import *

View File

@ -1,24 +1,8 @@
-from django.db.backends.postgresql_psycopg2.base import DatabaseOperations
+from django.db.backends.postgresql.introspection import DatabaseIntrospection as PostgresDatabaseIntrospection

-quote_name = DatabaseOperations().quote_name
+class DatabaseIntrospection(PostgresDatabaseIntrospection):

-def get_table_list(cursor):
+    def get_relations(self, cursor, table_name):
"Returns a list of table names in the current database."
cursor.execute("""
SELECT c.relname
FROM pg_catalog.pg_class c
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
WHERE c.relkind IN ('r', 'v', '')
AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
AND pg_catalog.pg_table_is_visible(c.oid)""")
return [row[0] for row in cursor.fetchall()]
def get_table_description(cursor, table_name):
"Returns a description of the table, with the DB-API cursor.description interface."
cursor.execute("SELECT * FROM %s LIMIT 1" % quote_name(table_name))
return cursor.description
def get_relations(cursor, table_name):
""" """
Returns a dictionary of {field_index: (field_index_other_table, other_table)} Returns a dictionary of {field_index: (field_index_other_table, other_table)}
representing all relationships to the given table. Indexes are 0-based. representing all relationships to the given table. Indexes are 0-based.
@ -35,49 +19,3 @@ def get_relations(cursor, table_name):
# row[0] and row[1] are single-item lists, so grab the single item. # row[0] and row[1] are single-item lists, so grab the single item.
relations[row[0][0] - 1] = (row[1][0] - 1, row[2]) relations[row[0][0] - 1] = (row[1][0] - 1, row[2])
return relations return relations
def get_indexes(cursor, table_name):
"""
Returns a dictionary of fieldname -> infodict for the given table,
where each infodict is in the format:
{'primary_key': boolean representing whether it's the primary key,
'unique': boolean representing whether it's a unique index}
"""
# This query retrieves each index on the given table, including the
# first associated field name
cursor.execute("""
SELECT attr.attname, idx.indkey, idx.indisunique, idx.indisprimary
FROM pg_catalog.pg_class c, pg_catalog.pg_class c2,
pg_catalog.pg_index idx, pg_catalog.pg_attribute attr
WHERE c.oid = idx.indrelid
AND idx.indexrelid = c2.oid
AND attr.attrelid = c.oid
AND attr.attnum = idx.indkey[0]
AND c.relname = %s""", [table_name])
indexes = {}
for row in cursor.fetchall():
# row[1] (idx.indkey) is stored in the DB as an array. It comes out as
# a string of space-separated integers. This designates the field
# indexes (1-based) of the fields that have indexes on the table.
# Here, we skip any indexes across multiple fields.
if ' ' in row[1]:
continue
indexes[row[0]] = {'primary_key': row[3], 'unique': row[2]}
return indexes
# Maps type codes to Django Field types.
DATA_TYPES_REVERSE = {
16: 'BooleanField',
21: 'SmallIntegerField',
23: 'IntegerField',
25: 'TextField',
701: 'FloatField',
869: 'IPAddressField',
1043: 'CharField',
1082: 'DateField',
1083: 'TimeField',
1114: 'DateTimeField',
1184: 'DateTimeField',
1266: 'TimeField',
1700: 'DecimalField',
}

View File

@ -6,7 +6,11 @@ Python 2.3 and 2.4 require pysqlite2 (http://pysqlite.org/).
Python 2.5 and later use the sqlite3 module in the standard library. Python 2.5 and later use the sqlite3 module in the standard library.
""" """
from django.db.backends import BaseDatabaseWrapper, BaseDatabaseFeatures, BaseDatabaseOperations, util from django.db.backends import *
from django.db.backends.sqlite3.client import DatabaseClient
from django.db.backends.sqlite3.creation import DatabaseCreation
from django.db.backends.sqlite3.introspection import DatabaseIntrospection
try: try:
try: try:
from sqlite3 import dbapi2 as Database from sqlite3 import dbapi2 as Database
@ -46,7 +50,6 @@ if Database.version_info >= (2,4,1):
Database.register_adapter(str, lambda s:s.decode('utf-8')) Database.register_adapter(str, lambda s:s.decode('utf-8'))
class DatabaseFeatures(BaseDatabaseFeatures): class DatabaseFeatures(BaseDatabaseFeatures):
supports_constraints = False
# SQLite cannot handle us only partially reading from a cursor's result set # SQLite cannot handle us only partially reading from a cursor's result set
# and then writing the same rows to the database in another cursor. This # and then writing the same rows to the database in another cursor. This
# setting ensures we always read result sets fully into memory all in one # setting ensures we always read result sets fully into memory all in one
@ -96,10 +99,7 @@ class DatabaseOperations(BaseDatabaseOperations):
second = '%s-12-31 23:59:59.999999' second = '%s-12-31 23:59:59.999999'
return [first % value, second % value] return [first % value, second % value]
class DatabaseWrapper(BaseDatabaseWrapper): class DatabaseWrapper(BaseDatabaseWrapper):
features = DatabaseFeatures()
ops = DatabaseOperations()
# SQLite requires LIKE statements to include an ESCAPE clause if the value # SQLite requires LIKE statements to include an ESCAPE clause if the value
# being escaped has a percent or underscore in it. # being escaped has a percent or underscore in it.
@ -121,6 +121,16 @@ class DatabaseWrapper(BaseDatabaseWrapper):
'iendswith': "LIKE %s ESCAPE '\\'", 'iendswith': "LIKE %s ESCAPE '\\'",
} }
def __init__(self, *args, **kwargs):
super(DatabaseWrapper, self).__init__(*args, **kwargs)
self.features = DatabaseFeatures()
self.ops = DatabaseOperations()
self.client = DatabaseClient()
self.creation = DatabaseCreation(self)
self.introspection = DatabaseIntrospection(self)
self.validation = BaseDatabaseValidation()
def _cursor(self, settings): def _cursor(self, settings):
if self.connection is None: if self.connection is None:
if not settings.DATABASE_NAME: if not settings.DATABASE_NAME:

View File

@ -1,6 +1,8 @@
+from django.db.backends import BaseDatabaseClient
 from django.conf import settings
 import os

-def runshell():
-    args = ['', settings.DATABASE_NAME]
-    os.execvp('sqlite3', args)
+class DatabaseClient(BaseDatabaseClient):
+    def runshell(self):
+        args = ['', settings.DATABASE_NAME]
+        os.execvp('sqlite3', args)

View File

@ -1,7 +1,13 @@
-# SQLite doesn't actually support most of these types, but it "does the right
-# thing" given more verbose field definitions, so leave them as is so that
-# schema inspection is more useful.
-DATA_TYPES = {
+import os
+import sys
+from django.conf import settings
+from django.db.backends.creation import BaseDatabaseCreation
+
+class DatabaseCreation(BaseDatabaseCreation):
+    # SQLite doesn't actually support most of these types, but it "does the right
+    # thing" given more verbose field definitions, so leave them as is so that
+    # schema inspection is more useful.
+    data_types = {
        'AutoField': 'integer',
        'BooleanField': 'bool',
        'CharField': 'varchar(%(max_length)s)',
@ -24,4 +30,44 @@ DATA_TYPES = {
        'TextField': 'text',
        'TimeField': 'time',
        'USStateField': 'varchar(2)',
     }
def sql_for_pending_references(self, model, style, pending_references):
"SQLite3 doesn't support constraints"
return []
def sql_remove_table_constraints(self, model, references_to_delete):
"SQLite3 doesn't support constraints"
return []
def _create_test_db(self, verbosity, autoclobber):
if settings.TEST_DATABASE_NAME and settings.TEST_DATABASE_NAME != ":memory:":
test_database_name = settings.TEST_DATABASE_NAME
# Erase the old test database
if verbosity >= 1:
print "Destroying old test database..."
if os.access(test_database_name, os.F_OK):
if not autoclobber:
confirm = raw_input("Type 'yes' if you would like to try deleting the test database '%s', or 'no' to cancel: " % test_database_name)
if autoclobber or confirm == 'yes':
try:
if verbosity >= 1:
print "Destroying old test database..."
os.remove(test_database_name)
except Exception, e:
sys.stderr.write("Got an error deleting the old test database: %s\n" % e)
sys.exit(2)
else:
print "Tests cancelled."
sys.exit(1)
if verbosity >= 1:
print "Creating test database..."
else:
test_database_name = ":memory:"
return test_database_name
def _destroy_test_db(self, test_database_name, verbosity):
if test_database_name and test_database_name != ":memory:":
# Remove the SQLite database file
os.remove(test_database_name)
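
Note the branch above: unless TEST_DATABASE_NAME names a real file, SQLite tests run against the in-memory ':memory:' database and there is nothing to delete afterwards. A hypothetical settings fragment (paths are placeholders) that forces the file-based branch instead:

# Hypothetical settings fragment; the paths are placeholders.
DATABASE_ENGINE = 'sqlite3'
DATABASE_NAME = '/srv/myproject/myproject.db'
# A real file path here sends _create_test_db() above down the file-based
# branch; _destroy_test_db() removes the file again afterwards.
TEST_DATABASE_NAME = '/tmp/myproject_test.db'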

View File

@ -1,63 +1,13 @@
-from django.db.backends.sqlite3.base import DatabaseOperations
+from django.db.backends import BaseDatabaseIntrospection

-quote_name = DatabaseOperations().quote_name

-def get_table_list(cursor):
-    "Returns a list of table names in the current database."
-    # Skip the sqlite_sequence system table used for autoincrement key
-    # generation.
-    cursor.execute("""
-        SELECT name FROM sqlite_master
-        WHERE type='table' AND NOT name='sqlite_sequence'
-        ORDER BY name""")
-    return [row[0] for row in cursor.fetchall()]
+# This light wrapper "fakes" a dictionary interface, because some SQLite data
+# types include variables in them -- e.g. "varchar(30)" -- and can't be matched
+# as a simple dictionary lookup.
+class FlexibleFieldLookupDict:
+    # Maps SQL types to Django Field types. Some of the SQL types have multiple
+    # entries here because SQLite allows for anything and doesn't normalize the
+    # field type; it uses whatever was given.
+    base_data_types_reverse = {
def get_table_description(cursor, table_name):
"Returns a description of the table, with the DB-API cursor.description interface."
return [(info['name'], info['type'], None, None, None, None,
info['null_ok']) for info in _table_info(cursor, table_name)]
def get_relations(cursor, table_name):
raise NotImplementedError
def get_indexes(cursor, table_name):
"""
Returns a dictionary of fieldname -> infodict for the given table,
where each infodict is in the format:
{'primary_key': boolean representing whether it's the primary key,
'unique': boolean representing whether it's a unique index}
"""
indexes = {}
for info in _table_info(cursor, table_name):
indexes[info['name']] = {'primary_key': info['pk'] != 0,
'unique': False}
cursor.execute('PRAGMA index_list(%s)' % quote_name(table_name))
# seq, name, unique
for index, unique in [(field[1], field[2]) for field in cursor.fetchall()]:
if not unique:
continue
cursor.execute('PRAGMA index_info(%s)' % quote_name(index))
info = cursor.fetchall()
# Skip indexes across multiple fields
if len(info) != 1:
continue
name = info[0][2] # seqno, cid, name
indexes[name]['unique'] = True
return indexes
def _table_info(cursor, name):
cursor.execute('PRAGMA table_info(%s)' % quote_name(name))
# cid, name, type, notnull, dflt_value, pk
return [{'name': field[1],
'type': field[2],
'null_ok': not field[3],
'pk': field[5] # undocumented
} for field in cursor.fetchall()]
# Maps SQL types to Django Field types. Some of the SQL types have multiple
# entries here because SQLite allows for anything and doesn't normalize the
# field type; it uses whatever was given.
BASE_DATA_TYPES_REVERSE = {
'bool': 'BooleanField', 'bool': 'BooleanField',
'boolean': 'BooleanField', 'boolean': 'BooleanField',
'smallint': 'SmallIntegerField', 'smallint': 'SmallIntegerField',
@ -69,16 +19,12 @@ BASE_DATA_TYPES_REVERSE = {
'date': 'DateField', 'date': 'DateField',
'datetime': 'DateTimeField', 'datetime': 'DateTimeField',
'time': 'TimeField', 'time': 'TimeField',
} }
# This light wrapper "fakes" a dictionary interface, because some SQLite data
# types include variables in them -- e.g. "varchar(30)" -- and can't be matched
# as a simple dictionary lookup.
class FlexibleFieldLookupDict:
     def __getitem__(self, key):
         key = key.lower()
         try:
-            return BASE_DATA_TYPES_REVERSE[key]
+            return self.base_data_types_reverse[key]
         except KeyError:
             import re
             m = re.search(r'^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$', key)
@ -86,4 +32,58 @@ class FlexibleFieldLookupDict:
                 return ('CharField', {'max_length': int(m.group(1))})
             raise KeyError

-DATA_TYPES_REVERSE = FlexibleFieldLookupDict()
+class DatabaseIntrospection(BaseDatabaseIntrospection):
+    data_types_reverse = FlexibleFieldLookupDict()
def get_table_list(self, cursor):
"Returns a list of table names in the current database."
# Skip the sqlite_sequence system table used for autoincrement key
# generation.
cursor.execute("""
SELECT name FROM sqlite_master
WHERE type='table' AND NOT name='sqlite_sequence'
ORDER BY name""")
return [row[0] for row in cursor.fetchall()]
def get_table_description(self, cursor, table_name):
"Returns a description of the table, with the DB-API cursor.description interface."
return [(info['name'], info['type'], None, None, None, None,
info['null_ok']) for info in self._table_info(cursor, table_name)]
def get_relations(self, cursor, table_name):
raise NotImplementedError
def get_indexes(self, cursor, table_name):
"""
Returns a dictionary of fieldname -> infodict for the given table,
where each infodict is in the format:
{'primary_key': boolean representing whether it's the primary key,
'unique': boolean representing whether it's a unique index}
"""
indexes = {}
for info in self._table_info(cursor, table_name):
indexes[info['name']] = {'primary_key': info['pk'] != 0,
'unique': False}
cursor.execute('PRAGMA index_list(%s)' % self.connection.ops.quote_name(table_name))
# seq, name, unique
for index, unique in [(field[1], field[2]) for field in cursor.fetchall()]:
if not unique:
continue
cursor.execute('PRAGMA index_info(%s)' % self.connection.ops.quote_name(index))
info = cursor.fetchall()
# Skip indexes across multiple fields
if len(info) != 1:
continue
name = info[0][2] # seqno, cid, name
indexes[name]['unique'] = True
return indexes
def _table_info(self, cursor, name):
cursor.execute('PRAGMA table_info(%s)' % self.connection.ops.quote_name(name))
# cid, name, type, notnull, dflt_value, pk
return [{'name': field[1],
'type': field[2],
'null_ok': not field[3],
'pk': field[5] # undocumented
} for field in cursor.fetchall()]
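
The introspection helpers that used to be module-level functions in this file are now reached through the connection as ``connection.introspection``. A minimal usage sketch under that assumption; the table name ``myapp_person`` is made up, and the parameterized type lookup at the end is specific to the SQLite backend's ``FlexibleFieldLookupDict``:

    from django.db import connection

    cursor = connection.cursor()

    # The old module functions are now methods on the introspection object.
    tables = connection.introspection.get_table_list(cursor)
    columns = connection.introspection.get_table_description(cursor, 'myapp_person')

    # get_indexes() still returns {fieldname: {'primary_key': bool, 'unique': bool}}.
    indexes = connection.introspection.get_indexes(cursor, 'myapp_person')

    # SQLite only: parameterized column types resolve through the flexible lookup,
    # e.g. 'varchar(30)' maps to ('CharField', {'max_length': 30}).
    field_type = connection.introspection.data_types_reverse['varchar(30)']
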
@@ -7,7 +7,7 @@ try:
except ImportError:
from django.utils import _decimal as decimal # for Python 2.3
from django.db import connection, get_creation_module
from django.db import connection
from django.db.models import signals
from django.db.models.query_utils import QueryWrapper
from django.dispatch import dispatcher
@@ -145,14 +145,14 @@ class Field(object):
# as the TextField Django field type, which means XMLField's
# get_internal_type() returns 'TextField'.
#
# But the limitation of the get_internal_type() / DATA_TYPES approach
# But the limitation of the get_internal_type() / data_types approach
# is that it cannot handle database column types that aren't already
# mapped to one of the built-in Django field types. In this case, you
# can implement db_type() instead of get_internal_type() to specify
# exactly which wacky database column type you want to use.
data = DictWrapper(self.__dict__, connection.ops.quote_name, "qn_")
try:
return get_creation_module().DATA_TYPES[self.get_internal_type()] % data
return connection.creation.data_types[self.get_internal_type()] % data
except KeyError:
return None
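
The comment block above distinguishes two extension points: report one of Django's built-in internal types and let ``connection.creation.data_types`` supply the column definition, or override ``db_type()`` to name the column type directly. A hedged sketch of both routes follows; the field classes and the ``citext`` column type are illustrative, not part of Django:

    from django.db import models

    class BigTextField(models.Field):
        # Route 1: reuse an existing entry in connection.creation.data_types
        # by reporting a built-in internal type.
        def get_internal_type(self):
            return 'TextField'

    class CaseInsensitiveTextField(models.Field):
        # Route 2: bypass the mapping and name the database column type yourself.
        def db_type(self):
            return 'citext'
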
@@ -3,7 +3,6 @@ from django.conf import settings
from django.db.models import get_app, get_apps
from django.test import _doctest as doctest
from django.test.utils import setup_test_environment, teardown_test_environment
from django.test.utils import create_test_db, destroy_test_db
from django.test.testcases import OutputChecker, DocTestRunner
# The module name for tests outside models.py
@@ -139,9 +138,10 @@ def run_tests(test_labels, verbosity=1, interactive=True, extra_tests=[]):
suite.addTest(test)
old_name = settings.DATABASE_NAME
create_test_db(verbosity, autoclobber=not interactive)
from django.db import connection
connection.creation.create_test_db(verbosity, autoclobber=not interactive)
result = unittest.TextTestRunner(verbosity=verbosity).run(suite)
destroy_test_db(old_name, verbosity)
connection.creation.destroy_test_db(old_name, verbosity)
teardown_test_environment()
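
Condensed, the lifecycle the runner follows is: set up the test environment, build the test database through the connection's creation object, run the suite, then tear both down. A sketch of that flow, assuming ``suite`` is an already-built ``unittest.TestSuite``:

    import unittest

    from django.conf import settings
    from django.db import connection
    from django.test.utils import setup_test_environment, teardown_test_environment

    def run_suite(suite, verbosity=1, interactive=False):
        setup_test_environment()
        old_name = settings.DATABASE_NAME
        connection.creation.create_test_db(verbosity, autoclobber=not interactive)
        result = unittest.TextTestRunner(verbosity=verbosity).run(suite)
        connection.creation.destroy_test_db(old_name, verbosity)
        teardown_test_environment()
        return len(result.failures) + len(result.errors)
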
@@ -1,16 +1,11 @@
import sys, time, os
from django.conf import settings
from django.db import connection, get_creation_module
from django.db import connection
from django.core import mail
from django.core.management import call_command
from django.test import signals
from django.template import Template
from django.utils.translation import deactivate
# The prefix to put on the default database name when creating
# the test database.
TEST_DATABASE_PREFIX = 'test_'
def instrumented_test_render(self, context):
"""
An instrumented Template render method, providing a signal
@@ -70,147 +65,3 @@ def teardown_test_environment():
del mail.outbox
def _set_autocommit(connection):
"Make sure a connection is in autocommit mode."
if hasattr(connection.connection, "autocommit"):
if callable(connection.connection.autocommit):
connection.connection.autocommit(True)
else:
connection.connection.autocommit = True
elif hasattr(connection.connection, "set_isolation_level"):
connection.connection.set_isolation_level(0)
def get_mysql_create_suffix():
suffix = []
if settings.TEST_DATABASE_CHARSET:
suffix.append('CHARACTER SET %s' % settings.TEST_DATABASE_CHARSET)
if settings.TEST_DATABASE_COLLATION:
suffix.append('COLLATE %s' % settings.TEST_DATABASE_COLLATION)
return ' '.join(suffix)
def get_postgresql_create_suffix():
assert settings.TEST_DATABASE_COLLATION is None, "PostgreSQL does not support collation setting at database creation time."
if settings.TEST_DATABASE_CHARSET:
return "WITH ENCODING '%s'" % settings.TEST_DATABASE_CHARSET
return ''
def create_test_db(verbosity=1, autoclobber=False):
"""
Creates a test database, prompting the user for confirmation if the
database already exists. Returns the name of the test database created.
"""
# If the database backend wants to create the test DB itself, let it
creation_module = get_creation_module()
if hasattr(creation_module, "create_test_db"):
creation_module.create_test_db(settings, connection, verbosity, autoclobber)
return
if verbosity >= 1:
print "Creating test database..."
# If we're using SQLite, it's more convenient to test against an
# in-memory database. Using the TEST_DATABASE_NAME setting you can still choose
# to run on a physical database.
if settings.DATABASE_ENGINE == "sqlite3":
if settings.TEST_DATABASE_NAME and settings.TEST_DATABASE_NAME != ":memory:":
TEST_DATABASE_NAME = settings.TEST_DATABASE_NAME
# Erase the old test database
if verbosity >= 1:
print "Destroying old test database..."
if os.access(TEST_DATABASE_NAME, os.F_OK):
if not autoclobber:
confirm = raw_input("Type 'yes' if you would like to try deleting the test database '%s', or 'no' to cancel: " % TEST_DATABASE_NAME)
if autoclobber or confirm == 'yes':
try:
if verbosity >= 1:
print "Destroying old test database..."
os.remove(TEST_DATABASE_NAME)
except Exception, e:
sys.stderr.write("Got an error deleting the old test database: %s\n" % e)
sys.exit(2)
else:
print "Tests cancelled."
sys.exit(1)
if verbosity >= 1:
print "Creating test database..."
else:
TEST_DATABASE_NAME = ":memory:"
else:
suffix = {
'postgresql': get_postgresql_create_suffix,
'postgresql_psycopg2': get_postgresql_create_suffix,
'mysql': get_mysql_create_suffix,
}.get(settings.DATABASE_ENGINE, lambda: '')()
if settings.TEST_DATABASE_NAME:
TEST_DATABASE_NAME = settings.TEST_DATABASE_NAME
else:
TEST_DATABASE_NAME = TEST_DATABASE_PREFIX + settings.DATABASE_NAME
qn = connection.ops.quote_name
# Create the test database and connect to it. We need to autocommit
# if the database supports it because PostgreSQL doesn't allow
# CREATE/DROP DATABASE statements within transactions.
cursor = connection.cursor()
_set_autocommit(connection)
try:
cursor.execute("CREATE DATABASE %s %s" % (qn(TEST_DATABASE_NAME), suffix))
except Exception, e:
sys.stderr.write("Got an error creating the test database: %s\n" % e)
if not autoclobber:
confirm = raw_input("Type 'yes' if you would like to try deleting the test database '%s', or 'no' to cancel: " % TEST_DATABASE_NAME)
if autoclobber or confirm == 'yes':
try:
if verbosity >= 1:
print "Destroying old test database..."
cursor.execute("DROP DATABASE %s" % qn(TEST_DATABASE_NAME))
if verbosity >= 1:
print "Creating test database..."
cursor.execute("CREATE DATABASE %s %s" % (qn(TEST_DATABASE_NAME), suffix))
except Exception, e:
sys.stderr.write("Got an error recreating the test database: %s\n" % e)
sys.exit(2)
else:
print "Tests cancelled."
sys.exit(1)
connection.close()
settings.DATABASE_NAME = TEST_DATABASE_NAME
call_command('syncdb', verbosity=verbosity, interactive=False)
if settings.CACHE_BACKEND.startswith('db://'):
cache_name = settings.CACHE_BACKEND[len('db://'):]
call_command('createcachetable', cache_name)
# Get a cursor (even though we don't need one yet). This has
# the side effect of initializing the test database.
cursor = connection.cursor()
return TEST_DATABASE_NAME
def destroy_test_db(old_database_name, verbosity=1):
# If the database wants to drop the test DB itself, let it
creation_module = get_creation_module()
if hasattr(creation_module, "destroy_test_db"):
creation_module.destroy_test_db(settings, connection, old_database_name, verbosity)
return
if verbosity >= 1:
print "Destroying test database..."
connection.close()
TEST_DATABASE_NAME = settings.DATABASE_NAME
settings.DATABASE_NAME = old_database_name
if settings.DATABASE_ENGINE == "sqlite3":
if TEST_DATABASE_NAME and TEST_DATABASE_NAME != ":memory:":
# Remove the SQLite database file
os.remove(TEST_DATABASE_NAME)
else:
# Remove the test database to clean up after
# ourselves. Connect to the previous database (not the test database)
# to do so, because it's not allowed to delete a database while being
# connected to it.
cursor = connection.cursor()
_set_autocommit(connection)
time.sleep(1) # To avoid "database is being accessed by other users" errors.
cursor.execute("DROP DATABASE %s" % connection.ops.quote_name(TEST_DATABASE_NAME))
connection.close()
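
None of the helpers removed above disappear outright; the test-database machinery moves onto the backend's creation class, so callers stop importing functions from ``django.test.utils`` and call methods on ``connection.creation`` instead. A rough before/after sketch (the autocommit and charset-suffix helpers become internal details of the creation class and are not shown):

    # Old call sites, now gone:
    #     from django.test.utils import create_test_db, destroy_test_db
    #     create_test_db(verbosity, autoclobber)
    #     destroy_test_db(old_database_name, verbosity)

    # New call sites:
    from django.conf import settings
    from django.db import connection

    old_name = settings.DATABASE_NAME
    connection.creation.create_test_db(verbosity=1, autoclobber=False)
    connection.creation.destroy_test_db(old_name, verbosity=1)
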
@@ -1026,6 +1026,9 @@ a number of utility methods in the ``django.test.utils`` module.
black magic hooks into the template system and restoring normal e-mail
services.
The creation module of the database backend (``connection.creation``) also
provides some utilities that can be useful during testing.
``create_test_db(verbosity=1, autoclobber=False)``
Creates a new test database and runs ``syncdb`` against it.
@@ -1044,7 +1047,7 @@ a number of utility methods in the ``django.test.utils`` module.
``create_test_db()`` has the side effect of modifying
``settings.DATABASE_NAME`` to match the name of the test database.
New in the Django development version, this function returns the name of
**New in Django development version:** This function returns the name of
the test database that it created.
``destroy_test_db(old_database_name, verbosity=1)``
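
Taken together, the documented behaviour suggests a pattern like the following for scripts that want a scratch database outside the test runner (the verbosity values are only illustrative)::

    from django.conf import settings
    from django.db import connection

    old_name = settings.DATABASE_NAME
    test_name = connection.creation.create_test_db(verbosity=1)
    # settings.DATABASE_NAME now refers to the test database, and
    # test_name holds the same value.
    # ... exercise the test database ...
    connection.creation.destroy_test_db(old_name, verbosity=1)
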
@@ -15,10 +15,6 @@ class Person(models.Model):
def __unicode__(self):
return u'%s %s' % (self.first_name, self.last_name)
if connection.features.uses_case_insensitive_names:
t_convert = lambda x: x.upper()
else:
t_convert = lambda x: x
qn = connection.ops.quote_name
__test__ = {'API_TESTS': """
@@ -29,7 +25,7 @@ __test__ = {'API_TESTS': """
>>> opts = Square._meta
>>> f1, f2 = opts.get_field('root'), opts.get_field('square')
>>> query = ('INSERT INTO %s (%s, %s) VALUES (%%s, %%s)'
... % (t_convert(opts.db_table), qn(f1.column), qn(f2.column)))
... % (connection.introspection.table_name_converter(opts.db_table), qn(f1.column), qn(f2.column)))
>>> cursor.executemany(query, [(i, i**2) for i in range(-5, 6)]) and None or None
>>> Square.objects.order_by('root')
[<Square: -5 ** 2 == 25>, <Square: -4 ** 2 == 16>, <Square: -3 ** 2 == 9>, <Square: -2 ** 2 == 4>, <Square: -1 ** 2 == 1>, <Square: 0 ** 2 == 0>, <Square: 1 ** 2 == 1>, <Square: 2 ** 2 == 4>, <Square: 3 ** 2 == 9>, <Square: 4 ** 2 == 16>, <Square: 5 ** 2 == 25>]
@@ -48,7 +44,7 @@ __test__ = {'API_TESTS': """
>>> opts2 = Person._meta
>>> f3, f4 = opts2.get_field('first_name'), opts2.get_field('last_name')
>>> query2 = ('SELECT %s, %s FROM %s ORDER BY %s'
... % (qn(f3.column), qn(f4.column), t_convert(opts2.db_table),
... % (qn(f3.column), qn(f4.column), connection.introspection.table_name_converter(opts2.db_table),
... qn(f3.column)))
>>> cursor.execute(query2) and None or None
>>> cursor.fetchone()
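
The per-module ``t_convert`` shim is no longer needed because name normalization is now the backend's job: ``connection.introspection.table_name_converter()`` returns the name unchanged on most backends, and backends with case-insensitive identifiers override it. A hedged sketch of the same pattern outside a doctest, reusing the ``Person`` model defined above:

    from django.db import connection

    qn = connection.ops.quote_name
    opts = Person._meta

    table = connection.introspection.table_name_converter(opts.db_table)
    column = qn(opts.get_field('first_name').column)

    cursor = connection.cursor()
    cursor.execute('SELECT %s FROM %s ORDER BY %s' % (column, table, column))
    rows = cursor.fetchall()
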