diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 89332699d5..c1a76584f0 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -876,8 +876,11 @@ class When(Expression): conditional = False def __init__(self, condition=None, then=None, **lookups): - if lookups and condition is None: - condition, lookups = Q(**lookups), None + if lookups: + if condition is None: + condition, lookups = Q(**lookups), None + elif getattr(condition, 'conditional', False): + condition, lookups = Q(condition, **lookups), None if condition is None or not getattr(condition, 'conditional', False) or lookups: raise TypeError( 'When() supports a Q object, a boolean expression, or lookups ' diff --git a/docs/ref/models/conditional-expressions.txt b/docs/ref/models/conditional-expressions.txt index 7616b98e0a..a1e2430ec1 100644 --- a/docs/ref/models/conditional-expressions.txt +++ b/docs/ref/models/conditional-expressions.txt @@ -81,6 +81,10 @@ Keep in mind that each of these values can be an expression. >>> When(then__exact=0, then=1) >>> When(Q(then=0), then=1) +.. versionchanged:: 3.2 + + Support for using the ``condition`` argument with ``lookups`` was added. + ``Case`` -------- diff --git a/docs/releases/3.2.txt b/docs/releases/3.2.txt index 7aa01dcc23..2d49299440 100644 --- a/docs/releases/3.2.txt +++ b/docs/releases/3.2.txt @@ -178,6 +178,9 @@ Models supported on PostgreSQL, allows acquiring weaker locks that don't block the creation of rows that reference locked rows through a foreign key. +* :class:`When() ` expression now allows + using the ``condition`` argument with ``lookups``. + Requests and Responses ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/expressions_case/tests.py b/tests/expressions_case/tests.py index f85def932a..efb4ba227e 100644 --- a/tests/expressions_case/tests.py +++ b/tests/expressions_case/tests.py @@ -6,7 +6,7 @@ from uuid import UUID from django.core.exceptions import FieldError from django.db.models import ( - BinaryField, Case, CharField, Count, DurationField, F, + BinaryField, BooleanField, Case, CharField, Count, DurationField, F, GenericIPAddressField, IntegerField, Max, Min, Q, Sum, TextField, TimeField, UUIDField, Value, When, ) @@ -312,6 +312,17 @@ class CaseExpressionTests(TestCase): transform=attrgetter('integer', 'integer2') ) + def test_condition_with_lookups(self): + qs = CaseTestModel.objects.annotate( + test=Case( + When(Q(integer2=1), string='2', then=Value(False)), + When(Q(integer2=1), string='1', then=Value(True)), + default=Value(False), + output_field=BooleanField(), + ), + ) + self.assertIs(qs.get(integer=1).test, True) + def test_case_reuse(self): SOME_CASE = Case( When(pk=0, then=Value('0')), @@ -1350,6 +1361,8 @@ class CaseWhenTests(SimpleTestCase): When(condition=object()) with self.assertRaisesMessage(TypeError, msg): When(condition=Value(1, output_field=IntegerField())) + with self.assertRaisesMessage(TypeError, msg): + When(Value(1, output_field=IntegerField()), string='1') with self.assertRaisesMessage(TypeError, msg): When()