1
0
mirror of https://github.com/django/django.git synced 2025-07-05 18:29:11 +00:00

[soc2009/multidb] Ensure that when a QuerySet is given a Query object in its construct that we correct the detect the connection that is being used

git-svn-id: http://code.djangoproject.com/svn/django/branches/soc2009/multidb@11073 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Alex Gaynor 2009-06-18 17:31:36 +00:00
parent f3808a02c3
commit 4f461542b5
3 changed files with 17 additions and 3 deletions

View File

@ -38,8 +38,7 @@ class QuerySet(object):
self._result_cache = None self._result_cache = None
self._iter = None self._iter = None
self._sticky_filter = False self._sticky_filter = False
self._using = DEFAULT_DB_ALIAS # this will be wrong if a custom Query self._using = connections.alias_for_connection(self.query.connection)
# is provided with a non default connection
######################## ########################
# PYTHON MAGIC METHODS # # PYTHON MAGIC METHODS #

View File

@ -68,3 +68,13 @@ class ConnectionHandler(object):
def all(self): def all(self):
return [self[alias] for alias in self] return [self[alias] for alias in self]
def alias_for_connection(self, connection):
"""
Returns the alias for the given connection object.
"""
for alias in self:
conn_settings = self.databases[alias]
if conn_settings == connection.settings_dict:
return alias
return None

View File

@ -6,7 +6,7 @@ from django.test import TestCase
from models import Book from models import Book
class DatabaseSettingTestCase(TestCase): class ConnectionHandlerTestCase(TestCase):
def setUp(self): def setUp(self):
settings.DATABASES['__test_db'] = { settings.DATABASES['__test_db'] = {
'DATABASE_ENGINE': 'sqlite3', 'DATABASE_ENGINE': 'sqlite3',
@ -20,6 +20,11 @@ class DatabaseSettingTestCase(TestCase):
connections['default'].cursor() connections['default'].cursor()
connections['__test_db'].cursor() connections['__test_db'].cursor()
def test_alias_for_connection(self):
for db in connections:
self.assertEqual(db, connections.alias_for_connection(connections[db]))
class QueryTestCase(TestCase): class QueryTestCase(TestCase):
def test_basic_queries(self): def test_basic_queries(self):
for db in connections: for db in connections: