diff --git a/TODO.txt b/TODO.txt index b13387351f..2aaf24a413 100644 --- a/TODO.txt +++ b/TODO.txt @@ -22,4 +22,4 @@ that need to be done. I'm trying to be as granular as possible. 6) Generate SQL, instead of an error for nesting on different DBs. -7) Time permitting add support for a ``DatabaseManager``. +7) Time permitting add support for a ``DatabaseManager``. diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index 0f7bf45402..ac50115b48 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -33,6 +33,12 @@ class BaseDatabaseWrapper(local): self.queries = [] self.settings_dict = settings_dict + def __eq__(self, other): + return self.settings_dict == other.settings_dict + + def __ne__(self, other): + return not self == other + def _commit(self): if self.connection is not None: return self.connection.commit() diff --git a/django/db/models/query.py b/django/db/models/query.py index 8f4e435a12..4c80be1bc9 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -750,7 +750,7 @@ class QuerySet(object): Returns the internal query's SQL and parameters (as a tuple). """ obj = self.values("pk") - if connection.settings_dict == obj.query.connection.settings_dict: + if connection == obj.query.connection: return obj.query.as_nested_sql() raise ValueError("Can't do subqueries with queries on different DBs.") @@ -879,7 +879,7 @@ class ValuesQuerySet(QuerySet): % self.__class__.__name__) obj = self._clone() - if connection.settings_dict == obj.query.connection.settings_dict: + if connection == obj.query.connection: return obj.query.as_nested_sql() raise ValueError("Can't do subqueries with queries on different DBs.") diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 0a61e3a845..0de66ee8f7 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -13,7 +13,7 @@ from django.utils.tree import Node from django.utils.datastructures import SortedDict from django.utils.encoding import force_unicode from django.db.backends.util import truncate_name -from django.db import connection +from django.db import connection, connections from django.db.models import signals from django.db.models.fields import FieldDoesNotExist from django.db.models.query_utils import select_related_descend @@ -126,6 +126,7 @@ class BaseQuery(object): obj_dict['related_select_fields'] = [] obj_dict['related_select_cols'] = [] del obj_dict['connection'] + obj_dict['connection_settings'] = self.connection.settings_dict # Fields can't be pickled, so if a field list has been # specified, we pickle the list of field names instead. @@ -147,10 +148,8 @@ class BaseQuery(object): ] self.__dict__.update(obj_dict) - # XXX: Need a better solution for this when multi-db stuff is - # supported. It's the only class-reference to the module-level - # connection variable. - self.connection = connection + self.connection = connections[connections.alias_for_settings( + obj_dict['connection_settings'])] def get_meta(self): """ diff --git a/django/db/utils.py b/django/db/utils.py index 6b78e5b91e..cee3c392a6 100644 --- a/django/db/utils.py +++ b/django/db/utils.py @@ -81,8 +81,14 @@ class ConnectionHandler(object): """ Returns the alias for the given connection object. """ + return self.alias_for_settings(connection.settings_dict) + + def alias_for_settings(self, settings_dict): + """ + Returns the alias for the given settings dictionary. + """ for alias in self: conn_settings = self.databases[alias] - if conn_settings == connection.settings_dict: + if conn_settings == settings_dict: return alias return None diff --git a/tests/regressiontests/multiple_database/tests.py b/tests/regressiontests/multiple_database/tests.py index 6c98291330..f5d41e9be0 100644 --- a/tests/regressiontests/multiple_database/tests.py +++ b/tests/regressiontests/multiple_database/tests.py @@ -1,4 +1,5 @@ import datetime +import pickle from django.conf import settings from django.db import connections @@ -79,6 +80,15 @@ class QueryTestCase(TestCase): months = Book.objects.dates('published', 'month').using(db) self.assertEqual(sorted(o.month for o in months), [5, 12]) +class PickleQuerySetTestCase(TestCase): + def test_pickling(self): + for db in connections: + qs = Book.objects.all() + self.assertEqual(qs.query.connection, + pickle.loads(pickle.dumps(qs)).query.connection) + self.assertEqual(qs._using, pickle.loads(pickle.dumps(qs))._using) + + if len(settings.DATABASES) > 1: class MetaUsingTestCase(TestCase): def test_meta_using_queries(self):