diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index ac50115b48..8bf29032c1 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -97,7 +97,6 @@ class BaseDatabaseFeatures(object): # True if django.db.backend.utils.typecast_timestamp is used on values # returned from dates() calls. needs_datetime_string_cast = True - uses_custom_query_class = False empty_fetchmany_value = [] update_can_self_select = True interprets_empty_strings_as_nulls = False @@ -115,6 +114,10 @@ class BaseDatabaseOperations(object): a backend performs ordering or calculates the ID of a recently-inserted row. """ + def __init__(self): + # this cache is used for backends that provide custom Queyr classes + self._cache = {} + def autoinc_sql(self, table, column): """ Returns any SQL needed to support auto-incrementing primary keys, or @@ -277,14 +280,15 @@ class BaseDatabaseOperations(object): """ pass - def query_class(self, DefaultQueryClass): + def query_class(self, DefaultQueryClass, subclass=None): """ Given the default Query class, returns a custom Query class - to use for this backend. Returns None if a custom Query isn't used. - See also BaseDatabaseFeatures.uses_custom_query_class, which regulates - whether this method is called at all. + to use for this backend. Returns the Query class unmodified if the + backend doesn't need a custom Query clsas. """ - return None + if subclass is not None: + return subclass + return DefaultQueryClass def quote_name(self, name): """ diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py index ab33e92243..0ce74d0281 100644 --- a/django/db/backends/oracle/base.py +++ b/django/db/backends/oracle/base.py @@ -133,8 +133,14 @@ WHEN (new.%(col_name)s IS NULL) return u'' return force_unicode(value.read()) - def query_class(self, DefaultQueryClass): - return query.query_class(DefaultQueryClass, Database) + def query_class(self, DefaultQueryClass, subclass=None): + if (DefaultQueryClass, subclass) in self._cache: + return self._cache[DefaultQueryClass, subclass] + Query = query.query_class(DefaultQueryClass, Database) + if subclass is not None: + Query = type('Query', (subclsas, Query), {}) + self._cache[DefaultQueryClass, subclass] = Query + return Query def quote_name(self, name): # SQL92 requires delimited (quoted) names to be case-sensitive. When diff --git a/django/db/models/query.py b/django/db/models/query.py index 4c80be1bc9..4e7bacceca 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -40,7 +40,7 @@ class QuerySet(object): using = None using = using or DEFAULT_DB_ALIAS connection = connections[using] - self.query = query or sql.Query(self.model, connection) + self.query = query or connection.ops.query_class(sql.Query)(self.model, connection) self._result_cache = None self._iter = None self._sticky_filter = False @@ -665,6 +665,10 @@ class QuerySet(object): clone._using = alias connection = connections[alias] clone.query.set_connection(connection) + cls = clone.query.get_query_class() + clone.query.__class__ = connection.ops.query_class( + sql.Query, cls is sql.Query and None or cls + ) return clone ################################### @@ -1078,16 +1082,16 @@ def delete_objects(seen_objs, using): signals.pre_delete.send(sender=cls, instance=instance) pk_list = [pk for pk,instance in items] - del_query = sql.DeleteQuery(cls, connection) + del_query = connection.ops.query_class(sql.Query, sql.DeleteQuery)(cls, connection) del_query.delete_batch_related(pk_list) - update_query = sql.UpdateQuery(cls, connection) + update_query = connection.ops.query_class(sql.Query, sql.UpdateQuery)(cls, connection) for field, model in cls._meta.get_fields_with_model(): if (field.rel and field.null and field.rel.to in seen_objs and filter(lambda f: f.column == field.rel.get_related_field().column, field.rel.to._meta.fields)): if model: - sql.UpdateQuery(model, connection).clear_related(field, + connection.ops.query_class(sql.Query, sql.UpdateQuery)(model, connection).clear_related(field, pk_list) else: update_query.clear_related(field, pk_list) @@ -1098,7 +1102,7 @@ def delete_objects(seen_objs, using): items.reverse() pk_list = [pk for pk,instance in items] - del_query = sql.DeleteQuery(cls, connection) + del_query = connection.ops.query_class(sql.Query, sql.DeleteQuery)(cls, connection) del_query.delete_batch(pk_list) # Last cleanup; set NULLs where there once was a reference to the @@ -1128,6 +1132,6 @@ def insert_query(model, values, return_id=False, raw_values=False, using=None): part of the public API. """ connection = connections[using] - query = sql.InsertQuery(model, connection) + query = connection.ops.query_class(sql.Query, sql.InsertQuery)(model, connection) query.insert_values(values, raw_values) return query.execute_sql(return_id) diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 0de66ee8f7..dce7e5a4a1 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -29,9 +29,9 @@ try: except NameError: from sets import Set as set # Python 2.3 fallback -__all__ = ['Query', 'BaseQuery'] +__all__ = ['Query'] -class BaseQuery(object): +class Query(object): """ A single SQL query. """ @@ -151,6 +151,9 @@ class BaseQuery(object): self.connection = connections[connections.alias_for_settings( obj_dict['connection_settings'])] + def get_query_class(self): + return Query + def get_meta(self): """ Returns the Options instance (the model._meta) from which to start @@ -319,7 +322,7 @@ class BaseQuery(object): # over the subquery instead. if self.group_by is not None: from subqueries import AggregateQuery - query = AggregateQuery(self.model, self.connection) + query = self.connection.ops.query_class(Query, AggregateQuery)(self.model, self.connection) obj = self.clone() @@ -368,7 +371,7 @@ class BaseQuery(object): subquery.clear_ordering(True) subquery.clear_limits() - obj = AggregateQuery(obj.model, obj.connection) + obj = self.connection.ops.query_class(Query, AggregateQuery)(obj.model, obj.connection) obj.add_subquery(subquery) obj.add_count_column() @@ -1962,7 +1965,7 @@ class BaseQuery(object): original exclude filter (filter_expr) and the portion up to the first N-to-many relation field. """ - query = Query(self.model, self.connection) + query = self.connection.ops.query_class(Query)(self.model, self.connection) query.add_filter(filter_expr, can_reuse=can_reuse) query.bump_prefix() query.clear_ordering(True) @@ -2389,13 +2392,6 @@ class BaseQuery(object): return list(result) return result -# Use the backend's custom Query class if it defines one. Otherwise, use the -# default. -if connection.features.uses_custom_query_class: - Query = connection.ops.query_class(BaseQuery) -else: - Query = BaseQuery - def get_order_dir(field, default='ASC'): """ Returns the field name and direction for an order specification. For diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index 03cfd12a1d..71654e5be9 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -17,6 +17,9 @@ class DeleteQuery(Query): Delete queries are done through this class, since they are more constrained than general queries. """ + def get_query_class(self): + return DeleteQuery + def as_sql(self): """ Creates the SQL for this query. Returns the SQL string and list of @@ -96,6 +99,9 @@ class UpdateQuery(Query): super(UpdateQuery, self).__init__(*args, **kwargs) self._setup_query() + def get_query_class(self): + return UpdateQuery + def _setup_query(self): """ Runs on initialization and after cloning. Any attributes that would @@ -279,7 +285,7 @@ class UpdateQuery(Query): return [] result = [] for model, values in self.related_updates.iteritems(): - query = UpdateQuery(model, self.connection) + query = self.connection.ops.query_class(Query, UpdateQuery)(model, self.connection) query.values = values if self.related_ids: query.add_filter(('pk__in', self.related_ids)) @@ -294,6 +300,9 @@ class InsertQuery(Query): self.params = () self.return_id = False + def get_query_class(self): + return InsertQuery + def clone(self, klass=None, **kwargs): extras = {'columns': self.columns[:], 'values': self.values[:], 'params': self.params, 'return_id': self.return_id} @@ -376,6 +385,9 @@ class DateQuery(Query): if isinstance(elt, Date): self.date_sql_func = self.connection.ops.date_trunc_sql + def get_query_class(self): + return DateQuery + def results_iter(self): """ Returns an iterator over the results from executing this query. @@ -418,6 +430,8 @@ class AggregateQuery(Query): An AggregateQuery takes another query as a parameter to the FROM clause and only selects the elements in the provided list. """ + def get_query_class(self): + return AggregateQuery def add_subquery(self, query): self.subquery, self.sub_params = query.as_sql(with_col_aliases=True)