diff --git a/django/db/__init__.py b/django/db/__init__.py index 22bb720dd4..b0dfa31848 100644 --- a/django/db/__init__.py +++ b/django/db/__init__.py @@ -14,6 +14,8 @@ __all__ = ('backend', 'connection', 'DatabaseError') # singleton to represent the default connection in connections class dummy(object): + def __repr__(self): + return self.__str__() def __str__(self): return '' _default = dummy() @@ -28,22 +30,12 @@ if not settings.DATABASE_ENGINE: settings.DATABASE_ENGINE = 'dummy' -def connect(settings): +def connect(settings, **kw): """Connect to the database specified in settings. Returns a ConnectionInfo on success, raises ImproperlyConfigured if the settings don't specify a valid database connection. """ - info = ConnectionInfo(settings) - - # Register an event that closes the database connection - # when a Django request is finished. - dispatcher.connect(info.close, signal=signals.request_finished) - - # Register an event that resets connection.queries - # when a Django request is started. - dispatcher.connect(info.reset_queries, signal=signals.request_started) - - return info + return ConnectionInfo(settings, **kw) class ConnectionInfo(object): @@ -52,7 +44,8 @@ class ConnectionInfo(object): creation, introspection, and shell modules, closing the connection, and resetting the connection's query log. """ - def __init__(self, settings=None): + def __init__(self, settings=None, **kw): + super(ConnectionInfo, self).__init__(**kw) if settings is None: from django.conf import settings self.settings = settings @@ -60,6 +53,14 @@ class ConnectionInfo(object): self.connection = self.backend.DatabaseWrapper(settings) self.DatabaseError = self.backend.DatabaseError + # Register an event that closes the database connection + # when a Django request is finished. + dispatcher.connect(self.close, signal=signals.request_finished) + + # Register an event that resets connection.queries + # when a Django request is started. + dispatcher.connect(self.reset_queries, signal=signals.request_started) + def __repr__(self): return "Connection: %r (ENGINE=%s NAME=%s)" \ % (self.connection, @@ -114,7 +115,7 @@ class ConnectionInfo(object): """Reset log of queries executed by connection""" self.connection.queries = [] - + class LazyConnectionManager(object): """Manages named connections lazily, instantiating them as they are requested. @@ -123,8 +124,14 @@ class LazyConnectionManager(object): self.local = local() self.local.connections = {} + # Reset connections on request finish, to make sure each request can + # load the correct connections for its settings + dispatcher.connect(self.reset, signal=signals.request_finished) + def __iter__(self): - return self.local.connections.keys() + # Iterates only over *active* connections, not all possible + # connections + return iter(self.local.connections.keys()) def __getattr__(self, attr): return getattr(self.local.connections, attr) @@ -177,10 +184,27 @@ class LazyConnectionManager(object): cnx[name] = connect(database) return cnx[name] + def items(self): + # Iterates over *all possible* connections + items = [] + for key in self.keys(): + items.append((key, self[key])) + return items + + def keys(self): + # Iterates over *all possible* connections + keys = [_default] + try: + keys.extend(settings.OTHER_DATABASES.keys()) + except AttributeError: + pass + return keys + def reset(self): + if not hasattr(self.local, 'connections'): + return self.local.connections = {} - - + def model_connection_name(klass): """Get the connection name that a model is configured to use, with the current settings. @@ -261,10 +285,11 @@ class ConnectionInfoDescriptor(object): def get_connection(self, instance): return connections[model_connection_name(instance.model)] - - def reset(self): - self.local.cnx = {} + def reset(self): + if not hasattr(self.local, 'cnx'): + return + self.local.cnx = {} class LocalizingProxy: """A lazy-initializing proxy. The proxied object is not @@ -277,6 +302,13 @@ class LocalizingProxy: self.__func = func self.__arg = arg self.__kw = kw + + # We need to clear out this thread's storage at the end of each + # request, in case new settings are loaded with the next + def reset(stor=storage, name=name): + if hasattr(stor, name): + delattr(stor, name) + dispatcher.connect(reset, signal=signals.request_finished) def __getattr__(self, attr): # Private (__*) attributes are munged @@ -295,13 +327,12 @@ class LocalizingProxy: self.__dict__[attr] = val return try: - print self.__storage, self.__name stor = getattr(self.__storage, self.__name) except AttributeError: stor = self.__func(*self.__arg) setattr(self.__storage, self.__name, stor) setattr(stor, attr, val) - + # Create a manager for named connections connections = LazyConnectionManager() @@ -327,11 +358,6 @@ runshell = LocalizingProxy('runshell', _local, lambda: connections[_default].runshell) -# Reset connections on request finish, to make sure each request can -# load the correct connections for its settings -dispatcher.connect(connections.reset, signal=signals.request_finished) - - # Register an event that rolls back all connections # when a Django request has an exception. def _rollback_on_exception(): diff --git a/django/test/utils.py b/django/test/utils.py index 1a144031ab..d39a889bc3 100644 --- a/django/test/utils.py +++ b/django/test/utils.py @@ -1,6 +1,5 @@ import sys, time from django.conf import settings - from django.db import backend, connect, connection, connection_info, connections from django.dispatch import dispatcher from django.test import signals @@ -94,23 +93,15 @@ def create_test_db(verbosity=1, autoclobber=False): cursor = connection.cursor() # Fill OTHER_DATABASES with the TEST_DATABASES settings, - # and connect each named connection to the test database, using - # a separate connection instance for each (so, eg, transactions don't - # collide) + # and connect each named connection to the test database. test_databases = {} for db_name in settings.TEST_DATABASES: - if settings.DATABASE_ENGINE == 'sqlite3': - full_name = TEST_DATABASE_NAME - else: - full_name = TEST_DATABASE_NAME + db_name - db_st = {'DATABASE_NAME': full_name} + db_st = {'DATABASE_NAME': TEST_DATABASE_NAME} if db_name in settings.TEST_DATABASE_MODELS: db_st['MODELS'] = settings.TEST_DATABASE_MODELS.get(db_name, []) test_databases[db_name] = db_st - connections[db_name] = connect(connection_info.settings) - connections[db_name].connection.cursor() # Initialize it settings.OTHER_DATABASES = test_databases - + def destroy_test_db(old_database_name, old_databases, verbosity=1): # Unless we're using SQLite, remove the test database to clean up after # ourselves. Connect to the previous database (not the test database) @@ -118,15 +109,26 @@ def destroy_test_db(old_database_name, old_databases, verbosity=1): # connected to it. if verbosity >= 1: print "Destroying test database..." + + connection.close() for cnx in connections.keys(): connections[cnx].close() + connections.reset() + TEST_DATABASE_NAME = settings.DATABASE_NAME + if verbosity >= 2: + print "Closed connections to %s" % TEST_DATABASE_NAME settings.DATABASE_NAME = old_database_name - + if settings.DATABASE_ENGINE != "sqlite3": settings.OTHER_DATABASES = old_databases for cnx in connections.keys(): - connections[cnx].connection.cursor() + try: + connections[cnx].connection.cursor() + except (KeyboardInterrupt, SystemExit): + raise + except: + pass cursor = connection.cursor() _set_autocommit(connection) time.sleep(1) # To avoid "database is being accessed by other users" errors. diff --git a/tests/modeltests/multiple_databases/models.py b/tests/modeltests/multiple_databases/models.py index 97e91b429c..a0fa492631 100644 --- a/tests/modeltests/multiple_databases/models.py +++ b/tests/modeltests/multiple_databases/models.py @@ -84,22 +84,24 @@ Connection: ... >>> connections['_b'] Connection: ... -# Let's see what connections are available.The default connection is -# in there, but let's ignore it +# Let's see what connections are available. The default connection is always +# included in connections as well, and may be accessed as connections[_default]. ->>> non_default = connections.keys() ->>> non_default.remove(_default) ->>> non_default.sort() ->>> non_default -['_a', '_b'] +>>> connection_names = connections.keys() +>>> connection_names.sort() +>>> connection_names +[, '_a', '_b'] # Invalid connection names raise ImproperlyConfigured + >>> connections['bad'] Traceback (most recent call last): ... ImproperlyConfigured: No database connection 'bad' has been configured -# Models can access their connections through their managers +# The model_connection_name() function will tell you the name of the +# connection that a model is configured to use. + >>> model_connection_name(Artist) '_a' >>> model_connection_name(Widget) @@ -116,6 +118,15 @@ True >>> list(artists) [] +# Models can access their connections through the db property of their +# default manager. + +>>> paul = _[0] +>>> Artist.objects.db +Connection: ... (ENGINE=... NAME=...) +>>> paul._default_manager.db +Connection: ... (ENGINE=... NAME=...) + # When transactions are not managed, model save will commit only # for the model's connection.