From 6a37e9bfae42eb180a8c978a3ade5b41c97c26b1 Mon Sep 17 00:00:00 2001 From: sharonwoo <29156885+sharonwoo@users.noreply.github.com> Date: Thu, 21 Mar 2024 12:53:45 +0800 Subject: [PATCH] Fixed #35257 -- Corrected resolving output_field for IntegerField/DecimalField with NULL. --- django/db/models/expressions.py | 16 +++++++++++----- tests/expressions/tests.py | 23 +++++++++++++++++++++++ 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 198e094385..6032b4d1f4 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -5,6 +5,7 @@ import inspect from collections import defaultdict from decimal import Decimal from enum import Enum +from itertools import chain from types import NoneType from uuid import UUID @@ -597,10 +598,16 @@ _connector_combinations = [ }, # Numeric with NULL. { - connector: [ - (field_type, NoneType, field_type), - (NoneType, field_type, field_type), - ] + connector: list( + chain.from_iterable( + [(field_type, NoneType, field_type), (NoneType, field_type, field_type)] + for field_type in ( + fields.IntegerField, + fields.DecimalField, + fields.FloatField, + ) + ) + ) for connector in ( Combinable.ADD, Combinable.SUB, @@ -609,7 +616,6 @@ _connector_combinations = [ Combinable.MOD, Combinable.POW, ) - for field_type in (fields.IntegerField, fields.DecimalField, fields.FloatField) }, # Date/DateTimeField/DurationField/TimeField. { diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py index 909e317dca..f7233305a7 100644 --- a/tests/expressions/tests.py +++ b/tests/expressions/tests.py @@ -2654,6 +2654,29 @@ class CombinedExpressionTests(SimpleTestCase): with self.assertRaisesMessage(FieldError, msg): expr.output_field + def test_resolve_output_field_numbers_with_null(self): + test_values = [ + (3.14159, None, FloatField), + (None, 3.14159, FloatField), + (None, 42, IntegerField), + (42, None, IntegerField), + (None, Decimal("3.14"), DecimalField), + (Decimal("3.14"), None, DecimalField), + ] + connectors = [ + Combinable.ADD, + Combinable.SUB, + Combinable.MUL, + Combinable.DIV, + Combinable.MOD, + Combinable.POW, + ] + for lhs, rhs, expected_output_field in test_values: + for connector in connectors: + with self.subTest(lhs=lhs, connector=connector, rhs=rhs): + expr = CombinedExpression(Value(lhs), connector, Value(rhs)) + self.assertIsInstance(expr.output_field, expected_output_field) + def test_resolve_output_field_dates(self): tests = [ # Add - same type.