1
0
mirror of https://github.com/django/django.git synced 2025-03-06 15:32:33 +00:00

[5.2.x] Refs #36042 -- Consolidated composite expression checks in BaseExpression.

Remove redundant Func.resolve_expression and adjust CombinedExpression to
delegate source expression resolving to super() to perform checks against
allows_composite_expressions in a single location.

Backport of a76035e925ff4e6d8676c65cb135c74b993b1039 from main.
This commit is contained in:
Simon Charette 2025-01-20 22:26:01 -05:00 committed by Sarah Boyce
parent ae2f5381fe
commit e306687a3a
3 changed files with 30 additions and 58 deletions

View File

@ -292,16 +292,22 @@ class BaseExpression:
""" """
c = self.copy() c = self.copy()
c.is_summary = summarize c.is_summary = summarize
c.set_source_expressions( source_expressions = [
[ (
( expr.resolve_expression(query, allow_joins, reuse, summarize)
expr.resolve_expression(query, allow_joins, reuse, summarize) if expr is not None
if expr else None
else None )
) for expr in c.get_source_expressions()
for expr in c.get_source_expressions() ]
] if not self.allows_composite_expressions and any(
) isinstance(expr, ColPairs) for expr in source_expressions
):
raise ValueError(
f"{self.__class__.__name__} expression does not support "
"composite primary keys."
)
c.set_source_expressions(source_expressions)
return c return c
@property @property
@ -754,32 +760,25 @@ class CombinedExpression(SQLiteNumericMixin, Expression):
def resolve_expression( def resolve_expression(
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
): ):
lhs = self.lhs.resolve_expression( resolved = super().resolve_expression(
query, allow_joins, reuse, summarize, for_save query,
allow_joins,
reuse,
summarize,
for_save,
) )
rhs = self.rhs.resolve_expression(
query, allow_joins, reuse, summarize, for_save
)
if isinstance(lhs, ColPairs) or isinstance(rhs, ColPairs):
raise ValueError("CompositePrimaryKey is not combinable.")
if not isinstance(self, (DurationExpression, TemporalSubtraction)): if not isinstance(self, (DurationExpression, TemporalSubtraction)):
try: try:
lhs_type = lhs.output_field.get_internal_type() lhs_type = resolved.lhs.output_field.get_internal_type()
except (AttributeError, FieldError): except (AttributeError, FieldError):
lhs_type = None lhs_type = None
try: try:
rhs_type = rhs.output_field.get_internal_type() rhs_type = resolved.rhs.output_field.get_internal_type()
except (AttributeError, FieldError): except (AttributeError, FieldError):
rhs_type = None rhs_type = None
if "DurationField" in {lhs_type, rhs_type} and lhs_type != rhs_type: if "DurationField" in {lhs_type, rhs_type} and lhs_type != rhs_type:
return DurationExpression( return DurationExpression(
self.lhs, self.connector, self.rhs resolved.lhs, resolved.connector, resolved.rhs
).resolve_expression(
query,
allow_joins,
reuse,
summarize,
for_save,
) )
datetime_fields = {"DateField", "DateTimeField", "TimeField"} datetime_fields = {"DateField", "DateTimeField", "TimeField"}
if ( if (
@ -787,18 +786,8 @@ class CombinedExpression(SQLiteNumericMixin, Expression):
and lhs_type in datetime_fields and lhs_type in datetime_fields
and lhs_type == rhs_type and lhs_type == rhs_type
): ):
return TemporalSubtraction(self.lhs, self.rhs).resolve_expression( return TemporalSubtraction(resolved.lhs, resolved.rhs)
query, return resolved
allow_joins,
reuse,
summarize,
for_save,
)
c = self.copy()
c.is_summary = summarize
c.lhs = lhs
c.rhs = rhs
return c
@cached_property @cached_property
def allowed_default(self): def allowed_default(self):
@ -1070,23 +1059,6 @@ class Func(SQLiteNumericMixin, Expression):
def set_source_expressions(self, exprs): def set_source_expressions(self, exprs):
self.source_expressions = exprs self.source_expressions = exprs
def resolve_expression(
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
):
c = self.copy()
c.is_summary = summarize
for pos, arg in enumerate(c.source_expressions):
c.source_expressions[pos] = arg.resolve_expression(
query, allow_joins, reuse, summarize, for_save
)
if not self.allows_composite_expressions and any(
isinstance(expr, ColPairs) for expr in c.get_source_expressions()
):
raise ValueError(
f"{self.__class__.__name__} does not support composite primary keys."
)
return c
def as_sql( def as_sql(
self, self,
compiler, compiler,

View File

@ -138,6 +138,6 @@ class CompositePKAggregateTests(TestCase):
) )
def test_max_pk(self): def test_max_pk(self):
msg = "Max does not support composite primary keys." msg = "Max expression does not support composite primary keys."
with self.assertRaisesMessage(ValueError, msg): with self.assertRaisesMessage(ValueError, msg):
Comment.objects.aggregate(Max("pk")) Comment.objects.aggregate(Max("pk"))

View File

@ -63,7 +63,7 @@ class CompositePKFilterTests(TestCase):
Comment.objects.filter(text__gt=F("pk")).count() Comment.objects.filter(text__gt=F("pk")).count()
def test_rhs_combinable(self): def test_rhs_combinable(self):
msg = "CompositePrimaryKey is not combinable." msg = "CombinedExpression expression does not support composite primary keys."
for expr in [F("pk") + (1, 1), (1, 1) + F("pk")]: for expr in [F("pk") + (1, 1), (1, 1) + F("pk")]:
with ( with (
self.subTest(expression=expr), self.subTest(expression=expr),
@ -405,7 +405,7 @@ class CompositePKFilterTests(TestCase):
self.assertSequenceEqual(queryset, (self.user_2,)) self.assertSequenceEqual(queryset, (self.user_2,))
def test_cannot_cast_pk(self): def test_cannot_cast_pk(self):
msg = "Cast does not support composite primary keys." msg = "Cast expression does not support composite primary keys."
with self.assertRaisesMessage(ValueError, msg): with self.assertRaisesMessage(ValueError, msg):
Comment.objects.filter(text__gt=Cast(F("pk"), TextField())).count() Comment.objects.filter(text__gt=Cast(F("pk"), TextField())).count()