From 1efea11808f7b8a3b31445e0c1c7d270f832d965 Mon Sep 17 00:00:00 2001 From: Luke Plant Date: Thu, 31 Mar 2022 08:10:22 +0200 Subject: [PATCH] Refs #33397 -- Added register_combinable_fields(). --- django/db/models/expressions.py | 65 ++++++++++++++++++++++++++++----- 1 file changed, 55 insertions(+), 10 deletions(-) diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index edd644da54..32982777ef 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -2,6 +2,7 @@ import copy import datetime import functools import inspect +from collections import defaultdict from decimal import Decimal from uuid import UUID @@ -465,16 +466,60 @@ class Expression(BaseExpression, Combinable): return hash(self.identity) -_connector_combinators = { - connector: [ - (fields.IntegerField, fields.IntegerField, fields.IntegerField), - (fields.IntegerField, fields.DecimalField, fields.DecimalField), - (fields.DecimalField, fields.IntegerField, fields.DecimalField), - (fields.IntegerField, fields.FloatField, fields.FloatField), - (fields.FloatField, fields.IntegerField, fields.FloatField), - ] - for connector in (Combinable.ADD, Combinable.SUB, Combinable.MUL, Combinable.DIV) -} +# Type inference for CombinedExpression.output_field. +_connector_combinations = [ + # Numeric operations - operands of same type. + { + connector: [ + (fields.IntegerField, fields.IntegerField, fields.IntegerField), + (fields.FloatField, fields.FloatField, fields.FloatField), + (fields.DecimalField, fields.DecimalField, fields.DecimalField), + ] + for connector in ( + Combinable.ADD, + Combinable.SUB, + Combinable.MUL, + # Behavior for DIV with integer arguments follows Postgres/SQLite, + # not MySQL/Oracle. + Combinable.DIV, + ) + }, + # Numeric operations - operands of different type. + { + connector: [ + (fields.IntegerField, fields.DecimalField, fields.DecimalField), + (fields.DecimalField, fields.IntegerField, fields.DecimalField), + (fields.IntegerField, fields.FloatField, fields.FloatField), + (fields.FloatField, fields.IntegerField, fields.FloatField), + ] + for connector in ( + Combinable.ADD, + Combinable.SUB, + Combinable.MUL, + Combinable.DIV, + ) + }, +] + +_connector_combinators = defaultdict(list) + + +def register_combinable_fields(lhs, connector, rhs, result): + """ + Register combinable types: + lhs rhs -> result + e.g. + register_combinable_fields( + IntegerField, Combinable.ADD, FloatField, FloatField + ) + """ + _connector_combinators[connector].append((lhs, rhs, result)) + + +for d in _connector_combinations: + for connector, field_types in d.items(): + for lhs, rhs, result in field_types: + register_combinable_fields(lhs, connector, rhs, result) @functools.lru_cache(maxsize=128)