1
0
mirror of https://github.com/django/django.git synced 2025-03-06 15:32:33 +00:00

[5.2.x] Refs #36042 -- Consolidated composite expression checks in BaseExpression.

Remove redundant Func.resolve_expression and adjust CombinedExpression to
delegate source expression resolving to super() to perform checks against
allows_composite_expressions in a single location.

Backport of a76035e925ff4e6d8676c65cb135c74b993b1039 from main.
This commit is contained in:
Simon Charette 2025-01-20 22:26:01 -05:00 committed by Sarah Boyce
parent ae2f5381fe
commit e306687a3a
3 changed files with 30 additions and 58 deletions

View File

@ -292,16 +292,22 @@ class BaseExpression:
"""
c = self.copy()
c.is_summary = summarize
c.set_source_expressions(
[
source_expressions = [
(
expr.resolve_expression(query, allow_joins, reuse, summarize)
if expr
if expr is not None
else None
)
for expr in c.get_source_expressions()
]
if not self.allows_composite_expressions and any(
isinstance(expr, ColPairs) for expr in source_expressions
):
raise ValueError(
f"{self.__class__.__name__} expression does not support "
"composite primary keys."
)
c.set_source_expressions(source_expressions)
return c
@property
@ -754,51 +760,34 @@ class CombinedExpression(SQLiteNumericMixin, Expression):
def resolve_expression(
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
):
lhs = self.lhs.resolve_expression(
query, allow_joins, reuse, summarize, for_save
)
rhs = self.rhs.resolve_expression(
query, allow_joins, reuse, summarize, for_save
)
if isinstance(lhs, ColPairs) or isinstance(rhs, ColPairs):
raise ValueError("CompositePrimaryKey is not combinable.")
if not isinstance(self, (DurationExpression, TemporalSubtraction)):
try:
lhs_type = lhs.output_field.get_internal_type()
except (AttributeError, FieldError):
lhs_type = None
try:
rhs_type = rhs.output_field.get_internal_type()
except (AttributeError, FieldError):
rhs_type = None
if "DurationField" in {lhs_type, rhs_type} and lhs_type != rhs_type:
return DurationExpression(
self.lhs, self.connector, self.rhs
).resolve_expression(
resolved = super().resolve_expression(
query,
allow_joins,
reuse,
summarize,
for_save,
)
if not isinstance(self, (DurationExpression, TemporalSubtraction)):
try:
lhs_type = resolved.lhs.output_field.get_internal_type()
except (AttributeError, FieldError):
lhs_type = None
try:
rhs_type = resolved.rhs.output_field.get_internal_type()
except (AttributeError, FieldError):
rhs_type = None
if "DurationField" in {lhs_type, rhs_type} and lhs_type != rhs_type:
return DurationExpression(
resolved.lhs, resolved.connector, resolved.rhs
)
datetime_fields = {"DateField", "DateTimeField", "TimeField"}
if (
self.connector == self.SUB
and lhs_type in datetime_fields
and lhs_type == rhs_type
):
return TemporalSubtraction(self.lhs, self.rhs).resolve_expression(
query,
allow_joins,
reuse,
summarize,
for_save,
)
c = self.copy()
c.is_summary = summarize
c.lhs = lhs
c.rhs = rhs
return c
return TemporalSubtraction(resolved.lhs, resolved.rhs)
return resolved
@cached_property
def allowed_default(self):
@ -1070,23 +1059,6 @@ class Func(SQLiteNumericMixin, Expression):
def set_source_expressions(self, exprs):
self.source_expressions = exprs
def resolve_expression(
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
):
c = self.copy()
c.is_summary = summarize
for pos, arg in enumerate(c.source_expressions):
c.source_expressions[pos] = arg.resolve_expression(
query, allow_joins, reuse, summarize, for_save
)
if not self.allows_composite_expressions and any(
isinstance(expr, ColPairs) for expr in c.get_source_expressions()
):
raise ValueError(
f"{self.__class__.__name__} does not support composite primary keys."
)
return c
def as_sql(
self,
compiler,

View File

@ -138,6 +138,6 @@ class CompositePKAggregateTests(TestCase):
)
def test_max_pk(self):
msg = "Max does not support composite primary keys."
msg = "Max expression does not support composite primary keys."
with self.assertRaisesMessage(ValueError, msg):
Comment.objects.aggregate(Max("pk"))

View File

@ -63,7 +63,7 @@ class CompositePKFilterTests(TestCase):
Comment.objects.filter(text__gt=F("pk")).count()
def test_rhs_combinable(self):
msg = "CompositePrimaryKey is not combinable."
msg = "CombinedExpression expression does not support composite primary keys."
for expr in [F("pk") + (1, 1), (1, 1) + F("pk")]:
with (
self.subTest(expression=expr),
@ -405,7 +405,7 @@ class CompositePKFilterTests(TestCase):
self.assertSequenceEqual(queryset, (self.user_2,))
def test_cannot_cast_pk(self):
msg = "Cast does not support composite primary keys."
msg = "Cast expression does not support composite primary keys."
with self.assertRaisesMessage(ValueError, msg):
Comment.objects.filter(text__gt=Cast(F("pk"), TextField())).count()