1
0
mirror of https://github.com/django/django.git synced 2025-07-04 09:49:12 +00:00

Fixed #12855 -- QuerySets? with extra where parameters now combine correctly. Thanks, Alex Gaynor.

Backport of r12502 from trunk.


git-svn-id: http://code.djangoproject.com/svn/django/branches/releases/1.1.X@12507 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Justin Bronn 2010-02-23 05:59:04 +00:00
parent 16efc1a92e
commit 6a28f581c0
3 changed files with 17 additions and 22 deletions

View File

@ -18,7 +18,8 @@ from django.db.models.fields import FieldDoesNotExist
from django.db.models.query_utils import select_related_descend from django.db.models.query_utils import select_related_descend
from django.db.models.sql import aggregates as base_aggregates_module from django.db.models.sql import aggregates as base_aggregates_module
from django.db.models.sql.expressions import SQLEvaluator from django.db.models.sql.expressions import SQLEvaluator
from django.db.models.sql.where import WhereNode, Constraint, EverythingNode, AND, OR from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode,
ExtraWhere, AND, OR)
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from datastructures import EmptyResultSet, Empty, MultiJoin from datastructures import EmptyResultSet, Empty, MultiJoin
from constants import * from constants import *
@ -92,8 +93,6 @@ class BaseQuery(object):
self._extra_select_cache = None self._extra_select_cache = None
self.extra_tables = () self.extra_tables = ()
self.extra_where = ()
self.extra_params = ()
self.extra_order_by = () self.extra_order_by = ()
# A tuple that is a set of model field names and either True, if these # A tuple that is a set of model field names and either True, if these
@ -232,8 +231,6 @@ class BaseQuery(object):
else: else:
obj._extra_select_cache = self._extra_select_cache.copy() obj._extra_select_cache = self._extra_select_cache.copy()
obj.extra_tables = self.extra_tables obj.extra_tables = self.extra_tables
obj.extra_where = self.extra_where
obj.extra_params = self.extra_params
obj.extra_order_by = self.extra_order_by obj.extra_order_by = self.extra_order_by
obj.deferred_loading = deepcopy(self.deferred_loading) obj.deferred_loading = deepcopy(self.deferred_loading)
if self.filter_is_sticky and self.used_aliases: if self.filter_is_sticky and self.used_aliases:
@ -418,12 +415,6 @@ class BaseQuery(object):
if where: if where:
result.append('WHERE %s' % where) result.append('WHERE %s' % where)
params.extend(w_params) params.extend(w_params)
if self.extra_where:
if not where:
result.append('WHERE')
else:
result.append('AND')
result.append(' AND '.join(self.extra_where))
grouping, gb_params = self.get_grouping() grouping, gb_params = self.get_grouping()
if grouping: if grouping:
@ -458,7 +449,6 @@ class BaseQuery(object):
result.append('LIMIT %d' % val) result.append('LIMIT %d' % val)
result.append('OFFSET %d' % self.low_mark) result.append('OFFSET %d' % self.low_mark)
params.extend(self.extra_params)
return ' '.join(result), tuple(params) return ' '.join(result), tuple(params)
def as_nested_sql(self): def as_nested_sql(self):
@ -553,9 +543,6 @@ class BaseQuery(object):
if self.extra and rhs.extra: if self.extra and rhs.extra:
raise ValueError("When merging querysets using 'or', you " raise ValueError("When merging querysets using 'or', you "
"cannot have extra(select=...) on both sides.") "cannot have extra(select=...) on both sides.")
if self.extra_where and rhs.extra_where:
raise ValueError("When merging querysets using 'or', you "
"cannot have extra(where=...) on both sides.")
self.extra.update(rhs.extra) self.extra.update(rhs.extra)
extra_select_mask = set() extra_select_mask = set()
if self.extra_select_mask is not None: if self.extra_select_mask is not None:
@ -565,8 +552,6 @@ class BaseQuery(object):
if extra_select_mask: if extra_select_mask:
self.set_extra_mask(extra_select_mask) self.set_extra_mask(extra_select_mask)
self.extra_tables += rhs.extra_tables self.extra_tables += rhs.extra_tables
self.extra_where += rhs.extra_where
self.extra_params += rhs.extra_params
# Ordering uses the 'rhs' ordering, unless it has none, in which case # Ordering uses the 'rhs' ordering, unless it has none, in which case
# the current ordering is used. # the current ordering is used.
@ -2181,10 +2166,8 @@ class BaseQuery(object):
select_pairs[name] = (entry, entry_params) select_pairs[name] = (entry, entry_params)
# This is order preserving, since self.extra_select is a SortedDict. # This is order preserving, since self.extra_select is a SortedDict.
self.extra.update(select_pairs) self.extra.update(select_pairs)
if where: if where or params:
self.extra_where += tuple(where) self.where.add(ExtraWhere(where, params), AND)
if params:
self.extra_params += tuple(params)
if tables: if tables:
self.extra_tables += tuple(tables) self.extra_tables += tuple(tables)
if order_by: if order_by:

View File

@ -216,7 +216,7 @@ class WhereNode(tree.Node):
child.relabel_aliases(change_map) child.relabel_aliases(change_map)
elif isinstance(child, tree.Node): elif isinstance(child, tree.Node):
self.relabel_aliases(change_map, child) self.relabel_aliases(change_map, child)
else: elif isinstance(child, (list, tuple)):
if isinstance(child[0], (list, tuple)): if isinstance(child[0], (list, tuple)):
elt = list(child[0]) elt = list(child[0])
if elt[0] in change_map: if elt[0] in change_map:
@ -249,6 +249,14 @@ class NothingNode(object):
def relabel_aliases(self, change_map, node=None): def relabel_aliases(self, change_map, node=None):
return return
class ExtraWhere(object):
def __init__(self, sqls, params):
self.sqls = sqls
self.params = params
def as_sql(self, qn=None):
return " AND ".join(self.sqls), tuple(self.params or ())
class Constraint(object): class Constraint(object):
""" """
An object that can be passed to WhereNode.add() and knows how to An object that can be passed to WhereNode.add() and knows how to

View File

@ -208,4 +208,8 @@ True
>>> TestObject.objects.filter(pk__in=TestObject.objects.values('pk').extra(select={'extra': 1})) >>> TestObject.objects.filter(pk__in=TestObject.objects.values('pk').extra(select={'extra': 1}))
[<TestObject: TestObject: first,second,third>] [<TestObject: TestObject: first,second,third>]
>>> pk = TestObject.objects.get().pk
>>> TestObject.objects.filter(pk=pk) | TestObject.objects.extra(where=["id > %s"], params=[pk])
[<TestObject: TestObject: first,second,third>]
"""} """}