From 00c690efbc0b10f67924687f24a7b30397bf47d9 Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Mon, 20 Jan 2025 22:36:47 -0500 Subject: [PATCH] Fixed #36117 -- Raised ValueError when providing composite expressions to case / when. Remove redundant Case and When.resolve_expression to delegate composite expression support to BaseExpression. Thanks Jacob Tyler Walls for the report and test. --- django/db/models/expressions.py | 28 ---------------------------- tests/composite_pk/test_filter.py | 19 ++++++++++++++++++- 2 files changed, 18 insertions(+), 29 deletions(-) diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 50b62a6d38..ad8f8e6650 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -1577,20 +1577,6 @@ class When(Expression): # We're only interested in the fields of the result expressions. return [self.result._output_field_or_none] - def resolve_expression( - self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False - ): - c = self.copy() - c.is_summary = summarize - if hasattr(c.condition, "resolve_expression"): - c.condition = c.condition.resolve_expression( - query, allow_joins, reuse, summarize, False - ) - c.result = c.result.resolve_expression( - query, allow_joins, reuse, summarize, for_save - ) - return c - def as_sql(self, compiler, connection, template=None, **extra_context): connection.ops.check_expression_support(self) template_params = extra_context @@ -1658,20 +1644,6 @@ class Case(SQLiteNumericMixin, Expression): def set_source_expressions(self, exprs): *self.cases, self.default = exprs - def resolve_expression( - self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False - ): - c = self.copy() - c.is_summary = summarize - for pos, case in enumerate(c.cases): - c.cases[pos] = case.resolve_expression( - query, allow_joins, reuse, summarize, for_save - ) - c.default = c.default.resolve_expression( - query, allow_joins, reuse, summarize, for_save - ) - return c - def copy(self): c = super().copy() c.cases = c.cases[:] diff --git a/tests/composite_pk/test_filter.py b/tests/composite_pk/test_filter.py index aa2d9ebe36..fe942b9e5b 100644 --- a/tests/composite_pk/test_filter.py +++ b/tests/composite_pk/test_filter.py @@ -1,4 +1,13 @@ -from django.db.models import F, FilteredRelation, OuterRef, Q, Subquery, TextField +from django.db.models import ( + Case, + F, + FilteredRelation, + OuterRef, + Q, + Subquery, + TextField, + When, +) from django.db.models.functions import Cast from django.db.models.lookups import Exact from django.test import TestCase @@ -409,6 +418,14 @@ class CompositePKFilterTests(TestCase): with self.assertRaisesMessage(ValueError, msg): Comment.objects.filter(text__gt=Cast(F("pk"), TextField())).count() + def test_filter_case_when(self): + msg = "When expression does not support composite primary keys." + with self.assertRaisesMessage(ValueError, msg): + Comment.objects.filter(text=Case(When(text="", then="pk"))) + msg = "Case expression does not support composite primary keys." + with self.assertRaisesMessage(ValueError, msg): + Comment.objects.filter(text=Case(When(text="", then="text"), default="pk")) + def test_outer_ref_pk(self): subquery = Subquery(Comment.objects.filter(pk=OuterRef("pk")).values("id")) tests = [