mirror of
https://github.com/django/django.git
synced 2025-01-23 08:39:17 +00:00
Fixed #10868 -- Stopped restoring database connections after the tests' execution in order to prevent the production database from being exposed to potential threads that would still be running. Also did a bit of PEP8-cleaning while I was in the area. Many thanks to ovidiu for the report and to Anssi Kääriäinen for thoroughly investigating this issue.
git-svn-id: http://code.djangoproject.com/svn/django/trunk@17411 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
parent
b5d0cc9091
commit
f1dc83cb98
@ -2,11 +2,13 @@ import sys
|
||||
import time
|
||||
|
||||
from django.conf import settings
|
||||
from django.db.utils import load_backend
|
||||
|
||||
# The prefix to put on the default database name when creating
|
||||
# the test database.
|
||||
TEST_DATABASE_PREFIX = 'test_'
|
||||
|
||||
|
||||
class BaseDatabaseCreation(object):
|
||||
"""
|
||||
This class encapsulates all backend-specific differences that pertain to
|
||||
@ -57,35 +59,45 @@ class BaseDatabaseCreation(object):
|
||||
if tablespace and f.unique:
|
||||
# We must specify the index tablespace inline, because we
|
||||
# won't be generating a CREATE INDEX statement for this field.
|
||||
tablespace_sql = self.connection.ops.tablespace_sql(tablespace, inline=True)
|
||||
tablespace_sql = self.connection.ops.tablespace_sql(
|
||||
tablespace, inline=True)
|
||||
if tablespace_sql:
|
||||
field_output.append(tablespace_sql)
|
||||
if f.rel:
|
||||
ref_output, pending = self.sql_for_inline_foreign_key_references(f, known_models, style)
|
||||
ref_output, pending = self.sql_for_inline_foreign_key_references(
|
||||
f, known_models, style)
|
||||
if pending:
|
||||
pending_references.setdefault(f.rel.to, []).append((model, f))
|
||||
pending_references.setdefault(f.rel.to, []).append(
|
||||
(model, f))
|
||||
else:
|
||||
field_output.extend(ref_output)
|
||||
table_output.append(' '.join(field_output))
|
||||
for field_constraints in opts.unique_together:
|
||||
table_output.append(style.SQL_KEYWORD('UNIQUE') + ' (%s)' % \
|
||||
", ".join([style.SQL_FIELD(qn(opts.get_field(f).column)) for f in field_constraints]))
|
||||
table_output.append(style.SQL_KEYWORD('UNIQUE') + ' (%s)' %
|
||||
", ".join(
|
||||
[style.SQL_FIELD(qn(opts.get_field(f).column))
|
||||
for f in field_constraints]))
|
||||
|
||||
full_statement = [style.SQL_KEYWORD('CREATE TABLE') + ' ' + style.SQL_TABLE(qn(opts.db_table)) + ' (']
|
||||
full_statement = [style.SQL_KEYWORD('CREATE TABLE') + ' ' +
|
||||
style.SQL_TABLE(qn(opts.db_table)) + ' (']
|
||||
for i, line in enumerate(table_output): # Combine and add commas.
|
||||
full_statement.append(' %s%s' % (line, i < len(table_output)-1 and ',' or ''))
|
||||
full_statement.append(
|
||||
' %s%s' % (line, i < len(table_output)-1 and ',' or ''))
|
||||
full_statement.append(')')
|
||||
if opts.db_tablespace:
|
||||
tablespace_sql = self.connection.ops.tablespace_sql(opts.db_tablespace)
|
||||
tablespace_sql = self.connection.ops.tablespace_sql(
|
||||
opts.db_tablespace)
|
||||
if tablespace_sql:
|
||||
full_statement.append(tablespace_sql)
|
||||
full_statement.append(';')
|
||||
final_output.append('\n'.join(full_statement))
|
||||
|
||||
if opts.has_auto_field:
|
||||
# Add any extra SQL needed to support auto-incrementing primary keys.
|
||||
# Add any extra SQL needed to support auto-incrementing primary
|
||||
# keys.
|
||||
auto_column = opts.auto_field.db_column or opts.auto_field.name
|
||||
autoinc_sql = self.connection.ops.autoinc_sql(opts.db_table, auto_column)
|
||||
autoinc_sql = self.connection.ops.autoinc_sql(opts.db_table,
|
||||
auto_column)
|
||||
if autoinc_sql:
|
||||
for stmt in autoinc_sql:
|
||||
final_output.append(stmt)
|
||||
@ -93,12 +105,15 @@ class BaseDatabaseCreation(object):
|
||||
return final_output, pending_references
|
||||
|
||||
def sql_for_inline_foreign_key_references(self, field, known_models, style):
|
||||
"Return the SQL snippet defining the foreign key reference for a field"
|
||||
"""
|
||||
Return the SQL snippet defining the foreign key reference for a field.
|
||||
"""
|
||||
qn = self.connection.ops.quote_name
|
||||
if field.rel.to in known_models:
|
||||
output = [style.SQL_KEYWORD('REFERENCES') + ' ' + \
|
||||
style.SQL_TABLE(qn(field.rel.to._meta.db_table)) + ' (' + \
|
||||
style.SQL_FIELD(qn(field.rel.to._meta.get_field(field.rel.field_name).column)) + ')' +
|
||||
output = [style.SQL_KEYWORD('REFERENCES') + ' ' +
|
||||
style.SQL_TABLE(qn(field.rel.to._meta.db_table)) + ' (' +
|
||||
style.SQL_FIELD(qn(field.rel.to._meta.get_field(
|
||||
field.rel.field_name).column)) + ')' +
|
||||
self.connection.ops.deferrable_sql()
|
||||
]
|
||||
pending = False
|
||||
@ -111,7 +126,9 @@ class BaseDatabaseCreation(object):
|
||||
return output, pending
|
||||
|
||||
def sql_for_pending_references(self, model, style, pending_references):
|
||||
"Returns any ALTER TABLE statements to add constraints after the fact."
|
||||
"""
|
||||
Returns any ALTER TABLE statements to add constraints after the fact.
|
||||
"""
|
||||
from django.db.backends.util import truncate_name
|
||||
|
||||
if not model._meta.managed or model._meta.proxy:
|
||||
@ -128,16 +145,21 @@ class BaseDatabaseCreation(object):
|
||||
col = opts.get_field(f.rel.field_name).column
|
||||
# For MySQL, r_name must be unique in the first 64 characters.
|
||||
# So we are careful with character usage here.
|
||||
r_name = '%s_refs_%s_%s' % (r_col, col, self._digest(r_table, table))
|
||||
final_output.append(style.SQL_KEYWORD('ALTER TABLE') + ' %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)%s;' % \
|
||||
(qn(r_table), qn(truncate_name(r_name, self.connection.ops.max_name_length())),
|
||||
r_name = '%s_refs_%s_%s' % (
|
||||
r_col, col, self._digest(r_table, table))
|
||||
final_output.append(style.SQL_KEYWORD('ALTER TABLE') +
|
||||
' %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)%s;' %
|
||||
(qn(r_table), qn(truncate_name(
|
||||
r_name, self.connection.ops.max_name_length())),
|
||||
qn(r_col), qn(table), qn(col),
|
||||
self.connection.ops.deferrable_sql()))
|
||||
del pending_references[model]
|
||||
return final_output
|
||||
|
||||
def sql_indexes_for_model(self, model, style):
|
||||
"Returns the CREATE INDEX SQL statements for a single model"
|
||||
"""
|
||||
Returns the CREATE INDEX SQL statements for a single model.
|
||||
"""
|
||||
if not model._meta.managed or model._meta.proxy:
|
||||
return []
|
||||
output = []
|
||||
@ -146,7 +168,9 @@ class BaseDatabaseCreation(object):
|
||||
return output
|
||||
|
||||
def sql_indexes_for_field(self, model, f, style):
|
||||
"Return the CREATE INDEX SQL statements for a single model field"
|
||||
"""
|
||||
Return the CREATE INDEX SQL statements for a single model field.
|
||||
"""
|
||||
from django.db.backends.util import truncate_name
|
||||
|
||||
if f.db_index and not f.unique:
|
||||
@ -160,7 +184,8 @@ class BaseDatabaseCreation(object):
|
||||
tablespace_sql = ''
|
||||
i_name = '%s_%s' % (model._meta.db_table, self._digest(f.column))
|
||||
output = [style.SQL_KEYWORD('CREATE INDEX') + ' ' +
|
||||
style.SQL_TABLE(qn(truncate_name(i_name, self.connection.ops.max_name_length()))) + ' ' +
|
||||
style.SQL_TABLE(qn(truncate_name(
|
||||
i_name, self.connection.ops.max_name_length()))) + ' ' +
|
||||
style.SQL_KEYWORD('ON') + ' ' +
|
||||
style.SQL_TABLE(qn(model._meta.db_table)) + ' ' +
|
||||
"(%s)" % style.SQL_FIELD(qn(f.column)) +
|
||||
@ -170,7 +195,10 @@ class BaseDatabaseCreation(object):
|
||||
return output
|
||||
|
||||
def sql_destroy_model(self, model, references_to_delete, style):
|
||||
"Return the DROP TABLE and restraint dropping statements for a single model"
|
||||
"""
|
||||
Return the DROP TABLE and restraint dropping statements for a single
|
||||
model.
|
||||
"""
|
||||
if not model._meta.managed or model._meta.proxy:
|
||||
return []
|
||||
# Drop the table now
|
||||
@ -178,8 +206,8 @@ class BaseDatabaseCreation(object):
|
||||
output = ['%s %s;' % (style.SQL_KEYWORD('DROP TABLE'),
|
||||
style.SQL_TABLE(qn(model._meta.db_table)))]
|
||||
if model in references_to_delete:
|
||||
output.extend(self.sql_remove_table_constraints(model, references_to_delete, style))
|
||||
|
||||
output.extend(self.sql_remove_table_constraints(
|
||||
model, references_to_delete, style))
|
||||
if model._meta.has_auto_field:
|
||||
ds = self.connection.ops.drop_sequence_sql(model._meta.db_table)
|
||||
if ds:
|
||||
@ -188,7 +216,6 @@ class BaseDatabaseCreation(object):
|
||||
|
||||
def sql_remove_table_constraints(self, model, references_to_delete, style):
|
||||
from django.db.backends.util import truncate_name
|
||||
|
||||
if not model._meta.managed or model._meta.proxy:
|
||||
return []
|
||||
output = []
|
||||
@ -198,12 +225,14 @@ class BaseDatabaseCreation(object):
|
||||
col = f.column
|
||||
r_table = model._meta.db_table
|
||||
r_col = model._meta.get_field(f.rel.field_name).column
|
||||
r_name = '%s_refs_%s_%s' % (col, r_col, self._digest(table, r_table))
|
||||
r_name = '%s_refs_%s_%s' % (
|
||||
col, r_col, self._digest(table, r_table))
|
||||
output.append('%s %s %s %s;' % \
|
||||
(style.SQL_KEYWORD('ALTER TABLE'),
|
||||
style.SQL_TABLE(qn(table)),
|
||||
style.SQL_KEYWORD(self.connection.ops.drop_foreignkey_sql()),
|
||||
style.SQL_FIELD(qn(truncate_name(r_name, self.connection.ops.max_name_length())))))
|
||||
style.SQL_FIELD(qn(truncate_name(
|
||||
r_name, self.connection.ops.max_name_length())))))
|
||||
del references_to_delete[model]
|
||||
return output
|
||||
|
||||
@ -221,7 +250,8 @@ class BaseDatabaseCreation(object):
|
||||
test_db_repr = ''
|
||||
if verbosity >= 2:
|
||||
test_db_repr = " ('%s')" % test_database_name
|
||||
print "Creating test database for alias '%s'%s..." % (self.connection.alias, test_db_repr)
|
||||
print "Creating test database for alias '%s'%s..." % (
|
||||
self.connection.alias, test_db_repr)
|
||||
|
||||
self._create_test_db(verbosity, autoclobber)
|
||||
|
||||
@ -255,7 +285,8 @@ class BaseDatabaseCreation(object):
|
||||
for cache_alias in settings.CACHES:
|
||||
cache = get_cache(cache_alias)
|
||||
if isinstance(cache, BaseDatabaseCache):
|
||||
call_command('createcachetable', cache._table, database=self.connection.alias)
|
||||
call_command('createcachetable', cache._table,
|
||||
database=self.connection.alias)
|
||||
|
||||
# Get a cursor (even though we don't need one yet). This has
|
||||
# the side effect of initializing the test database.
|
||||
@ -275,7 +306,9 @@ class BaseDatabaseCreation(object):
|
||||
return TEST_DATABASE_PREFIX + self.connection.settings_dict['NAME']
|
||||
|
||||
def _create_test_db(self, verbosity, autoclobber):
|
||||
"Internal implementation - creates the test db tables."
|
||||
"""
|
||||
Internal implementation - creates the test db tables.
|
||||
"""
|
||||
suffix = self.sql_table_creation_suffix()
|
||||
|
||||
test_database_name = self._get_test_db_name()
|
||||
@ -288,19 +321,28 @@ class BaseDatabaseCreation(object):
|
||||
cursor = self.connection.cursor()
|
||||
self._prepare_for_test_db_ddl()
|
||||
try:
|
||||
cursor.execute("CREATE DATABASE %s %s" % (qn(test_database_name), suffix))
|
||||
cursor.execute(
|
||||
"CREATE DATABASE %s %s" % (qn(test_database_name), suffix))
|
||||
except Exception, e:
|
||||
sys.stderr.write("Got an error creating the test database: %s\n" % e)
|
||||
sys.stderr.write(
|
||||
"Got an error creating the test database: %s\n" % e)
|
||||
if not autoclobber:
|
||||
confirm = raw_input("Type 'yes' if you would like to try deleting the test database '%s', or 'no' to cancel: " % test_database_name)
|
||||
confirm = raw_input(
|
||||
"Type 'yes' if you would like to try deleting the test "
|
||||
"database '%s', or 'no' to cancel: " % test_database_name)
|
||||
if autoclobber or confirm == 'yes':
|
||||
try:
|
||||
if verbosity >= 1:
|
||||
print "Destroying old test database '%s'..." % self.connection.alias
|
||||
cursor.execute("DROP DATABASE %s" % qn(test_database_name))
|
||||
cursor.execute("CREATE DATABASE %s %s" % (qn(test_database_name), suffix))
|
||||
print ("Destroying old test database '%s'..."
|
||||
% self.connection.alias)
|
||||
cursor.execute(
|
||||
"DROP DATABASE %s" % qn(test_database_name))
|
||||
cursor.execute(
|
||||
"CREATE DATABASE %s %s" % (qn(test_database_name),
|
||||
suffix))
|
||||
except Exception, e:
|
||||
sys.stderr.write("Got an error recreating the test database: %s\n" % e)
|
||||
sys.stderr.write(
|
||||
"Got an error recreating the test database: %s\n" % e)
|
||||
sys.exit(2)
|
||||
else:
|
||||
print "Tests cancelled."
|
||||
@ -319,21 +361,36 @@ class BaseDatabaseCreation(object):
|
||||
test_db_repr = ''
|
||||
if verbosity >= 2:
|
||||
test_db_repr = " ('%s')" % test_database_name
|
||||
print "Destroying test database for alias '%s'%s..." % (self.connection.alias, test_db_repr)
|
||||
self.connection.settings_dict['NAME'] = old_database_name
|
||||
print "Destroying test database for alias '%s'%s..." % (
|
||||
self.connection.alias, test_db_repr)
|
||||
|
||||
self._destroy_test_db(test_database_name, verbosity)
|
||||
# Temporarily use a new connection and a copy of the settings dict.
|
||||
# This prevents the production database from being exposed to potential
|
||||
# child threads while (or after) the test database is destroyed.
|
||||
# Refs #10868.
|
||||
settings_dict = self.connection.settings_dict.copy()
|
||||
settings_dict['NAME'] = old_database_name
|
||||
backend = load_backend(settings_dict['ENGINE'])
|
||||
new_connection = backend.DatabaseWrapper(
|
||||
settings_dict,
|
||||
alias='__destroy_test_db__',
|
||||
allow_thread_sharing=False)
|
||||
new_connection.creation._destroy_test_db(test_database_name, verbosity)
|
||||
|
||||
def _destroy_test_db(self, test_database_name, verbosity):
|
||||
"Internal implementation - remove the test db tables."
|
||||
"""
|
||||
Internal implementation - remove the test db tables.
|
||||
"""
|
||||
# Remove the test database to clean up after
|
||||
# ourselves. Connect to the previous database (not the test database)
|
||||
# to do so, because it's not allowed to delete a database while being
|
||||
# connected to it.
|
||||
cursor = self.connection.cursor()
|
||||
self._prepare_for_test_db_ddl()
|
||||
time.sleep(1) # To avoid "database is being accessed by other users" errors.
|
||||
cursor.execute("DROP DATABASE %s" % self.connection.ops.quote_name(test_database_name))
|
||||
# Wait to avoid "database is being accessed by other users" errors.
|
||||
time.sleep(1)
|
||||
cursor.execute("DROP DATABASE %s"
|
||||
% self.connection.ops.quote_name(test_database_name))
|
||||
self.connection.close()
|
||||
|
||||
def set_autocommit(self):
|
||||
@ -346,15 +403,17 @@ class BaseDatabaseCreation(object):
|
||||
|
||||
def _prepare_for_test_db_ddl(self):
|
||||
"""
|
||||
Internal implementation - Hook for tasks that should be performed before
|
||||
the ``CREATE DATABASE``/``DROP DATABASE`` clauses used by testing code
|
||||
to create/ destroy test databases. Needed e.g. in PostgreSQL to rollback
|
||||
and close any active transaction.
|
||||
Internal implementation - Hook for tasks that should be performed
|
||||
before the ``CREATE DATABASE``/``DROP DATABASE`` clauses used by
|
||||
testing code to create/ destroy test databases. Needed e.g. in
|
||||
PostgreSQL to rollback and close any active transaction.
|
||||
"""
|
||||
pass
|
||||
|
||||
def sql_table_creation_suffix(self):
|
||||
"SQL to append to the end of the test table creation statements"
|
||||
"""
|
||||
SQL to append to the end of the test table creation statements.
|
||||
"""
|
||||
return ''
|
||||
|
||||
def test_db_signature(self):
|
||||
|
@ -17,15 +17,18 @@ TEST_MODULE = 'tests'
|
||||
|
||||
doctestOutputChecker = OutputChecker()
|
||||
|
||||
|
||||
class DjangoTestRunner(unittest.TextTestRunner):
|
||||
def __init__(self, *args, **kwargs):
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"DjangoTestRunner is deprecated; it's functionality is indistinguishable from TextTestRunner",
|
||||
"DjangoTestRunner is deprecated; it's functionality is "
|
||||
"indistinguishable from TextTestRunner",
|
||||
DeprecationWarning
|
||||
)
|
||||
super(DjangoTestRunner, self).__init__(*args, **kwargs)
|
||||
|
||||
|
||||
def get_tests(app_module):
|
||||
parts = app_module.__name__.split('.')
|
||||
prefix, last = parts[:-1], parts[-1]
|
||||
@ -49,8 +52,11 @@ def get_tests(app_module):
|
||||
raise
|
||||
return test_module
|
||||
|
||||
|
||||
def build_suite(app_module):
|
||||
"Create a complete Django test suite for the provided application module"
|
||||
"""
|
||||
Create a complete Django test suite for the provided application module.
|
||||
"""
|
||||
suite = unittest.TestSuite()
|
||||
|
||||
# Load unit and doctests in the models.py module. If module has
|
||||
@ -58,7 +64,8 @@ def build_suite(app_module):
|
||||
if hasattr(app_module, 'suite'):
|
||||
suite.addTest(app_module.suite())
|
||||
else:
|
||||
suite.addTest(unittest.defaultTestLoader.loadTestsFromModule(app_module))
|
||||
suite.addTest(unittest.defaultTestLoader.loadTestsFromModule(
|
||||
app_module))
|
||||
try:
|
||||
suite.addTest(doctest.DocTestSuite(app_module,
|
||||
checker=doctestOutputChecker,
|
||||
@ -76,25 +83,29 @@ def build_suite(app_module):
|
||||
if hasattr(test_module, 'suite'):
|
||||
suite.addTest(test_module.suite())
|
||||
else:
|
||||
suite.addTest(unittest.defaultTestLoader.loadTestsFromModule(test_module))
|
||||
suite.addTest(unittest.defaultTestLoader.loadTestsFromModule(
|
||||
test_module))
|
||||
try:
|
||||
suite.addTest(doctest.DocTestSuite(test_module,
|
||||
checker=doctestOutputChecker,
|
||||
runner=DocTestRunner))
|
||||
suite.addTest(doctest.DocTestSuite(
|
||||
test_module, checker=doctestOutputChecker,
|
||||
runner=DocTestRunner))
|
||||
except ValueError:
|
||||
# No doc tests in tests.py
|
||||
pass
|
||||
return suite
|
||||
|
||||
|
||||
def build_test(label):
|
||||
"""Construct a test case with the specified label. Label should be of the
|
||||
"""
|
||||
Construct a test case with the specified label. Label should be of the
|
||||
form model.TestClass or model.TestClass.test_method. Returns an
|
||||
instantiated test or test suite corresponding to the label provided.
|
||||
|
||||
"""
|
||||
parts = label.split('.')
|
||||
if len(parts) < 2 or len(parts) > 3:
|
||||
raise ValueError("Test label '%s' should be of the form app.TestCase or app.TestCase.test_method" % label)
|
||||
raise ValueError("Test label '%s' should be of the form app.TestCase "
|
||||
"or app.TestCase.test_method" % label)
|
||||
|
||||
#
|
||||
# First, look for TestCase instances with a name that matches
|
||||
@ -112,9 +123,12 @@ def build_test(label):
|
||||
if issubclass(TestClass, (unittest.TestCase, real_unittest.TestCase)):
|
||||
if len(parts) == 2: # label is app.TestClass
|
||||
try:
|
||||
return unittest.TestLoader().loadTestsFromTestCase(TestClass)
|
||||
return unittest.TestLoader().loadTestsFromTestCase(
|
||||
TestClass)
|
||||
except TypeError:
|
||||
raise ValueError("Test label '%s' does not refer to a test class" % label)
|
||||
raise ValueError(
|
||||
"Test label '%s' does not refer to a test class"
|
||||
% label)
|
||||
else: # label is app.TestClass.test_method
|
||||
return TestClass(parts[2])
|
||||
except TypeError:
|
||||
@ -135,7 +149,8 @@ def build_test(label):
|
||||
for test in doctests:
|
||||
if test._dt_test.name in (
|
||||
'%s.%s' % (module.__name__, '.'.join(parts[1:])),
|
||||
'%s.__test__.%s' % (module.__name__, '.'.join(parts[1:]))):
|
||||
'%s.__test__.%s' % (
|
||||
module.__name__, '.'.join(parts[1:]))):
|
||||
tests.append(test)
|
||||
except ValueError:
|
||||
# No doctests found.
|
||||
@ -148,6 +163,7 @@ def build_test(label):
|
||||
# Construct a suite out of the tests that matched.
|
||||
return unittest.TestSuite(tests)
|
||||
|
||||
|
||||
def partition_suite(suite, classes, bins):
|
||||
"""
|
||||
Partitions a test suite by test type.
|
||||
@ -169,14 +185,15 @@ def partition_suite(suite, classes, bins):
|
||||
else:
|
||||
bins[-1].addTest(test)
|
||||
|
||||
|
||||
def reorder_suite(suite, classes):
|
||||
"""
|
||||
Reorders a test suite by test type.
|
||||
|
||||
classes is a sequence of types
|
||||
`classes` is a sequence of types
|
||||
|
||||
All tests of type clases[0] are placed first, then tests of type classes[1], etc.
|
||||
Tests with no match in classes are placed last.
|
||||
All tests of type classes[0] are placed first, then tests of type
|
||||
classes[1], etc. Tests with no match in classes are placed last.
|
||||
"""
|
||||
class_count = len(classes)
|
||||
bins = [unittest.TestSuite() for i in range(class_count+1)]
|
||||
@ -185,6 +202,7 @@ def reorder_suite(suite, classes):
|
||||
bins[0].addTests(bins[i+1])
|
||||
return bins[0]
|
||||
|
||||
|
||||
def dependency_ordered(test_databases, dependencies):
|
||||
"""Reorder test_databases into an order that honors the dependencies
|
||||
described in TEST_DEPENDENCIES.
|
||||
@ -200,7 +218,8 @@ def dependency_ordered(test_databases, dependencies):
|
||||
dependencies_satisfied = True
|
||||
for alias in aliases:
|
||||
if alias in dependencies:
|
||||
if all(a in resolved_databases for a in dependencies[alias]):
|
||||
if all(a in resolved_databases
|
||||
for a in dependencies[alias]):
|
||||
# all dependencies for this alias are satisfied
|
||||
dependencies.pop(alias)
|
||||
resolved_databases.add(alias)
|
||||
@ -216,10 +235,12 @@ def dependency_ordered(test_databases, dependencies):
|
||||
deferred.append((signature, (db_name, aliases)))
|
||||
|
||||
if not changed:
|
||||
raise ImproperlyConfigured("Circular dependency in TEST_DEPENDENCIES")
|
||||
raise ImproperlyConfigured(
|
||||
"Circular dependency in TEST_DEPENDENCIES")
|
||||
test_databases = deferred
|
||||
return ordered_test_databases
|
||||
|
||||
|
||||
class DjangoTestSuiteRunner(object):
|
||||
def __init__(self, verbosity=1, interactive=True, failfast=True, **kwargs):
|
||||
self.verbosity = verbosity
|
||||
@ -264,7 +285,8 @@ class DjangoTestSuiteRunner(object):
|
||||
if connection.settings_dict['TEST_MIRROR']:
|
||||
# If the database is marked as a test mirror, save
|
||||
# the alias.
|
||||
mirrored_aliases[alias] = connection.settings_dict['TEST_MIRROR']
|
||||
mirrored_aliases[alias] = (
|
||||
connection.settings_dict['TEST_MIRROR'])
|
||||
else:
|
||||
# Store a tuple with DB parameters that uniquely identify it.
|
||||
# If we have two aliases with the same values for that tuple,
|
||||
@ -276,53 +298,57 @@ class DjangoTestSuiteRunner(object):
|
||||
item[1].append(alias)
|
||||
|
||||
if 'TEST_DEPENDENCIES' in connection.settings_dict:
|
||||
dependencies[alias] = connection.settings_dict['TEST_DEPENDENCIES']
|
||||
dependencies[alias] = (
|
||||
connection.settings_dict['TEST_DEPENDENCIES'])
|
||||
else:
|
||||
if alias != DEFAULT_DB_ALIAS:
|
||||
dependencies[alias] = connection.settings_dict.get('TEST_DEPENDENCIES', [DEFAULT_DB_ALIAS])
|
||||
dependencies[alias] = connection.settings_dict.get(
|
||||
'TEST_DEPENDENCIES', [DEFAULT_DB_ALIAS])
|
||||
|
||||
# Second pass -- actually create the databases.
|
||||
old_names = []
|
||||
mirrors = []
|
||||
for signature, (db_name, aliases) in dependency_ordered(test_databases.items(), dependencies):
|
||||
for signature, (db_name, aliases) in dependency_ordered(
|
||||
test_databases.items(), dependencies):
|
||||
# Actually create the database for the first connection
|
||||
connection = connections[aliases[0]]
|
||||
old_names.append((connection, db_name, True))
|
||||
test_db_name = connection.creation.create_test_db(self.verbosity, autoclobber=not self.interactive)
|
||||
test_db_name = connection.creation.create_test_db(
|
||||
self.verbosity, autoclobber=not self.interactive)
|
||||
for alias in aliases[1:]:
|
||||
connection = connections[alias]
|
||||
if db_name:
|
||||
old_names.append((connection, db_name, False))
|
||||
connection.settings_dict['NAME'] = test_db_name
|
||||
else:
|
||||
# If settings_dict['NAME'] isn't defined, we have a backend where
|
||||
# the name isn't important -- e.g., SQLite, which uses :memory:.
|
||||
# Force create the database instead of assuming it's a duplicate.
|
||||
# If settings_dict['NAME'] isn't defined, we have a backend
|
||||
# where the name isn't important -- e.g., SQLite, which
|
||||
# uses :memory:. Force create the database instead of
|
||||
# assuming it's a duplicate.
|
||||
old_names.append((connection, db_name, True))
|
||||
connection.creation.create_test_db(self.verbosity, autoclobber=not self.interactive)
|
||||
connection.creation.create_test_db(
|
||||
self.verbosity, autoclobber=not self.interactive)
|
||||
|
||||
for alias, mirror_alias in mirrored_aliases.items():
|
||||
mirrors.append((alias, connections[alias].settings_dict['NAME']))
|
||||
connections[alias].settings_dict['NAME'] = connections[mirror_alias].settings_dict['NAME']
|
||||
connections[alias].settings_dict['NAME'] = (
|
||||
connections[mirror_alias].settings_dict['NAME'])
|
||||
connections[alias].features = connections[mirror_alias].features
|
||||
|
||||
return old_names, mirrors
|
||||
|
||||
def run_suite(self, suite, **kwargs):
|
||||
return unittest.TextTestRunner(verbosity=self.verbosity, failfast=self.failfast).run(suite)
|
||||
return unittest.TextTestRunner(
|
||||
verbosity=self.verbosity, failfast=self.failfast).run(suite)
|
||||
|
||||
def teardown_databases(self, old_config, **kwargs):
|
||||
from django.db import connections
|
||||
"""
|
||||
Destroys all the non-mirror databases.
|
||||
"""
|
||||
old_names, mirrors = old_config
|
||||
# Point all the mirrors back to the originals
|
||||
for alias, old_name in mirrors:
|
||||
connections[alias].settings_dict['NAME'] = old_name
|
||||
# Destroy all the non-mirror databases
|
||||
for connection, old_name, destroy in old_names:
|
||||
if destroy:
|
||||
connection.creation.destroy_test_db(old_name, self.verbosity)
|
||||
else:
|
||||
connection.settings_dict['NAME'] = old_name
|
||||
|
||||
def teardown_test_environment(self, **kwargs):
|
||||
unittest.removeHandler()
|
||||
|
@ -946,6 +946,19 @@ apply URL escaping again. This is wrong for URLs whose unquoted form contains
|
||||
a ``%xx`` sequence, but such URLs are very unlikely to happen in the wild,
|
||||
since they would confuse browsers too.
|
||||
|
||||
Database connections after running the test suite
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
The default test runner now does not restore the database connections after the
|
||||
tests' execution any more. This prevents the production database from being
|
||||
exposed to potential threads that would still be running and attempting to
|
||||
create new connections.
|
||||
|
||||
If your code relied on connections to the production database being created
|
||||
after the tests' execution, then you may restore the previous behavior by
|
||||
subclassing ``DjangoTestRunner`` and overriding its ``teardown_databases()``
|
||||
method.
|
||||
|
||||
Features deprecated in 1.4
|
||||
==========================
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user