1
0
mirror of https://github.com/django/django.git synced 2025-07-05 10:19:20 +00:00

[soc2009/multidb] Removed several instances of unnescary usage of the global connection object, where instead we should be using the connection object for the given Query

git-svn-id: http://code.djangoproject.com/svn/django/branches/soc2009/multidb@10896 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Alex Gaynor 2009-06-03 02:16:53 +00:00
parent 9286db5145
commit 15d405077e
6 changed files with 68 additions and 44 deletions

View File

@ -37,14 +37,10 @@ that need to be done. I'm trying to be as granular as possible.
7) Remove any references to the global ``django.db.connection`` object in the 7) Remove any references to the global ``django.db.connection`` object in the
SQL creation process. This includes(but is probably not limited to): SQL creation process. This includes(but is probably not limited to):
* ``django.db.models.sql.where.Where``
* ``django.db.models.sql.expressions.SQLEvaluator``
* ``django.db.models.sql.query.Query`` uses ``connection`` in place of
``self.connection`` in ``self.add_filter``
* The way we create ``Query`` from ``BaseQuery`` is awkward and hacky. * The way we create ``Query`` from ``BaseQuery`` is awkward and hacky.
* ``django.db.models.query.delete_objects`` * ``django.db.models.query.delete_objects``
* ``django.db.models.query.insert_query`` * ``django.db.models.query.insert_query``
* ``django.db.models.base.Model`` * ``django.db.models.base.Model`` -- in ``save_base``
* ``django.db.models.fields.Field`` This uses it, as do it's subclasses. * ``django.db.models.fields.Field`` This uses it, as do it's subclasses.
* ``django.db.models.fields.related`` It's used all over the place here, * ``django.db.models.fields.related`` It's used all over the place here,
including opening a cursor and executing queries, so that's going to including opening a cursor and executing queries, so that's going to
@ -54,6 +50,7 @@ that need to be done. I'm trying to be as granular as possible.
5) Add the ``using`` Meta option. Tests and docs(these are to be assumed at 5) Add the ``using`` Meta option. Tests and docs(these are to be assumed at
each stage from here on out). each stage from here on out).
5) Implement using kwarg on save() method.
6) Add the ``using`` method to ``QuerySet``. This will more or less "just 6) Add the ``using`` method to ``QuerySet``. This will more or less "just
work" across multiple databases that use the same backend. However, it work" across multiple databases that use the same backend. However, it
will fail gratuitously when trying to use 2 different backends. will fail gratuitously when trying to use 2 different backends.

View File

@ -580,17 +580,16 @@ class Model(object):
def _get_next_or_previous_in_order(self, is_next): def _get_next_or_previous_in_order(self, is_next):
cachename = "__%s_order_cache" % is_next cachename = "__%s_order_cache" % is_next
if not hasattr(self, cachename): if not hasattr(self, cachename):
qn = connection.ops.quote_name op = is_next and 'gt' or 'lt'
op = is_next and '>' or '<'
order = not is_next and '-_order' or '_order' order = not is_next and '-_order' or '_order'
order_field = self._meta.order_with_respect_to order_field = self._meta.order_with_respect_to
# FIXME: When querysets support nested queries, this can be turned obj = self._default_manager.filter(**{
# into a pure queryset operation. order_field.name: getattr(self, order_field.attname)
where = ['%s %s (SELECT %s FROM %s WHERE %s=%%s)' % \ }).filter(**{
(qn('_order'), op, qn('_order'), '_order__%s' % op: self._default_manager.values('_order').filter(**{
qn(self._meta.db_table), qn(self._meta.pk.column))] self._meta.pk.name: self.pk
params = [self.pk] })
obj = self._default_manager.filter(**{order_field.name: getattr(self, order_field.attname)}).extra(where=where, params=params).order_by(order)[:1].get() }).order_by(order)[:1].get()
setattr(self, cachename, obj) setattr(self, cachename, obj)
return getattr(self, cachename) return getattr(self, cachename)

View File

@ -1,5 +1,4 @@
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import connection
from django.db.models.fields import FieldDoesNotExist from django.db.models.fields import FieldDoesNotExist
from django.db.models.sql.constants import LOOKUP_SEP from django.db.models.sql.constants import LOOKUP_SEP
@ -10,6 +9,7 @@ class SQLEvaluator(object):
self.cols = {} self.cols = {}
self.contains_aggregate = False self.contains_aggregate = False
self.connection = query.connection
self.expression.prepare(self, query, allow_joins) self.expression.prepare(self, query, allow_joins)
def as_sql(self, qn=None): def as_sql(self, qn=None):
@ -19,6 +19,9 @@ class SQLEvaluator(object):
for node, col in self.cols.items(): for node, col in self.cols.items():
self.cols[node] = (change_map.get(col[0], col[0]), col[1]) self.cols[node] = (change_map.get(col[0], col[0]), col[1])
def update_connection(self, connection):
self.connection = connection
##################################################### #####################################################
# Vistor methods for initial expression preparation # # Vistor methods for initial expression preparation #
##################################################### #####################################################
@ -56,7 +59,7 @@ class SQLEvaluator(object):
def evaluate_node(self, node, qn): def evaluate_node(self, node, qn):
if not qn: if not qn:
qn = connection.ops.quote_name qn = self.connection.ops.quote_name
expressions = [] expressions = []
expression_params = [] expression_params = []
@ -75,11 +78,11 @@ class SQLEvaluator(object):
expressions.append(format % sql) expressions.append(format % sql)
expression_params.extend(params) expression_params.extend(params)
return connection.ops.combine_expression(node.connector, expressions), expression_params return self.connection.ops.combine_expression(node.connector, expressions), expression_params
def evaluate_leaf(self, node, qn): def evaluate_leaf(self, node, qn):
if not qn: if not qn:
qn = connection.ops.quote_name qn = self.connection.ops.quote_name
col = self.cols[node] col = self.cols[node]
if hasattr(col, 'as_sql'): if hasattr(col, 'as_sql'):

View File

@ -67,10 +67,10 @@ class BaseQuery(object):
# SQL-related attributes # SQL-related attributes
self.select = [] self.select = []
self.tables = [] # Aliases in the order they are created. self.tables = [] # Aliases in the order they are created.
self.where = where() self.where = where(connection=self.connection)
self.where_class = where self.where_class = where
self.group_by = None self.group_by = None
self.having = where() self.having = where(connection=self.connection)
self.order_by = [] self.order_by = []
self.low_mark, self.high_mark = 0, None # Used for offset/limit self.low_mark, self.high_mark = 0, None # Used for offset/limit
self.distinct = False self.distinct = False
@ -151,6 +151,8 @@ class BaseQuery(object):
# supported. It's the only class-reference to the module-level # supported. It's the only class-reference to the module-level
# connection variable. # connection variable.
self.connection = connection self.connection = connection
self.where.update_connection(self.connection)
self.having.update_connection(self.connection)
def get_meta(self): def get_meta(self):
""" """
@ -243,6 +245,8 @@ class BaseQuery(object):
obj.used_aliases = set() obj.used_aliases = set()
obj.filter_is_sticky = False obj.filter_is_sticky = False
obj.__dict__.update(kwargs) obj.__dict__.update(kwargs)
obj.where.update_connection(obj.connection) # where and having track their own connection
obj.having.update_connection(obj.connection)# we need to keep this up to date
if hasattr(obj, '_setup_query'): if hasattr(obj, '_setup_query'):
obj._setup_query() obj._setup_query()
return obj return obj
@ -530,10 +534,10 @@ class BaseQuery(object):
self.where.add(EverythingNode(), AND) self.where.add(EverythingNode(), AND)
elif self.where: elif self.where:
# rhs has an empty where clause. # rhs has an empty where clause.
w = self.where_class() w = self.where_class(connection=self.connection)
w.add(EverythingNode(), AND) w.add(EverythingNode(), AND)
else: else:
w = self.where_class() w = self.where_class(connection=self.connection)
self.where.add(w, connector) self.where.add(w, connector)
# Selection columns and extra extensions are those provided by 'rhs'. # Selection columns and extra extensions are those provided by 'rhs'.
@ -1534,7 +1538,7 @@ class BaseQuery(object):
lookup_type = 'isnull' lookup_type = 'isnull'
value = True value = True
elif (value == '' and lookup_type == 'exact' and elif (value == '' and lookup_type == 'exact' and
connection.features.interprets_empty_strings_as_nulls): self.connection.features.interprets_empty_strings_as_nulls):
lookup_type = 'isnull' lookup_type = 'isnull'
value = True value = True
elif callable(value): elif callable(value):
@ -1546,7 +1550,7 @@ class BaseQuery(object):
for alias, aggregate in self.aggregates.items(): for alias, aggregate in self.aggregates.items():
if alias == parts[0]: if alias == parts[0]:
entry = self.where_class() entry = self.where_class(connection=self.connection)
entry.add((aggregate, lookup_type, value), AND) entry.add((aggregate, lookup_type, value), AND)
if negate: if negate:
entry.negate() entry.negate()
@ -1614,7 +1618,7 @@ class BaseQuery(object):
for alias in join_list: for alias in join_list:
if self.alias_map[alias][JOIN_TYPE] == self.LOUTER: if self.alias_map[alias][JOIN_TYPE] == self.LOUTER:
j_col = self.alias_map[alias][RHS_JOIN_COL] j_col = self.alias_map[alias][RHS_JOIN_COL]
entry = self.where_class() entry = self.where_class(connection=self.connection)
entry.add((Constraint(alias, j_col, None), 'isnull', True), AND) entry.add((Constraint(alias, j_col, None), 'isnull', True), AND)
entry.negate() entry.negate()
self.where.add(entry, AND) self.where.add(entry, AND)
@ -1623,7 +1627,7 @@ class BaseQuery(object):
# Leaky abstraction artifact: We have to specifically # Leaky abstraction artifact: We have to specifically
# exclude the "foo__in=[]" case from this handling, because # exclude the "foo__in=[]" case from this handling, because
# it's short-circuited in the Where class. # it's short-circuited in the Where class.
entry = self.where_class() entry = self.where_class(connection=self.connection)
entry.add((Constraint(alias, col, None), 'isnull', True), AND) entry.add((Constraint(alias, col, None), 'isnull', True), AND)
entry.negate() entry.negate()
self.where.add(entry, AND) self.where.add(entry, AND)

View File

@ -48,7 +48,7 @@ class DeleteQuery(Query):
for related in cls._meta.get_all_related_many_to_many_objects(): for related in cls._meta.get_all_related_many_to_many_objects():
if not isinstance(related.field, generic.GenericRelation): if not isinstance(related.field, generic.GenericRelation):
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
where = self.where_class() where = self.where_class(connection=self.connection)
where.add((Constraint(None, where.add((Constraint(None,
related.field.m2m_reverse_name(), related.field), related.field.m2m_reverse_name(), related.field),
'in', 'in',
@ -57,14 +57,14 @@ class DeleteQuery(Query):
self.do_query(related.field.m2m_db_table(), where) self.do_query(related.field.m2m_db_table(), where)
for f in cls._meta.many_to_many: for f in cls._meta.many_to_many:
w1 = self.where_class() w1 = self.where_class(connection=self.connection)
if isinstance(f, generic.GenericRelation): if isinstance(f, generic.GenericRelation):
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
field = f.rel.to._meta.get_field(f.content_type_field_name) field = f.rel.to._meta.get_field(f.content_type_field_name)
w1.add((Constraint(None, field.column, field), 'exact', w1.add((Constraint(None, field.column, field), 'exact',
ContentType.objects.get_for_model(cls).id), AND) ContentType.objects.get_for_model(cls).id), AND)
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
where = self.where_class() where = self.where_class(connection=self.connection)
where.add((Constraint(None, f.m2m_column_name(), f), 'in', where.add((Constraint(None, f.m2m_column_name(), f), 'in',
pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]),
AND) AND)
@ -81,7 +81,7 @@ class DeleteQuery(Query):
lot of values in pk_list. lot of values in pk_list.
""" """
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
where = self.where_class() where = self.where_class(connection=self.connection)
field = self.model._meta.pk field = self.model._meta.pk
where.add((Constraint(None, field.column, field), 'in', where.add((Constraint(None, field.column, field), 'in',
pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND) pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND)
@ -185,7 +185,7 @@ class UpdateQuery(Query):
# Now we adjust the current query: reset the where clause and get rid # Now we adjust the current query: reset the where clause and get rid
# of all the tables we don't need (since they're in the sub-select). # of all the tables we don't need (since they're in the sub-select).
self.where = self.where_class() self.where = self.where_class(connection=self.connection)
if self.related_updates or must_pre_select: if self.related_updates or must_pre_select:
# Either we're using the idents in multiple update queries (so # Either we're using the idents in multiple update queries (so
# don't want them to change), or the db backend doesn't support # don't want them to change), or the db backend doesn't support
@ -209,7 +209,7 @@ class UpdateQuery(Query):
This is used by the QuerySet.delete_objects() method. This is used by the QuerySet.delete_objects() method.
""" """
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
self.where = self.where_class() self.where = self.where_class(connection=self.connection)
f = self.model._meta.pk f = self.model._meta.pk
self.where.add((Constraint(None, f.column, f), 'in', self.where.add((Constraint(None, f.column, f), 'in',
pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]),

View File

@ -4,7 +4,6 @@ Code to manage the creation and SQL rendering of 'where' constraints.
import datetime import datetime
from django.utils import tree from django.utils import tree
from django.db import connection
from django.db.models.fields import Field from django.db.models.fields import Field
from django.db.models.query_utils import QueryWrapper from django.db.models.query_utils import QueryWrapper
from datastructures import EmptyResultSet, FullResultSet from datastructures import EmptyResultSet, FullResultSet
@ -34,6 +33,18 @@ class WhereNode(tree.Node):
""" """
default = AND default = AND
def __init__(self, *args, **kwargs):
self.connection = kwargs.pop('connection', None)
super(WhereNode, self).__init__(*args, **kwargs)
def __getstate__(self):
"""
Don't try to pickle the connection, our Query will restore it for us.
"""
data = self.__dict__.copy()
del data['connection']
return data
def add(self, data, connector): def add(self, data, connector):
""" """
Add a node to the where-tree. If the data is a list or tuple, it is Add a node to the where-tree. If the data is a list or tuple, it is
@ -53,7 +64,9 @@ class WhereNode(tree.Node):
value = list(value) value = list(value)
if hasattr(obj, "process"): if hasattr(obj, "process"):
try: try:
obj, params = obj.process(lookup_type, value) # FIXME We're calling process too early, the connection could
# change
obj, params = obj.process(lookup_type, value, self.connection)
except (EmptyShortCircuit, EmptyResultSet): except (EmptyShortCircuit, EmptyResultSet):
# There are situations where we want to short-circuit any # There are situations where we want to short-circuit any
# comparisons and make sure that nothing is returned. One # comparisons and make sure that nothing is returned. One
@ -78,6 +91,14 @@ class WhereNode(tree.Node):
super(WhereNode, self).add((obj, lookup_type, annotation, params), super(WhereNode, self).add((obj, lookup_type, annotation, params),
connector) connector)
def update_connection(self, connection):
self.connection = connection
for child in self.children:
if hasattr(child, 'update_connection'):
child.update_connection(connection)
elif hasattr(child[3], 'update_connection'):
child[3].update_connection(connection)
def as_sql(self, qn=None): def as_sql(self, qn=None):
""" """
Returns the SQL version of the where clause and the value to be Returns the SQL version of the where clause and the value to be
@ -88,7 +109,7 @@ class WhereNode(tree.Node):
recursion). recursion).
""" """
if not qn: if not qn:
qn = connection.ops.quote_name qn = self.connection.ops.quote_name
if not self.children: if not self.children:
return None, [] return None, []
result = [] result = []
@ -153,7 +174,7 @@ class WhereNode(tree.Node):
field_sql = lvalue.as_sql(quote_func=qn) field_sql = lvalue.as_sql(quote_func=qn)
if value_annot is datetime.datetime: if value_annot is datetime.datetime:
cast_sql = connection.ops.datetime_cast_sql() cast_sql = self.connection.ops.datetime_cast_sql()
else: else:
cast_sql = '%s' cast_sql = '%s'
@ -163,10 +184,10 @@ class WhereNode(tree.Node):
else: else:
extra = '' extra = ''
if lookup_type in connection.operators: if lookup_type in self.connection.operators:
format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),) format = "%s %%s %%s" % (self.connection.ops.lookup_cast(lookup_type),)
return (format % (field_sql, return (format % (field_sql,
connection.operators[lookup_type] % cast_sql, self.connection.operators[lookup_type] % cast_sql,
extra), params) extra), params)
if lookup_type == 'in': if lookup_type == 'in':
@ -179,15 +200,15 @@ class WhereNode(tree.Node):
elif lookup_type in ('range', 'year'): elif lookup_type in ('range', 'year'):
return ('%s BETWEEN %%s and %%s' % field_sql, params) return ('%s BETWEEN %%s and %%s' % field_sql, params)
elif lookup_type in ('month', 'day', 'week_day'): elif lookup_type in ('month', 'day', 'week_day'):
return ('%s = %%s' % connection.ops.date_extract_sql(lookup_type, field_sql), return ('%s = %%s' % self.connection.ops.date_extract_sql(lookup_type, field_sql),
params) params)
elif lookup_type == 'isnull': elif lookup_type == 'isnull':
return ('%s IS %sNULL' % (field_sql, return ('%s IS %sNULL' % (field_sql,
(not value_annot and 'NOT ' or '')), ()) (not value_annot and 'NOT ' or '')), ())
elif lookup_type == 'search': elif lookup_type == 'search':
return (connection.ops.fulltext_search_sql(field_sql), params) return (self.connection.ops.fulltext_search_sql(field_sql), params)
elif lookup_type in ('regex', 'iregex'): elif lookup_type in ('regex', 'iregex'):
return connection.ops.regex_lookup(lookup_type) % (field_sql, cast_sql), params return self.connection.ops.regex_lookup(lookup_type) % (field_sql, cast_sql), params
raise TypeError('Invalid lookup_type: %r' % lookup_type) raise TypeError('Invalid lookup_type: %r' % lookup_type)
@ -202,7 +223,7 @@ class WhereNode(tree.Node):
lhs = '%s.%s' % (qn(table_alias), qn(name)) lhs = '%s.%s' % (qn(table_alias), qn(name))
else: else:
lhs = qn(name) lhs = qn(name)
return connection.ops.field_cast_sql(db_type) % lhs return self.connection.ops.field_cast_sql(db_type) % lhs
def relabel_aliases(self, change_map, node=None): def relabel_aliases(self, change_map, node=None):
""" """
@ -257,7 +278,7 @@ class Constraint(object):
def __init__(self, alias, col, field): def __init__(self, alias, col, field):
self.alias, self.col, self.field = alias, col, field self.alias, self.col, self.field = alias, col, field
def process(self, lookup_type, value): def process(self, lookup_type, value, connection):
""" """
Returns a tuple of data suitable for inclusion in a WhereNode Returns a tuple of data suitable for inclusion in a WhereNode
instance. instance.