mirror of
https://github.com/django/django.git
synced 2025-07-04 17:59:13 +00:00
[soc2009/multidb] Support multiple databases where one of them has a custom Query class. This needs more testing as I don't have access to Oracle (or DB2, or MSSQL, or Sybase)
git-svn-id: http://code.djangoproject.com/svn/django/branches/soc2009/multidb@11384 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
parent
3dd211ff03
commit
ada3f39dca
@ -97,7 +97,6 @@ class BaseDatabaseFeatures(object):
|
||||
# True if django.db.backend.utils.typecast_timestamp is used on values
|
||||
# returned from dates() calls.
|
||||
needs_datetime_string_cast = True
|
||||
uses_custom_query_class = False
|
||||
empty_fetchmany_value = []
|
||||
update_can_self_select = True
|
||||
interprets_empty_strings_as_nulls = False
|
||||
@ -115,6 +114,10 @@ class BaseDatabaseOperations(object):
|
||||
a backend performs ordering or calculates the ID of a recently-inserted
|
||||
row.
|
||||
"""
|
||||
def __init__(self):
|
||||
# this cache is used for backends that provide custom Queyr classes
|
||||
self._cache = {}
|
||||
|
||||
def autoinc_sql(self, table, column):
|
||||
"""
|
||||
Returns any SQL needed to support auto-incrementing primary keys, or
|
||||
@ -277,14 +280,15 @@ class BaseDatabaseOperations(object):
|
||||
"""
|
||||
pass
|
||||
|
||||
def query_class(self, DefaultQueryClass):
|
||||
def query_class(self, DefaultQueryClass, subclass=None):
|
||||
"""
|
||||
Given the default Query class, returns a custom Query class
|
||||
to use for this backend. Returns None if a custom Query isn't used.
|
||||
See also BaseDatabaseFeatures.uses_custom_query_class, which regulates
|
||||
whether this method is called at all.
|
||||
to use for this backend. Returns the Query class unmodified if the
|
||||
backend doesn't need a custom Query clsas.
|
||||
"""
|
||||
return None
|
||||
if subclass is not None:
|
||||
return subclass
|
||||
return DefaultQueryClass
|
||||
|
||||
def quote_name(self, name):
|
||||
"""
|
||||
|
@ -133,8 +133,14 @@ WHEN (new.%(col_name)s IS NULL)
|
||||
return u''
|
||||
return force_unicode(value.read())
|
||||
|
||||
def query_class(self, DefaultQueryClass):
|
||||
return query.query_class(DefaultQueryClass, Database)
|
||||
def query_class(self, DefaultQueryClass, subclass=None):
|
||||
if (DefaultQueryClass, subclass) in self._cache:
|
||||
return self._cache[DefaultQueryClass, subclass]
|
||||
Query = query.query_class(DefaultQueryClass, Database)
|
||||
if subclass is not None:
|
||||
Query = type('Query', (subclsas, Query), {})
|
||||
self._cache[DefaultQueryClass, subclass] = Query
|
||||
return Query
|
||||
|
||||
def quote_name(self, name):
|
||||
# SQL92 requires delimited (quoted) names to be case-sensitive. When
|
||||
|
@ -40,7 +40,7 @@ class QuerySet(object):
|
||||
using = None
|
||||
using = using or DEFAULT_DB_ALIAS
|
||||
connection = connections[using]
|
||||
self.query = query or sql.Query(self.model, connection)
|
||||
self.query = query or connection.ops.query_class(sql.Query)(self.model, connection)
|
||||
self._result_cache = None
|
||||
self._iter = None
|
||||
self._sticky_filter = False
|
||||
@ -665,6 +665,10 @@ class QuerySet(object):
|
||||
clone._using = alias
|
||||
connection = connections[alias]
|
||||
clone.query.set_connection(connection)
|
||||
cls = clone.query.get_query_class()
|
||||
clone.query.__class__ = connection.ops.query_class(
|
||||
sql.Query, cls is sql.Query and None or cls
|
||||
)
|
||||
return clone
|
||||
|
||||
###################################
|
||||
@ -1078,16 +1082,16 @@ def delete_objects(seen_objs, using):
|
||||
signals.pre_delete.send(sender=cls, instance=instance)
|
||||
|
||||
pk_list = [pk for pk,instance in items]
|
||||
del_query = sql.DeleteQuery(cls, connection)
|
||||
del_query = connection.ops.query_class(sql.Query, sql.DeleteQuery)(cls, connection)
|
||||
del_query.delete_batch_related(pk_list)
|
||||
|
||||
update_query = sql.UpdateQuery(cls, connection)
|
||||
update_query = connection.ops.query_class(sql.Query, sql.UpdateQuery)(cls, connection)
|
||||
for field, model in cls._meta.get_fields_with_model():
|
||||
if (field.rel and field.null and field.rel.to in seen_objs and
|
||||
filter(lambda f: f.column == field.rel.get_related_field().column,
|
||||
field.rel.to._meta.fields)):
|
||||
if model:
|
||||
sql.UpdateQuery(model, connection).clear_related(field,
|
||||
connection.ops.query_class(sql.Query, sql.UpdateQuery)(model, connection).clear_related(field,
|
||||
pk_list)
|
||||
else:
|
||||
update_query.clear_related(field, pk_list)
|
||||
@ -1098,7 +1102,7 @@ def delete_objects(seen_objs, using):
|
||||
items.reverse()
|
||||
|
||||
pk_list = [pk for pk,instance in items]
|
||||
del_query = sql.DeleteQuery(cls, connection)
|
||||
del_query = connection.ops.query_class(sql.Query, sql.DeleteQuery)(cls, connection)
|
||||
del_query.delete_batch(pk_list)
|
||||
|
||||
# Last cleanup; set NULLs where there once was a reference to the
|
||||
@ -1128,6 +1132,6 @@ def insert_query(model, values, return_id=False, raw_values=False, using=None):
|
||||
part of the public API.
|
||||
"""
|
||||
connection = connections[using]
|
||||
query = sql.InsertQuery(model, connection)
|
||||
query = connection.ops.query_class(sql.Query, sql.InsertQuery)(model, connection)
|
||||
query.insert_values(values, raw_values)
|
||||
return query.execute_sql(return_id)
|
||||
|
@ -29,9 +29,9 @@ try:
|
||||
except NameError:
|
||||
from sets import Set as set # Python 2.3 fallback
|
||||
|
||||
__all__ = ['Query', 'BaseQuery']
|
||||
__all__ = ['Query']
|
||||
|
||||
class BaseQuery(object):
|
||||
class Query(object):
|
||||
"""
|
||||
A single SQL query.
|
||||
"""
|
||||
@ -151,6 +151,9 @@ class BaseQuery(object):
|
||||
self.connection = connections[connections.alias_for_settings(
|
||||
obj_dict['connection_settings'])]
|
||||
|
||||
def get_query_class(self):
|
||||
return Query
|
||||
|
||||
def get_meta(self):
|
||||
"""
|
||||
Returns the Options instance (the model._meta) from which to start
|
||||
@ -319,7 +322,7 @@ class BaseQuery(object):
|
||||
# over the subquery instead.
|
||||
if self.group_by is not None:
|
||||
from subqueries import AggregateQuery
|
||||
query = AggregateQuery(self.model, self.connection)
|
||||
query = self.connection.ops.query_class(Query, AggregateQuery)(self.model, self.connection)
|
||||
|
||||
obj = self.clone()
|
||||
|
||||
@ -368,7 +371,7 @@ class BaseQuery(object):
|
||||
subquery.clear_ordering(True)
|
||||
subquery.clear_limits()
|
||||
|
||||
obj = AggregateQuery(obj.model, obj.connection)
|
||||
obj = self.connection.ops.query_class(Query, AggregateQuery)(obj.model, obj.connection)
|
||||
obj.add_subquery(subquery)
|
||||
|
||||
obj.add_count_column()
|
||||
@ -1962,7 +1965,7 @@ class BaseQuery(object):
|
||||
original exclude filter (filter_expr) and the portion up to the first
|
||||
N-to-many relation field.
|
||||
"""
|
||||
query = Query(self.model, self.connection)
|
||||
query = self.connection.ops.query_class(Query)(self.model, self.connection)
|
||||
query.add_filter(filter_expr, can_reuse=can_reuse)
|
||||
query.bump_prefix()
|
||||
query.clear_ordering(True)
|
||||
@ -2389,13 +2392,6 @@ class BaseQuery(object):
|
||||
return list(result)
|
||||
return result
|
||||
|
||||
# Use the backend's custom Query class if it defines one. Otherwise, use the
|
||||
# default.
|
||||
if connection.features.uses_custom_query_class:
|
||||
Query = connection.ops.query_class(BaseQuery)
|
||||
else:
|
||||
Query = BaseQuery
|
||||
|
||||
def get_order_dir(field, default='ASC'):
|
||||
"""
|
||||
Returns the field name and direction for an order specification. For
|
||||
|
@ -17,6 +17,9 @@ class DeleteQuery(Query):
|
||||
Delete queries are done through this class, since they are more constrained
|
||||
than general queries.
|
||||
"""
|
||||
def get_query_class(self):
|
||||
return DeleteQuery
|
||||
|
||||
def as_sql(self):
|
||||
"""
|
||||
Creates the SQL for this query. Returns the SQL string and list of
|
||||
@ -96,6 +99,9 @@ class UpdateQuery(Query):
|
||||
super(UpdateQuery, self).__init__(*args, **kwargs)
|
||||
self._setup_query()
|
||||
|
||||
def get_query_class(self):
|
||||
return UpdateQuery
|
||||
|
||||
def _setup_query(self):
|
||||
"""
|
||||
Runs on initialization and after cloning. Any attributes that would
|
||||
@ -279,7 +285,7 @@ class UpdateQuery(Query):
|
||||
return []
|
||||
result = []
|
||||
for model, values in self.related_updates.iteritems():
|
||||
query = UpdateQuery(model, self.connection)
|
||||
query = self.connection.ops.query_class(Query, UpdateQuery)(model, self.connection)
|
||||
query.values = values
|
||||
if self.related_ids:
|
||||
query.add_filter(('pk__in', self.related_ids))
|
||||
@ -294,6 +300,9 @@ class InsertQuery(Query):
|
||||
self.params = ()
|
||||
self.return_id = False
|
||||
|
||||
def get_query_class(self):
|
||||
return InsertQuery
|
||||
|
||||
def clone(self, klass=None, **kwargs):
|
||||
extras = {'columns': self.columns[:], 'values': self.values[:],
|
||||
'params': self.params, 'return_id': self.return_id}
|
||||
@ -376,6 +385,9 @@ class DateQuery(Query):
|
||||
if isinstance(elt, Date):
|
||||
self.date_sql_func = self.connection.ops.date_trunc_sql
|
||||
|
||||
def get_query_class(self):
|
||||
return DateQuery
|
||||
|
||||
def results_iter(self):
|
||||
"""
|
||||
Returns an iterator over the results from executing this query.
|
||||
@ -418,6 +430,8 @@ class AggregateQuery(Query):
|
||||
An AggregateQuery takes another query as a parameter to the FROM
|
||||
clause and only selects the elements in the provided list.
|
||||
"""
|
||||
def get_query_class(self):
|
||||
return AggregateQuery
|
||||
|
||||
def add_subquery(self, query):
|
||||
self.subquery, self.sub_params = query.as_sql(with_col_aliases=True)
|
||||
|
Loading…
x
Reference in New Issue
Block a user