
Fixed #5461 -- Refactored the database backend code to use classes for the creation and introspection modules. Introduces a new validation module for DB-specific validation. This is a backwards incompatible change; see the wiki for details.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@8296 bcc190cf-cafb-0310-a4f2-bffc1f526a37
Russell Keith-Magee 2008-08-11 12:11:25 +00:00
parent cec69eb70d
commit 9dc4ba875f
43 changed files with 1528 additions and 1423 deletions
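
In practice, the refactoring moves what used to be free functions (get_introspection_module(), runshell(), and the SQL builders in django.core.management.sql) onto per-backend classes that hang off the connection object. A minimal sketch of the new entry points, assuming a checkout from this revision configured for the sqlite3 backend (the settings values below are illustrative, not part of the commit):

    from django.conf import settings
    settings.configure(DATABASE_ENGINE='sqlite3', DATABASE_NAME=':memory:')

    from django.db import connection

    cursor = connection.cursor()
    # Introspection helpers now live on connection.introspection ...
    print connection.introspection.table_names()        # [] for an empty database
    print connection.introspection.sequence_list()      # [] -- no apps installed here
    # ... schema and test-database helpers live on connection.creation
    # (sql_create_model, create_test_db, ...), the dbshell launcher on
    # connection.client.runshell(), and field validation on
    # connection.validation.validate_field(errors, opts, field).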

View File

@@ -1,5 +1,5 @@
-from django.test.utils import create_test_db
def create_spatial_db(test=True, verbosity=1, autoclobber=False):
    if not test: raise NotImplementedError('This uses `create_test_db` from test/utils.py')
-    create_test_db(verbosity, autoclobber)
+    from django.db import connection
+    connection.creation.create_test_db(verbosity, autoclobber)

View File

@@ -1,8 +1,6 @@
-from django.db.backends.oracle.creation import create_test_db
def create_spatial_db(test=True, verbosity=1, autoclobber=False):
    "A wrapper over the Oracle `create_test_db` routine."
    if not test: raise NotImplementedError('This uses `create_test_db` from db/backends/oracle/creation.py')
-    from django.conf import settings
    from django.db import connection
-    create_test_db(settings, connection, verbosity, autoclobber)
+    connection.creation.create_test_db(verbosity, autoclobber)

View File

@@ -1,7 +1,7 @@
from django.conf import settings
from django.core.management import call_command
from django.db import connection
-from django.test.utils import _set_autocommit, TEST_DATABASE_PREFIX
+from django.db.backends.creation import TEST_DATABASE_PREFIX
import os, re, sys
def getstatusoutput(cmd):
@@ -38,9 +38,9 @@ def _create_with_cursor(db_name, verbosity=1, autoclobber=False):
    create_sql = 'CREATE DATABASE %s' % connection.ops.quote_name(db_name)
    if settings.DATABASE_USER:
        create_sql += ' OWNER %s' % settings.DATABASE_USER
    cursor = connection.cursor()
-    _set_autocommit(connection)
+    connection.creation.set_autocommit(connection)
    try:
        # Trying to create the database first.
@@ -58,12 +58,12 @@ def _create_with_cursor(db_name, verbosity=1, autoclobber=False):
            else:
                raise Exception('Spatial Database Creation canceled.')
foo = _create_with_cursor
created_regex = re.compile(r'^createdb: database creation failed: ERROR: database ".+" already exists')
def _create_with_shell(db_name, verbosity=1, autoclobber=False):
    """
    If no spatial database already exists, then using a cursor will not work.
    Thus, a `createdb` command will be issued through the shell to bootstrap
    creation of the spatial database.
    """
@@ -83,7 +83,7 @@ def _create_with_shell(db_name, verbosity=1, autoclobber=False):
                if verbosity >= 1: print 'Destroying old spatial database...'
                drop_cmd = 'dropdb %s%s' % (options, db_name)
                status, output = getstatusoutput(drop_cmd)
                if status != 0:
                    raise Exception('Could not drop database %s: %s' % (db_name, output))
                if verbosity >= 1: print 'Creating new spatial database...'
                status, output = getstatusoutput(create_cmd)
@@ -102,10 +102,10 @@ def create_spatial_db(test=False, verbosity=1, autoclobber=False, interactive=Fa
        raise Exception('Spatial database creation only supported postgresql_psycopg2 platform.')
    # Getting the spatial database name
    if test:
        db_name = get_spatial_db(test=True)
        _create_with_cursor(db_name, verbosity=verbosity, autoclobber=autoclobber)
    else:
        db_name = get_spatial_db()
        _create_with_shell(db_name, verbosity=verbosity, autoclobber=autoclobber)
@@ -125,7 +125,7 @@ def create_spatial_db(test=False, verbosity=1, autoclobber=False, interactive=Fa
    # Syncing the database
    call_command('syncdb', verbosity=verbosity, interactive=interactive)
def drop_db(db_name=False, test=False):
    """
    Drops the given database (defaults to what is returned from
@@ -151,7 +151,7 @@ def get_cmd_options(db_name):
def get_spatial_db(test=False):
    """
    Returns the name of the spatial database. The 'test' keyword may be set
    to return the test spatial database name.
    """
    if test:
@@ -167,13 +167,13 @@ def get_spatial_db(test=False):
def load_postgis_sql(db_name, verbosity=1):
    """
    This routine loads up the PostGIS SQL files lwpostgis.sql and
    spatial_ref_sys.sql.
    """
    # Getting the path to the PostGIS SQL
    try:
        # POSTGIS_SQL_PATH may be placed in settings to tell GeoDjango where the
        # PostGIS SQL files are located. This is especially useful on Win32
        # platforms since the output of pg_config looks like "C:/PROGRA~1/..".
        sql_path = settings.POSTGIS_SQL_PATH
@@ -193,7 +193,7 @@ def load_postgis_sql(db_name, verbosity=1):
    # Getting the psql command-line options, and command format.
    options = get_cmd_options(db_name)
    cmd_fmt = 'psql %s-f "%%s"' % options
    # Now trying to load up the PostGIS functions
    cmd = cmd_fmt % lwpostgis_file
    if verbosity >= 1: print cmd
@@ -211,8 +211,8 @@ def load_postgis_sql(db_name, verbosity=1):
    # Setting the permissions because on Windows platforms the owner
    # of the spatial_ref_sys and geometry_columns tables is always
    # the postgres user, regardless of how the db is created.
    if os.name == 'nt': set_permissions(db_name)
def set_permissions(db_name):
    """
    Sets the permissions on the given database to that of the user specified

View File

@@ -7,7 +7,7 @@ from django.core.management.commands.inspectdb import Command as InspectCommand
from django.contrib.gis.db.backend import SpatialBackend
class Command(InspectCommand):
    # Mapping from lower-case OGC type to the corresponding GeoDjango field.
    geofield_mapping = {'point' : 'PointField',
                        'linestring' : 'LineStringField',
@@ -21,11 +21,11 @@ class Command(InspectCommand):
    def geometry_columns(self):
        """
        Returns a datastructure of metadata information associated with the
        `geometry_columns` (or equivalent) table.
        """
        # The `geo_cols` is a dictionary data structure that holds information
        # about any geographic columns in the database.
        geo_cols = {}
        def add_col(table, column, coldata):
            if table in geo_cols:
@@ -47,7 +47,7 @@ class Command(InspectCommand):
        elif SpatialBackend.name == 'mysql':
            # On MySQL have to get all table metadata before hand; this means walking through
            # each table and seeing if any column types are spatial. Can't detect this with
            # `cursor.description` (what the introspection module does) because all spatial types
            # have the same integer type (255 for GEOMETRY).
            from django.db import connection
            cursor = connection.cursor()
@@ -67,13 +67,11 @@ class Command(InspectCommand):
    def handle_inspection(self):
        "Overloaded from Django's version to handle geographic database tables."
-        from django.db import connection, get_introspection_module
+        from django.db import connection
        import keyword
-        introspection_module = get_introspection_module()
        geo_cols = self.geometry_columns()
        table2model = lambda table_name: table_name.title().replace('_', '')
        cursor = connection.cursor()
@@ -88,20 +86,20 @@ class Command(InspectCommand):
        yield ''
        yield 'from django.contrib.gis.db import models'
        yield ''
-        for table_name in introspection_module.get_table_list(cursor):
+        for table_name in connection.introspection.get_table_list(cursor):
            # Getting the geographic table dictionary.
            geo_table = geo_cols.get(table_name, {})
            yield 'class %s(models.Model):' % table2model(table_name)
            try:
-                relations = introspection_module.get_relations(cursor, table_name)
+                relations = connection.introspection.get_relations(cursor, table_name)
            except NotImplementedError:
                relations = {}
            try:
-                indexes = introspection_module.get_indexes(cursor, table_name)
+                indexes = connection.introspection.get_indexes(cursor, table_name)
            except NotImplementedError:
                indexes = {}
-            for i, row in enumerate(introspection_module.get_table_description(cursor, table_name)):
+            for i, row in enumerate(connection.introspection.get_table_description(cursor, table_name)):
                att_name, iatt_name = row[0].lower(), row[0]
                comment_notes = [] # Holds Field notes, to be displayed in a Python comment.
                extra_params = {} # Holds Field parameters such as 'db_column'.
@@ -133,12 +131,12 @@ class Command(InspectCommand):
                    if srid != 4326: extra_params['srid'] = srid
                else:
                    try:
-                        field_type = introspection_module.DATA_TYPES_REVERSE[row[1]]
+                        field_type = connection.introspection.data_types_reverse[row[1]]
                    except KeyError:
                        field_type = 'TextField'
                        comment_notes.append('This field type is a guess.')
-                # This is a hook for DATA_TYPES_REVERSE to return a tuple of
+                # This is a hook for data_types_reverse to return a tuple of
                # (field_type, extra_params_dict).
                if type(field_type) is tuple:
                    field_type, new_params = field_type

View File

@@ -6,5 +6,5 @@ class Command(NoArgsCommand):
    requires_model_validation = False
    def handle_noargs(self, **options):
-        from django.db import runshell
-        runshell()
+        from django.db import connection
+        connection.client.runshell()

View File

@@ -13,11 +13,9 @@ class Command(NoArgsCommand):
            raise CommandError("Database inspection isn't supported for the currently selected database backend.")
    def handle_inspection(self):
-        from django.db import connection, get_introspection_module
+        from django.db import connection
        import keyword
-        introspection_module = get_introspection_module()
        table2model = lambda table_name: table_name.title().replace('_', '')
        cursor = connection.cursor()
@@ -32,17 +30,17 @@ class Command(NoArgsCommand):
        yield ''
        yield 'from django.db import models'
        yield ''
-        for table_name in introspection_module.get_table_list(cursor):
+        for table_name in connection.introspection.get_table_list(cursor):
            yield 'class %s(models.Model):' % table2model(table_name)
            try:
-                relations = introspection_module.get_relations(cursor, table_name)
+                relations = connection.introspection.get_relations(cursor, table_name)
            except NotImplementedError:
                relations = {}
            try:
-                indexes = introspection_module.get_indexes(cursor, table_name)
+                indexes = connection.introspection.get_indexes(cursor, table_name)
            except NotImplementedError:
                indexes = {}
-            for i, row in enumerate(introspection_module.get_table_description(cursor, table_name)):
+            for i, row in enumerate(connection.introspection.get_table_description(cursor, table_name)):
                att_name = row[0].lower()
                comment_notes = [] # Holds Field notes, to be displayed in a Python comment.
                extra_params = {} # Holds Field parameters such as 'db_column'.
@@ -65,7 +63,7 @@ class Command(NoArgsCommand):
                    extra_params['db_column'] = att_name
                else:
                    try:
-                        field_type = introspection_module.DATA_TYPES_REVERSE[row[1]]
+                        field_type = connection.introspection.data_types_reverse[row[1]]
                    except KeyError:
                        field_type = 'TextField'
                        comment_notes.append('This field type is a guess.')

View File

@@ -21,7 +21,7 @@ class Command(NoArgsCommand):
    def handle_noargs(self, **options):
        from django.db import connection, transaction, models
        from django.conf import settings
-        from django.core.management.sql import table_names, installed_models, sql_model_create, sql_for_pending_references, many_to_many_sql_for_model, custom_sql_for_model, sql_indexes_for_model, emit_post_sync_signal
+        from django.core.management.sql import custom_sql_for_model, emit_post_sync_signal
        verbosity = int(options.get('verbosity', 1))
        interactive = options.get('interactive')
@@ -50,16 +50,9 @@ class Command(NoArgsCommand):
        cursor = connection.cursor()
-        if connection.features.uses_case_insensitive_names:
-            table_name_converter = lambda x: x.upper()
-        else:
-            table_name_converter = lambda x: x
-        # Get a list of all existing database tables, so we know what needs to
-        # be added.
-        tables = [table_name_converter(name) for name in table_names()]
        # Get a list of already installed *models* so that references work right.
-        seen_models = installed_models(tables)
+        tables = connection.introspection.table_names()
+        seen_models = connection.introspection.installed_models(tables)
        created_models = set()
        pending_references = {}
@@ -71,21 +64,21 @@ class Command(NoArgsCommand):
                # Create the model's database table, if it doesn't already exist.
                if verbosity >= 2:
                    print "Processing %s.%s model" % (app_name, model._meta.object_name)
-                if table_name_converter(model._meta.db_table) in tables:
+                if connection.introspection.table_name_converter(model._meta.db_table) in tables:
                    continue
-                sql, references = sql_model_create(model, self.style, seen_models)
+                sql, references = connection.creation.sql_create_model(model, self.style, seen_models)
                seen_models.add(model)
                created_models.add(model)
                for refto, refs in references.items():
                    pending_references.setdefault(refto, []).extend(refs)
                    if refto in seen_models:
-                        sql.extend(sql_for_pending_references(refto, self.style, pending_references))
+                        sql.extend(connection.creation.sql_for_pending_references(refto, self.style, pending_references))
-                sql.extend(sql_for_pending_references(model, self.style, pending_references))
+                sql.extend(connection.creation.sql_for_pending_references(model, self.style, pending_references))
                if verbosity >= 1:
                    print "Creating table %s" % model._meta.db_table
                for statement in sql:
                    cursor.execute(statement)
-                tables.append(table_name_converter(model._meta.db_table))
+                tables.append(connection.introspection.table_name_converter(model._meta.db_table))
        # Create the m2m tables. This must be done after all tables have been created
        # to ensure that all referred tables will exist.
@@ -94,7 +87,7 @@ class Command(NoArgsCommand):
            model_list = models.get_models(app)
            for model in model_list:
                if model in created_models:
-                    sql = many_to_many_sql_for_model(model, self.style)
+                    sql = connection.creation.sql_for_many_to_many(model, self.style)
                    if sql:
                        if verbosity >= 2:
                            print "Creating many-to-many tables for %s.%s model" % (app_name, model._meta.object_name)
@@ -140,7 +133,7 @@ class Command(NoArgsCommand):
            app_name = app.__name__.split('.')[-2]
            for model in models.get_models(app):
                if model in created_models:
-                    index_sql = sql_indexes_for_model(model, self.style)
+                    index_sql = connection.creation.sql_indexes_for_model(model, self.style)
                    if index_sql:
                        if verbosity >= 1:
                            print "Installing index for %s.%s model" % (app_name, model._meta.object_name)

View File

@@ -18,13 +18,13 @@ class Command(BaseCommand):
    def handle(self, *fixture_labels, **options):
        from django.core.management import call_command
-        from django.test.utils import create_test_db
+        from django.db import connection
        verbosity = int(options.get('verbosity', 1))
        addrport = options.get('addrport')
        # Create a test database.
-        db_name = create_test_db(verbosity=verbosity)
+        db_name = connection.creation.create_test_db(verbosity=verbosity)
        # Import the fixture data into the test database.
        call_command('loaddata', *fixture_labels, **{'verbosity': verbosity})

View File

@@ -7,65 +7,9 @@ try:
except NameError:
    from sets import Set as set # Python 2.3 fallback
def table_names():
"Returns a list of all table names that exist in the database."
from django.db import connection, get_introspection_module
cursor = connection.cursor()
return set(get_introspection_module().get_table_list(cursor))
def django_table_names(only_existing=False):
"""
Returns a list of all table names that have associated Django models and
are in INSTALLED_APPS.
If only_existing is True, the resulting list will only include the tables
that actually exist in the database.
"""
from django.db import models
tables = set()
for app in models.get_apps():
for model in models.get_models(app):
tables.add(model._meta.db_table)
tables.update([f.m2m_db_table() for f in model._meta.local_many_to_many])
if only_existing:
tables = [t for t in tables if t in table_names()]
return tables
def installed_models(table_list):
"Returns a set of all models that are installed, given a list of existing table names."
from django.db import connection, models
all_models = []
for app in models.get_apps():
for model in models.get_models(app):
all_models.append(model)
if connection.features.uses_case_insensitive_names:
converter = lambda x: x.upper()
else:
converter = lambda x: x
return set([m for m in all_models if converter(m._meta.db_table) in map(converter, table_list)])
def sequence_list():
"Returns a list of information about all DB sequences for all models in all apps."
from django.db import models
apps = models.get_apps()
sequence_list = []
for app in apps:
for model in models.get_models(app):
for f in model._meta.local_fields:
if isinstance(f, models.AutoField):
sequence_list.append({'table': model._meta.db_table, 'column': f.column})
break # Only one AutoField is allowed per model, so don't bother continuing.
for f in model._meta.local_many_to_many:
sequence_list.append({'table': f.m2m_db_table(), 'column': None})
return sequence_list
def sql_create(app, style):
    "Returns a list of the CREATE TABLE SQL statements for the given app."
-    from django.db import models
+    from django.db import connection, models
    from django.conf import settings
    if settings.DATABASE_ENGINE == 'dummy':
@@ -81,23 +25,24 @@ def sql_create(app, style):
    # we can be conservative).
    app_models = models.get_models(app)
    final_output = []
-    known_models = set([model for model in installed_models(table_names()) if model not in app_models])
+    tables = connection.introspection.table_names()
+    known_models = set([model for model in connection.introspection.installed_models(tables) if model not in app_models])
    pending_references = {}
    for model in app_models:
-        output, references = sql_model_create(model, style, known_models)
+        output, references = connection.creation.sql_create_model(model, style, known_models)
        final_output.extend(output)
        for refto, refs in references.items():
            pending_references.setdefault(refto, []).extend(refs)
            if refto in known_models:
-                final_output.extend(sql_for_pending_references(refto, style, pending_references))
+                final_output.extend(connection.creation.sql_for_pending_references(refto, style, pending_references))
-        final_output.extend(sql_for_pending_references(model, style, pending_references))
+        final_output.extend(connection.creation.sql_for_pending_references(model, style, pending_references))
        # Keep track of the fact that we've created the table for this model.
        known_models.add(model)
    # Create the many-to-many join tables.
    for model in app_models:
-        final_output.extend(many_to_many_sql_for_model(model, style))
+        final_output.extend(connection.creation.sql_for_many_to_many(model, style))
    # Handle references to tables that are from other apps
    # but don't exist physically.
@@ -106,7 +51,7 @@ def sql_create(app, style):
    alter_sql = []
    for model in not_installed_models:
        alter_sql.extend(['-- ' + sql for sql in
-            sql_for_pending_references(model, style, pending_references)])
+            connection.creation.sql_for_pending_references(model, style, pending_references)])
    if alter_sql:
        final_output.append('-- The following references should be added but depend on non-existent tables:')
        final_output.extend(alter_sql)
@@ -115,10 +60,9 @@ def sql_create(app, style):
def sql_delete(app, style):
    "Returns a list of the DROP TABLE SQL statements for the given app."
-    from django.db import connection, models, get_introspection_module
+    from django.db import connection, models
    from django.db.backends.util import truncate_name
    from django.contrib.contenttypes import generic
-    introspection = get_introspection_module()
    # This should work even if a connection isn't available
    try:
@@ -128,16 +72,11 @@ def sql_delete(app, style):
    # Figure out which tables already exist
    if cursor:
-        table_names = introspection.get_table_list(cursor)
+        table_names = connection.introspection.get_table_list(cursor)
    else:
        table_names = []
-    if connection.features.uses_case_insensitive_names:
-        table_name_converter = lambda x: x.upper()
-    else:
-        table_name_converter = lambda x: x
    output = []
-    qn = connection.ops.quote_name
    # Output DROP TABLE statements for standard application tables.
    to_delete = set()
@@ -145,7 +84,7 @@ def sql_delete(app, style):
    references_to_delete = {}
    app_models = models.get_models(app)
    for model in app_models:
-        if cursor and table_name_converter(model._meta.db_table) in table_names:
+        if cursor and connection.introspection.table_name_converter(model._meta.db_table) in table_names:
            # The table exists, so it needs to be dropped
            opts = model._meta
            for f in opts.local_fields:
@@ -155,40 +94,15 @@ def sql_delete(app, style):
            to_delete.add(model)
    for model in app_models:
-        if cursor and table_name_converter(model._meta.db_table) in table_names:
-            # Drop the table now
+        if connection.introspection.table_name_converter(model._meta.db_table) in table_names:
+            output.extend(connection.creation.sql_destroy_model(model, references_to_delete, style))
output.append('%s %s;' % (style.SQL_KEYWORD('DROP TABLE'),
style.SQL_TABLE(qn(model._meta.db_table))))
if connection.features.supports_constraints and model in references_to_delete:
for rel_class, f in references_to_delete[model]:
table = rel_class._meta.db_table
col = f.column
r_table = model._meta.db_table
r_col = model._meta.get_field(f.rel.field_name).column
r_name = '%s_refs_%s_%x' % (col, r_col, abs(hash((table, r_table))))
output.append('%s %s %s %s;' % \
(style.SQL_KEYWORD('ALTER TABLE'),
style.SQL_TABLE(qn(table)),
style.SQL_KEYWORD(connection.ops.drop_foreignkey_sql()),
style.SQL_FIELD(truncate_name(r_name, connection.ops.max_name_length()))))
del references_to_delete[model]
if model._meta.has_auto_field:
ds = connection.ops.drop_sequence_sql(model._meta.db_table)
if ds:
output.append(ds)
    # Output DROP TABLE statements for many-to-many tables.
    for model in app_models:
        opts = model._meta
        for f in opts.local_many_to_many:
-            if not f.creates_table:
-                continue
-            if cursor and table_name_converter(f.m2m_db_table()) in table_names:
-                output.append("%s %s;" % (style.SQL_KEYWORD('DROP TABLE'),
-                    style.SQL_TABLE(qn(f.m2m_db_table()))))
-                ds = connection.ops.drop_sequence_sql("%s_%s" % (model._meta.db_table, f.column))
-                if ds:
-                    output.append(ds)
+            if cursor and connection.introspection.table_name_converter(f.m2m_db_table()) in table_names:
+                output.extend(connection.creation.sql_destroy_many_to_many(model, f, style))
    app_label = app_models[0]._meta.app_label
@@ -213,10 +127,10 @@ def sql_flush(style, only_django=False):
    """
    from django.db import connection
    if only_django:
-        tables = django_table_names()
+        tables = connection.introspection.django_table_names()
    else:
-        tables = table_names()
+        tables = connection.introspection.table_names()
-    statements = connection.ops.sql_flush(style, tables, sequence_list())
+    statements = connection.ops.sql_flush(style, tables, connection.introspection.sequence_list())
    return statements
def sql_custom(app, style):
@@ -234,198 +148,16 @@ def sql_custom(app, style):
def sql_indexes(app, style):
    "Returns a list of the CREATE INDEX SQL statements for all models in the given app."
-    from django.db import models
+    from django.db import connection, models
    output = []
    for model in models.get_models(app):
-        output.extend(sql_indexes_for_model(model, style))
+        output.extend(connection.creation.sql_indexes_for_model(model, style))
    return output
def sql_all(app, style):
    "Returns a list of CREATE TABLE SQL, initial-data inserts, and CREATE INDEX SQL for the given module."
    return sql_create(app, style) + sql_custom(app, style) + sql_indexes(app, style)
def sql_model_create(model, style, known_models=set()):
"""
Returns the SQL required to create a single model, as a tuple of:
(list_of_sql, pending_references_dict)
"""
from django.db import connection, models
opts = model._meta
final_output = []
table_output = []
pending_references = {}
qn = connection.ops.quote_name
inline_references = connection.features.inline_fk_references
for f in opts.local_fields:
col_type = f.db_type()
tablespace = f.db_tablespace or opts.db_tablespace
if col_type is None:
# Skip ManyToManyFields, because they're not represented as
# database columns in this table.
continue
# Make the definition (e.g. 'foo VARCHAR(30)') for this field.
field_output = [style.SQL_FIELD(qn(f.column)),
style.SQL_COLTYPE(col_type)]
field_output.append(style.SQL_KEYWORD('%sNULL' % (not f.null and 'NOT ' or '')))
if f.primary_key:
field_output.append(style.SQL_KEYWORD('PRIMARY KEY'))
elif f.unique:
field_output.append(style.SQL_KEYWORD('UNIQUE'))
if tablespace and connection.features.supports_tablespaces and f.unique:
# We must specify the index tablespace inline, because we
# won't be generating a CREATE INDEX statement for this field.
field_output.append(connection.ops.tablespace_sql(tablespace, inline=True))
if f.rel:
if inline_references and f.rel.to in known_models:
field_output.append(style.SQL_KEYWORD('REFERENCES') + ' ' + \
style.SQL_TABLE(qn(f.rel.to._meta.db_table)) + ' (' + \
style.SQL_FIELD(qn(f.rel.to._meta.get_field(f.rel.field_name).column)) + ')' +
connection.ops.deferrable_sql()
)
else:
# We haven't yet created the table to which this field
# is related, so save it for later.
pr = pending_references.setdefault(f.rel.to, []).append((model, f))
table_output.append(' '.join(field_output))
if opts.order_with_respect_to:
table_output.append(style.SQL_FIELD(qn('_order')) + ' ' + \
style.SQL_COLTYPE(models.IntegerField().db_type()) + ' ' + \
style.SQL_KEYWORD('NULL'))
for field_constraints in opts.unique_together:
table_output.append(style.SQL_KEYWORD('UNIQUE') + ' (%s)' % \
", ".join([style.SQL_FIELD(qn(opts.get_field(f).column)) for f in field_constraints]))
full_statement = [style.SQL_KEYWORD('CREATE TABLE') + ' ' + style.SQL_TABLE(qn(opts.db_table)) + ' (']
for i, line in enumerate(table_output): # Combine and add commas.
full_statement.append(' %s%s' % (line, i < len(table_output)-1 and ',' or ''))
full_statement.append(')')
if opts.db_tablespace and connection.features.supports_tablespaces:
full_statement.append(connection.ops.tablespace_sql(opts.db_tablespace))
full_statement.append(';')
final_output.append('\n'.join(full_statement))
if opts.has_auto_field:
# Add any extra SQL needed to support auto-incrementing primary keys.
auto_column = opts.auto_field.db_column or opts.auto_field.name
autoinc_sql = connection.ops.autoinc_sql(opts.db_table, auto_column)
if autoinc_sql:
for stmt in autoinc_sql:
final_output.append(stmt)
return final_output, pending_references
def sql_for_pending_references(model, style, pending_references):
"""
Returns any ALTER TABLE statements to add constraints after the fact.
"""
from django.db import connection
from django.db.backends.util import truncate_name
qn = connection.ops.quote_name
final_output = []
if connection.features.supports_constraints:
opts = model._meta
if model in pending_references:
for rel_class, f in pending_references[model]:
rel_opts = rel_class._meta
r_table = rel_opts.db_table
r_col = f.column
table = opts.db_table
col = opts.get_field(f.rel.field_name).column
# For MySQL, r_name must be unique in the first 64 characters.
# So we are careful with character usage here.
r_name = '%s_refs_%s_%x' % (r_col, col, abs(hash((r_table, table))))
final_output.append(style.SQL_KEYWORD('ALTER TABLE') + ' %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)%s;' % \
(qn(r_table), truncate_name(r_name, connection.ops.max_name_length()),
qn(r_col), qn(table), qn(col),
connection.ops.deferrable_sql()))
del pending_references[model]
return final_output
def many_to_many_sql_for_model(model, style):
from django.db import connection, models
from django.contrib.contenttypes import generic
from django.db.backends.util import truncate_name
opts = model._meta
final_output = []
qn = connection.ops.quote_name
inline_references = connection.features.inline_fk_references
for f in opts.local_many_to_many:
if f.creates_table:
tablespace = f.db_tablespace or opts.db_tablespace
if tablespace and connection.features.supports_tablespaces:
tablespace_sql = ' ' + connection.ops.tablespace_sql(tablespace, inline=True)
else:
tablespace_sql = ''
table_output = [style.SQL_KEYWORD('CREATE TABLE') + ' ' + \
style.SQL_TABLE(qn(f.m2m_db_table())) + ' (']
table_output.append(' %s %s %s%s,' %
(style.SQL_FIELD(qn('id')),
style.SQL_COLTYPE(models.AutoField(primary_key=True).db_type()),
style.SQL_KEYWORD('NOT NULL PRIMARY KEY'),
tablespace_sql))
if inline_references:
deferred = []
table_output.append(' %s %s %s %s (%s)%s,' %
(style.SQL_FIELD(qn(f.m2m_column_name())),
style.SQL_COLTYPE(models.ForeignKey(model).db_type()),
style.SQL_KEYWORD('NOT NULL REFERENCES'),
style.SQL_TABLE(qn(opts.db_table)),
style.SQL_FIELD(qn(opts.pk.column)),
connection.ops.deferrable_sql()))
table_output.append(' %s %s %s %s (%s)%s,' %
(style.SQL_FIELD(qn(f.m2m_reverse_name())),
style.SQL_COLTYPE(models.ForeignKey(f.rel.to).db_type()),
style.SQL_KEYWORD('NOT NULL REFERENCES'),
style.SQL_TABLE(qn(f.rel.to._meta.db_table)),
style.SQL_FIELD(qn(f.rel.to._meta.pk.column)),
connection.ops.deferrable_sql()))
else:
table_output.append(' %s %s %s,' %
(style.SQL_FIELD(qn(f.m2m_column_name())),
style.SQL_COLTYPE(models.ForeignKey(model).db_type()),
style.SQL_KEYWORD('NOT NULL')))
table_output.append(' %s %s %s,' %
(style.SQL_FIELD(qn(f.m2m_reverse_name())),
style.SQL_COLTYPE(models.ForeignKey(f.rel.to).db_type()),
style.SQL_KEYWORD('NOT NULL')))
deferred = [
(f.m2m_db_table(), f.m2m_column_name(), opts.db_table,
opts.pk.column),
( f.m2m_db_table(), f.m2m_reverse_name(),
f.rel.to._meta.db_table, f.rel.to._meta.pk.column)
]
table_output.append(' %s (%s, %s)%s' %
(style.SQL_KEYWORD('UNIQUE'),
style.SQL_FIELD(qn(f.m2m_column_name())),
style.SQL_FIELD(qn(f.m2m_reverse_name())),
tablespace_sql))
table_output.append(')')
if opts.db_tablespace and connection.features.supports_tablespaces:
# f.db_tablespace is only for indices, so ignore its value here.
table_output.append(connection.ops.tablespace_sql(opts.db_tablespace))
table_output.append(';')
final_output.append('\n'.join(table_output))
for r_table, r_col, table, col in deferred:
r_name = '%s_refs_%s_%x' % (r_col, col,
abs(hash((r_table, table))))
final_output.append(style.SQL_KEYWORD('ALTER TABLE') + ' %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)%s;' %
(qn(r_table),
truncate_name(r_name, connection.ops.max_name_length()),
qn(r_col), qn(table), qn(col),
connection.ops.deferrable_sql()))
# Add any extra SQL needed to support auto-incrementing PKs
autoinc_sql = connection.ops.autoinc_sql(f.m2m_db_table(), 'id')
if autoinc_sql:
for stmt in autoinc_sql:
final_output.append(stmt)
return final_output
def custom_sql_for_model(model, style):
    from django.db import models
    from django.conf import settings
@@ -461,28 +193,6 @@ def custom_sql_for_model(model, style):
    return output
def sql_indexes_for_model(model, style):
"Returns the CREATE INDEX SQL statements for a single model"
from django.db import connection
output = []
qn = connection.ops.quote_name
for f in model._meta.local_fields:
if f.db_index and not f.unique:
tablespace = f.db_tablespace or model._meta.db_tablespace
if tablespace and connection.features.supports_tablespaces:
tablespace_sql = ' ' + connection.ops.tablespace_sql(tablespace)
else:
tablespace_sql = ''
output.append(
style.SQL_KEYWORD('CREATE INDEX') + ' ' + \
style.SQL_TABLE(qn('%s_%s' % (model._meta.db_table, f.column))) + ' ' + \
style.SQL_KEYWORD('ON') + ' ' + \
style.SQL_TABLE(qn(model._meta.db_table)) + ' ' + \
"(%s)" % style.SQL_FIELD(qn(f.column)) + \
"%s;" % tablespace_sql
)
return output
def emit_post_sync_signal(created_models, verbosity, interactive):
    from django.db import models

View File

@@ -61,11 +61,8 @@ def get_validation_errors(outfile, app=None):
            if f.db_index not in (None, True, False):
                e.add(opts, '"%s": "db_index" should be either None, True or False.' % f.name)
-            # Check that max_length <= 255 if using older MySQL versions.
-            if settings.DATABASE_ENGINE == 'mysql':
-                db_version = connection.get_server_version()
-                if db_version < (5, 0, 3) and isinstance(f, (models.CharField, models.CommaSeparatedIntegerField, models.SlugField)) and f.max_length > 255:
-                    e.add(opts, '"%s": %s cannot have a "max_length" greater than 255 when you are using a version of MySQL prior to 5.0.3 (you are using %s).' % (f.name, f.__class__.__name__, '.'.join([str(n) for n in db_version[:3]])))
+            # Perform any backend-specific field validation.
+            connection.validation.validate_field(e, opts, f)
            # Check to see if the related field will clash with any existing
            # fields, m2m fields, m2m related objects or related objects

View File

@@ -14,14 +14,12 @@ try:
    # backends that ships with Django, so look there first.
    _import_path = 'django.db.backends.'
    backend = __import__('%s%s.base' % (_import_path, settings.DATABASE_ENGINE), {}, {}, [''])
-    creation = __import__('%s%s.creation' % (_import_path, settings.DATABASE_ENGINE), {}, {}, [''])
except ImportError, e:
    # If the import failed, we might be looking for a database backend
    # distributed external to Django. So we'll try that next.
    try:
        _import_path = ''
        backend = __import__('%s.base' % settings.DATABASE_ENGINE, {}, {}, [''])
-        creation = __import__('%s.creation' % settings.DATABASE_ENGINE, {}, {}, [''])
    except ImportError, e_user:
        # The database backend wasn't found. Display a helpful error message
        # listing all possible (built-in) database backends.
@@ -29,27 +27,11 @@ except ImportError, e:
        available_backends = [f for f in os.listdir(backend_dir) if not f.startswith('_') and not f.startswith('.') and not f.endswith('.py') and not f.endswith('.pyc')]
        available_backends.sort()
        if settings.DATABASE_ENGINE not in available_backends:
-            raise ImproperlyConfigured, "%r isn't an available database backend. Available options are: %s" % \
-                (settings.DATABASE_ENGINE, ", ".join(map(repr, available_backends)))
+            raise ImproperlyConfigured, "%r isn't an available database backend. Available options are: %s\nError was: %s" % \
+                (settings.DATABASE_ENGINE, ", ".join(map(repr, available_backends, e_user)))
        else:
            raise # If there's some other error, this must be an error in Django itself.
def _import_database_module(import_path='', module_name=''):
"""Lazily import a database module when requested."""
return __import__('%s%s.%s' % (import_path, settings.DATABASE_ENGINE, module_name), {}, {}, [''])
# We don't want to import the introspect module unless someone asks for it, so
# lazily load it on demmand.
get_introspection_module = curry(_import_database_module, _import_path, 'introspection')
def get_creation_module():
return creation
# We want runshell() to work the same way, but we have to treat it a
# little differently (since it just runs instead of returning a module like
# the above) and wrap the lazily-loaded runshell() method.
runshell = lambda: _import_database_module(_import_path, "client").runshell()
# Convenient aliases for backend bits.
connection = backend.DatabaseWrapper(**settings.DATABASE_OPTIONS)
DatabaseError = backend.DatabaseError

View File

@@ -42,14 +42,9 @@ class BaseDatabaseWrapper(local):
        return util.CursorDebugWrapper(cursor, self)
class BaseDatabaseFeatures(object):
-    allows_group_by_ordinal = True
-    inline_fk_references = True
    # True if django.db.backend.utils.typecast_timestamp is used on values
    # returned from dates() calls.
    needs_datetime_string_cast = True
-    supports_constraints = True
-    supports_tablespaces = False
-    uses_case_insensitive_names = False
    uses_custom_query_class = False
    empty_fetchmany_value = []
    update_can_self_select = True
@@ -253,13 +248,13 @@ class BaseDatabaseOperations(object):
        """
        return "BEGIN;"
-    def tablespace_sql(self, tablespace, inline=False):
+    def sql_for_tablespace(self, tablespace, inline=False):
        """
-        Returns the tablespace SQL, or None if the backend doesn't use
-        tablespaces.
+        Returns the SQL that will be appended to tables or rows to define
+        a tablespace. Returns '' if the backend doesn't use tablespaces.
        """
-        return None
+        return ''
    def prep_for_like_query(self, x):
        """Prepares a value for use in a LIKE query."""
        from django.utils.encoding import smart_unicode
@@ -325,3 +320,89 @@ class BaseDatabaseOperations(object):
        """
        return self.year_lookup_bounds(value)
class BaseDatabaseIntrospection(object):
"""
This class encapsulates all backend-specific introspection utilities
"""
data_types_reverse = {}
def __init__(self, connection):
self.connection = connection
def table_name_converter(self, name):
"""Apply a conversion to the name for the purposes of comparison.
The default table name converter is for case sensitive comparison.
"""
return name
def table_names(self):
"Returns a list of names of all tables that exist in the database."
cursor = self.connection.cursor()
return self.get_table_list(cursor)
def django_table_names(self, only_existing=False):
"""
Returns a list of all table names that have associated Django models and
are in INSTALLED_APPS.
If only_existing is True, the resulting list will only include the tables
that actually exist in the database.
"""
from django.db import models
tables = set()
for app in models.get_apps():
for model in models.get_models(app):
tables.add(model._meta.db_table)
tables.update([f.m2m_db_table() for f in model._meta.local_many_to_many])
if only_existing:
tables = [t for t in tables if t in self.table_names()]
return tables
def installed_models(self, tables):
"Returns a set of all models represented by the provided list of table names."
from django.db import models
all_models = []
for app in models.get_apps():
for model in models.get_models(app):
all_models.append(model)
return set([m for m in all_models
if self.table_name_converter(m._meta.db_table) in map(self.table_name_converter, tables)
])
def sequence_list(self):
"Returns a list of information about all DB sequences for all models in all apps."
from django.db import models
apps = models.get_apps()
sequence_list = []
for app in apps:
for model in models.get_models(app):
for f in model._meta.local_fields:
if isinstance(f, models.AutoField):
sequence_list.append({'table': model._meta.db_table, 'column': f.column})
break # Only one AutoField is allowed per model, so don't bother continuing.
for f in model._meta.local_many_to_many:
sequence_list.append({'table': f.m2m_db_table(), 'column': None})
return sequence_list
class BaseDatabaseClient(object):
"""
This class encapsualtes all backend-specific methods for opening a
client shell
"""
def runshell(self):
raise NotImplementedError()
class BaseDatabaseValidation(object):
"""
This class encapsualtes all backend-specific model validation.
"""
def validate_field(self, errors, opts, f):
"By default, there is no backend-specific validation"
pass
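
The base classes above are what an individual backend module is now expected to subclass. As a rough illustration only -- the MyBackend* names, the catalog query and the CLI below are hypothetical, not part of this commit -- a backend would plug in roughly like this:

    from django.db.backends import BaseDatabaseIntrospection, BaseDatabaseClient, BaseDatabaseValidation

    class MyBackendIntrospection(BaseDatabaseIntrospection):
        # Maps backend type codes to Django field class names.
        data_types_reverse = {'varchar': 'CharField', 'int': 'IntegerField'}

        def get_table_list(self, cursor):
            # Hypothetical catalog query; each backend supplies its own.
            cursor.execute("SELECT table_name FROM catalog.tables")
            return [row[0] for row in cursor.fetchall()]

    class MyBackendClient(BaseDatabaseClient):
        def runshell(self):
            import os
            from django.conf import settings
            # Hypothetical command-line client for this backend.
            os.execvp('mybackend-cli', ['mybackend-cli', settings.DATABASE_NAME])

    class MyBackendValidation(BaseDatabaseValidation):
        def validate_field(self, errors, opts, f):
            # No backend-specific restrictions in this sketch.
            pass

The backend's DatabaseWrapper is then expected to expose instances of these as connection.introspection, connection.client and connection.validation, which is how the management commands shown earlier in this diff reach them.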

View File

@@ -1,7 +1,396 @@
-class BaseCreation(object):
+import sys
+import time
+from django.conf import settings
+from django.core.management import call_command
+# The prefix to put on the default database name when creating
+# the test database.
+TEST_DATABASE_PREFIX = 'test_'
+class BaseDatabaseCreation(object):
    """
    This class encapsulates all backend-specific differences that pertain to
    database *creation*, such as the column types to use for particular Django
-    Fields.
+    Fields, the SQL used to create and destroy tables, and the creation and
+    destruction of test databases.
    """
-    pass
+    data_types = {}
def __init__(self, connection):
self.connection = connection
def sql_create_model(self, model, style, known_models=set()):
"""
Returns the SQL required to create a single model, as a tuple of:
(list_of_sql, pending_references_dict)
"""
from django.db import models
opts = model._meta
final_output = []
table_output = []
pending_references = {}
qn = self.connection.ops.quote_name
for f in opts.local_fields:
col_type = f.db_type()
tablespace = f.db_tablespace or opts.db_tablespace
if col_type is None:
# Skip ManyToManyFields, because they're not represented as
# database columns in this table.
continue
# Make the definition (e.g. 'foo VARCHAR(30)') for this field.
field_output = [style.SQL_FIELD(qn(f.column)),
style.SQL_COLTYPE(col_type)]
field_output.append(style.SQL_KEYWORD('%sNULL' % (not f.null and 'NOT ' or '')))
if f.primary_key:
field_output.append(style.SQL_KEYWORD('PRIMARY KEY'))
elif f.unique:
field_output.append(style.SQL_KEYWORD('UNIQUE'))
if tablespace and f.unique:
# We must specify the index tablespace inline, because we
# won't be generating a CREATE INDEX statement for this field.
field_output.append(self.connection.ops.tablespace_sql(tablespace, inline=True))
if f.rel:
ref_output, pending = self.sql_for_inline_foreign_key_references(f, known_models, style)
if pending:
pr = pending_references.setdefault(f.rel.to, []).append((model, f))
else:
field_output.extend(ref_output)
table_output.append(' '.join(field_output))
if opts.order_with_respect_to:
table_output.append(style.SQL_FIELD(qn('_order')) + ' ' + \
style.SQL_COLTYPE(models.IntegerField().db_type()) + ' ' + \
style.SQL_KEYWORD('NULL'))
for field_constraints in opts.unique_together:
table_output.append(style.SQL_KEYWORD('UNIQUE') + ' (%s)' % \
", ".join([style.SQL_FIELD(qn(opts.get_field(f).column)) for f in field_constraints]))
full_statement = [style.SQL_KEYWORD('CREATE TABLE') + ' ' + style.SQL_TABLE(qn(opts.db_table)) + ' (']
for i, line in enumerate(table_output): # Combine and add commas.
full_statement.append(' %s%s' % (line, i < len(table_output)-1 and ',' or ''))
full_statement.append(')')
if opts.db_tablespace:
full_statement.append(self.connection.ops.tablespace_sql(opts.db_tablespace))
full_statement.append(';')
final_output.append('\n'.join(full_statement))
if opts.has_auto_field:
# Add any extra SQL needed to support auto-incrementing primary keys.
auto_column = opts.auto_field.db_column or opts.auto_field.name
autoinc_sql = self.connection.ops.autoinc_sql(opts.db_table, auto_column)
if autoinc_sql:
for stmt in autoinc_sql:
final_output.append(stmt)
return final_output, pending_references
def sql_for_inline_foreign_key_references(self, field, known_models, style):
"Return the SQL snippet defining the foreign key reference for a field"
qn = self.connection.ops.quote_name
if field.rel.to in known_models:
output = [style.SQL_KEYWORD('REFERENCES') + ' ' + \
style.SQL_TABLE(qn(field.rel.to._meta.db_table)) + ' (' + \
style.SQL_FIELD(qn(field.rel.to._meta.get_field(field.rel.field_name).column)) + ')' +
self.connection.ops.deferrable_sql()
]
pending = False
else:
# We haven't yet created the table to which this field
# is related, so save it for later.
output = []
pending = True
return output, pending
def sql_for_pending_references(self, model, style, pending_references):
"Returns any ALTER TABLE statements to add constraints after the fact."
from django.db.backends.util import truncate_name
qn = self.connection.ops.quote_name
final_output = []
opts = model._meta
if model in pending_references:
for rel_class, f in pending_references[model]:
rel_opts = rel_class._meta
r_table = rel_opts.db_table
r_col = f.column
table = opts.db_table
col = opts.get_field(f.rel.field_name).column
# For MySQL, r_name must be unique in the first 64 characters.
# So we are careful with character usage here.
r_name = '%s_refs_%s_%x' % (r_col, col, abs(hash((r_table, table))))
final_output.append(style.SQL_KEYWORD('ALTER TABLE') + ' %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)%s;' % \
(qn(r_table), truncate_name(r_name, self.connection.ops.max_name_length()),
qn(r_col), qn(table), qn(col),
self.connection.ops.deferrable_sql()))
del pending_references[model]
return final_output
def sql_for_many_to_many(self, model, style):
"Return the CREATE TABLE statments for all the many-to-many tables defined on a model"
output = []
for f in model._meta.local_many_to_many:
output.extend(self.sql_for_many_to_many_field(model, f, style))
return output
def sql_for_many_to_many_field(self, model, f, style):
"Return the CREATE TABLE statements for a single m2m field"
from django.db import models
from django.db.backends.util import truncate_name
output = []
if f.creates_table:
opts = model._meta
qn = self.connection.ops.quote_name
tablespace = f.db_tablespace or opts.db_tablespace
if tablespace:
sql = self.connection.ops.tablespace_sql(tablespace, inline=True)
if sql:
tablespace_sql = ' ' + sql
else:
tablespace_sql = ''
else:
tablespace_sql = ''
table_output = [style.SQL_KEYWORD('CREATE TABLE') + ' ' + \
style.SQL_TABLE(qn(f.m2m_db_table())) + ' (']
table_output.append(' %s %s %s%s,' %
(style.SQL_FIELD(qn('id')),
style.SQL_COLTYPE(models.AutoField(primary_key=True).db_type()),
style.SQL_KEYWORD('NOT NULL PRIMARY KEY'),
tablespace_sql))
deferred = []
inline_output, deferred = self.sql_for_inline_many_to_many_references(model, f, style)
table_output.extend(inline_output)
table_output.append(' %s (%s, %s)%s' %
(style.SQL_KEYWORD('UNIQUE'),
style.SQL_FIELD(qn(f.m2m_column_name())),
style.SQL_FIELD(qn(f.m2m_reverse_name())),
tablespace_sql))
table_output.append(')')
if opts.db_tablespace:
# f.db_tablespace is only for indices, so ignore its value here.
table_output.append(self.connection.ops.tablespace_sql(opts.db_tablespace))
table_output.append(';')
output.append('\n'.join(table_output))
for r_table, r_col, table, col in deferred:
r_name = '%s_refs_%s_%x' % (r_col, col,
abs(hash((r_table, table))))
output.append(style.SQL_KEYWORD('ALTER TABLE') + ' %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)%s;' %
(qn(r_table),
truncate_name(r_name, self.connection.ops.max_name_length()),
qn(r_col), qn(table), qn(col),
self.connection.ops.deferrable_sql()))
# Add any extra SQL needed to support auto-incrementing PKs
autoinc_sql = self.connection.ops.autoinc_sql(f.m2m_db_table(), 'id')
if autoinc_sql:
for stmt in autoinc_sql:
output.append(stmt)
return output
def sql_for_inline_many_to_many_references(self, model, field, style):
"Create the references to other tables required by a many-to-many table"
from django.db import models
opts = model._meta
qn = self.connection.ops.quote_name
table_output = [
' %s %s %s %s (%s)%s,' %
(style.SQL_FIELD(qn(field.m2m_column_name())),
style.SQL_COLTYPE(models.ForeignKey(model).db_type()),
style.SQL_KEYWORD('NOT NULL REFERENCES'),
style.SQL_TABLE(qn(opts.db_table)),
style.SQL_FIELD(qn(opts.pk.column)),
self.connection.ops.deferrable_sql()),
' %s %s %s %s (%s)%s,' %
(style.SQL_FIELD(qn(field.m2m_reverse_name())),
style.SQL_COLTYPE(models.ForeignKey(field.rel.to).db_type()),
style.SQL_KEYWORD('NOT NULL REFERENCES'),
style.SQL_TABLE(qn(field.rel.to._meta.db_table)),
style.SQL_FIELD(qn(field.rel.to._meta.pk.column)),
self.connection.ops.deferrable_sql())
]
deferred = []
return table_output, deferred
def sql_indexes_for_model(self, model, style):
"Returns the CREATE INDEX SQL statements for a single model"
output = []
for f in model._meta.local_fields:
output.extend(self.sql_indexes_for_field(model, f, style))
return output
def sql_indexes_for_field(self, model, f, style):
"Return the CREATE INDEX SQL statements for a single model field"
if f.db_index and not f.unique:
qn = self.connection.ops.quote_name
tablespace = f.db_tablespace or model._meta.db_tablespace
if tablespace:
sql = self.connection.ops.tablespace_sql(tablespace)
if sql:
tablespace_sql = ' ' + sql
else:
tablespace_sql = ''
else:
tablespace_sql = ''
output = [style.SQL_KEYWORD('CREATE INDEX') + ' ' +
style.SQL_TABLE(qn('%s_%s' % (model._meta.db_table, f.column))) + ' ' +
style.SQL_KEYWORD('ON') + ' ' +
style.SQL_TABLE(qn(model._meta.db_table)) + ' ' +
"(%s)" % style.SQL_FIELD(qn(f.column)) +
"%s;" % tablespace_sql]
else:
output = []
return output
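# Illustration (hypothetical table/column names): a field with db_index=True and
# unique=False named "city" on the table "myapp_place" would yield roughly:
#
#   CREATE INDEX "myapp_place_city" ON "myapp_place" ("city");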
def sql_destroy_model(self, model, references_to_delete, style):
"Return the DROP TABLE and restraint dropping statements for a single model"
# Drop the table now
qn = self.connection.ops.quote_name
output = ['%s %s;' % (style.SQL_KEYWORD('DROP TABLE'),
style.SQL_TABLE(qn(model._meta.db_table)))]
if model in references_to_delete:
output.extend(self.sql_remove_table_constraints(model, references_to_delete, style))
if model._meta.has_auto_field:
ds = self.connection.ops.drop_sequence_sql(model._meta.db_table)
if ds:
output.append(ds)
return output
def sql_remove_table_constraints(self, model, references_to_delete, style):
    from django.db.backends.util import truncate_name
    qn = self.connection.ops.quote_name
    output = []
for rel_class, f in references_to_delete[model]:
table = rel_class._meta.db_table
col = f.column
r_table = model._meta.db_table
r_col = model._meta.get_field(f.rel.field_name).column
r_name = '%s_refs_%s_%x' % (col, r_col, abs(hash((table, r_table))))
output.append('%s %s %s %s;' % \
(style.SQL_KEYWORD('ALTER TABLE'),
style.SQL_TABLE(qn(table)),
style.SQL_KEYWORD(self.connection.ops.drop_foreignkey_sql()),
style.SQL_FIELD(truncate_name(r_name, self.connection.ops.max_name_length()))))
del references_to_delete[model]
return output
def sql_destroy_many_to_many(self, model, f, style):
"Returns the DROP TABLE statements for a single m2m field"
qn = self.connection.ops.quote_name
output = []
if f.creates_table:
output.append("%s %s;" % (style.SQL_KEYWORD('DROP TABLE'),
style.SQL_TABLE(qn(f.m2m_db_table()))))
ds = self.connection.ops.drop_sequence_sql("%s_%s" % (model._meta.db_table, f.column))
if ds:
output.append(ds)
return output
def create_test_db(self, verbosity=1, autoclobber=False):
"""
Creates a test database, prompting the user for confirmation if the
database already exists. Returns the name of the test database created.
"""
if verbosity >= 1:
print "Creating test database..."
test_database_name = self._create_test_db(verbosity, autoclobber)
self.connection.close()
settings.DATABASE_NAME = test_database_name
call_command('syncdb', verbosity=verbosity, interactive=False)
if settings.CACHE_BACKEND.startswith('db://'):
cache_name = settings.CACHE_BACKEND[len('db://'):]
call_command('createcachetable', cache_name)
# Get a cursor (even though we don't need one yet). This has
# the side effect of initializing the test database.
cursor = self.connection.cursor()
return test_database_name
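# A minimal sketch of how a test runner is expected to drive this API after the
# refactor (standard django.db connection assumed; error handling omitted):
#
#   from django.conf import settings
#   from django.db import connection
#
#   old_name = settings.DATABASE_NAME
#   test_name = connection.creation.create_test_db(verbosity=1, autoclobber=False)
#   # ... run the test suite against the freshly created test database ...
#   connection.creation.destroy_test_db(old_name, verbosity=1)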
def _create_test_db(self, verbosity, autoclobber):
"Internal implementation - creates the test db tables."
suffix = self.sql_table_creation_suffix()
if settings.TEST_DATABASE_NAME:
test_database_name = settings.TEST_DATABASE_NAME
else:
test_database_name = TEST_DATABASE_PREFIX + settings.DATABASE_NAME
qn = self.connection.ops.quote_name
# Create the test database and connect to it. We need to autocommit
# if the database supports it because PostgreSQL doesn't allow
# CREATE/DROP DATABASE statements within transactions.
cursor = self.connection.cursor()
self.set_autocommit()
try:
cursor.execute("CREATE DATABASE %s %s" % (qn(test_database_name), suffix))
except Exception, e:
sys.stderr.write("Got an error creating the test database: %s\n" % e)
if not autoclobber:
confirm = raw_input("Type 'yes' if you would like to try deleting the test database '%s', or 'no' to cancel: " % test_database_name)
if autoclobber or confirm == 'yes':
try:
if verbosity >= 1:
print "Destroying old test database..."
cursor.execute("DROP DATABASE %s" % qn(test_database_name))
if verbosity >= 1:
print "Creating test database..."
cursor.execute("CREATE DATABASE %s %s" % (qn(test_database_name), suffix))
except Exception, e:
sys.stderr.write("Got an error recreating the test database: %s\n" % e)
sys.exit(2)
else:
print "Tests cancelled."
sys.exit(1)
return test_database_name
def destroy_test_db(self, old_database_name, verbosity=1):
"""
Destroy the test database named in settings.DATABASE_NAME, then restore
the original database name passed in as old_database_name.
"""
if verbosity >= 1:
print "Destroying test database..."
self.connection.close()
test_database_name = settings.DATABASE_NAME
settings.DATABASE_NAME = old_database_name
self._destroy_test_db(test_database_name, verbosity)
def _destroy_test_db(self, test_database_name, verbosity):
"Internal implementation - remove the test db tables."
# Remove the test database to clean up after
# ourselves. Connect to the previous database (not the test database)
# to do so, because it's not allowed to delete a database while being
# connected to it.
cursor = self.connection.cursor()
self.set_autocommit()
time.sleep(1) # To avoid "database is being accessed by other users" errors.
cursor.execute("DROP DATABASE %s" % self.connection.ops.quote_name(test_database_name))
self.connection.close()
def set_autocommit(self):
"Make sure a connection is in autocommit mode."
if hasattr(self.connection.connection, "autocommit"):
if callable(self.connection.connection.autocommit):
self.connection.connection.autocommit(True)
else:
self.connection.connection.autocommit = True
elif hasattr(self.connection.connection, "set_isolation_level"):
self.connection.connection.set_isolation_level(0)
def sql_table_creation_suffix(self):
"SQL to append to the end of the test table creation statements"
return ''

View File: django/db/backends/dummy/base.py
@@ -8,7 +8,8 @@ ImproperlyConfigured.
 """
 from django.core.exceptions import ImproperlyConfigured
-from django.db.backends import BaseDatabaseFeatures, BaseDatabaseOperations
+from django.db.backends import *
+from django.db.backends.creation import BaseDatabaseCreation
 
 def complain(*args, **kwargs):
     raise ImproperlyConfigured, "You haven't set the DATABASE_ENGINE setting yet."
@@ -25,16 +26,30 @@ class IntegrityError(DatabaseError):
 class DatabaseOperations(BaseDatabaseOperations):
     quote_name = complain
 
-class DatabaseWrapper(object):
-    features = BaseDatabaseFeatures()
-    ops = DatabaseOperations()
+class DatabaseClient(BaseDatabaseClient):
+    runshell = complain
+
+class DatabaseIntrospection(BaseDatabaseIntrospection):
+    get_table_list = complain
+    get_table_description = complain
+    get_relations = complain
+    get_indexes = complain
+
+class DatabaseWrapper(object):
     operators = {}
     cursor = complain
     _commit = complain
     _rollback = ignore
 
-    def __init__(self, **kwargs):
-        pass
+    def __init__(self, *args, **kwargs):
+        super(DatabaseWrapper, self).__init__(*args, **kwargs)
+        self.features = BaseDatabaseFeatures()
+        self.ops = DatabaseOperations()
+        self.client = DatabaseClient()
+        self.creation = BaseDatabaseCreation(self)
+        self.introspection = DatabaseIntrospection(self)
+        self.validation = BaseDatabaseValidation()
 
     def close(self):
         pass
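# The pattern above is the heart of this refactor: every backend service is now an
# attribute on the connection's DatabaseWrapper rather than a free-standing module.
# A hedged sketch of the resulting access pattern (names taken from this diff):
#
#   from django.db import connection
#
#   qn = connection.ops.quote_name                                  # dialect helpers
#   tables = connection.introspection.get_table_list(connection.cursor())
#   suffix = connection.creation.sql_table_creation_suffix()
#   connection.client.runshell()                                    # backend CLI shell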

View File: django/db/backends/dummy/client.py (deleted)
@ -1,3 +0,0 @@
from django.db.backends.dummy.base import complain
runshell = complain

View File: django/db/backends/dummy/creation.py (deleted)
@ -1 +0,0 @@
DATA_TYPES = {}

View File: django/db/backends/dummy/introspection.py (deleted)
@ -1,8 +0,0 @@
from django.db.backends.dummy.base import complain
get_table_list = complain
get_table_description = complain
get_relations = complain
get_indexes = complain
DATA_TYPES_REVERSE = {}

View File

@@ -4,7 +4,12 @@ MySQL database backend for Django.
 Requires MySQLdb: http://sourceforge.net/projects/mysql-python
 """
-from django.db.backends import BaseDatabaseWrapper, BaseDatabaseFeatures, BaseDatabaseOperations, util
+from django.db.backends import *
+from django.db.backends.mysql.client import DatabaseClient
+from django.db.backends.mysql.creation import DatabaseCreation
+from django.db.backends.mysql.introspection import DatabaseIntrospection
+from django.db.backends.mysql.validation import DatabaseValidation
 
 try:
     import MySQLdb as Database
 except ImportError, e:
@@ -60,7 +65,6 @@ server_version_re = re.compile(r'(\d{1,2})\.(\d{1,2})\.(\d{1,2})')
 # TRADITIONAL will automatically cause most warnings to be treated as errors.
 
 class DatabaseFeatures(BaseDatabaseFeatures):
-    inline_fk_references = False
     empty_fetchmany_value = ()
     update_can_self_select = False
@@ -142,8 +146,7 @@ class DatabaseOperations(BaseDatabaseOperations):
         return [first % value, second % value]
 
 class DatabaseWrapper(BaseDatabaseWrapper):
-    features = DatabaseFeatures()
-    ops = DatabaseOperations()
     operators = {
         'exact': '= BINARY %s',
         'iexact': 'LIKE %s',
@@ -164,6 +167,13 @@ class DatabaseWrapper(BaseDatabaseWrapper):
     def __init__(self, **kwargs):
         super(DatabaseWrapper, self).__init__(**kwargs)
         self.server_version = None
+        self.features = DatabaseFeatures()
+        self.ops = DatabaseOperations()
+        self.client = DatabaseClient()
+        self.creation = DatabaseCreation(self)
+        self.introspection = DatabaseIntrospection(self)
+        self.validation = DatabaseValidation()
 
     def _valid_connection(self):
         if self.connection is not None:

View File

@@ -1,27 +1,29 @@
from django.db.backends import BaseDatabaseClient
from django.conf import settings
import os

class DatabaseClient(BaseDatabaseClient):
    def runshell(self):
        args = ['']
        db = settings.DATABASE_OPTIONS.get('db', settings.DATABASE_NAME)
        user = settings.DATABASE_OPTIONS.get('user', settings.DATABASE_USER)
        passwd = settings.DATABASE_OPTIONS.get('passwd', settings.DATABASE_PASSWORD)
        host = settings.DATABASE_OPTIONS.get('host', settings.DATABASE_HOST)
        port = settings.DATABASE_OPTIONS.get('port', settings.DATABASE_PORT)
        defaults_file = settings.DATABASE_OPTIONS.get('read_default_file')
        # Seems to be no good way to set sql_mode with CLI
        if defaults_file:
            args += ["--defaults-file=%s" % defaults_file]
        if user:
            args += ["--user=%s" % user]
        if passwd:
            args += ["--password=%s" % passwd]
        if host:
            args += ["--host=%s" % host]
        if port:
            args += ["--port=%s" % port]
        if db:
            args += [db]
        os.execvp('mysql', args)

View File

@@ -1,28 +1,68 @@
from django.conf import settings
from django.db.backends.creation import BaseDatabaseCreation

class DatabaseCreation(BaseDatabaseCreation):
    # This dictionary maps Field objects to their associated MySQL column
    # types, as strings. Column-type strings can contain format strings; they'll
    # be interpolated against the values of Field.__dict__ before being output.
    # If a column type is set to None, it won't be included in the output.
    data_types = {
        'AutoField': 'integer AUTO_INCREMENT',
        'BooleanField': 'bool',
        'CharField': 'varchar(%(max_length)s)',
        'CommaSeparatedIntegerField': 'varchar(%(max_length)s)',
        'DateField': 'date',
        'DateTimeField': 'datetime',
        'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)',
        'FileField': 'varchar(%(max_length)s)',
        'FilePathField': 'varchar(%(max_length)s)',
        'FloatField': 'double precision',
        'IntegerField': 'integer',
        'IPAddressField': 'char(15)',
        'NullBooleanField': 'bool',
        'OneToOneField': 'integer',
        'PhoneNumberField': 'varchar(20)',
        'PositiveIntegerField': 'integer UNSIGNED',
        'PositiveSmallIntegerField': 'smallint UNSIGNED',
        'SlugField': 'varchar(%(max_length)s)',
        'SmallIntegerField': 'smallint',
        'TextField': 'longtext',
        'TimeField': 'time',
        'USStateField': 'varchar(2)',
    }

    def sql_table_creation_suffix(self):
        suffix = []
        if settings.TEST_DATABASE_CHARSET:
            suffix.append('CHARACTER SET %s' % settings.TEST_DATABASE_CHARSET)
        if settings.TEST_DATABASE_COLLATION:
            suffix.append('COLLATE %s' % settings.TEST_DATABASE_COLLATION)
        return ' '.join(suffix)

    def sql_for_inline_foreign_key_references(self, field, known_models, style):
        "All inline references are pending under MySQL"
        return [], True

    def sql_for_inline_many_to_many_references(self, model, field, style):
        from django.db import models
        opts = model._meta
        qn = self.connection.ops.quote_name
        table_output = [
            '    %s %s %s,' %
                (style.SQL_FIELD(qn(field.m2m_column_name())),
                style.SQL_COLTYPE(models.ForeignKey(model).db_type()),
                style.SQL_KEYWORD('NOT NULL')),
            '    %s %s %s,' %
                (style.SQL_FIELD(qn(field.m2m_reverse_name())),
                style.SQL_COLTYPE(models.ForeignKey(field.rel.to).db_type()),
                style.SQL_KEYWORD('NOT NULL'))
        ]
        deferred = [
            (field.m2m_db_table(), field.m2m_column_name(), opts.db_table,
                opts.pk.column),
            (field.m2m_db_table(), field.m2m_reverse_name(),
                field.rel.to._meta.db_table, field.rel.to._meta.pk.column)
        ]
        return table_output, deferred
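# Because MySQL cannot reliably create these foreign keys inline, the method above
# returns plain NOT NULL columns plus "deferred" tuples; BaseDatabaseCreation's
# sql_for_many_to_many_field (shown earlier) turns each tuple into an
# ALTER TABLE ... ADD CONSTRAINT ... FOREIGN KEY once the table exists.
# One deferred entry, with hypothetical names, looks like:
#
#   ('myapp_article_tags', 'article_id', 'myapp_article', 'id')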

View File

@@ -1,96 +1,97 @@
from django.db.backends import BaseDatabaseIntrospection
from MySQLdb import ProgrammingError, OperationalError
from MySQLdb.constants import FIELD_TYPE
import re

foreign_key_re = re.compile(r"\sCONSTRAINT `[^`]*` FOREIGN KEY \(`([^`]*)`\) REFERENCES `([^`]*)` \(`([^`]*)`\)")

class DatabaseIntrospection(BaseDatabaseIntrospection):
    data_types_reverse = {
        FIELD_TYPE.BLOB: 'TextField',
        FIELD_TYPE.CHAR: 'CharField',
        FIELD_TYPE.DECIMAL: 'DecimalField',
        FIELD_TYPE.DATE: 'DateField',
        FIELD_TYPE.DATETIME: 'DateTimeField',
        FIELD_TYPE.DOUBLE: 'FloatField',
        FIELD_TYPE.FLOAT: 'FloatField',
        FIELD_TYPE.INT24: 'IntegerField',
        FIELD_TYPE.LONG: 'IntegerField',
        FIELD_TYPE.LONGLONG: 'IntegerField',
        FIELD_TYPE.SHORT: 'IntegerField',
        FIELD_TYPE.STRING: 'CharField',
        FIELD_TYPE.TIMESTAMP: 'DateTimeField',
        FIELD_TYPE.TINY: 'IntegerField',
        FIELD_TYPE.TINY_BLOB: 'TextField',
        FIELD_TYPE.MEDIUM_BLOB: 'TextField',
        FIELD_TYPE.LONG_BLOB: 'TextField',
        FIELD_TYPE.VAR_STRING: 'CharField',
    }

    def get_table_list(self, cursor):
        "Returns a list of table names in the current database."
        cursor.execute("SHOW TABLES")
        return [row[0] for row in cursor.fetchall()]

    def get_table_description(self, cursor, table_name):
        "Returns a description of the table, with the DB-API cursor.description interface."
        cursor.execute("SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name))
        return cursor.description

    def _name_to_index(self, cursor, table_name):
        """
        Returns a dictionary of {field_name: field_index} for the given table.
        Indexes are 0-based.
        """
        return dict([(d[0], i) for i, d in enumerate(self.get_table_description(cursor, table_name))])

    def get_relations(self, cursor, table_name):
        """
        Returns a dictionary of {field_index: (field_index_other_table, other_table)}
        representing all relationships to the given table. Indexes are 0-based.
        """
        my_field_dict = self._name_to_index(cursor, table_name)
        constraints = []
        relations = {}
        try:
            # This should work for MySQL 5.0.
            cursor.execute("""
                SELECT column_name, referenced_table_name, referenced_column_name
                FROM information_schema.key_column_usage
                WHERE table_name = %s
                    AND table_schema = DATABASE()
                    AND referenced_table_name IS NOT NULL
                    AND referenced_column_name IS NOT NULL""", [table_name])
            constraints.extend(cursor.fetchall())
        except (ProgrammingError, OperationalError):
            # Fall back to "SHOW CREATE TABLE", for previous MySQL versions.
            # Go through all constraints and save the equal matches.
            cursor.execute("SHOW CREATE TABLE %s" % self.connection.ops.quote_name(table_name))
            for row in cursor.fetchall():
                pos = 0
                while True:
                    match = foreign_key_re.search(row[1], pos)
                    if match == None:
                        break
                    pos = match.end()
                    constraints.append(match.groups())

        for my_fieldname, other_table, other_field in constraints:
            other_field_index = self._name_to_index(cursor, other_table)[other_field]
            my_field_index = my_field_dict[my_fieldname]
            relations[my_field_index] = (other_field_index, other_table)

        return relations

    def get_indexes(self, cursor, table_name):
        """
        Returns a dictionary of fieldname -> infodict for the given table,
        where each infodict is in the format:
            {'primary_key': boolean representing whether it's the primary key,
             'unique': boolean representing whether it's a unique index}
        """
        cursor.execute("SHOW INDEX FROM %s" % self.connection.ops.quote_name(table_name))
        indexes = {}
        for row in cursor.fetchall():
            indexes[row[4]] = {'primary_key': (row[2] == 'PRIMARY'), 'unique': not bool(row[1])}
        return indexes
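# A short, hypothetical interactive sketch of the class-based introspection API
# (table names and return values are illustrative):
#
#   from django.db import connection
#
#   cursor = connection.cursor()
#   connection.introspection.get_table_list(cursor)               # ['myapp_author', 'myapp_book']
#   connection.introspection.get_relations(cursor, 'myapp_book')  # {1: (0, 'myapp_author')}
#   connection.introspection.get_indexes(cursor, 'myapp_book')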

View File

@@ -0,0 +1,13 @@
from django.db.backends import BaseDatabaseValidation

class DatabaseValidation(BaseDatabaseValidation):
    def validate_field(self, errors, opts, f):
        "Prior to MySQL 5.0.3, character fields could not exceed 255 characters"
        from django.db import models
        from django.db import connection
        db_version = connection.get_server_version()
        if db_version < (5, 0, 3) and isinstance(f, (models.CharField, models.CommaSeparatedIntegerField, models.SlugField)) and f.max_length > 255:
            errors.add(opts,
                '"%s": %s cannot have a "max_length" greater than 255 when you are using a version of MySQL prior to 5.0.3 (you are using %s).' %
                    (f.name, f.__class__.__name__, '.'.join([str(n) for n in db_version[:3]])))
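# This is the new per-backend validation hook mentioned in the commit message.
# A hedged sketch of a caller (the error-collector object and iteration are assumed,
# mirroring how management validation hands fields to connection.validation):
#
#   from django.db import connection
#
#   def check_model_fields(errors, opts):
#       # errors is the validation collector; opts is model._meta
#       for f in opts.local_fields:
#           connection.validation.validate_field(errors, opts, f)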

View File

@@ -8,8 +8,11 @@ import os
 import datetime
 import time
-from django.db.backends import BaseDatabaseWrapper, BaseDatabaseFeatures, BaseDatabaseOperations, util
+from django.db.backends import *
 from django.db.backends.oracle import query
+from django.db.backends.oracle.client import DatabaseClient
+from django.db.backends.oracle.creation import DatabaseCreation
+from django.db.backends.oracle.introspection import DatabaseIntrospection
 from django.utils.encoding import smart_str, force_unicode
 
 # Oracle takes client-side character set encoding from the environment.
@@ -24,11 +27,8 @@ DatabaseError = Database.Error
 IntegrityError = Database.IntegrityError
 
 class DatabaseFeatures(BaseDatabaseFeatures):
-    allows_group_by_ordinal = False
     empty_fetchmany_value = ()
     needs_datetime_string_cast = False
-    supports_tablespaces = True
-    uses_case_insensitive_names = True
     uses_custom_query_class = True
     interprets_empty_strings_as_nulls = True
@@ -194,10 +194,8 @@ class DatabaseOperations(BaseDatabaseOperations):
         return [first % value, second % value]
 
 class DatabaseWrapper(BaseDatabaseWrapper):
-    features = DatabaseFeatures()
-    ops = DatabaseOperations()
     operators = {
         'exact': '= %s',
         'iexact': '= UPPER(%s)',
@@ -214,6 +212,16 @@ class DatabaseWrapper(BaseDatabaseWrapper):
     }
     oracle_version = None
 
+    def __init__(self, *args, **kwargs):
+        super(DatabaseWrapper, self).__init__(*args, **kwargs)
+        self.features = DatabaseFeatures()
+        self.ops = DatabaseOperations()
+        self.client = DatabaseClient()
+        self.creation = DatabaseCreation(self)
+        self.introspection = DatabaseIntrospection(self)
+        self.validation = BaseDatabaseValidation()
+
     def _valid_connection(self):
         return self.connection is not None

View File

@@ -1,11 +1,13 @@
from django.db.backends import BaseDatabaseClient
from django.conf import settings
import os

class DatabaseClient(BaseDatabaseClient):
    def runshell(self):
        dsn = settings.DATABASE_USER
        if settings.DATABASE_PASSWORD:
            dsn += "/%s" % settings.DATABASE_PASSWORD
        if settings.DATABASE_NAME:
            dsn += "@%s" % settings.DATABASE_NAME
        args = ["sqlplus", "-L", dsn]
        os.execvp("sqlplus", args)

View File

@@ -1,291 +1,289 @@
import sys, time
from django.conf import settings
from django.core import management
from django.db.backends.creation import BaseDatabaseCreation

TEST_DATABASE_PREFIX = 'test_'
PASSWORD = 'Im_a_lumberjack'

class DatabaseCreation(BaseDatabaseCreation):
    # This dictionary maps Field objects to their associated Oracle column
    # types, as strings. Column-type strings can contain format strings; they'll
    # be interpolated against the values of Field.__dict__ before being output.
    # If a column type is set to None, it won't be included in the output.
    #
    # Any format strings starting with "qn_" are quoted before being used in the
    # output (the "qn_" prefix is stripped before the lookup is performed).
    data_types = {
        'AutoField': 'NUMBER(11)',
        'BooleanField': 'NUMBER(1) CHECK (%(qn_column)s IN (0,1))',
        'CharField': 'NVARCHAR2(%(max_length)s)',
        'CommaSeparatedIntegerField': 'VARCHAR2(%(max_length)s)',
        'DateField': 'DATE',
        'DateTimeField': 'TIMESTAMP',
        'DecimalField': 'NUMBER(%(max_digits)s, %(decimal_places)s)',
        'FileField': 'NVARCHAR2(%(max_length)s)',
        'FilePathField': 'NVARCHAR2(%(max_length)s)',
        'FloatField': 'DOUBLE PRECISION',
        'IntegerField': 'NUMBER(11)',
        'IPAddressField': 'VARCHAR2(15)',
        'NullBooleanField': 'NUMBER(1) CHECK ((%(qn_column)s IN (0,1)) OR (%(qn_column)s IS NULL))',
        'OneToOneField': 'NUMBER(11)',
        'PhoneNumberField': 'VARCHAR2(20)',
        'PositiveIntegerField': 'NUMBER(11) CHECK (%(qn_column)s >= 0)',
        'PositiveSmallIntegerField': 'NUMBER(11) CHECK (%(qn_column)s >= 0)',
        'SlugField': 'NVARCHAR2(50)',
        'SmallIntegerField': 'NUMBER(11)',
        'TextField': 'NCLOB',
        'TimeField': 'TIMESTAMP',
        'URLField': 'VARCHAR2(%(max_length)s)',
        'USStateField': 'CHAR(2)',
    }

    remember = {}

    def _create_test_db(self, verbosity, autoclobber):
        TEST_DATABASE_NAME = self._test_database_name(settings)
        TEST_DATABASE_USER = self._test_database_user(settings)
        TEST_DATABASE_PASSWD = self._test_database_passwd(settings)
        TEST_DATABASE_TBLSPACE = self._test_database_tblspace(settings)
        TEST_DATABASE_TBLSPACE_TMP = self._test_database_tblspace_tmp(settings)

        parameters = {
            'dbname': TEST_DATABASE_NAME,
            'user': TEST_DATABASE_USER,
            'password': TEST_DATABASE_PASSWD,
            'tblspace': TEST_DATABASE_TBLSPACE,
            'tblspace_temp': TEST_DATABASE_TBLSPACE_TMP,
        }

        self.remember['user'] = settings.DATABASE_USER
        self.remember['passwd'] = settings.DATABASE_PASSWORD

        cursor = self.connection.cursor()
        if self._test_database_create(settings):
            if verbosity >= 1:
                print 'Creating test database...'
            try:
                self._execute_test_db_creation(cursor, parameters, verbosity)
            except Exception, e:
                sys.stderr.write("Got an error creating the test database: %s\n" % e)
                if not autoclobber:
                    confirm = raw_input("It appears the test database, %s, already exists. Type 'yes' to delete it, or 'no' to cancel: " % TEST_DATABASE_NAME)
                if autoclobber or confirm == 'yes':
                    try:
                        if verbosity >= 1:
                            print "Destroying old test database..."
                        self._execute_test_db_destruction(cursor, parameters, verbosity)
                        if verbosity >= 1:
                            print "Creating test database..."
                        self._execute_test_db_creation(cursor, parameters, verbosity)
                    except Exception, e:
                        sys.stderr.write("Got an error recreating the test database: %s\n" % e)
                        sys.exit(2)
                else:
                    print "Tests cancelled."
                    sys.exit(1)

        if self._test_user_create(settings):
            if verbosity >= 1:
                print "Creating test user..."
            try:
                self._create_test_user(cursor, parameters, verbosity)
            except Exception, e:
                sys.stderr.write("Got an error creating the test user: %s\n" % e)
                if not autoclobber:
                    confirm = raw_input("It appears the test user, %s, already exists. Type 'yes' to delete it, or 'no' to cancel: " % TEST_DATABASE_USER)
                if autoclobber or confirm == 'yes':
                    try:
                        if verbosity >= 1:
                            print "Destroying old test user..."
                        self._destroy_test_user(cursor, parameters, verbosity)
                        if verbosity >= 1:
                            print "Creating test user..."
                        self._create_test_user(cursor, parameters, verbosity)
                    except Exception, e:
                        sys.stderr.write("Got an error recreating the test user: %s\n" % e)
                        sys.exit(2)
                else:
                    print "Tests cancelled."
                    sys.exit(1)

        settings.DATABASE_USER = TEST_DATABASE_USER
        settings.DATABASE_PASSWORD = TEST_DATABASE_PASSWD

        return TEST_DATABASE_NAME

    def _destroy_test_db(self, test_database_name, verbosity=1):
        """
        Destroy the test database, the test user and the test tablespaces,
        then restore the original connection settings.
        """
        TEST_DATABASE_NAME = self._test_database_name(settings)
        TEST_DATABASE_USER = self._test_database_user(settings)
        TEST_DATABASE_PASSWD = self._test_database_passwd(settings)
        TEST_DATABASE_TBLSPACE = self._test_database_tblspace(settings)
        TEST_DATABASE_TBLSPACE_TMP = self._test_database_tblspace_tmp(settings)

        settings.DATABASE_USER = self.remember['user']
        settings.DATABASE_PASSWORD = self.remember['passwd']

        parameters = {
            'dbname': TEST_DATABASE_NAME,
            'user': TEST_DATABASE_USER,
            'password': TEST_DATABASE_PASSWD,
            'tblspace': TEST_DATABASE_TBLSPACE,
            'tblspace_temp': TEST_DATABASE_TBLSPACE_TMP,
        }

        self.remember['user'] = settings.DATABASE_USER
        self.remember['passwd'] = settings.DATABASE_PASSWORD

        cursor = self.connection.cursor()
        time.sleep(1) # To avoid "database is being accessed by other users" errors.
        if self._test_user_create(settings):
            if verbosity >= 1:
                print 'Destroying test user...'
            self._destroy_test_user(cursor, parameters, verbosity)
        if self._test_database_create(settings):
            if verbosity >= 1:
                print 'Destroying test database tables...'
            self._execute_test_db_destruction(cursor, parameters, verbosity)
        self.connection.close()

    def _execute_test_db_creation(self, cursor, parameters, verbosity):
        if verbosity >= 2:
            print "_create_test_db(): dbname = %s" % parameters['dbname']
        statements = [
            """CREATE TABLESPACE %(tblspace)s
               DATAFILE '%(tblspace)s.dbf' SIZE 20M
               REUSE AUTOEXTEND ON NEXT 10M MAXSIZE 100M
            """,
            """CREATE TEMPORARY TABLESPACE %(tblspace_temp)s
               TEMPFILE '%(tblspace_temp)s.dbf' SIZE 20M
               REUSE AUTOEXTEND ON NEXT 10M MAXSIZE 100M
            """,
        ]
        self._execute_statements(cursor, statements, parameters, verbosity)

    def _create_test_user(self, cursor, parameters, verbosity):
        if verbosity >= 2:
            print "_create_test_user(): username = %s" % parameters['user']
        statements = [
            """CREATE USER %(user)s
               IDENTIFIED BY %(password)s
               DEFAULT TABLESPACE %(tblspace)s
               TEMPORARY TABLESPACE %(tblspace_temp)s
            """,
            """GRANT CONNECT, RESOURCE TO %(user)s""",
        ]
        self._execute_statements(cursor, statements, parameters, verbosity)

    def _execute_test_db_destruction(self, cursor, parameters, verbosity):
        if verbosity >= 2:
            print "_execute_test_db_destruction(): dbname=%s" % parameters['dbname']
        statements = [
            'DROP TABLESPACE %(tblspace)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS',
            'DROP TABLESPACE %(tblspace_temp)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS',
        ]
        self._execute_statements(cursor, statements, parameters, verbosity)

    def _destroy_test_user(self, cursor, parameters, verbosity):
        if verbosity >= 2:
            print "_destroy_test_user(): user=%s" % parameters['user']
            print "Be patient. This can take some time..."
        statements = [
            'DROP USER %(user)s CASCADE',
        ]
        self._execute_statements(cursor, statements, parameters, verbosity)

    def _execute_statements(self, cursor, statements, parameters, verbosity):
        for template in statements:
            stmt = template % parameters
            if verbosity >= 2:
                print stmt
            try:
                cursor.execute(stmt)
            except Exception, err:
                sys.stderr.write("Failed (%s)\n" % (err))
                raise

    def _test_database_name(self, settings):
        name = TEST_DATABASE_PREFIX + settings.DATABASE_NAME
        try:
            if settings.TEST_DATABASE_NAME:
                name = settings.TEST_DATABASE_NAME
        except AttributeError:
            pass
        except:
            raise
        return name

    def _test_database_create(self, settings):
        name = True
        try:
            if settings.TEST_DATABASE_CREATE:
                name = True
            else:
                name = False
        except AttributeError:
            pass
        except:
            raise
        return name

    def _test_user_create(self, settings):
        name = True
        try:
            if settings.TEST_USER_CREATE:
                name = True
            else:
                name = False
        except AttributeError:
            pass
        except:
            raise
        return name

    def _test_database_user(self, settings):
        name = TEST_DATABASE_PREFIX + settings.DATABASE_NAME
        try:
            if settings.TEST_DATABASE_USER:
                name = settings.TEST_DATABASE_USER
        except AttributeError:
            pass
        except:
            raise
        return name

    def _test_database_passwd(self, settings):
        name = PASSWORD
        try:
            if settings.TEST_DATABASE_PASSWD:
                name = settings.TEST_DATABASE_PASSWD
        except AttributeError:
            pass
        except:
            raise
        return name

    def _test_database_tblspace(self, settings):
        name = TEST_DATABASE_PREFIX + settings.DATABASE_NAME
        try:
            if settings.TEST_DATABASE_TBLSPACE:
                name = settings.TEST_DATABASE_TBLSPACE
        except AttributeError:
            pass
        except:
            raise
        return name

    def _test_database_tblspace_tmp(self, settings):
        name = TEST_DATABASE_PREFIX + settings.DATABASE_NAME + '_temp'
        try:
            if settings.TEST_DATABASE_TBLSPACE_TMP:
                name = settings.TEST_DATABASE_TBLSPACE_TMP
        except AttributeError:
            pass
        except:
            raise
        return name
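# The Oracle test database is controlled almost entirely through optional settings;
# a sketch of the ones consulted above (setting names come from this file, the
# values are purely illustrative):
#
#   # settings.py
#   TEST_DATABASE_NAME = 'test_main'
#   TEST_DATABASE_USER = 'test_main'
#   TEST_DATABASE_PASSWD = 'Im_a_lumberjack'
#   TEST_DATABASE_TBLSPACE = 'test_main'
#   TEST_DATABASE_TBLSPACE_TMP = 'test_main_temp'
#   TEST_DATABASE_CREATE = True    # set False to reuse existing tablespaces
#   TEST_USER_CREATE = True        # set False to reuse an existing test user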

View File

@@ -1,98 +1,103 @@
from django.db.backends import BaseDatabaseIntrospection
import cx_Oracle
import re

foreign_key_re = re.compile(r"\sCONSTRAINT `[^`]*` FOREIGN KEY \(`([^`]*)`\) REFERENCES `([^`]*)` \(`([^`]*)`\)")

class DatabaseIntrospection(BaseDatabaseIntrospection):
    # Maps type objects to Django Field types.
    data_types_reverse = {
        cx_Oracle.CLOB: 'TextField',
        cx_Oracle.DATETIME: 'DateTimeField',
        cx_Oracle.FIXED_CHAR: 'CharField',
        cx_Oracle.NCLOB: 'TextField',
        cx_Oracle.NUMBER: 'DecimalField',
        cx_Oracle.STRING: 'CharField',
        cx_Oracle.TIMESTAMP: 'DateTimeField',
    }

    def get_table_list(self, cursor):
        "Returns a list of table names in the current database."
        cursor.execute("SELECT TABLE_NAME FROM USER_TABLES")
        return [row[0].upper() for row in cursor.fetchall()]

    def get_table_description(self, cursor, table_name):
        "Returns a description of the table, with the DB-API cursor.description interface."
        cursor.execute("SELECT * FROM %s WHERE ROWNUM < 2" % self.connection.ops.quote_name(table_name))
        return cursor.description

    def table_name_converter(self, name):
        "Table name comparison is case insensitive under Oracle"
        return name.upper()

    def _name_to_index(self, cursor, table_name):
        """
        Returns a dictionary of {field_name: field_index} for the given table.
        Indexes are 0-based.
        """
        return dict([(d[0], i) for i, d in enumerate(self.get_table_description(cursor, table_name))])

    def get_relations(self, cursor, table_name):
        """
        Returns a dictionary of {field_index: (field_index_other_table, other_table)}
        representing all relationships to the given table. Indexes are 0-based.
        """
        cursor.execute("""
            SELECT ta.column_id - 1, tb.table_name, tb.column_id - 1
            FROM   user_constraints, USER_CONS_COLUMNS ca, USER_CONS_COLUMNS cb,
                   user_tab_cols ta, user_tab_cols tb
            WHERE  user_constraints.table_name = %s AND
                   ta.table_name = %s AND
                   ta.column_name = ca.column_name AND
                   ca.table_name = %s AND
                   user_constraints.constraint_name = ca.constraint_name AND
                   user_constraints.r_constraint_name = cb.constraint_name AND
                   cb.table_name = tb.table_name AND
                   cb.column_name = tb.column_name AND
                   ca.position = cb.position""", [table_name, table_name, table_name])

        relations = {}
        for row in cursor.fetchall():
            relations[row[0]] = (row[2], row[1])
        return relations

    def get_indexes(self, cursor, table_name):
        """
        Returns a dictionary of fieldname -> infodict for the given table,
        where each infodict is in the format:
            {'primary_key': boolean representing whether it's the primary key,
             'unique': boolean representing whether it's a unique index}
        """
        # This query retrieves each index on the given table, including the
        # first associated field name
        # "We were in the nick of time; you were in great peril!"
        sql = """
            WITH primarycols AS (
             SELECT user_cons_columns.table_name, user_cons_columns.column_name, 1 AS PRIMARYCOL
             FROM   user_cons_columns, user_constraints
             WHERE  user_cons_columns.constraint_name = user_constraints.constraint_name AND
                    user_constraints.constraint_type = 'P' AND
                    user_cons_columns.table_name = %s),
             uniquecols AS (
             SELECT user_ind_columns.table_name, user_ind_columns.column_name, 1 AS UNIQUECOL
             FROM   user_indexes, user_ind_columns
             WHERE  uniqueness = 'UNIQUE' AND
                    user_indexes.index_name = user_ind_columns.index_name AND
                    user_ind_columns.table_name = %s)
            SELECT allcols.column_name, primarycols.primarycol, uniquecols.UNIQUECOL
            FROM   (SELECT column_name FROM primarycols UNION SELECT column_name FROM
                    uniquecols) allcols,
                   primarycols, uniquecols
            WHERE  allcols.column_name = primarycols.column_name (+) AND
                   allcols.column_name = uniquecols.column_name (+)
        """
        cursor.execute(sql, [table_name, table_name])
        indexes = {}
        for row in cursor.fetchall():
            # row[1] (idx.indkey) is stored in the DB as an array. It comes out as
            # a string of space-separated integers. This designates the field
            # indexes (1-based) of the fields that have indexes on the table.
            # Here, we skip any indexes across multiple fields.
            indexes[row[0]] = {'primary_key': row[1], 'unique': row[2]}
        return indexes

View File

@@ -4,9 +4,13 @@ PostgreSQL database backend for Django.
 Requires psycopg 1: http://initd.org/projects/psycopg1
 """
-from django.utils.encoding import smart_str, smart_unicode
-from django.db.backends import BaseDatabaseWrapper, BaseDatabaseFeatures, util
+from django.db.backends import *
+from django.db.backends.postgresql.client import DatabaseClient
+from django.db.backends.postgresql.creation import DatabaseCreation
+from django.db.backends.postgresql.introspection import DatabaseIntrospection
 from django.db.backends.postgresql.operations import DatabaseOperations
+from django.utils.encoding import smart_str, smart_unicode
 
 try:
     import psycopg as Database
 except ImportError, e:
@@ -59,12 +63,7 @@ class UnicodeCursorWrapper(object):
     def __iter__(self):
         return iter(self.cursor)
 
-class DatabaseFeatures(BaseDatabaseFeatures):
-    pass # This backend uses all the defaults.
-
 class DatabaseWrapper(BaseDatabaseWrapper):
-    features = DatabaseFeatures()
-    ops = DatabaseOperations()
     operators = {
         'exact': '= %s',
         'iexact': 'ILIKE %s',
@@ -82,6 +81,16 @@ class DatabaseWrapper(BaseDatabaseWrapper):
         'iendswith': 'ILIKE %s',
     }
 
+    def __init__(self, *args, **kwargs):
+        super(DatabaseWrapper, self).__init__(*args, **kwargs)
+        self.features = BaseDatabaseFeatures()
+        self.ops = DatabaseOperations()
+        self.client = DatabaseClient()
+        self.creation = DatabaseCreation(self)
+        self.introspection = DatabaseIntrospection(self)
+        self.validation = BaseDatabaseValidation()
+
     def _cursor(self, settings):
         set_tz = False
         if self.connection is None:

View File

@@ -1,15 +1,17 @@
from django.db.backends import BaseDatabaseClient
from django.conf import settings
import os

class DatabaseClient(BaseDatabaseClient):
    def runshell(self):
        args = ['psql']
        if settings.DATABASE_USER:
            args += ["-U", settings.DATABASE_USER]
        if settings.DATABASE_PASSWORD:
            args += ["-W"]
        if settings.DATABASE_HOST:
            args.extend(["-h", settings.DATABASE_HOST])
        if settings.DATABASE_PORT:
            args.extend(["-p", str(settings.DATABASE_PORT)])
        args += [settings.DATABASE_NAME]
        os.execvp('psql', args)

View File

@@ -1,28 +1,38 @@
from django.conf import settings
from django.db.backends.creation import BaseDatabaseCreation

class DatabaseCreation(BaseDatabaseCreation):
    # This dictionary maps Field objects to their associated PostgreSQL column
    # types, as strings. Column-type strings can contain format strings; they'll
    # be interpolated against the values of Field.__dict__ before being output.
    # If a column type is set to None, it won't be included in the output.
    data_types = {
        'AutoField': 'serial',
        'BooleanField': 'boolean',
        'CharField': 'varchar(%(max_length)s)',
        'CommaSeparatedIntegerField': 'varchar(%(max_length)s)',
        'DateField': 'date',
        'DateTimeField': 'timestamp with time zone',
        'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)',
        'FileField': 'varchar(%(max_length)s)',
        'FilePathField': 'varchar(%(max_length)s)',
        'FloatField': 'double precision',
        'IntegerField': 'integer',
        'IPAddressField': 'inet',
        'NullBooleanField': 'boolean',
        'OneToOneField': 'integer',
        'PhoneNumberField': 'varchar(20)',
        'PositiveIntegerField': 'integer CHECK ("%(column)s" >= 0)',
        'PositiveSmallIntegerField': 'smallint CHECK ("%(column)s" >= 0)',
        'SlugField': 'varchar(%(max_length)s)',
        'SmallIntegerField': 'smallint',
        'TextField': 'text',
        'TimeField': 'time',
        'USStateField': 'varchar(2)',
    }

    def sql_table_creation_suffix(self):
        assert settings.TEST_DATABASE_COLLATION is None, "PostgreSQL does not support collation setting at database creation time."
        if settings.TEST_DATABASE_CHARSET:
            return "WITH ENCODING '%s'" % settings.TEST_DATABASE_CHARSET
        return ''
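The comments above note that the column-type strings in data_types are plain format strings interpolated against a field's __dict__. As a rough illustration (the attribute values here are made up, not taken from the patch):

# Hypothetical attribute values, as they might appear in a Field's __dict__:
field_attrs = {'max_length': 75, 'column': 'score'}

print 'varchar(%(max_length)s)' % field_attrs
# varchar(75)
print 'integer CHECK ("%(column)s" >= 0)' % field_attrs
# integer CHECK ("score" >= 0)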
@@ -1,86 +1,86 @@
from django.db.backends import BaseDatabaseIntrospection

class DatabaseIntrospection(BaseDatabaseIntrospection):
    # Maps type codes to Django Field types.
    data_types_reverse = {
        16: 'BooleanField',
        21: 'SmallIntegerField',
        23: 'IntegerField',
        25: 'TextField',
        701: 'FloatField',
        869: 'IPAddressField',
        1043: 'CharField',
        1082: 'DateField',
        1083: 'TimeField',
        1114: 'DateTimeField',
        1184: 'DateTimeField',
        1266: 'TimeField',
        1700: 'DecimalField',
    }

    def get_table_list(self, cursor):
        "Returns a list of table names in the current database."
        cursor.execute("""
            SELECT c.relname
            FROM pg_catalog.pg_class c
            LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
            WHERE c.relkind IN ('r', 'v', '')
                AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
                AND pg_catalog.pg_table_is_visible(c.oid)""")
        return [row[0] for row in cursor.fetchall()]

    def get_table_description(self, cursor, table_name):
        "Returns a description of the table, with the DB-API cursor.description interface."
        cursor.execute("SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name))
        return cursor.description

    def get_relations(self, cursor, table_name):
        """
        Returns a dictionary of {field_index: (field_index_other_table, other_table)}
        representing all relationships to the given table. Indexes are 0-based.
        """
        cursor.execute("""
            SELECT con.conkey, con.confkey, c2.relname
            FROM pg_constraint con, pg_class c1, pg_class c2
            WHERE c1.oid = con.conrelid
                AND c2.oid = con.confrelid
                AND c1.relname = %s
                AND con.contype = 'f'""", [table_name])
        relations = {}
        for row in cursor.fetchall():
            try:
                # row[0] and row[1] are like "{2}", so strip the curly braces.
                relations[int(row[0][1:-1]) - 1] = (int(row[1][1:-1]) - 1, row[2])
            except ValueError:
                continue
        return relations

    def get_indexes(self, cursor, table_name):
        """
        Returns a dictionary of fieldname -> infodict for the given table,
        where each infodict is in the format:
            {'primary_key': boolean representing whether it's the primary key,
             'unique': boolean representing whether it's a unique index}
        """
        # This query retrieves each index on the given table, including the
        # first associated field name
        cursor.execute("""
            SELECT attr.attname, idx.indkey, idx.indisunique, idx.indisprimary
            FROM pg_catalog.pg_class c, pg_catalog.pg_class c2,
                pg_catalog.pg_index idx, pg_catalog.pg_attribute attr
            WHERE c.oid = idx.indrelid
                AND idx.indexrelid = c2.oid
                AND attr.attrelid = c.oid
                AND attr.attnum = idx.indkey[0]
                AND c.relname = %s""", [table_name])
        indexes = {}
        for row in cursor.fetchall():
            # row[1] (idx.indkey) is stored in the DB as an array. It comes out as
            # a string of space-separated integers. This designates the field
            # indexes (1-based) of the fields that have indexes on the table.
            # Here, we skip any indexes across multiple fields.
            if ' ' in row[1]:
                continue
            indexes[row[0]] = {'primary_key': row[3], 'unique': row[2]}
        return indexes
def get_indexes(cursor, table_name):
"""
Returns a dictionary of fieldname -> infodict for the given table,
where each infodict is in the format:
{'primary_key': boolean representing whether it's the primary key,
'unique': boolean representing whether it's a unique index}
"""
# This query retrieves each index on the given table, including the
# first associated field name
cursor.execute("""
SELECT attr.attname, idx.indkey, idx.indisunique, idx.indisprimary
FROM pg_catalog.pg_class c, pg_catalog.pg_class c2,
pg_catalog.pg_index idx, pg_catalog.pg_attribute attr
WHERE c.oid = idx.indrelid
AND idx.indexrelid = c2.oid
AND attr.attrelid = c.oid
AND attr.attnum = idx.indkey[0]
AND c.relname = %s""", [table_name])
indexes = {}
for row in cursor.fetchall():
# row[1] (idx.indkey) is stored in the DB as an array. It comes out as
# a string of space-separated integers. This designates the field
# indexes (1-based) of the fields that have indexes on the table.
# Here, we skip any indexes across multiple fields.
if ' ' in row[1]:
continue
indexes[row[0]] = {'primary_key': row[3], 'unique': row[2]}
return indexes
# Maps type codes to Django Field types.
DATA_TYPES_REVERSE = {
16: 'BooleanField',
21: 'SmallIntegerField',
23: 'IntegerField',
25: 'TextField',
701: 'FloatField',
869: 'IPAddressField',
1043: 'CharField',
1082: 'DateField',
1083: 'TimeField',
1114: 'DateTimeField',
1184: 'DateTimeField',
1266: 'TimeField',
1700: 'DecimalField',
}
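For orientation, a short sketch (not part of the patch) of how the refactored introspection API is reached from a connection; it assumes a configured PostgreSQL database in settings:

from django.db import connection

cursor = connection.cursor()
# Introspection now hangs off the connection object instead of a module.
for table in connection.introspection.get_table_list(cursor):
    description = connection.introspection.get_table_description(cursor, table)
    print table, [col[0] for col in description]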
@@ -4,8 +4,12 @@ PostgreSQL database backend for Django.

Requires psycopg 2: http://initd.org/projects/psycopg2
"""

from django.db.backends import BaseDatabaseWrapper, BaseDatabaseFeatures
from django.db.backends import *
from django.db.backends.postgresql.operations import DatabaseOperations as PostgresqlDatabaseOperations
from django.db.backends.postgresql.client import DatabaseClient
from django.db.backends.postgresql.creation import DatabaseCreation
from django.db.backends.postgresql_psycopg2.introspection import DatabaseIntrospection
from django.utils.safestring import SafeUnicode

try:
    import psycopg2 as Database
@@ -31,8 +35,6 @@ class DatabaseOperations(PostgresqlDatabaseOperations):
        return cursor.query

class DatabaseWrapper(BaseDatabaseWrapper):
    features = DatabaseFeatures()
    ops = DatabaseOperations()
    operators = {
        'exact': '= %s',
        'iexact': 'ILIKE %s',
@@ -50,6 +52,16 @@ class DatabaseWrapper(BaseDatabaseWrapper):
        'iendswith': 'ILIKE %s',
    }

    def __init__(self, *args, **kwargs):
        super(DatabaseWrapper, self).__init__(*args, **kwargs)
        self.features = DatabaseFeatures()
        self.ops = DatabaseOperations()
        self.client = DatabaseClient()
        self.creation = DatabaseCreation(self)
        self.introspection = DatabaseIntrospection(self)
        self.validation = BaseDatabaseValidation()

    def _cursor(self, settings):
        set_tz = False
        if self.connection is None:
@@ -1 +0,0 @@
from django.db.backends.postgresql.client import *
@@ -1 +0,0 @@
from django.db.backends.postgresql.creation import *
@@ -1,83 +1,21 @@
from django.db.backends.postgresql.introspection import DatabaseIntrospection as PostgresDatabaseIntrospection

class DatabaseIntrospection(PostgresDatabaseIntrospection):
    def get_relations(self, cursor, table_name):
        """
        Returns a dictionary of {field_index: (field_index_other_table, other_table)}
        representing all relationships to the given table. Indexes are 0-based.
        """
        cursor.execute("""
            SELECT con.conkey, con.confkey, c2.relname
            FROM pg_constraint con, pg_class c1, pg_class c2
            WHERE c1.oid = con.conrelid
                AND c2.oid = con.confrelid
                AND c1.relname = %s
                AND con.contype = 'f'""", [table_name])
        relations = {}
        for row in cursor.fetchall():
            # row[0] and row[1] are single-item lists, so grab the single item.
            relations[row[0][0] - 1] = (row[1][0] - 1, row[2])
        return relations

from django.db.backends.postgresql_psycopg2.base import DatabaseOperations
quote_name = DatabaseOperations().quote_name

def get_relations(cursor, table_name):
"""
Returns a dictionary of {field_index: (field_index_other_table, other_table)}
representing all relationships to the given table. Indexes are 0-based.
"""
cursor.execute("""
SELECT con.conkey, con.confkey, c2.relname
FROM pg_constraint con, pg_class c1, pg_class c2
WHERE c1.oid = con.conrelid
AND c2.oid = con.confrelid
AND c1.relname = %s
AND con.contype = 'f'""", [table_name])
relations = {}
for row in cursor.fetchall():
# row[0] and row[1] are single-item lists, so grab the single item.
relations[row[0][0] - 1] = (row[1][0] - 1, row[2])
return relations
def get_indexes(cursor, table_name):
"""
Returns a dictionary of fieldname -> infodict for the given table,
where each infodict is in the format:
{'primary_key': boolean representing whether it's the primary key,
'unique': boolean representing whether it's a unique index}
"""
# This query retrieves each index on the given table, including the
# first associated field name
cursor.execute("""
SELECT attr.attname, idx.indkey, idx.indisunique, idx.indisprimary
FROM pg_catalog.pg_class c, pg_catalog.pg_class c2,
pg_catalog.pg_index idx, pg_catalog.pg_attribute attr
WHERE c.oid = idx.indrelid
AND idx.indexrelid = c2.oid
AND attr.attrelid = c.oid
AND attr.attnum = idx.indkey[0]
AND c.relname = %s""", [table_name])
indexes = {}
for row in cursor.fetchall():
# row[1] (idx.indkey) is stored in the DB as an array. It comes out as
# a string of space-separated integers. This designates the field
# indexes (1-based) of the fields that have indexes on the table.
# Here, we skip any indexes across multiple fields.
if ' ' in row[1]:
continue
indexes[row[0]] = {'primary_key': row[3], 'unique': row[2]}
return indexes
# Maps type codes to Django Field types.
DATA_TYPES_REVERSE = {
16: 'BooleanField',
21: 'SmallIntegerField',
23: 'IntegerField',
25: 'TextField',
701: 'FloatField',
869: 'IPAddressField',
1043: 'CharField',
1082: 'DateField',
1083: 'TimeField',
1114: 'DateTimeField',
1184: 'DateTimeField',
1266: 'TimeField',
1700: 'DecimalField',
}
@@ -6,7 +6,7 @@ Python 2.3 and 2.4 require pysqlite2 (http://pysqlite.org/).
Python 2.5 and later use the sqlite3 module in the standard library.
"""

from django.db.backends import BaseDatabaseWrapper, BaseDatabaseFeatures, BaseDatabaseOperations, util
from django.db.backends import *
from django.db.backends.sqlite3.client import DatabaseClient
from django.db.backends.sqlite3.creation import DatabaseCreation
from django.db.backends.sqlite3.introspection import DatabaseIntrospection

try:
    try:
        from sqlite3 import dbapi2 as Database
@@ -46,7 +50,6 @@ if Database.version_info >= (2,4,1):
    Database.register_adapter(str, lambda s:s.decode('utf-8'))

class DatabaseFeatures(BaseDatabaseFeatures):
    supports_constraints = False
    # SQLite cannot handle us only partially reading from a cursor's result set
    # and then writing the same rows to the database in another cursor. This
    # setting ensures we always read result sets fully into memory all in one
@@ -96,11 +99,8 @@ class DatabaseOperations(BaseDatabaseOperations):
        second = '%s-12-31 23:59:59.999999'
        return [first % value, second % value]

class DatabaseWrapper(BaseDatabaseWrapper):
    features = DatabaseFeatures()
    ops = DatabaseOperations()
    # SQLite requires LIKE statements to include an ESCAPE clause if the value
    # being escaped has a percent or underscore in it.
    # See http://www.sqlite.org/lang_expr.html for an explanation.
@@ -121,6 +121,16 @@ class DatabaseWrapper(BaseDatabaseWrapper):
        'iendswith': "LIKE %s ESCAPE '\\'",
    }

    def __init__(self, *args, **kwargs):
        super(DatabaseWrapper, self).__init__(*args, **kwargs)
        self.features = DatabaseFeatures()
        self.ops = DatabaseOperations()
        self.client = DatabaseClient()
        self.creation = DatabaseCreation(self)
        self.introspection = DatabaseIntrospection(self)
        self.validation = BaseDatabaseValidation()

    def _cursor(self, settings):
        if self.connection is None:
            if not settings.DATABASE_NAME:
@@ -1,6 +1,8 @@
from django.db.backends import BaseDatabaseClient
from django.conf import settings
import os

class DatabaseClient(BaseDatabaseClient):
    def runshell(self):
        args = ['', settings.DATABASE_NAME]
        os.execvp('sqlite3', args)
@@ -1,27 +1,73 @@
import os
import sys
from django.conf import settings
from django.db.backends.creation import BaseDatabaseCreation

class DatabaseCreation(BaseDatabaseCreation):
    # SQLite doesn't actually support most of these types, but it "does the right
    # thing" given more verbose field definitions, so leave them as is so that
    # schema inspection is more useful.
    data_types = {
        'AutoField': 'integer',
        'BooleanField': 'bool',
        'CharField': 'varchar(%(max_length)s)',
        'CommaSeparatedIntegerField': 'varchar(%(max_length)s)',
        'DateField': 'date',
        'DateTimeField': 'datetime',
        'DecimalField': 'decimal',
        'FileField': 'varchar(%(max_length)s)',
        'FilePathField': 'varchar(%(max_length)s)',
        'FloatField': 'real',
        'IntegerField': 'integer',
        'IPAddressField': 'char(15)',
        'NullBooleanField': 'bool',
        'OneToOneField': 'integer',
        'PhoneNumberField': 'varchar(20)',
        'PositiveIntegerField': 'integer unsigned',
        'PositiveSmallIntegerField': 'smallint unsigned',
        'SlugField': 'varchar(%(max_length)s)',
        'SmallIntegerField': 'smallint',
        'TextField': 'text',
        'TimeField': 'time',
        'USStateField': 'varchar(2)',
    }

    def sql_for_pending_references(self, model, style, pending_references):
        "SQLite3 doesn't support constraints"
        return []

    def sql_remove_table_constraints(self, model, references_to_delete):
        "SQLite3 doesn't support constraints"
        return []

    def _create_test_db(self, verbosity, autoclobber):
        if settings.TEST_DATABASE_NAME and settings.TEST_DATABASE_NAME != ":memory:":
            test_database_name = settings.TEST_DATABASE_NAME
            # Erase the old test database
            if verbosity >= 1:
                print "Destroying old test database..."
            if os.access(test_database_name, os.F_OK):
                if not autoclobber:
                    confirm = raw_input("Type 'yes' if you would like to try deleting the test database '%s', or 'no' to cancel: " % test_database_name)
                if autoclobber or confirm == 'yes':
                    try:
                        if verbosity >= 1:
                            print "Destroying old test database..."
                        os.remove(test_database_name)
                    except Exception, e:
                        sys.stderr.write("Got an error deleting the old test database: %s\n" % e)
                        sys.exit(2)
                else:
                    print "Tests cancelled."
                    sys.exit(1)
            if verbosity >= 1:
                print "Creating test database..."
        else:
            test_database_name = ":memory:"
        return test_database_name

    def _destroy_test_db(self, test_database_name, verbosity):
        if test_database_name and test_database_name != ":memory:":
            # Remove the SQLite database file
            os.remove(test_database_name)
@@ -1,84 +1,30 @@
from django.db.backends import BaseDatabaseIntrospection

from django.db.backends.sqlite3.base import DatabaseOperations
quote_name = DatabaseOperations().quote_name
def get_table_list(cursor):
"Returns a list of table names in the current database."
# Skip the sqlite_sequence system table used for autoincrement key
# generation.
cursor.execute("""
SELECT name FROM sqlite_master
WHERE type='table' AND NOT name='sqlite_sequence'
ORDER BY name""")
return [row[0] for row in cursor.fetchall()]
def get_table_description(cursor, table_name):
"Returns a description of the table, with the DB-API cursor.description interface."
return [(info['name'], info['type'], None, None, None, None,
info['null_ok']) for info in _table_info(cursor, table_name)]
def get_relations(cursor, table_name):
raise NotImplementedError
def get_indexes(cursor, table_name):
"""
Returns a dictionary of fieldname -> infodict for the given table,
where each infodict is in the format:
{'primary_key': boolean representing whether it's the primary key,
'unique': boolean representing whether it's a unique index}
"""
indexes = {}
for info in _table_info(cursor, table_name):
indexes[info['name']] = {'primary_key': info['pk'] != 0,
'unique': False}
cursor.execute('PRAGMA index_list(%s)' % quote_name(table_name))
# seq, name, unique
for index, unique in [(field[1], field[2]) for field in cursor.fetchall()]:
if not unique:
continue
cursor.execute('PRAGMA index_info(%s)' % quote_name(index))
info = cursor.fetchall()
# Skip indexes across multiple fields
if len(info) != 1:
continue
name = info[0][2] # seqno, cid, name
indexes[name]['unique'] = True
return indexes
def _table_info(cursor, name):
cursor.execute('PRAGMA table_info(%s)' % quote_name(name))
# cid, name, type, notnull, dflt_value, pk
return [{'name': field[1],
'type': field[2],
'null_ok': not field[3],
'pk': field[5] # undocumented
} for field in cursor.fetchall()]
# Maps SQL types to Django Field types. Some of the SQL types have multiple
# entries here because SQLite allows for anything and doesn't normalize the
# field type; it uses whatever was given.
BASE_DATA_TYPES_REVERSE = {
'bool': 'BooleanField',
'boolean': 'BooleanField',
'smallint': 'SmallIntegerField',
'smallinteger': 'SmallIntegerField',
'int': 'IntegerField',
'integer': 'IntegerField',
'text': 'TextField',
'char': 'CharField',
'date': 'DateField',
'datetime': 'DateTimeField',
'time': 'TimeField',
}
# This light wrapper "fakes" a dictionary interface, because some SQLite data
# types include variables in them -- e.g. "varchar(30)" -- and can't be matched
# as a simple dictionary lookup.
class FlexibleFieldLookupDict:
# Maps SQL types to Django Field types. Some of the SQL types have multiple
# entries here because SQLite allows for anything and doesn't normalize the
# field type; it uses whatever was given.
base_data_types_reverse = {
'bool': 'BooleanField',
'boolean': 'BooleanField',
'smallint': 'SmallIntegerField',
'smallinteger': 'SmallIntegerField',
'int': 'IntegerField',
'integer': 'IntegerField',
'text': 'TextField',
'char': 'CharField',
'date': 'DateField',
'datetime': 'DateTimeField',
'time': 'TimeField',
}
    def __getitem__(self, key):
        key = key.lower()
        try:
            return self.base_data_types_reverse[key]
        except KeyError:
            import re
            m = re.search(r'^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$', key)
@@ -86,4 +32,58 @@ class FlexibleFieldLookupDict:
                return ('CharField', {'max_length': int(m.group(1))})
            raise KeyError

DATA_TYPES_REVERSE = FlexibleFieldLookupDict()

class DatabaseIntrospection(BaseDatabaseIntrospection):
data_types_reverse = FlexibleFieldLookupDict()
def get_table_list(self, cursor):
"Returns a list of table names in the current database."
# Skip the sqlite_sequence system table used for autoincrement key
# generation.
cursor.execute("""
SELECT name FROM sqlite_master
WHERE type='table' AND NOT name='sqlite_sequence'
ORDER BY name""")
return [row[0] for row in cursor.fetchall()]
def get_table_description(self, cursor, table_name):
"Returns a description of the table, with the DB-API cursor.description interface."
return [(info['name'], info['type'], None, None, None, None,
info['null_ok']) for info in self._table_info(cursor, table_name)]
def get_relations(self, cursor, table_name):
raise NotImplementedError
def get_indexes(self, cursor, table_name):
"""
Returns a dictionary of fieldname -> infodict for the given table,
where each infodict is in the format:
{'primary_key': boolean representing whether it's the primary key,
'unique': boolean representing whether it's a unique index}
"""
indexes = {}
for info in self._table_info(cursor, table_name):
indexes[info['name']] = {'primary_key': info['pk'] != 0,
'unique': False}
cursor.execute('PRAGMA index_list(%s)' % self.connection.ops.quote_name(table_name))
# seq, name, unique
for index, unique in [(field[1], field[2]) for field in cursor.fetchall()]:
if not unique:
continue
cursor.execute('PRAGMA index_info(%s)' % self.connection.ops.quote_name(index))
info = cursor.fetchall()
# Skip indexes across multiple fields
if len(info) != 1:
continue
name = info[0][2] # seqno, cid, name
indexes[name]['unique'] = True
return indexes
def _table_info(self, cursor, name):
cursor.execute('PRAGMA table_info(%s)' % self.connection.ops.quote_name(name))
# cid, name, type, notnull, dflt_value, pk
return [{'name': field[1],
'type': field[2],
'null_ok': not field[3],
'pk': field[5] # undocumented
} for field in cursor.fetchall()]
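As a small illustration (not part of the patch) of why the flexible lookup above is needed: SQLite reports column types verbatim, so fixed type names resolve directly while parametrized types are parsed on the fly and come back with their keyword arguments:

lookup = FlexibleFieldLookupDict()
print lookup['integer']        # 'IntegerField'
print lookup['varchar(30)']    # ('CharField', {'max_length': 30})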
@@ -7,7 +7,7 @@ try:
except ImportError:
    from django.utils import _decimal as decimal    # for Python 2.3

from django.db import connection, get_creation_module
from django.db import connection
from django.db.models import signals
from django.db.models.query_utils import QueryWrapper
from django.dispatch import dispatcher
@@ -145,14 +145,14 @@ class Field(object):
        # as the TextField Django field type, which means XMLField's
        # get_internal_type() returns 'TextField'.
        #
        # But the limitation of the get_internal_type() / data_types approach
        # is that it cannot handle database column types that aren't already
        # mapped to one of the built-in Django field types. In this case, you
        # can implement db_type() instead of get_internal_type() to specify
        # exactly which wacky database column type you want to use.
        data = DictWrapper(self.__dict__, connection.ops.quote_name, "qn_")
        try:
            return get_creation_module().DATA_TYPES[self.get_internal_type()] % data
            return connection.creation.data_types[self.get_internal_type()] % data
        except KeyError:
            return None
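The comment block above describes the db_type() escape hatch for column types that have no entry in the backend's data_types mapping. A minimal hypothetical sketch (the field name and column type below are invented for illustration, not part of the patch):

from django.db import models

class IntervalField(models.Field):
    # No 'IntervalField' key exists in the backend's data_types dictionary, so
    # instead of relying on get_internal_type() we return the column type directly.
    def db_type(self):
        return 'interval'  # assumes a backend, such as PostgreSQL, that has this type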
@@ -3,7 +3,6 @@ from django.conf import settings
from django.db.models import get_app, get_apps
from django.test import _doctest as doctest
from django.test.utils import setup_test_environment, teardown_test_environment
from django.test.utils import create_test_db, destroy_test_db
from django.test.testcases import OutputChecker, DocTestRunner

# The module name for tests outside models.py
@@ -139,9 +138,10 @@ def run_tests(test_labels, verbosity=1, interactive=True, extra_tests=[]):
        suite.addTest(test)

    old_name = settings.DATABASE_NAME
    create_test_db(verbosity, autoclobber=not interactive)
    from django.db import connection
    connection.creation.create_test_db(verbosity, autoclobber=not interactive)
    result = unittest.TextTestRunner(verbosity=verbosity).run(suite)
    destroy_test_db(old_name, verbosity)
    connection.creation.destroy_test_db(old_name, verbosity)
    teardown_test_environment()
@@ -1,16 +1,11 @@
import sys, time, os
from django.conf import settings
from django.db import connection, get_creation_module
from django.db import connection
from django.core import mail
from django.core.management import call_command
from django.test import signals
from django.template import Template
from django.utils.translation import deactivate

# The prefix to put on the default database name when creating
# the test database.
TEST_DATABASE_PREFIX = 'test_'

def instrumented_test_render(self, context):
    """
    An instrumented Template render method, providing a signal
@@ -70,147 +65,3 @@ def teardown_test_environment():
    del mail.outbox
def _set_autocommit(connection):
"Make sure a connection is in autocommit mode."
if hasattr(connection.connection, "autocommit"):
if callable(connection.connection.autocommit):
connection.connection.autocommit(True)
else:
connection.connection.autocommit = True
elif hasattr(connection.connection, "set_isolation_level"):
connection.connection.set_isolation_level(0)
def get_mysql_create_suffix():
suffix = []
if settings.TEST_DATABASE_CHARSET:
suffix.append('CHARACTER SET %s' % settings.TEST_DATABASE_CHARSET)
if settings.TEST_DATABASE_COLLATION:
suffix.append('COLLATE %s' % settings.TEST_DATABASE_COLLATION)
return ' '.join(suffix)
def get_postgresql_create_suffix():
assert settings.TEST_DATABASE_COLLATION is None, "PostgreSQL does not support collation setting at database creation time."
if settings.TEST_DATABASE_CHARSET:
return "WITH ENCODING '%s'" % settings.TEST_DATABASE_CHARSET
return ''
def create_test_db(verbosity=1, autoclobber=False):
"""
Creates a test database, prompting the user for confirmation if the
database already exists. Returns the name of the test database created.
"""
# If the database backend wants to create the test DB itself, let it
creation_module = get_creation_module()
if hasattr(creation_module, "create_test_db"):
creation_module.create_test_db(settings, connection, verbosity, autoclobber)
return
if verbosity >= 1:
print "Creating test database..."
# If we're using SQLite, it's more convenient to test against an
# in-memory database. Using the TEST_DATABASE_NAME setting you can still choose
# to run on a physical database.
if settings.DATABASE_ENGINE == "sqlite3":
if settings.TEST_DATABASE_NAME and settings.TEST_DATABASE_NAME != ":memory:":
TEST_DATABASE_NAME = settings.TEST_DATABASE_NAME
# Erase the old test database
if verbosity >= 1:
print "Destroying old test database..."
if os.access(TEST_DATABASE_NAME, os.F_OK):
if not autoclobber:
confirm = raw_input("Type 'yes' if you would like to try deleting the test database '%s', or 'no' to cancel: " % TEST_DATABASE_NAME)
if autoclobber or confirm == 'yes':
try:
if verbosity >= 1:
print "Destroying old test database..."
os.remove(TEST_DATABASE_NAME)
except Exception, e:
sys.stderr.write("Got an error deleting the old test database: %s\n" % e)
sys.exit(2)
else:
print "Tests cancelled."
sys.exit(1)
if verbosity >= 1:
print "Creating test database..."
else:
TEST_DATABASE_NAME = ":memory:"
else:
suffix = {
'postgresql': get_postgresql_create_suffix,
'postgresql_psycopg2': get_postgresql_create_suffix,
'mysql': get_mysql_create_suffix,
}.get(settings.DATABASE_ENGINE, lambda: '')()
if settings.TEST_DATABASE_NAME:
TEST_DATABASE_NAME = settings.TEST_DATABASE_NAME
else:
TEST_DATABASE_NAME = TEST_DATABASE_PREFIX + settings.DATABASE_NAME
qn = connection.ops.quote_name
# Create the test database and connect to it. We need to autocommit
# if the database supports it because PostgreSQL doesn't allow
# CREATE/DROP DATABASE statements within transactions.
cursor = connection.cursor()
_set_autocommit(connection)
try:
cursor.execute("CREATE DATABASE %s %s" % (qn(TEST_DATABASE_NAME), suffix))
except Exception, e:
sys.stderr.write("Got an error creating the test database: %s\n" % e)
if not autoclobber:
confirm = raw_input("Type 'yes' if you would like to try deleting the test database '%s', or 'no' to cancel: " % TEST_DATABASE_NAME)
if autoclobber or confirm == 'yes':
try:
if verbosity >= 1:
print "Destroying old test database..."
cursor.execute("DROP DATABASE %s" % qn(TEST_DATABASE_NAME))
if verbosity >= 1:
print "Creating test database..."
cursor.execute("CREATE DATABASE %s %s" % (qn(TEST_DATABASE_NAME), suffix))
except Exception, e:
sys.stderr.write("Got an error recreating the test database: %s\n" % e)
sys.exit(2)
else:
print "Tests cancelled."
sys.exit(1)
connection.close()
settings.DATABASE_NAME = TEST_DATABASE_NAME
call_command('syncdb', verbosity=verbosity, interactive=False)
if settings.CACHE_BACKEND.startswith('db://'):
cache_name = settings.CACHE_BACKEND[len('db://'):]
call_command('createcachetable', cache_name)
# Get a cursor (even though we don't need one yet). This has
# the side effect of initializing the test database.
cursor = connection.cursor()
return TEST_DATABASE_NAME
def destroy_test_db(old_database_name, verbosity=1):
# If the database wants to drop the test DB itself, let it
creation_module = get_creation_module()
if hasattr(creation_module, "destroy_test_db"):
creation_module.destroy_test_db(settings, connection, old_database_name, verbosity)
return
if verbosity >= 1:
print "Destroying test database..."
connection.close()
TEST_DATABASE_NAME = settings.DATABASE_NAME
settings.DATABASE_NAME = old_database_name
if settings.DATABASE_ENGINE == "sqlite3":
if TEST_DATABASE_NAME and TEST_DATABASE_NAME != ":memory:":
# Remove the SQLite database file
os.remove(TEST_DATABASE_NAME)
else:
# Remove the test database to clean up after
# ourselves. Connect to the previous database (not the test database)
# to do so, because it's not allowed to delete a database while being
# connected to it.
cursor = connection.cursor()
_set_autocommit(connection)
time.sleep(1) # To avoid "database is being accessed by other users" errors.
cursor.execute("DROP DATABASE %s" % connection.ops.quote_name(TEST_DATABASE_NAME))
connection.close()
@@ -1026,6 +1026,9 @@ a number of utility methods in the ``django.test.utils`` module.
black magic hooks into the template system and restoring normal e-mail
services.

The creation module of the database backend (``connection.creation``) also
provides some utilities that can be useful during testing.

``create_test_db(verbosity=1, autoclobber=False)``
    Creates a new test database and runs ``syncdb`` against it.

@@ -1044,7 +1047,7 @@ a number of utility methods in the ``django.test.utils`` module.
    ``create_test_db()`` has the side effect of modifying
    ``settings.DATABASE_NAME`` to match the name of the test database.

    **New in Django development version:** This function returns the name of
    the test database that it created.

``destroy_test_db(old_database_name, verbosity=1)``
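Both utilities are reached through the new ``connection.creation`` object. A short sketch of driving them by hand, mirroring what the test runner does (the surrounding script is assumed, not part of the patch):

from django.conf import settings
from django.db import connection

old_name = settings.DATABASE_NAME
connection.creation.create_test_db(verbosity=1, autoclobber=True)
# ... run code against the temporary test database ...
connection.creation.destroy_test_db(old_name, verbosity=1)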
@@ -15,10 +15,6 @@ class Person(models.Model):
    def __unicode__(self):
        return u'%s %s' % (self.first_name, self.last_name)

if connection.features.uses_case_insensitive_names:
    t_convert = lambda x: x.upper()
else:
    t_convert = lambda x: x

qn = connection.ops.quote_name

__test__ = {'API_TESTS': """
@@ -29,7 +25,7 @@ __test__ = {'API_TESTS': """
>>> opts = Square._meta
>>> f1, f2 = opts.get_field('root'), opts.get_field('square')
>>> query = ('INSERT INTO %s (%s, %s) VALUES (%%s, %%s)'
...     % (t_convert(opts.db_table), qn(f1.column), qn(f2.column)))
...     % (connection.introspection.table_name_converter(opts.db_table), qn(f1.column), qn(f2.column)))
>>> cursor.executemany(query, [(i, i**2) for i in range(-5, 6)]) and None or None
>>> Square.objects.order_by('root')
[<Square: -5 ** 2 == 25>, <Square: -4 ** 2 == 16>, <Square: -3 ** 2 == 9>, <Square: -2 ** 2 == 4>, <Square: -1 ** 2 == 1>, <Square: 0 ** 2 == 0>, <Square: 1 ** 2 == 1>, <Square: 2 ** 2 == 4>, <Square: 3 ** 2 == 9>, <Square: 4 ** 2 == 16>, <Square: 5 ** 2 == 25>]
@@ -48,7 +44,7 @@ __test__ = {'API_TESTS': """
>>> opts2 = Person._meta
>>> f3, f4 = opts2.get_field('first_name'), opts2.get_field('last_name')
>>> query2 = ('SELECT %s, %s FROM %s ORDER BY %s'
...     % (qn(f3.column), qn(f4.column), t_convert(opts2.db_table),
...     % (qn(f3.column), qn(f4.column), connection.introspection.table_name_converter(opts2.db_table),
...     qn(f3.column)))
>>> cursor.execute(query2) and None or None
>>> cursor.fetchone()