1
0
mirror of https://github.com/django/django.git synced 2025-07-05 10:19:20 +00:00

[soc2009/multidb] Support multiple databases where one of them has a custom Query class. This needs more testing as I don't have access to Oracle (or DB2, or MSSQL, or Sybase)

git-svn-id: http://code.djangoproject.com/svn/django/branches/soc2009/multidb@11384 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Alex Gaynor 2009-08-04 04:37:30 +00:00
parent 3dd211ff03
commit ada3f39dca
5 changed files with 51 additions and 27 deletions

View File

@ -97,7 +97,6 @@ class BaseDatabaseFeatures(object):
# True if django.db.backend.utils.typecast_timestamp is used on values # True if django.db.backend.utils.typecast_timestamp is used on values
# returned from dates() calls. # returned from dates() calls.
needs_datetime_string_cast = True needs_datetime_string_cast = True
uses_custom_query_class = False
empty_fetchmany_value = [] empty_fetchmany_value = []
update_can_self_select = True update_can_self_select = True
interprets_empty_strings_as_nulls = False interprets_empty_strings_as_nulls = False
@ -115,6 +114,10 @@ class BaseDatabaseOperations(object):
a backend performs ordering or calculates the ID of a recently-inserted a backend performs ordering or calculates the ID of a recently-inserted
row. row.
""" """
def __init__(self):
# this cache is used for backends that provide custom Queyr classes
self._cache = {}
def autoinc_sql(self, table, column): def autoinc_sql(self, table, column):
""" """
Returns any SQL needed to support auto-incrementing primary keys, or Returns any SQL needed to support auto-incrementing primary keys, or
@ -277,14 +280,15 @@ class BaseDatabaseOperations(object):
""" """
pass pass
def query_class(self, DefaultQueryClass): def query_class(self, DefaultQueryClass, subclass=None):
""" """
Given the default Query class, returns a custom Query class Given the default Query class, returns a custom Query class
to use for this backend. Returns None if a custom Query isn't used. to use for this backend. Returns the Query class unmodified if the
See also BaseDatabaseFeatures.uses_custom_query_class, which regulates backend doesn't need a custom Query clsas.
whether this method is called at all.
""" """
return None if subclass is not None:
return subclass
return DefaultQueryClass
def quote_name(self, name): def quote_name(self, name):
""" """

View File

@ -133,8 +133,14 @@ WHEN (new.%(col_name)s IS NULL)
return u'' return u''
return force_unicode(value.read()) return force_unicode(value.read())
def query_class(self, DefaultQueryClass): def query_class(self, DefaultQueryClass, subclass=None):
return query.query_class(DefaultQueryClass, Database) if (DefaultQueryClass, subclass) in self._cache:
return self._cache[DefaultQueryClass, subclass]
Query = query.query_class(DefaultQueryClass, Database)
if subclass is not None:
Query = type('Query', (subclsas, Query), {})
self._cache[DefaultQueryClass, subclass] = Query
return Query
def quote_name(self, name): def quote_name(self, name):
# SQL92 requires delimited (quoted) names to be case-sensitive. When # SQL92 requires delimited (quoted) names to be case-sensitive. When

View File

@ -40,7 +40,7 @@ class QuerySet(object):
using = None using = None
using = using or DEFAULT_DB_ALIAS using = using or DEFAULT_DB_ALIAS
connection = connections[using] connection = connections[using]
self.query = query or sql.Query(self.model, connection) self.query = query or connection.ops.query_class(sql.Query)(self.model, connection)
self._result_cache = None self._result_cache = None
self._iter = None self._iter = None
self._sticky_filter = False self._sticky_filter = False
@ -665,6 +665,10 @@ class QuerySet(object):
clone._using = alias clone._using = alias
connection = connections[alias] connection = connections[alias]
clone.query.set_connection(connection) clone.query.set_connection(connection)
cls = clone.query.get_query_class()
clone.query.__class__ = connection.ops.query_class(
sql.Query, cls is sql.Query and None or cls
)
return clone return clone
################################### ###################################
@ -1078,16 +1082,16 @@ def delete_objects(seen_objs, using):
signals.pre_delete.send(sender=cls, instance=instance) signals.pre_delete.send(sender=cls, instance=instance)
pk_list = [pk for pk,instance in items] pk_list = [pk for pk,instance in items]
del_query = sql.DeleteQuery(cls, connection) del_query = connection.ops.query_class(sql.Query, sql.DeleteQuery)(cls, connection)
del_query.delete_batch_related(pk_list) del_query.delete_batch_related(pk_list)
update_query = sql.UpdateQuery(cls, connection) update_query = connection.ops.query_class(sql.Query, sql.UpdateQuery)(cls, connection)
for field, model in cls._meta.get_fields_with_model(): for field, model in cls._meta.get_fields_with_model():
if (field.rel and field.null and field.rel.to in seen_objs and if (field.rel and field.null and field.rel.to in seen_objs and
filter(lambda f: f.column == field.rel.get_related_field().column, filter(lambda f: f.column == field.rel.get_related_field().column,
field.rel.to._meta.fields)): field.rel.to._meta.fields)):
if model: if model:
sql.UpdateQuery(model, connection).clear_related(field, connection.ops.query_class(sql.Query, sql.UpdateQuery)(model, connection).clear_related(field,
pk_list) pk_list)
else: else:
update_query.clear_related(field, pk_list) update_query.clear_related(field, pk_list)
@ -1098,7 +1102,7 @@ def delete_objects(seen_objs, using):
items.reverse() items.reverse()
pk_list = [pk for pk,instance in items] pk_list = [pk for pk,instance in items]
del_query = sql.DeleteQuery(cls, connection) del_query = connection.ops.query_class(sql.Query, sql.DeleteQuery)(cls, connection)
del_query.delete_batch(pk_list) del_query.delete_batch(pk_list)
# Last cleanup; set NULLs where there once was a reference to the # Last cleanup; set NULLs where there once was a reference to the
@ -1128,6 +1132,6 @@ def insert_query(model, values, return_id=False, raw_values=False, using=None):
part of the public API. part of the public API.
""" """
connection = connections[using] connection = connections[using]
query = sql.InsertQuery(model, connection) query = connection.ops.query_class(sql.Query, sql.InsertQuery)(model, connection)
query.insert_values(values, raw_values) query.insert_values(values, raw_values)
return query.execute_sql(return_id) return query.execute_sql(return_id)

View File

@ -29,9 +29,9 @@ try:
except NameError: except NameError:
from sets import Set as set # Python 2.3 fallback from sets import Set as set # Python 2.3 fallback
__all__ = ['Query', 'BaseQuery'] __all__ = ['Query']
class BaseQuery(object): class Query(object):
""" """
A single SQL query. A single SQL query.
""" """
@ -151,6 +151,9 @@ class BaseQuery(object):
self.connection = connections[connections.alias_for_settings( self.connection = connections[connections.alias_for_settings(
obj_dict['connection_settings'])] obj_dict['connection_settings'])]
def get_query_class(self):
return Query
def get_meta(self): def get_meta(self):
""" """
Returns the Options instance (the model._meta) from which to start Returns the Options instance (the model._meta) from which to start
@ -319,7 +322,7 @@ class BaseQuery(object):
# over the subquery instead. # over the subquery instead.
if self.group_by is not None: if self.group_by is not None:
from subqueries import AggregateQuery from subqueries import AggregateQuery
query = AggregateQuery(self.model, self.connection) query = self.connection.ops.query_class(Query, AggregateQuery)(self.model, self.connection)
obj = self.clone() obj = self.clone()
@ -368,7 +371,7 @@ class BaseQuery(object):
subquery.clear_ordering(True) subquery.clear_ordering(True)
subquery.clear_limits() subquery.clear_limits()
obj = AggregateQuery(obj.model, obj.connection) obj = self.connection.ops.query_class(Query, AggregateQuery)(obj.model, obj.connection)
obj.add_subquery(subquery) obj.add_subquery(subquery)
obj.add_count_column() obj.add_count_column()
@ -1962,7 +1965,7 @@ class BaseQuery(object):
original exclude filter (filter_expr) and the portion up to the first original exclude filter (filter_expr) and the portion up to the first
N-to-many relation field. N-to-many relation field.
""" """
query = Query(self.model, self.connection) query = self.connection.ops.query_class(Query)(self.model, self.connection)
query.add_filter(filter_expr, can_reuse=can_reuse) query.add_filter(filter_expr, can_reuse=can_reuse)
query.bump_prefix() query.bump_prefix()
query.clear_ordering(True) query.clear_ordering(True)
@ -2389,13 +2392,6 @@ class BaseQuery(object):
return list(result) return list(result)
return result return result
# Use the backend's custom Query class if it defines one. Otherwise, use the
# default.
if connection.features.uses_custom_query_class:
Query = connection.ops.query_class(BaseQuery)
else:
Query = BaseQuery
def get_order_dir(field, default='ASC'): def get_order_dir(field, default='ASC'):
""" """
Returns the field name and direction for an order specification. For Returns the field name and direction for an order specification. For

View File

@ -17,6 +17,9 @@ class DeleteQuery(Query):
Delete queries are done through this class, since they are more constrained Delete queries are done through this class, since they are more constrained
than general queries. than general queries.
""" """
def get_query_class(self):
return DeleteQuery
def as_sql(self): def as_sql(self):
""" """
Creates the SQL for this query. Returns the SQL string and list of Creates the SQL for this query. Returns the SQL string and list of
@ -96,6 +99,9 @@ class UpdateQuery(Query):
super(UpdateQuery, self).__init__(*args, **kwargs) super(UpdateQuery, self).__init__(*args, **kwargs)
self._setup_query() self._setup_query()
def get_query_class(self):
return UpdateQuery
def _setup_query(self): def _setup_query(self):
""" """
Runs on initialization and after cloning. Any attributes that would Runs on initialization and after cloning. Any attributes that would
@ -279,7 +285,7 @@ class UpdateQuery(Query):
return [] return []
result = [] result = []
for model, values in self.related_updates.iteritems(): for model, values in self.related_updates.iteritems():
query = UpdateQuery(model, self.connection) query = self.connection.ops.query_class(Query, UpdateQuery)(model, self.connection)
query.values = values query.values = values
if self.related_ids: if self.related_ids:
query.add_filter(('pk__in', self.related_ids)) query.add_filter(('pk__in', self.related_ids))
@ -294,6 +300,9 @@ class InsertQuery(Query):
self.params = () self.params = ()
self.return_id = False self.return_id = False
def get_query_class(self):
return InsertQuery
def clone(self, klass=None, **kwargs): def clone(self, klass=None, **kwargs):
extras = {'columns': self.columns[:], 'values': self.values[:], extras = {'columns': self.columns[:], 'values': self.values[:],
'params': self.params, 'return_id': self.return_id} 'params': self.params, 'return_id': self.return_id}
@ -376,6 +385,9 @@ class DateQuery(Query):
if isinstance(elt, Date): if isinstance(elt, Date):
self.date_sql_func = self.connection.ops.date_trunc_sql self.date_sql_func = self.connection.ops.date_trunc_sql
def get_query_class(self):
return DateQuery
def results_iter(self): def results_iter(self):
""" """
Returns an iterator over the results from executing this query. Returns an iterator over the results from executing this query.
@ -418,6 +430,8 @@ class AggregateQuery(Query):
An AggregateQuery takes another query as a parameter to the FROM An AggregateQuery takes another query as a parameter to the FROM
clause and only selects the elements in the provided list. clause and only selects the elements in the provided list.
""" """
def get_query_class(self):
return AggregateQuery
def add_subquery(self, query): def add_subquery(self, query):
self.subquery, self.sub_params = query.as_sql(with_col_aliases=True) self.subquery, self.sub_params = query.as_sql(with_col_aliases=True)