From e306687a3a5507d59365ba9bf545010e5fd4b2a8 Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Mon, 20 Jan 2025 22:26:01 -0500 Subject: [PATCH] [5.2.x] Refs #36042 -- Consolidated composite expression checks in BaseExpression. Remove redundant Func.resolve_expression and adjust CombinedExpression to delegate source expression resolving to super() to perform checks against allows_composite_expressions in a single location. Backport of a76035e925ff4e6d8676c65cb135c74b993b1039 from main. --- django/db/models/expressions.py | 82 +++++++++------------------- tests/composite_pk/test_aggregate.py | 2 +- tests/composite_pk/test_filter.py | 4 +- 3 files changed, 30 insertions(+), 58 deletions(-) diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 2494ec4139..50b62a6d38 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -292,16 +292,22 @@ class BaseExpression: """ c = self.copy() c.is_summary = summarize - c.set_source_expressions( - [ - ( - expr.resolve_expression(query, allow_joins, reuse, summarize) - if expr - else None - ) - for expr in c.get_source_expressions() - ] - ) + source_expressions = [ + ( + expr.resolve_expression(query, allow_joins, reuse, summarize) + if expr is not None + else None + ) + for expr in c.get_source_expressions() + ] + if not self.allows_composite_expressions and any( + isinstance(expr, ColPairs) for expr in source_expressions + ): + raise ValueError( + f"{self.__class__.__name__} expression does not support " + "composite primary keys." + ) + c.set_source_expressions(source_expressions) return c @property @@ -754,32 +760,25 @@ class CombinedExpression(SQLiteNumericMixin, Expression): def resolve_expression( self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False ): - lhs = self.lhs.resolve_expression( - query, allow_joins, reuse, summarize, for_save + resolved = super().resolve_expression( + query, + allow_joins, + reuse, + summarize, + for_save, ) - rhs = self.rhs.resolve_expression( - query, allow_joins, reuse, summarize, for_save - ) - if isinstance(lhs, ColPairs) or isinstance(rhs, ColPairs): - raise ValueError("CompositePrimaryKey is not combinable.") if not isinstance(self, (DurationExpression, TemporalSubtraction)): try: - lhs_type = lhs.output_field.get_internal_type() + lhs_type = resolved.lhs.output_field.get_internal_type() except (AttributeError, FieldError): lhs_type = None try: - rhs_type = rhs.output_field.get_internal_type() + rhs_type = resolved.rhs.output_field.get_internal_type() except (AttributeError, FieldError): rhs_type = None if "DurationField" in {lhs_type, rhs_type} and lhs_type != rhs_type: return DurationExpression( - self.lhs, self.connector, self.rhs - ).resolve_expression( - query, - allow_joins, - reuse, - summarize, - for_save, + resolved.lhs, resolved.connector, resolved.rhs ) datetime_fields = {"DateField", "DateTimeField", "TimeField"} if ( @@ -787,18 +786,8 @@ class CombinedExpression(SQLiteNumericMixin, Expression): and lhs_type in datetime_fields and lhs_type == rhs_type ): - return TemporalSubtraction(self.lhs, self.rhs).resolve_expression( - query, - allow_joins, - reuse, - summarize, - for_save, - ) - c = self.copy() - c.is_summary = summarize - c.lhs = lhs - c.rhs = rhs - return c + return TemporalSubtraction(resolved.lhs, resolved.rhs) + return resolved @cached_property def allowed_default(self): @@ -1070,23 +1059,6 @@ class Func(SQLiteNumericMixin, Expression): def set_source_expressions(self, exprs): self.source_expressions = exprs - def resolve_expression( - self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False - ): - c = self.copy() - c.is_summary = summarize - for pos, arg in enumerate(c.source_expressions): - c.source_expressions[pos] = arg.resolve_expression( - query, allow_joins, reuse, summarize, for_save - ) - if not self.allows_composite_expressions and any( - isinstance(expr, ColPairs) for expr in c.get_source_expressions() - ): - raise ValueError( - f"{self.__class__.__name__} does not support composite primary keys." - ) - return c - def as_sql( self, compiler, diff --git a/tests/composite_pk/test_aggregate.py b/tests/composite_pk/test_aggregate.py index e8751df0a3..d852fdce30 100644 --- a/tests/composite_pk/test_aggregate.py +++ b/tests/composite_pk/test_aggregate.py @@ -138,6 +138,6 @@ class CompositePKAggregateTests(TestCase): ) def test_max_pk(self): - msg = "Max does not support composite primary keys." + msg = "Max expression does not support composite primary keys." with self.assertRaisesMessage(ValueError, msg): Comment.objects.aggregate(Max("pk")) diff --git a/tests/composite_pk/test_filter.py b/tests/composite_pk/test_filter.py index 4edf947423..aa2d9ebe36 100644 --- a/tests/composite_pk/test_filter.py +++ b/tests/composite_pk/test_filter.py @@ -63,7 +63,7 @@ class CompositePKFilterTests(TestCase): Comment.objects.filter(text__gt=F("pk")).count() def test_rhs_combinable(self): - msg = "CompositePrimaryKey is not combinable." + msg = "CombinedExpression expression does not support composite primary keys." for expr in [F("pk") + (1, 1), (1, 1) + F("pk")]: with ( self.subTest(expression=expr), @@ -405,7 +405,7 @@ class CompositePKFilterTests(TestCase): self.assertSequenceEqual(queryset, (self.user_2,)) def test_cannot_cast_pk(self): - msg = "Cast does not support composite primary keys." + msg = "Cast expression does not support composite primary keys." with self.assertRaisesMessage(ValueError, msg): Comment.objects.filter(text__gt=Cast(F("pk"), TextField())).count()