From 04da22633fcda983cb9ee69e63b2ebe99301b717 Mon Sep 17 00:00:00 2001 From: Malcolm Tredinnick Date: Thu, 20 Mar 2008 19:16:04 +0000 Subject: [PATCH] queryset-refactor: Fixed up extra(select=...) calls with parameters so that the parameters are substituted in correctly in all cases. This introduces an extra argument to extra() for this purpose; no alternative there. Also fixed values() to work if you don't specify *all* the extra select aliases in the values() call. Refs #3141. git-svn-id: http://code.djangoproject.com/svn/django/branches/queryset-refactor@7340 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- django/db/models/query.py | 27 ++++++++----- django/db/models/sql/query.py | 40 ++++++++++++------ docs/db-api.txt | 54 ++++++++++--------------- tests/regressiontests/queries/models.py | 13 +++++- 4 files changed, 76 insertions(+), 58 deletions(-) diff --git a/django/db/models/query.py b/django/db/models/query.py index 5cbaf9d5ff..388a54dc85 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -142,16 +142,16 @@ class _QuerySet(object): else: requested = None max_depth = self.query.max_depth - index_end = len(self.model._meta.fields) extra_select = self.query.extra_select.keys() + index_start = len(extra_select) for row in self.query.results_iter(): if fill_cache: - obj, index_end = get_cached_row(self.model, row, 0, max_depth, - requested=requested) + obj, _ = get_cached_row(self.model, row, index_start, + max_depth, requested=requested) else: - obj = self.model(*row[:index_end]) + obj = self.model(*row[index_start:]) for i, k in enumerate(extra_select): - setattr(obj, k, row[index_end + i]) + setattr(obj, k, row[i]) yield obj def count(self): @@ -413,14 +413,14 @@ class _QuerySet(object): return obj def extra(self, select=None, where=None, params=None, tables=None, - order_by=None): + order_by=None, select_params=None): """ Add extra SQL fragments to the query. 
""" assert self.query.can_filter(), \ "Cannot change a query once a slice has been taken" clone = self._clone() - clone.query.add_extra(select, where, params, tables, order_by) + clone.query.add_extra(select, select_params, where, params, tables, order_by) return clone def reverse(self): @@ -475,9 +475,10 @@ class ValuesQuerySet(QuerySet): return self.iterator() def iterator(self): - self.field_names.extend([f for f in self.query.extra_select.keys()]) + self.query.trim_extra_select(self.extra_names) + names = self.query.extra_select.keys() + self.field_names for row in self.query.results_iter(): - yield dict(zip(self.field_names, row)) + yield dict(zip(names, row)) def _setup_query(self): """ @@ -487,6 +488,7 @@ class ValuesQuerySet(QuerySet): Called by the _clone() method after initialising the rest of the instance. """ + self.extra_names = [] if self._fields: if not self.query.extra_select: field_names = list(self._fields) @@ -496,7 +498,9 @@ class ValuesQuerySet(QuerySet): for f in self._fields: if f in names: field_names.append(f) - elif not self.query.extra_select.has_key(f): + elif self.query.extra_select.has_key(f): + self.extra_names.append(f) + else: raise FieldDoesNotExist('%s has no field named %r' % (self.model._meta.object_name, f)) else: @@ -513,7 +517,8 @@ class ValuesQuerySet(QuerySet): """ c = super(ValuesQuerySet, self)._clone(klass, **kwargs) c._fields = self._fields[:] - c.field_names = self.field_names[:] + c.field_names = self.field_names + c.extra_names = self.extra_names if setup and hasattr(c, '_setup_query'): c._setup_query() return c diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index b67bfae699..6a68448bf1 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -73,6 +73,7 @@ class Query(object): # These are for extensions. The contents are more or less appended # verbatim to the appropriate clause. self.extra_select = {} # Maps col_alias -> col_sql. 
+ self.extra_select_params = () self.extra_tables = () self.extra_where = () self.extra_params = () @@ -150,6 +151,7 @@ class Query(object): obj.select_related = self.select_related obj.max_depth = self.max_depth obj.extra_select = self.extra_select.copy() + obj.extra_select_params = self.extra_select_params obj.extra_tables = self.extra_tables obj.extra_where = self.extra_where obj.extra_params = self.extra_params @@ -214,6 +216,7 @@ class Query(object): # get_from_clause() for details. from_, f_params = self.get_from_clause() where, w_params = self.where.as_sql(qn=self.quote_name_unless_alias) + params = list(self.extra_select_params) result = ['SELECT'] if self.distinct: @@ -222,7 +225,7 @@ class Query(object): result.append('FROM') result.extend(from_) - params = list(f_params) + params.extend(f_params) if where: result.append('WHERE %s' % where) @@ -351,8 +354,8 @@ class Query(object): the model. """ qn = self.quote_name_unless_alias - result = [] - aliases = [] + result = ['(%s) AS %s' % (col, alias) for alias, col in self.extra_select.items()] + aliases = self.extra_select.keys() if self.select: for col in self.select: if isinstance(col, (list, tuple)): @@ -364,12 +367,9 @@ class Query(object): if hasattr(col, 'alias'): aliases.append(col.alias) elif self.default_cols: - result = self.get_default_columns(True) - aliases = result[:] - - result.extend(['(%s) AS %s' % (col, alias) - for alias, col in self.extra_select.items()]) - aliases.extend(self.extra_select.keys()) + cols = self.get_default_columns(True) + result.extend(cols) + aliases.extend(cols) self._select_aliases = set(aliases) return result @@ -403,9 +403,9 @@ class Query(object): def get_from_clause(self): """ Returns a list of strings that are joined together to go after the - "FROM" part of the query, as well as any extra parameters that need to - be included. Sub-classes, can override this to create a from-clause via - a "select", for example (e.g. CountQuery). 
+ "FROM" part of the query, as well as a list of any extra parameters that + need to be included. Sub-classes can override this to create a + from-clause via a "select", for example (e.g. CountQuery). This should only be called after any SQL construction methods that might change the tables we need. This means the select columns and @@ -1253,6 +1253,7 @@ class Query(object): self.distinct = False self.select = [select] self.extra_select = {} + self.extra_select_params = () def add_select_related(self, fields): """ @@ -1267,7 +1268,7 @@ class Query(object): d = d.setdefault(part, {}) self.select_related = field_dict - def add_extra(self, select, where, params, tables, order_by): + def add_extra(self, select, select_params, where, params, tables, order_by): """ Adds data to the various extra_* attributes for user-created additions to the query. @@ -1279,6 +1280,8 @@ class Query(object): not isinstance(self.extra_select, SortedDict)): self.extra_select = SortedDict(self.extra_select) self.extra_select.update(select) + if select_params: + self.extra_select_params += tuple(select_params) if where: self.extra_where += tuple(where) if params: @@ -1288,6 +1291,17 @@ class Query(object): if order_by: self.extra_order_by = order_by + def trim_extra_select(self, names): + """ + Removes any aliases in the extra_select dictionary that aren't in + 'names'. + + This is needed if we are selecting certain values that don't include + all of the extra_select names. + """ + for key in set(self.extra_select).difference(set(names)): + del self.extra_select[key] + def set_start(self, start): """ Sets the table from which to start joining. 
The start position is diff --git a/docs/db-api.txt b/docs/db-api.txt index 076f406aa0..dcd648be1c 100644 --- a/docs/db-api.txt +++ b/docs/db-api.txt @@ -841,8 +841,9 @@ You can only refer to ``ForeignKey`` relations in the list of fields passed to list of fields and the ``depth`` parameter in the same ``select_related()`` call, since they are conflicting options. -``extra(select=None, where=None, params=None, tables=None, order_by=None)`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +``extra(select=None, where=None, params=None, tables=None, order_by=None, +select_params=None)`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Sometimes, the Django query syntax by itself can't easily express a complex ``WHERE`` clause. For these edge cases, Django provides the ``extra()`` @@ -901,31 +902,18 @@ of the arguments is required, but you should use at least one of them. **New in Django development version** In some rare cases, you might wish to pass parameters to the SQL fragments - in ``extra(select=...)```. Since the ``params`` attribute is a sequence - and the ``select`` attribute is a dictionary, some care is required so - that the parameters are matched up correctly with the extra select pieces. - Firstly, in this situation, you should use a - ``django.utils.datastructures.SortedDict`` for the ``select`` value, not - just a normal Python dictionary. Secondly, make sure that your parameters - for the ``select`` come first in the list and that you have not passed any - parameters to an earlier ``extra()`` call for this queryset. + in ``extra(select=...)```. For this purpose, use the ``select_params`` + parameter. Since ``select_params`` is a sequence and the ``select`` + attribute is a dictionary, some care is required so that the parameters + are matched up correctly with the extra select pieces. 
In this situation, + you should use a ``django.utils.datastructures.SortedDict`` for the + ``select`` value, not just a normal Python dictionary. - This will work:: + This will work, for example:: Blog.objects.extra( select=SortedDict(('a', '%s'), ('b', '%s')), - params=('one', 'two')) - - ... while this won't:: - - # Will not work! - Blog.objects.extra(where=['foo=%s'], params=('bar',)).extra( - select=SortedDict(('a', '%s'), ('b', '%s')), - params=('one', 'two')) - - In the second example, the earlier ``params`` usage will mess up the later - one. So always put your extra select pieces in the first ``extra()`` call - if you need to use parameters in them. + select_params=('one', 'two')) ``where`` / ``tables`` You can define explicit SQL ``WHERE`` clauses -- perhaps to perform @@ -965,19 +953,18 @@ of the arguments is required, but you should use at least one of them. time). ``params`` - The ``select`` and ``where`` parameters described above may use standard - Python database string placeholders -- ``'%s'`` to indicate parameters the - database engine should automatically quote. The ``params`` argument is a - list of any extra parameters to be substituted. + The ``where`` parameter described above may use standard Python database + string placeholders -- ``'%s'`` to indicate parameters the database engine + should automatically quote. The ``params`` argument is a list of any extra + parameters to be substituted. Example:: Entry.objects.extra(where=['headline=%s'], params=['Lennon']) - Always use ``params`` instead of embedding values directly into ``select`` - or ``where`` because ``params`` will ensure values are quoted correctly - according to your particular backend. (For example, quotes will be escaped - correctly.) + Always use ``params`` instead of embedding values directly into ``where`` + because ``params`` will ensure values are quoted correctly according to + your particular backend. (For example, quotes will be escaped correctly.) 
Bad:: @@ -987,8 +974,9 @@ of the arguments is required, but you should use at least one of them. Entry.objects.extra(where=['headline=%s'], params=['Lennon']) - The combined number of placeholders in the list of strings for ``select`` - or ``where`` should equal the number of values in the ``params`` list. +**New in Django development version** The ``select_params`` argument to +``extra()`` is new. Previously, you could attempt to pass parameters for +``select`` in the ``params`` argument, but it worked very unreliably. QuerySet methods that do not return QuerySets --------------------------------------------- diff --git a/tests/regressiontests/queries/models.py b/tests/regressiontests/queries/models.py index 2db5bf8a34..9b86d60fde 100644 --- a/tests/regressiontests/queries/models.py +++ b/tests/regressiontests/queries/models.py @@ -282,6 +282,10 @@ Bug #1878, #2939 >>> xx.save() >>> Item.objects.exclude(name='two').values('creator', 'name').distinct().count() 4 +>>> Item.objects.exclude(name='two').extra(select={'foo': '%s'}, select_params=(1,)).values('creator', 'name', 'foo').distinct().count() +4 +>>> Item.objects.exclude(name='two').extra(select={'foo': '%s'}, select_params=(1,)).values('creator', 'name').distinct().count() +4 >>> xx.delete() Bug #2253 @@ -386,6 +390,8 @@ AssertionError: Cannot combine queries on two different base models. Bug #3141 >>> Author.objects.extra(select={'foo': '1'}).count() 4 +>>> Author.objects.extra(select={'foo': '%s'}, select_params=(1,)).count() +4 Bug #2400 >>> Author.objects.filter(item__isnull=True) @@ -462,6 +468,11 @@ True >>> qs.extra(order_by=('-good', 'id')) [, , ] +# Despite having some extra aliases in the query, we can still omit them in a +# values() query. +>>> qs.values('id', 'rank').order_by('id') +[{'id': 1, 'rank': 2}, {'id': 2, 'rank': 1}, {'id': 3, 'rank': 3}] + Bugs #2874, #3002 >>> qs = Item.objects.select_related().order_by('note__note', 'name') >>> list(qs) @@ -533,7 +544,7 @@ thus fail.) 
# This slightly odd comparison works aorund the fact that PostgreSQL will # return 'one' and 'two' as strings, not Unicode objects. It's a side-effect of # using constants here and not a real concern. ->>> d = Item.objects.extra(select=SortedDict(s), params=params).values('a', 'b')[0] +>>> d = Item.objects.extra(select=SortedDict(s), select_params=params).values('a', 'b')[0] >>> d == {'a': u'one', 'b': u'two'} True