mirror of
				https://github.com/django/django.git
				synced 2025-10-31 09:41:08 +00:00 
			
		
		
		
	Fixed #10847 -- Modified handling of extra() to use a masking strategy, rather than last-minute trimming. Thanks to Tai Lee for the report, and Alex Gaynor for his work on the patch.
This enables querysets with an extra clause to be used in an __in filter; as a side effect, it also means that as_sql() now returns the correct result for any query with an extra clause. git-svn-id: http://code.djangoproject.com/svn/django/trunk@10648 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
		| @@ -715,9 +715,6 @@ class ValuesQuerySet(QuerySet): | |||||||
|  |  | ||||||
|     def iterator(self): |     def iterator(self): | ||||||
|         # Purge any extra columns that haven't been explicitly asked for |         # Purge any extra columns that haven't been explicitly asked for | ||||||
|         if self.extra_names is not None: |  | ||||||
|             self.query.trim_extra_select(self.extra_names) |  | ||||||
|  |  | ||||||
|         extra_names = self.query.extra_select.keys() |         extra_names = self.query.extra_select.keys() | ||||||
|         field_names = self.field_names |         field_names = self.field_names | ||||||
|         aggregate_names = self.query.aggregate_select.keys() |         aggregate_names = self.query.aggregate_select.keys() | ||||||
| @@ -741,13 +738,18 @@ class ValuesQuerySet(QuerySet): | |||||||
|         if self._fields: |         if self._fields: | ||||||
|             self.extra_names = [] |             self.extra_names = [] | ||||||
|             self.aggregate_names = [] |             self.aggregate_names = [] | ||||||
|             if not self.query.extra_select and not self.query.aggregate_select: |             if not self.query.extra and not self.query.aggregates: | ||||||
|  |                 # Short cut - if there are no extra or aggregates, then | ||||||
|  |                 # the values() clause must be just field names. | ||||||
|                 self.field_names = list(self._fields) |                 self.field_names = list(self._fields) | ||||||
|             else: |             else: | ||||||
|                 self.query.default_cols = False |                 self.query.default_cols = False | ||||||
|                 self.field_names = [] |                 self.field_names = [] | ||||||
|                 for f in self._fields: |                 for f in self._fields: | ||||||
|                     if self.query.extra_select.has_key(f): |                     # we inspect the full extra_select list since we might | ||||||
|  |                     # be adding back an extra select item that we hadn't | ||||||
|  |                     # had selected previously. | ||||||
|  |                     if self.query.extra.has_key(f): | ||||||
|                         self.extra_names.append(f) |                         self.extra_names.append(f) | ||||||
|                     elif self.query.aggregate_select.has_key(f): |                     elif self.query.aggregate_select.has_key(f): | ||||||
|                         self.aggregate_names.append(f) |                         self.aggregate_names.append(f) | ||||||
| @@ -760,6 +762,8 @@ class ValuesQuerySet(QuerySet): | |||||||
|             self.aggregate_names = None |             self.aggregate_names = None | ||||||
|  |  | ||||||
|         self.query.select = [] |         self.query.select = [] | ||||||
|  |         if self.extra_names is not None: | ||||||
|  |             self.query.set_extra_mask(self.extra_names) | ||||||
|         self.query.add_fields(self.field_names, False) |         self.query.add_fields(self.field_names, False) | ||||||
|         if self.aggregate_names is not None: |         if self.aggregate_names is not None: | ||||||
|             self.query.set_aggregate_mask(self.aggregate_names) |             self.query.set_aggregate_mask(self.aggregate_names) | ||||||
| @@ -816,9 +820,6 @@ class ValuesQuerySet(QuerySet): | |||||||
|  |  | ||||||
| class ValuesListQuerySet(ValuesQuerySet): | class ValuesListQuerySet(ValuesQuerySet): | ||||||
|     def iterator(self): |     def iterator(self): | ||||||
|         if self.extra_names is not None: |  | ||||||
|             self.query.trim_extra_select(self.extra_names) |  | ||||||
|  |  | ||||||
|         if self.flat and len(self._fields) == 1: |         if self.flat and len(self._fields) == 1: | ||||||
|             for row in self.query.results_iter(): |             for row in self.query.results_iter(): | ||||||
|                 yield row[0] |                 yield row[0] | ||||||
|   | |||||||
| @@ -88,7 +88,10 @@ class BaseQuery(object): | |||||||
|  |  | ||||||
|         # These are for extensions. The contents are more or less appended |         # These are for extensions. The contents are more or less appended | ||||||
|         # verbatim to the appropriate clause. |         # verbatim to the appropriate clause. | ||||||
|         self.extra_select = SortedDict()  # Maps col_alias -> (col_sql, params). |         self.extra = SortedDict()  # Maps col_alias -> (col_sql, params). | ||||||
|  |         self.extra_select_mask = None | ||||||
|  |         self._extra_select_cache = None | ||||||
|  |  | ||||||
|         self.extra_tables = () |         self.extra_tables = () | ||||||
|         self.extra_where = () |         self.extra_where = () | ||||||
|         self.extra_params = () |         self.extra_params = () | ||||||
| @@ -214,13 +217,21 @@ class BaseQuery(object): | |||||||
|         if self.aggregate_select_mask is None: |         if self.aggregate_select_mask is None: | ||||||
|             obj.aggregate_select_mask = None |             obj.aggregate_select_mask = None | ||||||
|         else: |         else: | ||||||
|             obj.aggregate_select_mask = self.aggregate_select_mask[:] |             obj.aggregate_select_mask = self.aggregate_select_mask.copy() | ||||||
|         if self._aggregate_select_cache is None: |         if self._aggregate_select_cache is None: | ||||||
|             obj._aggregate_select_cache = None |             obj._aggregate_select_cache = None | ||||||
|         else: |         else: | ||||||
|             obj._aggregate_select_cache = self._aggregate_select_cache.copy() |             obj._aggregate_select_cache = self._aggregate_select_cache.copy() | ||||||
|         obj.max_depth = self.max_depth |         obj.max_depth = self.max_depth | ||||||
|         obj.extra_select = self.extra_select.copy() |         obj.extra = self.extra.copy() | ||||||
|  |         if self.extra_select_mask is None: | ||||||
|  |             obj.extra_select_mask = None | ||||||
|  |         else: | ||||||
|  |             obj.extra_select_mask = self.extra_select_mask.copy() | ||||||
|  |         if self._extra_select_cache is None: | ||||||
|  |             obj._extra_select_cache = None | ||||||
|  |         else: | ||||||
|  |             obj._extra_select_cache = self._extra_select_cache.copy() | ||||||
|         obj.extra_tables = self.extra_tables |         obj.extra_tables = self.extra_tables | ||||||
|         obj.extra_where = self.extra_where |         obj.extra_where = self.extra_where | ||||||
|         obj.extra_params = self.extra_params |         obj.extra_params = self.extra_params | ||||||
| @@ -325,7 +336,7 @@ class BaseQuery(object): | |||||||
|             query = self |             query = self | ||||||
|             self.select = [] |             self.select = [] | ||||||
|             self.default_cols = False |             self.default_cols = False | ||||||
|             self.extra_select = {} |             self.extra = {} | ||||||
|             self.remove_inherited_models() |             self.remove_inherited_models() | ||||||
|  |  | ||||||
|         query.clear_ordering(True) |         query.clear_ordering(True) | ||||||
| @@ -540,13 +551,20 @@ class BaseQuery(object): | |||||||
|             # It would be nice to be able to handle this, but the queries don't |             # It would be nice to be able to handle this, but the queries don't | ||||||
|             # really make sense (or return consistent value sets). Not worth |             # really make sense (or return consistent value sets). Not worth | ||||||
|             # the extra complexity when you can write a real query instead. |             # the extra complexity when you can write a real query instead. | ||||||
|             if self.extra_select and rhs.extra_select: |             if self.extra and rhs.extra: | ||||||
|                 raise ValueError("When merging querysets using 'or', you " |                 raise ValueError("When merging querysets using 'or', you " | ||||||
|                         "cannot have extra(select=...) on both sides.") |                         "cannot have extra(select=...) on both sides.") | ||||||
|             if self.extra_where and rhs.extra_where: |             if self.extra_where and rhs.extra_where: | ||||||
|                 raise ValueError("When merging querysets using 'or', you " |                 raise ValueError("When merging querysets using 'or', you " | ||||||
|                         "cannot have extra(where=...) on both sides.") |                         "cannot have extra(where=...) on both sides.") | ||||||
|         self.extra_select.update(rhs.extra_select) |         self.extra.update(rhs.extra) | ||||||
|  |         extra_select_mask = set() | ||||||
|  |         if self.extra_select_mask is not None: | ||||||
|  |             extra_select_mask.update(self.extra_select_mask) | ||||||
|  |         if rhs.extra_select_mask is not None: | ||||||
|  |             extra_select_mask.update(rhs.extra_select_mask) | ||||||
|  |         if extra_select_mask: | ||||||
|  |             self.set_extra_mask(extra_select_mask) | ||||||
|         self.extra_tables += rhs.extra_tables |         self.extra_tables += rhs.extra_tables | ||||||
|         self.extra_where += rhs.extra_where |         self.extra_where += rhs.extra_where | ||||||
|         self.extra_params += rhs.extra_params |         self.extra_params += rhs.extra_params | ||||||
| @@ -2011,7 +2029,7 @@ class BaseQuery(object): | |||||||
|         except MultiJoin: |         except MultiJoin: | ||||||
|             raise FieldError("Invalid field name: '%s'" % name) |             raise FieldError("Invalid field name: '%s'" % name) | ||||||
|         except FieldError: |         except FieldError: | ||||||
|             names = opts.get_all_field_names() + self.extra_select.keys() + self.aggregate_select.keys() |             names = opts.get_all_field_names() + self.extra.keys() + self.aggregate_select.keys() | ||||||
|             names.sort() |             names.sort() | ||||||
|             raise FieldError("Cannot resolve keyword %r into field. " |             raise FieldError("Cannot resolve keyword %r into field. " | ||||||
|                     "Choices are: %s" % (name, ", ".join(names))) |                     "Choices are: %s" % (name, ", ".join(names))) | ||||||
| @@ -2139,7 +2157,7 @@ class BaseQuery(object): | |||||||
|                     pos = entry.find("%s", pos + 2) |                     pos = entry.find("%s", pos + 2) | ||||||
|                 select_pairs[name] = (entry, entry_params) |                 select_pairs[name] = (entry, entry_params) | ||||||
|             # This is order preserving, since self.extra_select is a SortedDict. |             # This is order preserving, since self.extra_select is a SortedDict. | ||||||
|             self.extra_select.update(select_pairs) |             self.extra.update(select_pairs) | ||||||
|         if where: |         if where: | ||||||
|             self.extra_where += tuple(where) |             self.extra_where += tuple(where) | ||||||
|         if params: |         if params: | ||||||
| @@ -2213,22 +2231,26 @@ class BaseQuery(object): | |||||||
|         """ |         """ | ||||||
|         target[model] = set([f.name for f in fields]) |         target[model] = set([f.name for f in fields]) | ||||||
|  |  | ||||||
|     def trim_extra_select(self, names): |  | ||||||
|         """ |  | ||||||
|         Removes any aliases in the extra_select dictionary that aren't in |  | ||||||
|         'names'. |  | ||||||
|  |  | ||||||
|         This is needed if we are selecting certain values that don't incldue |  | ||||||
|         all of the extra_select names. |  | ||||||
|         """ |  | ||||||
|         for key in set(self.extra_select).difference(set(names)): |  | ||||||
|             del self.extra_select[key] |  | ||||||
|  |  | ||||||
|     def set_aggregate_mask(self, names): |     def set_aggregate_mask(self, names): | ||||||
|         "Set the mask of aggregates that will actually be returned by the SELECT" |         "Set the mask of aggregates that will actually be returned by the SELECT" | ||||||
|         self.aggregate_select_mask = names |         if names is None: | ||||||
|  |             self.aggregate_select_mask = None | ||||||
|  |         else: | ||||||
|  |             self.aggregate_select_mask = set(names) | ||||||
|         self._aggregate_select_cache = None |         self._aggregate_select_cache = None | ||||||
|  |  | ||||||
|  |     def set_extra_mask(self, names): | ||||||
|  |         """ | ||||||
|  |         Set the mask of extra select items that will be returned by SELECT, | ||||||
|  |         we don't actually remove them from the Query since they might be used | ||||||
|  |         later | ||||||
|  |         """ | ||||||
|  |         if names is None: | ||||||
|  |             self.extra_select_mask = None | ||||||
|  |         else: | ||||||
|  |             self.extra_select_mask = set(names) | ||||||
|  |         self._extra_select_cache = None | ||||||
|  |  | ||||||
|     def _aggregate_select(self): |     def _aggregate_select(self): | ||||||
|         """The SortedDict of aggregate columns that are not masked, and should |         """The SortedDict of aggregate columns that are not masked, and should | ||||||
|         be used in the SELECT clause. |         be used in the SELECT clause. | ||||||
| @@ -2247,6 +2269,19 @@ class BaseQuery(object): | |||||||
|             return self.aggregates |             return self.aggregates | ||||||
|     aggregate_select = property(_aggregate_select) |     aggregate_select = property(_aggregate_select) | ||||||
|  |  | ||||||
|  |     def _extra_select(self): | ||||||
|  |         if self._extra_select_cache is not None: | ||||||
|  |             return self._extra_select_cache | ||||||
|  |         elif self.extra_select_mask is not None: | ||||||
|  |             self._extra_select_cache = SortedDict([ | ||||||
|  |                 (k,v) for k,v in self.extra.items() | ||||||
|  |                 if k in self.extra_select_mask | ||||||
|  |             ]) | ||||||
|  |             return self._extra_select_cache | ||||||
|  |         else: | ||||||
|  |             return self.extra | ||||||
|  |     extra_select = property(_extra_select) | ||||||
|  |  | ||||||
|     def set_start(self, start): |     def set_start(self, start): | ||||||
|         """ |         """ | ||||||
|         Sets the table from which to start joining. The start position is |         Sets the table from which to start joining. The start position is | ||||||
|   | |||||||
| @@ -178,7 +178,7 @@ class UpdateQuery(Query): | |||||||
|         # from other tables. |         # from other tables. | ||||||
|         query = self.clone(klass=Query) |         query = self.clone(klass=Query) | ||||||
|         query.bump_prefix() |         query.bump_prefix() | ||||||
|         query.extra_select = {} |         query.extra = {} | ||||||
|         query.select = [] |         query.select = [] | ||||||
|         query.add_fields([query.model._meta.pk.name]) |         query.add_fields([query.model._meta.pk.name]) | ||||||
|         must_pre_select = count > 1 and not self.connection.features.update_can_self_select |         must_pre_select = count > 1 and not self.connection.features.update_can_self_select | ||||||
| @@ -409,7 +409,7 @@ class DateQuery(Query): | |||||||
|         self.select = [select] |         self.select = [select] | ||||||
|         self.select_fields = [None] |         self.select_fields = [None] | ||||||
|         self.select_related = False # See #7097. |         self.select_related = False # See #7097. | ||||||
|         self.extra_select = {} |         self.extra = {} | ||||||
|         self.distinct = True |         self.distinct = True | ||||||
|         self.order_by = order == 'ASC' and [1] or [-1] |         self.order_by = order == 'ASC' and [1] or [-1] | ||||||
|  |  | ||||||
|   | |||||||
| @@ -35,6 +35,9 @@ class TestObject(models.Model): | |||||||
|     second = models.CharField(max_length=20) |     second = models.CharField(max_length=20) | ||||||
|     third = models.CharField(max_length=20) |     third = models.CharField(max_length=20) | ||||||
|  |  | ||||||
|  |     def __unicode__(self): | ||||||
|  |         return u'TestObject: %s,%s,%s' % (self.first,self.second,self.third) | ||||||
|  |  | ||||||
| __test__ = {"API_TESTS": """ | __test__ = {"API_TESTS": """ | ||||||
| # Regression tests for #7314 and #7372 | # Regression tests for #7314 and #7372 | ||||||
|  |  | ||||||
| @@ -189,6 +192,19 @@ True | |||||||
| >>> TestObject.objects.extra(select=SortedDict((('foo','first'),('bar','second'),('whiz','third')))).values_list('whiz', 'first', 'bar', 'id') | >>> TestObject.objects.extra(select=SortedDict((('foo','first'),('bar','second'),('whiz','third')))).values_list('whiz', 'first', 'bar', 'id') | ||||||
| [(u'third', u'first', u'second', 1)] | [(u'third', u'first', u'second', 1)] | ||||||
|  |  | ||||||
|  | # Regression for #10847: the list of extra columns can always be accurately evaluated. | ||||||
|  | # Using an inner query ensures that as_sql() is producing correct output | ||||||
|  | # without requiring full evaluation and execution of the inner query. | ||||||
|  | >>> TestObject.objects.extra(select={'extra': 1}).values('pk') | ||||||
|  | [{'pk': 1}] | ||||||
|  |  | ||||||
|  | >>> TestObject.objects.filter(pk__in=TestObject.objects.extra(select={'extra': 1}).values('pk')) | ||||||
|  | [<TestObject: TestObject: first,second,third>] | ||||||
|  |  | ||||||
|  | >>> TestObject.objects.values('pk').extra(select={'extra': 1}) | ||||||
|  | [{'pk': 1}] | ||||||
|  |  | ||||||
|  | >>> TestObject.objects.filter(pk__in=TestObject.objects.values('pk').extra(select={'extra': 1})) | ||||||
|  | [<TestObject: TestObject: first,second,third>] | ||||||
|  |  | ||||||
| """} | """} | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user