diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index f2d6b577f2..1343f17209 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -77,8 +77,9 @@ class Query(object): LOUTER = 'LEFT OUTER JOIN' alias_prefix = 'T' + query_terms = QUERY_TERMS - def __init__(self, model, connection): + def __init__(self, model, connection, where=WhereNode): self.model = model self.connection = connection self.alias_map = {} # Maps alias to table name @@ -91,7 +92,8 @@ class Query(object): # SQL-related attributes self.select = [] self.tables = [] # Aliases in the order they are created. - self.where = WhereNode() + self.where = where() + self.where_class = where self.group_by = [] self.having = [] self.order_by = [] @@ -156,6 +158,7 @@ class Query(object): obj.select = self.select[:] obj.tables = self.tables[:] obj.where = copy.deepcopy(self.where) + obj.where_class = self.where_class obj.group_by = self.group_by[:] obj.having = self.having[:] obj.order_by = self.order_by[:] @@ -192,7 +195,7 @@ class Query(object): obj.clear_limits() obj.select_related = False if obj.distinct and len(obj.select) > 1: - obj = self.clone(CountQuery, _query=obj, where=WhereNode(), + obj = self.clone(CountQuery, _query=obj, where=self.where_class(), distinct=False) obj.select = [] obj.extra_select = SortedDict() @@ -319,12 +322,12 @@ class Query(object): elif self.where: # rhs has an empty where clause. Make it match everything (see # above for reasoning). - w = WhereNode() + w = self.where_class() alias = self.join((None, self.model._meta.db_table, None, None)) pk = self.model._meta.pk w.add(EverythingNode(), AND) else: - w = WhereNode() + w = self.where_class() self.where.add(w, connector) # Selection columns and extra extensions are those provided by 'rhs'. @@ -704,7 +707,7 @@ class Query(object): raise TypeError("Cannot parse keyword query %r" % arg) # Work out the lookup type and remove it from 'parts', if necessary. - if len(parts) == 1 or parts[-1] not in QUERY_TERMS: + if len(parts) == 1 or parts[-1] not in self.query_terms: lookup_type = 'exact' else: lookup_type = parts.pop() @@ -1109,7 +1112,7 @@ class DeleteQuery(Query): for related in cls._meta.get_all_related_many_to_many_objects(): if not isinstance(related.field, generic.GenericRelation): for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): - where = WhereNode() + where = self.where_class() where.add((None, related.field.m2m_reverse_name(), related.field, 'in', pk_list[offset : offset+GET_ITERATOR_CHUNK_SIZE]), @@ -1117,14 +1120,14 @@ class DeleteQuery(Query): self.do_query(related.field.m2m_db_table(), where) for f in cls._meta.many_to_many: - w1 = WhereNode() + w1 = self.where_class() if isinstance(f, generic.GenericRelation): from django.contrib.contenttypes.models import ContentType field = f.rel.to._meta.get_field(f.content_type_field_name) w1.add((None, field.column, field, 'exact', ContentType.objects.get_for_model(cls).id), AND) for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): - where = WhereNode() + where = self.where_class() where.add((None, f.m2m_column_name(), f, 'in', pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND) @@ -1141,7 +1144,7 @@ class DeleteQuery(Query): lot of values in pk_list. """ for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): - where = WhereNode() + where = self.where_class() field = self.model._meta.pk where.add((None, field.column, field, 'in', pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND) @@ -1185,7 +1188,7 @@ class UpdateQuery(Query): This is used by the QuerySet.delete_objects() method. """ for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): - where = WhereNode() + where = self.where_class() f = self.model._meta.pk where.add((None, f.column, f, 'in', pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]),