mirror of
				https://github.com/django/django.git
				synced 2025-10-25 06:36:07 +00:00 
			
		
		
		
	[1.8.x] Fixed #26071 -- Fixed crash with __in lookup in a Case expression.
Partial backport of afe0bb7b13 from master.
			
			
This commit is contained in:
		
				
					committed by
					
						 Tim Graham
						Tim Graham
					
				
			
			
				
	
			
			
			
						parent
						
							e625859f08
						
					
				
				
					commit
					5b3c66d8b6
				
			| @@ -92,6 +92,10 @@ class Transform(RegisterLookupMixin): | ||||
|             bilateral_transforms.append((self.__class__, self.init_lookups)) | ||||
|         return bilateral_transforms | ||||
|  | ||||
|     @cached_property | ||||
|     def contains_aggregate(self): | ||||
|         return self.lhs.contains_aggregate | ||||
|  | ||||
|  | ||||
| class Lookup(RegisterLookupMixin): | ||||
|     lookup_name = None | ||||
| @@ -194,6 +198,10 @@ class Lookup(RegisterLookupMixin): | ||||
|     def as_sql(self, compiler, connection): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     @cached_property | ||||
|     def contains_aggregate(self): | ||||
|         return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False) | ||||
|  | ||||
|  | ||||
| class BuiltinLookup(Lookup): | ||||
|     def process_lhs(self, compiler, connection, lhs=None): | ||||
|   | ||||
| @@ -315,9 +315,9 @@ class WhereNode(tree.Node): | ||||
|  | ||||
|     @classmethod | ||||
|     def _contains_aggregate(cls, obj): | ||||
|         if not isinstance(obj, tree.Node): | ||||
|             return getattr(obj.lhs, 'contains_aggregate', False) or getattr(obj.rhs, 'contains_aggregate', False) | ||||
|         if isinstance(obj, tree.Node): | ||||
|             return any(cls._contains_aggregate(c) for c in obj.children) | ||||
|         return obj.contains_aggregate | ||||
|  | ||||
|     @cached_property | ||||
|     def contains_aggregate(self): | ||||
| @@ -336,6 +336,7 @@ class EverythingNode(object): | ||||
|     """ | ||||
|     A node that matches everything. | ||||
|     """ | ||||
|     contains_aggregate = False | ||||
|  | ||||
|     def as_sql(self, compiler=None, connection=None): | ||||
|         return '', [] | ||||
| @@ -345,11 +346,16 @@ class NothingNode(object): | ||||
|     """ | ||||
|     A node that matches nothing. | ||||
|     """ | ||||
|     contains_aggregate = False | ||||
|  | ||||
|     def as_sql(self, compiler=None, connection=None): | ||||
|         raise EmptyResultSet | ||||
|  | ||||
|  | ||||
| class ExtraWhere(object): | ||||
|     # The contents are a black box - assume no aggregates are used. | ||||
|     contains_aggregate = False | ||||
|  | ||||
|     def __init__(self, sqls, params): | ||||
|         self.sqls = sqls | ||||
|         self.params = params | ||||
| @@ -410,6 +416,10 @@ class Constraint(object): | ||||
|  | ||||
|  | ||||
| class SubqueryConstraint(object): | ||||
|     # Even if aggregates would be used in a subquery, the outer query isn't | ||||
|     # interested about those. | ||||
|     contains_aggregate = False | ||||
|  | ||||
|     def __init__(self, alias, columns, targets, query_object): | ||||
|         self.alias = alias | ||||
|         self.columns = columns | ||||
|   | ||||
| @@ -23,3 +23,6 @@ Bugfixes | ||||
|   ``db_index=True`` or ``unique=True`` to a ``CharField`` or ``TextField`` that | ||||
|   already had the other specified, or when removing one of them from a field | ||||
|   that had both (:ticket:`26034`). | ||||
|  | ||||
| * Fixed a crash when using an ``__in`` lookup inside a ``Case`` expression | ||||
|   (:ticket:`26071`). | ||||
|   | ||||
| @@ -8,7 +8,7 @@ from uuid import UUID | ||||
|  | ||||
| from django.core.exceptions import FieldError | ||||
| from django.db import connection, models | ||||
| from django.db.models import F, Q, Max, Min, Value | ||||
| from django.db.models import F, Q, Max, Min, Sum, Value | ||||
| from django.db.models.expressions import Case, When | ||||
| from django.test import TestCase | ||||
| from django.utils import six | ||||
| @@ -119,6 +119,17 @@ class CaseExpressionTests(TestCase): | ||||
|             transform=attrgetter('integer', 'join_test') | ||||
|         ) | ||||
|  | ||||
|     def test_annotate_with_in_clause(self): | ||||
|         fk_rels = FKCaseTestModel.objects.filter(integer__in=[5]) | ||||
|         self.assertQuerysetEqual( | ||||
|             CaseTestModel.objects.only('pk', 'integer').annotate(in_test=Sum(Case( | ||||
|                 When(fk_rel__in=fk_rels, then=F('fk_rel__integer')), | ||||
|                 default=Value(0), | ||||
|             ))).order_by('pk'), | ||||
|             [(1, 0), (2, 0), (3, 0), (2, 0), (3, 0), (3, 0), (4, 5)], | ||||
|             transform=attrgetter('integer', 'in_test') | ||||
|         ) | ||||
|  | ||||
|     def test_annotate_with_join_in_condition(self): | ||||
|         self.assertQuerysetEqual( | ||||
|             CaseTestModel.objects.annotate(join_test=Case( | ||||
|   | ||||
		Reference in New Issue
	
	Block a user