diff --git a/TODO.TXT b/TODO.TXT index 85081d548e..506cf75a6e 100644 --- a/TODO.TXT +++ b/TODO.TXT @@ -37,14 +37,10 @@ that need to be done. I'm trying to be as granular as possible. 7) Remove any references to the global ``django.db.connection`` object in the SQL creation process. This includes(but is probably not limited to): - * ``django.db.models.sql.where.Where`` - * ``django.db.models.sql.expressions.SQLEvaluator`` - * ``django.db.models.sql.query.Query`` uses ``connection`` in place of - ``self.connection`` in ``self.add_filter`` * The way we create ``Query`` from ``BaseQuery`` is awkward and hacky. * ``django.db.models.query.delete_objects`` * ``django.db.models.query.insert_query`` - * ``django.db.models.base.Model`` + * ``django.db.models.base.Model`` -- in ``save_base`` * ``django.db.models.fields.Field`` This uses it, as do it's subclasses. * ``django.db.models.fields.related`` It's used all over the place here, including opening a cursor and executing queries, so that's going to @@ -54,6 +50,7 @@ that need to be done. I'm trying to be as granular as possible. 5) Add the ``using`` Meta option. Tests and docs(these are to be assumed at each stage from here on out). +5) Implement using kwarg on save() method. 6) Add the ``using`` method to ``QuerySet``. This will more or less "just work" across multiple databases that use the same backend. However, it will fail gratuitously when trying to use 2 different backends. diff --git a/django/db/models/base.py b/django/db/models/base.py index 13ff7e8f35..bb2dbf0504 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -580,17 +580,16 @@ class Model(object): def _get_next_or_previous_in_order(self, is_next): cachename = "__%s_order_cache" % is_next if not hasattr(self, cachename): - qn = connection.ops.quote_name - op = is_next and '>' or '<' + op = is_next and 'gt' or 'lt' order = not is_next and '-_order' or '_order' order_field = self._meta.order_with_respect_to - # FIXME: When querysets support nested queries, this can be turned - # into a pure queryset operation. - where = ['%s %s (SELECT %s FROM %s WHERE %s=%%s)' % \ - (qn('_order'), op, qn('_order'), - qn(self._meta.db_table), qn(self._meta.pk.column))] - params = [self.pk] - obj = self._default_manager.filter(**{order_field.name: getattr(self, order_field.attname)}).extra(where=where, params=params).order_by(order)[:1].get() + obj = self._default_manager.filter(**{ + order_field.name: getattr(self, order_field.attname) + }).filter(**{ + '_order__%s' % op: self._default_manager.values('_order').filter(**{ + self._meta.pk.name: self.pk + }) + }).order_by(order)[:1].get() setattr(self, cachename, obj) return getattr(self, cachename) diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py index 920cbffe73..826c4f398c 100644 --- a/django/db/models/sql/expressions.py +++ b/django/db/models/sql/expressions.py @@ -1,5 +1,4 @@ from django.core.exceptions import FieldError -from django.db import connection from django.db.models.fields import FieldDoesNotExist from django.db.models.sql.constants import LOOKUP_SEP @@ -10,6 +9,7 @@ class SQLEvaluator(object): self.cols = {} self.contains_aggregate = False + self.connection = query.connection self.expression.prepare(self, query, allow_joins) def as_sql(self, qn=None): @@ -19,6 +19,9 @@ class SQLEvaluator(object): for node, col in self.cols.items(): self.cols[node] = (change_map.get(col[0], col[0]), col[1]) + def update_connection(self, connection): + self.connection = connection + ##################################################### # Vistor methods for initial expression preparation # ##################################################### @@ -56,7 +59,7 @@ class SQLEvaluator(object): def evaluate_node(self, node, qn): if not qn: - qn = connection.ops.quote_name + qn = self.connection.ops.quote_name expressions = [] expression_params = [] @@ -75,11 +78,11 @@ class SQLEvaluator(object): expressions.append(format % sql) expression_params.extend(params) - return connection.ops.combine_expression(node.connector, expressions), expression_params + return self.connection.ops.combine_expression(node.connector, expressions), expression_params def evaluate_leaf(self, node, qn): if not qn: - qn = connection.ops.quote_name + qn = self.connection.ops.quote_name col = self.cols[node] if hasattr(col, 'as_sql'): diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index d290d60e63..2059139600 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -67,10 +67,10 @@ class BaseQuery(object): # SQL-related attributes self.select = [] self.tables = [] # Aliases in the order they are created. - self.where = where() + self.where = where(connection=self.connection) self.where_class = where self.group_by = None - self.having = where() + self.having = where(connection=self.connection) self.order_by = [] self.low_mark, self.high_mark = 0, None # Used for offset/limit self.distinct = False @@ -151,6 +151,8 @@ class BaseQuery(object): # supported. It's the only class-reference to the module-level # connection variable. self.connection = connection + self.where.update_connection(self.connection) + self.having.update_connection(self.connection) def get_meta(self): """ @@ -243,6 +245,8 @@ class BaseQuery(object): obj.used_aliases = set() obj.filter_is_sticky = False obj.__dict__.update(kwargs) + obj.where.update_connection(obj.connection) # where and having track their own connection + obj.having.update_connection(obj.connection)# we need to keep this up to date if hasattr(obj, '_setup_query'): obj._setup_query() return obj @@ -530,10 +534,10 @@ class BaseQuery(object): self.where.add(EverythingNode(), AND) elif self.where: # rhs has an empty where clause. - w = self.where_class() + w = self.where_class(connection=self.connection) w.add(EverythingNode(), AND) else: - w = self.where_class() + w = self.where_class(connection=self.connection) self.where.add(w, connector) # Selection columns and extra extensions are those provided by 'rhs'. @@ -1534,7 +1538,7 @@ class BaseQuery(object): lookup_type = 'isnull' value = True elif (value == '' and lookup_type == 'exact' and - connection.features.interprets_empty_strings_as_nulls): + self.connection.features.interprets_empty_strings_as_nulls): lookup_type = 'isnull' value = True elif callable(value): @@ -1546,7 +1550,7 @@ class BaseQuery(object): for alias, aggregate in self.aggregates.items(): if alias == parts[0]: - entry = self.where_class() + entry = self.where_class(connection=self.connection) entry.add((aggregate, lookup_type, value), AND) if negate: entry.negate() @@ -1614,7 +1618,7 @@ class BaseQuery(object): for alias in join_list: if self.alias_map[alias][JOIN_TYPE] == self.LOUTER: j_col = self.alias_map[alias][RHS_JOIN_COL] - entry = self.where_class() + entry = self.where_class(connection=self.connection) entry.add((Constraint(alias, j_col, None), 'isnull', True), AND) entry.negate() self.where.add(entry, AND) @@ -1623,7 +1627,7 @@ class BaseQuery(object): # Leaky abstraction artifact: We have to specifically # exclude the "foo__in=[]" case from this handling, because # it's short-circuited in the Where class. - entry = self.where_class() + entry = self.where_class(connection=self.connection) entry.add((Constraint(alias, col, None), 'isnull', True), AND) entry.negate() self.where.add(entry, AND) diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index 0cd393756d..def1ff8ad8 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -48,7 +48,7 @@ class DeleteQuery(Query): for related in cls._meta.get_all_related_many_to_many_objects(): if not isinstance(related.field, generic.GenericRelation): for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): - where = self.where_class() + where = self.where_class(connection=self.connection) where.add((Constraint(None, related.field.m2m_reverse_name(), related.field), 'in', @@ -57,14 +57,14 @@ class DeleteQuery(Query): self.do_query(related.field.m2m_db_table(), where) for f in cls._meta.many_to_many: - w1 = self.where_class() + w1 = self.where_class(connection=self.connection) if isinstance(f, generic.GenericRelation): from django.contrib.contenttypes.models import ContentType field = f.rel.to._meta.get_field(f.content_type_field_name) w1.add((Constraint(None, field.column, field), 'exact', ContentType.objects.get_for_model(cls).id), AND) for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): - where = self.where_class() + where = self.where_class(connection=self.connection) where.add((Constraint(None, f.m2m_column_name(), f), 'in', pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND) @@ -81,7 +81,7 @@ class DeleteQuery(Query): lot of values in pk_list. """ for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): - where = self.where_class() + where = self.where_class(connection=self.connection) field = self.model._meta.pk where.add((Constraint(None, field.column, field), 'in', pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND) @@ -185,7 +185,7 @@ class UpdateQuery(Query): # Now we adjust the current query: reset the where clause and get rid # of all the tables we don't need (since they're in the sub-select). - self.where = self.where_class() + self.where = self.where_class(connection=self.connection) if self.related_updates or must_pre_select: # Either we're using the idents in multiple update queries (so # don't want them to change), or the db backend doesn't support @@ -209,7 +209,7 @@ class UpdateQuery(Query): This is used by the QuerySet.delete_objects() method. """ for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): - self.where = self.where_class() + self.where = self.where_class(connection=self.connection) f = self.model._meta.pk self.where.add((Constraint(None, f.column, f), 'in', pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index b27bc992a6..8420e49d95 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -4,7 +4,6 @@ Code to manage the creation and SQL rendering of 'where' constraints. import datetime from django.utils import tree -from django.db import connection from django.db.models.fields import Field from django.db.models.query_utils import QueryWrapper from datastructures import EmptyResultSet, FullResultSet @@ -34,6 +33,18 @@ class WhereNode(tree.Node): """ default = AND + def __init__(self, *args, **kwargs): + self.connection = kwargs.pop('connection', None) + super(WhereNode, self).__init__(*args, **kwargs) + + def __getstate__(self): + """ + Don't try to pickle the connection, our Query will restore it for us. + """ + data = self.__dict__.copy() + del data['connection'] + return data + def add(self, data, connector): """ Add a node to the where-tree. If the data is a list or tuple, it is @@ -53,7 +64,9 @@ class WhereNode(tree.Node): value = list(value) if hasattr(obj, "process"): try: - obj, params = obj.process(lookup_type, value) + # FIXME We're calling process too early, the connection could + # change + obj, params = obj.process(lookup_type, value, self.connection) except (EmptyShortCircuit, EmptyResultSet): # There are situations where we want to short-circuit any # comparisons and make sure that nothing is returned. One @@ -78,6 +91,14 @@ class WhereNode(tree.Node): super(WhereNode, self).add((obj, lookup_type, annotation, params), connector) + def update_connection(self, connection): + self.connection = connection + for child in self.children: + if hasattr(child, 'update_connection'): + child.update_connection(connection) + elif hasattr(child[3], 'update_connection'): + child[3].update_connection(connection) + def as_sql(self, qn=None): """ Returns the SQL version of the where clause and the value to be @@ -88,7 +109,7 @@ class WhereNode(tree.Node): recursion). """ if not qn: - qn = connection.ops.quote_name + qn = self.connection.ops.quote_name if not self.children: return None, [] result = [] @@ -153,7 +174,7 @@ class WhereNode(tree.Node): field_sql = lvalue.as_sql(quote_func=qn) if value_annot is datetime.datetime: - cast_sql = connection.ops.datetime_cast_sql() + cast_sql = self.connection.ops.datetime_cast_sql() else: cast_sql = '%s' @@ -163,10 +184,10 @@ class WhereNode(tree.Node): else: extra = '' - if lookup_type in connection.operators: - format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),) + if lookup_type in self.connection.operators: + format = "%s %%s %%s" % (self.connection.ops.lookup_cast(lookup_type),) return (format % (field_sql, - connection.operators[lookup_type] % cast_sql, + self.connection.operators[lookup_type] % cast_sql, extra), params) if lookup_type == 'in': @@ -179,15 +200,15 @@ class WhereNode(tree.Node): elif lookup_type in ('range', 'year'): return ('%s BETWEEN %%s and %%s' % field_sql, params) elif lookup_type in ('month', 'day', 'week_day'): - return ('%s = %%s' % connection.ops.date_extract_sql(lookup_type, field_sql), + return ('%s = %%s' % self.connection.ops.date_extract_sql(lookup_type, field_sql), params) elif lookup_type == 'isnull': return ('%s IS %sNULL' % (field_sql, (not value_annot and 'NOT ' or '')), ()) elif lookup_type == 'search': - return (connection.ops.fulltext_search_sql(field_sql), params) + return (self.connection.ops.fulltext_search_sql(field_sql), params) elif lookup_type in ('regex', 'iregex'): - return connection.ops.regex_lookup(lookup_type) % (field_sql, cast_sql), params + return self.connection.ops.regex_lookup(lookup_type) % (field_sql, cast_sql), params raise TypeError('Invalid lookup_type: %r' % lookup_type) @@ -202,7 +223,7 @@ class WhereNode(tree.Node): lhs = '%s.%s' % (qn(table_alias), qn(name)) else: lhs = qn(name) - return connection.ops.field_cast_sql(db_type) % lhs + return self.connection.ops.field_cast_sql(db_type) % lhs def relabel_aliases(self, change_map, node=None): """ @@ -257,7 +278,7 @@ class Constraint(object): def __init__(self, alias, col, field): self.alias, self.col, self.field = alias, col, field - def process(self, lookup_type, value): + def process(self, lookup_type, value, connection): """ Returns a tuple of data suitable for inclusion in a WhereNode instance.