From 4f461542b59d4ba6c366eab08bde783f68977ce0 Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Thu, 18 Jun 2009 17:31:36 +0000 Subject: [PATCH] [soc2009/multidb] Ensure that when a QuerySet is given a Query object in its construct that we correct the detect the connection that is being used git-svn-id: http://code.djangoproject.com/svn/django/branches/soc2009/multidb@11073 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- django/db/models/query.py | 3 +-- django/db/utils.py | 10 ++++++++++ tests/regressiontests/multiple_database/tests.py | 7 ++++++- 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/django/db/models/query.py b/django/db/models/query.py index 89768ef023..c6071281b5 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -38,8 +38,7 @@ class QuerySet(object): self._result_cache = None self._iter = None self._sticky_filter = False - self._using = DEFAULT_DB_ALIAS # this will be wrong if a custom Query - # is provided with a non default connection + self._using = connections.alias_for_connection(self.query.connection) ######################## # PYTHON MAGIC METHODS # diff --git a/django/db/utils.py b/django/db/utils.py index 0b51648384..553bbca48d 100644 --- a/django/db/utils.py +++ b/django/db/utils.py @@ -68,3 +68,13 @@ class ConnectionHandler(object): def all(self): return [self[alias] for alias in self] + + def alias_for_connection(self, connection): + """ + Returns the alias for the given connection object. + """ + for alias in self: + conn_settings = self.databases[alias] + if conn_settings == connection.settings_dict: + return alias + return None diff --git a/tests/regressiontests/multiple_database/tests.py b/tests/regressiontests/multiple_database/tests.py index b7ca0e5e23..f38ddda490 100644 --- a/tests/regressiontests/multiple_database/tests.py +++ b/tests/regressiontests/multiple_database/tests.py @@ -6,7 +6,7 @@ from django.test import TestCase from models import Book -class DatabaseSettingTestCase(TestCase): +class ConnectionHandlerTestCase(TestCase): def setUp(self): settings.DATABASES['__test_db'] = { 'DATABASE_ENGINE': 'sqlite3', @@ -20,6 +20,11 @@ class DatabaseSettingTestCase(TestCase): connections['default'].cursor() connections['__test_db'].cursor() + def test_alias_for_connection(self): + for db in connections: + self.assertEqual(db, connections.alias_for_connection(connections[db])) + + class QueryTestCase(TestCase): def test_basic_queries(self): for db in connections: