diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index b5e99fb4c0..2548d63100 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -11,6 +11,7 @@ from copy import deepcopy from django.utils.tree import Node from django.utils.datastructures import SortedDict +from django.utils.encoding import force_unicode from django.db import connection from django.db.models import signals from django.db.models.fields import FieldDoesNotExist @@ -77,8 +78,7 @@ class Query(object): # These are for extensions. The contents are more or less appended # verbatim to the appropriate clause. - self.extra_select = {} # Maps col_alias -> col_sql. - self.extra_select_params = () + self.extra_select = SortedDict() # Maps col_alias -> col_sql. self.extra_tables = () self.extra_where = () self.extra_params = () @@ -181,7 +181,6 @@ class Query(object): obj.related_select_cols = [] obj.max_depth = self.max_depth obj.extra_select = self.extra_select.copy() - obj.extra_select_params = self.extra_select_params obj.extra_tables = self.extra_tables obj.extra_where = self.extra_where obj.extra_params = self.extra_params @@ -226,7 +225,7 @@ class Query(object): obj = self.clone(CountQuery, _query=obj, where=self.where_class(), distinct=False) obj.select = [] - obj.extra_select = {} + obj.extra_select = SortedDict() obj.add_count_column() data = obj.execute_sql(SINGLE) if not data: @@ -259,7 +258,9 @@ class Query(object): from_, f_params = self.get_from_clause() where, w_params = self.where.as_sql(qn=self.quote_name_unless_alias) - params = list(self.extra_select_params) + params = [] + for val in self.extra_select.itervalues(): + params.extend(val[1]) result = ['SELECT'] if self.distinct: @@ -413,7 +414,7 @@ class Query(object): """ qn = self.quote_name_unless_alias qn2 = self.connection.ops.quote_name - result = ['(%s) AS %s' % (col, qn2(alias)) for alias, col in self.extra_select.iteritems()] + result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in self.extra_select.iteritems()] aliases = set(self.extra_select.keys()) if with_aliases: col_aliases = aliases.copy() @@ -1510,7 +1511,6 @@ class Query(object): self.select = [select] self.select_fields = [None] self.extra_select = {} - self.extra_select_params = () def add_select_related(self, fields): """ @@ -1533,14 +1533,25 @@ class Query(object): to the query. """ if select: - # The extra select might be ordered (because it will be accepting - # parameters). - if (isinstance(select, SortedDict) and - not isinstance(self.extra_select, SortedDict)): - self.extra_select = SortedDict(self.extra_select) - self.extra_select.update(select) - if select_params: - self.extra_select_params += tuple(select_params) + # We need to pair any placeholder markers in the 'select' + # dictionary with their parameters in 'select_params' so that + # subsequent updates to the select dictionary also adjust the + # parameters appropriately. + select_pairs = SortedDict() + if select_params: + param_iter = iter(select_params) + else: + param_iter = iter([]) + for name, entry in select.items(): + entry = force_unicode(entry) + entry_params = [] + pos = entry.find("%s") + while pos != -1: + entry_params.append(param_iter.next()) + pos = entry.find("%s", pos + 2) + select_pairs[name] = (entry, entry_params) + # This is order preserving, since self.extra_select is a SortedDict. + self.extra_select.update(select_pairs) if where: self.extra_where += tuple(where) if params: diff --git a/docs/db-api.txt b/docs/db-api.txt index 3c97de3ca6..77a29cc8a5 100644 --- a/docs/db-api.txt +++ b/docs/db-api.txt @@ -1014,6 +1014,12 @@ of the arguments is required, but you should use at least one of them. select=SortedDict(('a', '%s'), ('b', '%s')), select_params=('one', 'two')) + The only thing to be careful about when using select parameters in + ``extra()`` is to avoid using the substring ``"%%s"`` (that's *two* + percent characters before the ``s``) in the select strings. Django's + tracking of parameters looks for ``%s`` and an escaped ``%`` character + like this isn't detected. That will lead to incorrect results. + ``where`` / ``tables`` You can define explicit SQL ``WHERE`` clauses -- perhaps to perform non-explicit joins -- by using ``where``. You can manually add tables to diff --git a/tests/regressiontests/extra_regress/models.py b/tests/regressiontests/extra_regress/models.py index e6665222bf..5bbb57c92c 100644 --- a/tests/regressiontests/extra_regress/models.py +++ b/tests/regressiontests/extra_regress/models.py @@ -1,7 +1,9 @@ import copy +from django.contrib.auth.models import User from django.db import models from django.db.models.query import Q +from django.utils.datastructures import SortedDict class RevisionableModel(models.Model): @@ -23,7 +25,7 @@ class RevisionableModel(models.Model): return new_revision __test__ = {"API_TESTS": """ -### Regression tests for #7314 and #7372 +# Regression tests for #7314 and #7372 >>> rm = RevisionableModel.objects.create(title='First Revision') >>> rm.pk, rm.base.pk @@ -52,4 +54,29 @@ Following queryset should return the most recent revision: >>> qs & qs2 [] +>>> u = User.objects.create_user(username="fred", password="secret", email="fred@example.com") + +# General regression tests: extra select parameters should stay tied to their +# corresponding select portions. Applies when portions are updated or otherwise +# moved around. +>>> qs = User.objects.extra(select=SortedDict((("alpha", "%s"), ("beta", "2"), ("gamma", "%s"))), select_params=(1, 3)) +>>> qs = qs.extra(select={"beta": 4}) +>>> qs = qs.extra(select={"alpha": "%s"}, select_params=[5]) +>>> result = {'alpha': 5, 'beta': 4, 'gamma': 3} +>>> list(qs.filter(id=u.id).values('alpha', 'beta', 'gamma')) == [result] +True + +# Regression test for #7957: Combining extra() calls should leave the +# corresponding parameters associated with the right extra() bit. I.e. internal +# dictionary must remain sorted. +>>> User.objects.extra(select={"alpha": "%s"}, select_params=(1,)).extra(select={"beta": "%s"}, select_params=(2,))[0].alpha +1 +>>> User.objects.extra(select={"beta": "%s"}, select_params=(1,)).extra(select={"alpha": "%s"}, select_params=(2,))[0].alpha +2 + +# Regression test for #7961: When not using a portion of an extra(...) in a +# query, remove any corresponding parameters from the query as well. +>>> list(User.objects.extra(select={"alpha": "%s"}, select_params=(-6,)).filter(id=u.id).values_list('id', flat=True)) == [u.id] +True + """}