mirror of
				https://github.com/django/django.git
				synced 2025-10-25 14:46:09 +00:00 
			
		
		
		
	[1.8.x] Fixed #26071 -- Fixed crash with __in lookup in a Case expression.
Partial backport of afe0bb7b13 from master.
			
			
This commit is contained in:
		
				
					committed by
					
						 Tim Graham
						Tim Graham
					
				
			
			
				
	
			
			
			
						parent
						
							e625859f08
						
					
				
				
					commit
					5b3c66d8b6
				
			| @@ -92,6 +92,10 @@ class Transform(RegisterLookupMixin): | |||||||
|             bilateral_transforms.append((self.__class__, self.init_lookups)) |             bilateral_transforms.append((self.__class__, self.init_lookups)) | ||||||
|         return bilateral_transforms |         return bilateral_transforms | ||||||
|  |  | ||||||
|  |     @cached_property | ||||||
|  |     def contains_aggregate(self): | ||||||
|  |         return self.lhs.contains_aggregate | ||||||
|  |  | ||||||
|  |  | ||||||
| class Lookup(RegisterLookupMixin): | class Lookup(RegisterLookupMixin): | ||||||
|     lookup_name = None |     lookup_name = None | ||||||
| @@ -194,6 +198,10 @@ class Lookup(RegisterLookupMixin): | |||||||
|     def as_sql(self, compiler, connection): |     def as_sql(self, compiler, connection): | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     @cached_property | ||||||
|  |     def contains_aggregate(self): | ||||||
|  |         return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False) | ||||||
|  |  | ||||||
|  |  | ||||||
| class BuiltinLookup(Lookup): | class BuiltinLookup(Lookup): | ||||||
|     def process_lhs(self, compiler, connection, lhs=None): |     def process_lhs(self, compiler, connection, lhs=None): | ||||||
|   | |||||||
| @@ -315,9 +315,9 @@ class WhereNode(tree.Node): | |||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def _contains_aggregate(cls, obj): |     def _contains_aggregate(cls, obj): | ||||||
|         if not isinstance(obj, tree.Node): |         if isinstance(obj, tree.Node): | ||||||
|             return getattr(obj.lhs, 'contains_aggregate', False) or getattr(obj.rhs, 'contains_aggregate', False) |             return any(cls._contains_aggregate(c) for c in obj.children) | ||||||
|         return any(cls._contains_aggregate(c) for c in obj.children) |         return obj.contains_aggregate | ||||||
|  |  | ||||||
|     @cached_property |     @cached_property | ||||||
|     def contains_aggregate(self): |     def contains_aggregate(self): | ||||||
| @@ -336,6 +336,7 @@ class EverythingNode(object): | |||||||
|     """ |     """ | ||||||
|     A node that matches everything. |     A node that matches everything. | ||||||
|     """ |     """ | ||||||
|  |     contains_aggregate = False | ||||||
|  |  | ||||||
|     def as_sql(self, compiler=None, connection=None): |     def as_sql(self, compiler=None, connection=None): | ||||||
|         return '', [] |         return '', [] | ||||||
| @@ -345,11 +346,16 @@ class NothingNode(object): | |||||||
|     """ |     """ | ||||||
|     A node that matches nothing. |     A node that matches nothing. | ||||||
|     """ |     """ | ||||||
|  |     contains_aggregate = False | ||||||
|  |  | ||||||
|     def as_sql(self, compiler=None, connection=None): |     def as_sql(self, compiler=None, connection=None): | ||||||
|         raise EmptyResultSet |         raise EmptyResultSet | ||||||
|  |  | ||||||
|  |  | ||||||
| class ExtraWhere(object): | class ExtraWhere(object): | ||||||
|  |     # The contents are a black box - assume no aggregates are used. | ||||||
|  |     contains_aggregate = False | ||||||
|  |  | ||||||
|     def __init__(self, sqls, params): |     def __init__(self, sqls, params): | ||||||
|         self.sqls = sqls |         self.sqls = sqls | ||||||
|         self.params = params |         self.params = params | ||||||
| @@ -410,6 +416,10 @@ class Constraint(object): | |||||||
|  |  | ||||||
|  |  | ||||||
| class SubqueryConstraint(object): | class SubqueryConstraint(object): | ||||||
|  |     # Even if aggregates would be used in a subquery, the outer query isn't | ||||||
|  |     # interested about those. | ||||||
|  |     contains_aggregate = False | ||||||
|  |  | ||||||
|     def __init__(self, alias, columns, targets, query_object): |     def __init__(self, alias, columns, targets, query_object): | ||||||
|         self.alias = alias |         self.alias = alias | ||||||
|         self.columns = columns |         self.columns = columns | ||||||
|   | |||||||
| @@ -23,3 +23,6 @@ Bugfixes | |||||||
|   ``db_index=True`` or ``unique=True`` to a ``CharField`` or ``TextField`` that |   ``db_index=True`` or ``unique=True`` to a ``CharField`` or ``TextField`` that | ||||||
|   already had the other specified, or when removing one of them from a field |   already had the other specified, or when removing one of them from a field | ||||||
|   that had both (:ticket:`26034`). |   that had both (:ticket:`26034`). | ||||||
|  |  | ||||||
|  | * Fixed a crash when using an ``__in`` lookup inside a ``Case`` expression | ||||||
|  |   (:ticket:`26071`). | ||||||
|   | |||||||
| @@ -8,7 +8,7 @@ from uuid import UUID | |||||||
|  |  | ||||||
| from django.core.exceptions import FieldError | from django.core.exceptions import FieldError | ||||||
| from django.db import connection, models | from django.db import connection, models | ||||||
| from django.db.models import F, Q, Max, Min, Value | from django.db.models import F, Q, Max, Min, Sum, Value | ||||||
| from django.db.models.expressions import Case, When | from django.db.models.expressions import Case, When | ||||||
| from django.test import TestCase | from django.test import TestCase | ||||||
| from django.utils import six | from django.utils import six | ||||||
| @@ -119,6 +119,17 @@ class CaseExpressionTests(TestCase): | |||||||
|             transform=attrgetter('integer', 'join_test') |             transform=attrgetter('integer', 'join_test') | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     def test_annotate_with_in_clause(self): | ||||||
|  |         fk_rels = FKCaseTestModel.objects.filter(integer__in=[5]) | ||||||
|  |         self.assertQuerysetEqual( | ||||||
|  |             CaseTestModel.objects.only('pk', 'integer').annotate(in_test=Sum(Case( | ||||||
|  |                 When(fk_rel__in=fk_rels, then=F('fk_rel__integer')), | ||||||
|  |                 default=Value(0), | ||||||
|  |             ))).order_by('pk'), | ||||||
|  |             [(1, 0), (2, 0), (3, 0), (2, 0), (3, 0), (3, 0), (4, 5)], | ||||||
|  |             transform=attrgetter('integer', 'in_test') | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     def test_annotate_with_join_in_condition(self): |     def test_annotate_with_join_in_condition(self): | ||||||
|         self.assertQuerysetEqual( |         self.assertQuerysetEqual( | ||||||
|             CaseTestModel.objects.annotate(join_test=Case( |             CaseTestModel.objects.annotate(join_test=Case( | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user