1
0
mirror of https://github.com/django/django.git synced 2025-07-04 17:59:13 +00:00

[soc2009/multidb] Support multiple databases where one of them has a custom Query class. This needs more testing as I don't have access to Oracle (or DB2, or MSSQL, or Sybase)

git-svn-id: http://code.djangoproject.com/svn/django/branches/soc2009/multidb@11384 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Alex Gaynor 2009-08-04 04:37:30 +00:00
parent 3dd211ff03
commit ada3f39dca
5 changed files with 51 additions and 27 deletions

View File

@ -97,7 +97,6 @@ class BaseDatabaseFeatures(object):
# True if django.db.backend.utils.typecast_timestamp is used on values
# returned from dates() calls.
needs_datetime_string_cast = True
uses_custom_query_class = False
empty_fetchmany_value = []
update_can_self_select = True
interprets_empty_strings_as_nulls = False
@ -115,6 +114,10 @@ class BaseDatabaseOperations(object):
a backend performs ordering or calculates the ID of a recently-inserted
row.
"""
def __init__(self):
# this cache is used for backends that provide custom Queyr classes
self._cache = {}
def autoinc_sql(self, table, column):
"""
Returns any SQL needed to support auto-incrementing primary keys, or
@ -277,14 +280,15 @@ class BaseDatabaseOperations(object):
"""
pass
def query_class(self, DefaultQueryClass):
def query_class(self, DefaultQueryClass, subclass=None):
"""
Given the default Query class, returns a custom Query class
to use for this backend. Returns None if a custom Query isn't used.
See also BaseDatabaseFeatures.uses_custom_query_class, which regulates
whether this method is called at all.
to use for this backend. Returns the Query class unmodified if the
backend doesn't need a custom Query clsas.
"""
return None
if subclass is not None:
return subclass
return DefaultQueryClass
def quote_name(self, name):
"""

View File

@ -133,8 +133,14 @@ WHEN (new.%(col_name)s IS NULL)
return u''
return force_unicode(value.read())
def query_class(self, DefaultQueryClass):
return query.query_class(DefaultQueryClass, Database)
def query_class(self, DefaultQueryClass, subclass=None):
if (DefaultQueryClass, subclass) in self._cache:
return self._cache[DefaultQueryClass, subclass]
Query = query.query_class(DefaultQueryClass, Database)
if subclass is not None:
Query = type('Query', (subclsas, Query), {})
self._cache[DefaultQueryClass, subclass] = Query
return Query
def quote_name(self, name):
# SQL92 requires delimited (quoted) names to be case-sensitive. When

View File

@ -40,7 +40,7 @@ class QuerySet(object):
using = None
using = using or DEFAULT_DB_ALIAS
connection = connections[using]
self.query = query or sql.Query(self.model, connection)
self.query = query or connection.ops.query_class(sql.Query)(self.model, connection)
self._result_cache = None
self._iter = None
self._sticky_filter = False
@ -665,6 +665,10 @@ class QuerySet(object):
clone._using = alias
connection = connections[alias]
clone.query.set_connection(connection)
cls = clone.query.get_query_class()
clone.query.__class__ = connection.ops.query_class(
sql.Query, cls is sql.Query and None or cls
)
return clone
###################################
@ -1078,16 +1082,16 @@ def delete_objects(seen_objs, using):
signals.pre_delete.send(sender=cls, instance=instance)
pk_list = [pk for pk,instance in items]
del_query = sql.DeleteQuery(cls, connection)
del_query = connection.ops.query_class(sql.Query, sql.DeleteQuery)(cls, connection)
del_query.delete_batch_related(pk_list)
update_query = sql.UpdateQuery(cls, connection)
update_query = connection.ops.query_class(sql.Query, sql.UpdateQuery)(cls, connection)
for field, model in cls._meta.get_fields_with_model():
if (field.rel and field.null and field.rel.to in seen_objs and
filter(lambda f: f.column == field.rel.get_related_field().column,
field.rel.to._meta.fields)):
if model:
sql.UpdateQuery(model, connection).clear_related(field,
connection.ops.query_class(sql.Query, sql.UpdateQuery)(model, connection).clear_related(field,
pk_list)
else:
update_query.clear_related(field, pk_list)
@ -1098,7 +1102,7 @@ def delete_objects(seen_objs, using):
items.reverse()
pk_list = [pk for pk,instance in items]
del_query = sql.DeleteQuery(cls, connection)
del_query = connection.ops.query_class(sql.Query, sql.DeleteQuery)(cls, connection)
del_query.delete_batch(pk_list)
# Last cleanup; set NULLs where there once was a reference to the
@ -1128,6 +1132,6 @@ def insert_query(model, values, return_id=False, raw_values=False, using=None):
part of the public API.
"""
connection = connections[using]
query = sql.InsertQuery(model, connection)
query = connection.ops.query_class(sql.Query, sql.InsertQuery)(model, connection)
query.insert_values(values, raw_values)
return query.execute_sql(return_id)

View File

@ -29,9 +29,9 @@ try:
except NameError:
from sets import Set as set # Python 2.3 fallback
__all__ = ['Query', 'BaseQuery']
__all__ = ['Query']
class BaseQuery(object):
class Query(object):
"""
A single SQL query.
"""
@ -151,6 +151,9 @@ class BaseQuery(object):
self.connection = connections[connections.alias_for_settings(
obj_dict['connection_settings'])]
def get_query_class(self):
return Query
def get_meta(self):
"""
Returns the Options instance (the model._meta) from which to start
@ -319,7 +322,7 @@ class BaseQuery(object):
# over the subquery instead.
if self.group_by is not None:
from subqueries import AggregateQuery
query = AggregateQuery(self.model, self.connection)
query = self.connection.ops.query_class(Query, AggregateQuery)(self.model, self.connection)
obj = self.clone()
@ -368,7 +371,7 @@ class BaseQuery(object):
subquery.clear_ordering(True)
subquery.clear_limits()
obj = AggregateQuery(obj.model, obj.connection)
obj = self.connection.ops.query_class(Query, AggregateQuery)(obj.model, obj.connection)
obj.add_subquery(subquery)
obj.add_count_column()
@ -1962,7 +1965,7 @@ class BaseQuery(object):
original exclude filter (filter_expr) and the portion up to the first
N-to-many relation field.
"""
query = Query(self.model, self.connection)
query = self.connection.ops.query_class(Query)(self.model, self.connection)
query.add_filter(filter_expr, can_reuse=can_reuse)
query.bump_prefix()
query.clear_ordering(True)
@ -2389,13 +2392,6 @@ class BaseQuery(object):
return list(result)
return result
# Use the backend's custom Query class if it defines one. Otherwise, use the
# default.
if connection.features.uses_custom_query_class:
Query = connection.ops.query_class(BaseQuery)
else:
Query = BaseQuery
def get_order_dir(field, default='ASC'):
"""
Returns the field name and direction for an order specification. For

View File

@ -17,6 +17,9 @@ class DeleteQuery(Query):
Delete queries are done through this class, since they are more constrained
than general queries.
"""
def get_query_class(self):
return DeleteQuery
def as_sql(self):
"""
Creates the SQL for this query. Returns the SQL string and list of
@ -96,6 +99,9 @@ class UpdateQuery(Query):
super(UpdateQuery, self).__init__(*args, **kwargs)
self._setup_query()
def get_query_class(self):
return UpdateQuery
def _setup_query(self):
"""
Runs on initialization and after cloning. Any attributes that would
@ -279,7 +285,7 @@ class UpdateQuery(Query):
return []
result = []
for model, values in self.related_updates.iteritems():
query = UpdateQuery(model, self.connection)
query = self.connection.ops.query_class(Query, UpdateQuery)(model, self.connection)
query.values = values
if self.related_ids:
query.add_filter(('pk__in', self.related_ids))
@ -294,6 +300,9 @@ class InsertQuery(Query):
self.params = ()
self.return_id = False
def get_query_class(self):
return InsertQuery
def clone(self, klass=None, **kwargs):
extras = {'columns': self.columns[:], 'values': self.values[:],
'params': self.params, 'return_id': self.return_id}
@ -376,6 +385,9 @@ class DateQuery(Query):
if isinstance(elt, Date):
self.date_sql_func = self.connection.ops.date_trunc_sql
def get_query_class(self):
return DateQuery
def results_iter(self):
"""
Returns an iterator over the results from executing this query.
@ -418,6 +430,8 @@ class AggregateQuery(Query):
An AggregateQuery takes another query as a parameter to the FROM
clause and only selects the elements in the provided list.
"""
def get_query_class(self):
return AggregateQuery
def add_subquery(self, query):
self.subquery, self.sub_params = query.as_sql(with_col_aliases=True)