mirror of
				https://github.com/django/django.git
				synced 2025-10-25 22:56:12 +00:00 
			
		
		
		
	Fixed #30446 -- Resolved Value.output_field for stdlib types.
This required implementing a limited form of dynamic dispatch to combine expressions with numerical output. Refs #26355 should eventually provide a better interface for that.
This commit is contained in:
		
				
					committed by
					
						 Mariusz Felisiak
						Mariusz Felisiak
					
				
			
			
				
	
			
			
			
						parent
						
							d08e6f55e3
						
					
				
				
					commit
					1e38f1191d
				
			| @@ -101,10 +101,13 @@ class SQLiteDecimalToFloatMixin: | |||||||
|     is not acceptable by the GIS functions expecting numeric values. |     is not acceptable by the GIS functions expecting numeric values. | ||||||
|     """ |     """ | ||||||
|     def as_sqlite(self, compiler, connection, **extra_context): |     def as_sqlite(self, compiler, connection, **extra_context): | ||||||
|         for expr in self.get_source_expressions(): |         copy = self.copy() | ||||||
|             if hasattr(expr, 'value') and isinstance(expr.value, Decimal): |         copy.set_source_expressions([ | ||||||
|                 expr.value = float(expr.value) |             Value(float(expr.value)) if hasattr(expr, 'value') and isinstance(expr.value, Decimal) | ||||||
|         return super().as_sql(compiler, connection, **extra_context) |             else expr | ||||||
|  |             for expr in copy.get_source_expressions() | ||||||
|  |         ]) | ||||||
|  |         return copy.as_sql(compiler, connection, **extra_context) | ||||||
|  |  | ||||||
|  |  | ||||||
| class OracleToleranceMixin: | class OracleToleranceMixin: | ||||||
|   | |||||||
| @@ -173,8 +173,7 @@ class DateTimeRangeContains(PostgresOperatorLookup): | |||||||
|     def process_rhs(self, compiler, connection): |     def process_rhs(self, compiler, connection): | ||||||
|         # Transform rhs value for db lookup. |         # Transform rhs value for db lookup. | ||||||
|         if isinstance(self.rhs, datetime.date): |         if isinstance(self.rhs, datetime.date): | ||||||
|             output_field = models.DateTimeField() if isinstance(self.rhs, datetime.datetime) else models.DateField() |             value = models.Value(self.rhs) | ||||||
|             value = models.Value(self.rhs, output_field=output_field) |  | ||||||
|             self.rhs = value.resolve_expression(compiler.query) |             self.rhs = value.resolve_expression(compiler.query) | ||||||
|         return super().process_rhs(compiler, connection) |         return super().process_rhs(compiler, connection) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,7 +1,9 @@ | |||||||
| import copy | import copy | ||||||
| import datetime | import datetime | ||||||
|  | import functools | ||||||
| import inspect | import inspect | ||||||
| from decimal import Decimal | from decimal import Decimal | ||||||
|  | from uuid import UUID | ||||||
|  |  | ||||||
| from django.core.exceptions import EmptyResultSet, FieldError | from django.core.exceptions import EmptyResultSet, FieldError | ||||||
| from django.db import NotSupportedError, connection | from django.db import NotSupportedError, connection | ||||||
| @@ -56,12 +58,7 @@ class Combinable: | |||||||
|     def _combine(self, other, connector, reversed): |     def _combine(self, other, connector, reversed): | ||||||
|         if not hasattr(other, 'resolve_expression'): |         if not hasattr(other, 'resolve_expression'): | ||||||
|             # everything must be resolvable to an expression |             # everything must be resolvable to an expression | ||||||
|             output_field = ( |             other = Value(other) | ||||||
|                 fields.DurationField() |  | ||||||
|                 if isinstance(other, datetime.timedelta) else |  | ||||||
|                 None |  | ||||||
|             ) |  | ||||||
|             other = Value(other, output_field=output_field) |  | ||||||
|  |  | ||||||
|         if reversed: |         if reversed: | ||||||
|             return CombinedExpression(other, connector, self) |             return CombinedExpression(other, connector, self) | ||||||
| @@ -422,6 +419,25 @@ class Expression(BaseExpression, Combinable): | |||||||
|     pass |     pass | ||||||
|  |  | ||||||
|  |  | ||||||
|  | _connector_combinators = { | ||||||
|  |     connector: [ | ||||||
|  |         (fields.IntegerField, fields.DecimalField, fields.DecimalField), | ||||||
|  |         (fields.DecimalField, fields.IntegerField, fields.DecimalField), | ||||||
|  |         (fields.IntegerField, fields.FloatField, fields.FloatField), | ||||||
|  |         (fields.FloatField, fields.IntegerField, fields.FloatField), | ||||||
|  |     ] | ||||||
|  |     for connector in (Combinable.ADD, Combinable.SUB, Combinable.MUL, Combinable.DIV) | ||||||
|  | } | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @functools.lru_cache(maxsize=128) | ||||||
|  | def _resolve_combined_type(connector, lhs_type, rhs_type): | ||||||
|  |     combinators = _connector_combinators.get(connector, ()) | ||||||
|  |     for combinator_lhs_type, combinator_rhs_type, combined_type in combinators: | ||||||
|  |         if issubclass(lhs_type, combinator_lhs_type) and issubclass(rhs_type, combinator_rhs_type): | ||||||
|  |             return combined_type | ||||||
|  |  | ||||||
|  |  | ||||||
| class CombinedExpression(SQLiteNumericMixin, Expression): | class CombinedExpression(SQLiteNumericMixin, Expression): | ||||||
|  |  | ||||||
|     def __init__(self, lhs, connector, rhs, output_field=None): |     def __init__(self, lhs, connector, rhs, output_field=None): | ||||||
| @@ -442,6 +458,19 @@ class CombinedExpression(SQLiteNumericMixin, Expression): | |||||||
|     def set_source_expressions(self, exprs): |     def set_source_expressions(self, exprs): | ||||||
|         self.lhs, self.rhs = exprs |         self.lhs, self.rhs = exprs | ||||||
|  |  | ||||||
|  |     def _resolve_output_field(self): | ||||||
|  |         try: | ||||||
|  |             return super()._resolve_output_field() | ||||||
|  |         except FieldError: | ||||||
|  |             combined_type = _resolve_combined_type( | ||||||
|  |                 self.connector, | ||||||
|  |                 type(self.lhs.output_field), | ||||||
|  |                 type(self.rhs.output_field), | ||||||
|  |             ) | ||||||
|  |             if combined_type is None: | ||||||
|  |                 raise | ||||||
|  |             return combined_type() | ||||||
|  |  | ||||||
|     def as_sql(self, compiler, connection): |     def as_sql(self, compiler, connection): | ||||||
|         expressions = [] |         expressions = [] | ||||||
|         expression_params = [] |         expression_params = [] | ||||||
| @@ -721,6 +750,30 @@ class Value(Expression): | |||||||
|     def get_group_by_cols(self, alias=None): |     def get_group_by_cols(self, alias=None): | ||||||
|         return [] |         return [] | ||||||
|  |  | ||||||
|  |     def _resolve_output_field(self): | ||||||
|  |         if isinstance(self.value, str): | ||||||
|  |             return fields.CharField() | ||||||
|  |         if isinstance(self.value, bool): | ||||||
|  |             return fields.BooleanField() | ||||||
|  |         if isinstance(self.value, int): | ||||||
|  |             return fields.IntegerField() | ||||||
|  |         if isinstance(self.value, float): | ||||||
|  |             return fields.FloatField() | ||||||
|  |         if isinstance(self.value, datetime.datetime): | ||||||
|  |             return fields.DateTimeField() | ||||||
|  |         if isinstance(self.value, datetime.date): | ||||||
|  |             return fields.DateField() | ||||||
|  |         if isinstance(self.value, datetime.time): | ||||||
|  |             return fields.TimeField() | ||||||
|  |         if isinstance(self.value, datetime.timedelta): | ||||||
|  |             return fields.DurationField() | ||||||
|  |         if isinstance(self.value, Decimal): | ||||||
|  |             return fields.DecimalField() | ||||||
|  |         if isinstance(self.value, bytes): | ||||||
|  |             return fields.BinaryField() | ||||||
|  |         if isinstance(self.value, UUID): | ||||||
|  |             return fields.UUIDField() | ||||||
|  |  | ||||||
|  |  | ||||||
| class RawSQL(Expression): | class RawSQL(Expression): | ||||||
|     def __init__(self, sql, params, output_field=None): |     def __init__(self, sql, params, output_field=None): | ||||||
| @@ -1177,7 +1230,6 @@ class OrderBy(BaseExpression): | |||||||
|             copy.expression = Case( |             copy.expression = Case( | ||||||
|                 When(self.expression, then=True), |                 When(self.expression, then=True), | ||||||
|                 default=False, |                 default=False, | ||||||
|                 output_field=fields.BooleanField(), |  | ||||||
|             ) |             ) | ||||||
|             return copy.as_sql(compiler, connection) |             return copy.as_sql(compiler, connection) | ||||||
|         return self.as_sql(compiler, connection) |         return self.as_sql(compiler, connection) | ||||||
|   | |||||||
| @@ -6,7 +6,7 @@ from copy import copy | |||||||
| from django.core.exceptions import EmptyResultSet | from django.core.exceptions import EmptyResultSet | ||||||
| from django.db.models.expressions import Case, Exists, Func, Value, When | from django.db.models.expressions import Case, Exists, Func, Value, When | ||||||
| from django.db.models.fields import ( | from django.db.models.fields import ( | ||||||
|     BooleanField, CharField, DateTimeField, Field, IntegerField, UUIDField, |     CharField, DateTimeField, Field, IntegerField, UUIDField, | ||||||
| ) | ) | ||||||
| from django.db.models.query_utils import RegisterLookupMixin | from django.db.models.query_utils import RegisterLookupMixin | ||||||
| from django.utils.datastructures import OrderedSet | from django.utils.datastructures import OrderedSet | ||||||
| @@ -123,7 +123,7 @@ class Lookup: | |||||||
|         exprs = [] |         exprs = [] | ||||||
|         for expr in (self.lhs, self.rhs): |         for expr in (self.lhs, self.rhs): | ||||||
|             if isinstance(expr, Exists): |             if isinstance(expr, Exists): | ||||||
|                 expr = Case(When(expr, then=True), default=False, output_field=BooleanField()) |                 expr = Case(When(expr, then=True), default=False) | ||||||
|                 wrapped = True |                 wrapped = True | ||||||
|             exprs.append(expr) |             exprs.append(expr) | ||||||
|         lookup = type(self)(*exprs) if wrapped else self |         lookup = type(self)(*exprs) if wrapped else self | ||||||
|   | |||||||
| @@ -484,7 +484,15 @@ The ``output_field`` argument should be a model field instance, like | |||||||
| after it's retrieved from the database. Usually no arguments are needed when | after it's retrieved from the database. Usually no arguments are needed when | ||||||
| instantiating the model field as any arguments relating to data validation | instantiating the model field as any arguments relating to data validation | ||||||
| (``max_length``, ``max_digits``, etc.) will not be enforced on the expression's | (``max_length``, ``max_digits``, etc.) will not be enforced on the expression's | ||||||
| output value. | output value. If no ``output_field`` is specified it will be tentatively | ||||||
|  | inferred from the :py:class:`type` of the provided ``value``, if possible. For | ||||||
|  | example, passing an instance of :py:class:`datetime.datetime` as ``value`` | ||||||
|  | would default ``output_field`` to :class:`~django.db.models.DateTimeField`. | ||||||
|  |  | ||||||
|  | .. versionchanged:: 3.2 | ||||||
|  |  | ||||||
|  |   Support for inferring a default ``output_field`` from the type of ``value`` | ||||||
|  |   was added. | ||||||
|  |  | ||||||
| ``ExpressionWrapper()`` expressions | ``ExpressionWrapper()`` expressions | ||||||
| ----------------------------------- | ----------------------------------- | ||||||
|   | |||||||
| @@ -233,6 +233,15 @@ Models | |||||||
| * The ``of`` argument of :meth:`.QuerySet.select_for_update()` is now allowed | * The ``of`` argument of :meth:`.QuerySet.select_for_update()` is now allowed | ||||||
|   on MySQL 8.0.1+. |   on MySQL 8.0.1+. | ||||||
|  |  | ||||||
|  | * :class:`Value() <django.db.models.Value>` expression now | ||||||
|  |   automatically resolves its ``output_field`` to the appropriate | ||||||
|  |   :class:`Field <django.db.models.Field>` subclass based on the type of | ||||||
|  |   it's provided ``value`` for :py:class:`bool`, :py:class:`bytes`, | ||||||
|  |   :py:class:`float`, :py:class:`int`, :py:class:`str`, | ||||||
|  |   :py:class:`datetime.date`, :py:class:`datetime.datetime`, | ||||||
|  |   :py:class:`datetime.time`, :py:class:`datetime.timedelta`, | ||||||
|  |   :py:class:`decimal.Decimal`, and :py:class:`uuid.UUID` instances. | ||||||
|  |  | ||||||
| Requests and Responses | Requests and Responses | ||||||
| ~~~~~~~~~~~~~~~~~~~~~~ | ~~~~~~~~~~~~~~~~~~~~~~ | ||||||
|  |  | ||||||
|   | |||||||
| @@ -848,10 +848,6 @@ class AggregateTestCase(TestCase): | |||||||
|         book = Book.objects.annotate(val=Max(2, output_field=IntegerField())).first() |         book = Book.objects.annotate(val=Max(2, output_field=IntegerField())).first() | ||||||
|         self.assertEqual(book.val, 2) |         self.assertEqual(book.val, 2) | ||||||
|  |  | ||||||
|     def test_missing_output_field_raises_error(self): |  | ||||||
|         with self.assertRaisesMessage(FieldError, 'Cannot resolve expression type, unknown output_field'): |  | ||||||
|             Book.objects.annotate(val=Max(2)).first() |  | ||||||
|  |  | ||||||
|     def test_annotation_expressions(self): |     def test_annotation_expressions(self): | ||||||
|         authors = Author.objects.annotate(combined_ages=Sum(F('age') + F('friends__age'))).order_by('name') |         authors = Author.objects.annotate(combined_ages=Sum(F('age') + F('friends__age'))).order_by('name') | ||||||
|         authors2 = Author.objects.annotate(combined_ages=Sum('age') + Sum('friends__age')).order_by('name') |         authors2 = Author.objects.annotate(combined_ages=Sum('age') + Sum('friends__age')).order_by('name') | ||||||
| @@ -893,7 +889,7 @@ class AggregateTestCase(TestCase): | |||||||
|  |  | ||||||
|     def test_combine_different_types(self): |     def test_combine_different_types(self): | ||||||
|         msg = ( |         msg = ( | ||||||
|             'Expression contains mixed types: FloatField, IntegerField. ' |             'Expression contains mixed types: FloatField, DecimalField. ' | ||||||
|             'You must set output_field.' |             'You must set output_field.' | ||||||
|         ) |         ) | ||||||
|         qs = Book.objects.annotate(sums=Sum('rating') + Sum('pages') + Sum('price')) |         qs = Book.objects.annotate(sums=Sum('rating') + Sum('pages') + Sum('price')) | ||||||
|   | |||||||
| @@ -388,7 +388,7 @@ class AggregationTests(TestCase): | |||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def test_annotated_conditional_aggregate(self): |     def test_annotated_conditional_aggregate(self): | ||||||
|         annotated_qs = Book.objects.annotate(discount_price=F('price') * 0.75) |         annotated_qs = Book.objects.annotate(discount_price=F('price') * Decimal('0.75')) | ||||||
|         self.assertAlmostEqual( |         self.assertAlmostEqual( | ||||||
|             annotated_qs.aggregate(test=Avg(Case( |             annotated_qs.aggregate(test=Avg(Case( | ||||||
|                 When(pages__lt=400, then='discount_price'), |                 When(pages__lt=400, then='discount_price'), | ||||||
|   | |||||||
| @@ -3,15 +3,17 @@ import pickle | |||||||
| import unittest | import unittest | ||||||
| import uuid | import uuid | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
|  | from decimal import Decimal | ||||||
| from unittest import mock | from unittest import mock | ||||||
|  |  | ||||||
| from django.core.exceptions import FieldError | from django.core.exceptions import FieldError | ||||||
| from django.db import DatabaseError, NotSupportedError, connection | from django.db import DatabaseError, NotSupportedError, connection | ||||||
| from django.db.models import ( | from django.db.models import ( | ||||||
|     Avg, BooleanField, Case, CharField, Count, DateField, DateTimeField, |     Avg, BinaryField, BooleanField, Case, CharField, Count, DateField, | ||||||
|     DurationField, Exists, Expression, ExpressionList, ExpressionWrapper, F, |     DateTimeField, DecimalField, DurationField, Exists, Expression, | ||||||
|     Func, IntegerField, Max, Min, Model, OrderBy, OuterRef, Q, StdDev, |     ExpressionList, ExpressionWrapper, F, FloatField, Func, IntegerField, Max, | ||||||
|     Subquery, Sum, TimeField, UUIDField, Value, Variance, When, |     Min, Model, OrderBy, OuterRef, Q, StdDev, Subquery, Sum, TimeField, | ||||||
|  |     UUIDField, Value, Variance, When, | ||||||
| ) | ) | ||||||
| from django.db.models.expressions import Col, Combinable, Random, RawSQL, Ref | from django.db.models.expressions import Col, Combinable, Random, RawSQL, Ref | ||||||
| from django.db.models.functions import ( | from django.db.models.functions import ( | ||||||
| @@ -1711,6 +1713,30 @@ class ValueTests(TestCase): | |||||||
|         value = Value('foo', output_field=CharField()) |         value = Value('foo', output_field=CharField()) | ||||||
|         self.assertEqual(value.as_sql(compiler, connection), ('%s', ['foo'])) |         self.assertEqual(value.as_sql(compiler, connection), ('%s', ['foo'])) | ||||||
|  |  | ||||||
|  |     def test_resolve_output_field(self): | ||||||
|  |         value_types = [ | ||||||
|  |             ('str', CharField), | ||||||
|  |             (True, BooleanField), | ||||||
|  |             (42, IntegerField), | ||||||
|  |             (3.14, FloatField), | ||||||
|  |             (datetime.date(2019, 5, 15), DateField), | ||||||
|  |             (datetime.datetime(2019, 5, 15), DateTimeField), | ||||||
|  |             (datetime.time(3, 16), TimeField), | ||||||
|  |             (datetime.timedelta(1), DurationField), | ||||||
|  |             (Decimal('3.14'), DecimalField), | ||||||
|  |             (b'', BinaryField), | ||||||
|  |             (uuid.uuid4(), UUIDField), | ||||||
|  |         ] | ||||||
|  |         for value, ouput_field_type in value_types: | ||||||
|  |             with self.subTest(type=type(value)): | ||||||
|  |                 expr = Value(value) | ||||||
|  |                 self.assertIsInstance(expr.output_field, ouput_field_type) | ||||||
|  |  | ||||||
|  |     def test_resolve_output_field_failure(self): | ||||||
|  |         msg = 'Cannot resolve expression type, unknown output_field' | ||||||
|  |         with self.assertRaisesMessage(FieldError, msg): | ||||||
|  |             Value(object()).output_field | ||||||
|  |  | ||||||
|  |  | ||||||
| class FieldTransformTests(TestCase): | class FieldTransformTests(TestCase): | ||||||
|  |  | ||||||
| @@ -1848,7 +1874,9 @@ class ExpressionWrapperTests(SimpleTestCase): | |||||||
|         self.assertEqual(expr.get_group_by_cols(alias=None), []) |         self.assertEqual(expr.get_group_by_cols(alias=None), []) | ||||||
|  |  | ||||||
|     def test_non_empty_group_by(self): |     def test_non_empty_group_by(self): | ||||||
|         expr = ExpressionWrapper(Lower(Value('f')), output_field=IntegerField()) |         value = Value('f') | ||||||
|  |         value.output_field = None | ||||||
|  |         expr = ExpressionWrapper(Lower(value), output_field=IntegerField()) | ||||||
|         group_by_cols = expr.get_group_by_cols(alias=None) |         group_by_cols = expr.get_group_by_cols(alias=None) | ||||||
|         self.assertEqual(group_by_cols, [expr.expression]) |         self.assertEqual(group_by_cols, [expr.expression]) | ||||||
|         self.assertEqual(group_by_cols[0].output_field, expr.output_field) |         self.assertEqual(group_by_cols[0].output_field, expr.output_field) | ||||||
|   | |||||||
| @@ -1,7 +1,6 @@ | |||||||
| from datetime import datetime | from datetime import datetime | ||||||
| from operator import attrgetter | from operator import attrgetter | ||||||
|  |  | ||||||
| from django.core.exceptions import FieldError |  | ||||||
| from django.db.models import ( | from django.db.models import ( | ||||||
|     CharField, DateTimeField, F, Max, OuterRef, Subquery, Value, |     CharField, DateTimeField, F, Max, OuterRef, Subquery, Value, | ||||||
| ) | ) | ||||||
| @@ -439,17 +438,6 @@ class OrderingTests(TestCase): | |||||||
|         qs = Article.objects.order_by(Value('1', output_field=CharField()), '-headline') |         qs = Article.objects.order_by(Value('1', output_field=CharField()), '-headline') | ||||||
|         self.assertSequenceEqual(qs, [self.a4, self.a3, self.a2, self.a1]) |         self.assertSequenceEqual(qs, [self.a4, self.a3, self.a2, self.a1]) | ||||||
|  |  | ||||||
|     def test_order_by_constant_value_without_output_field(self): |  | ||||||
|         msg = 'Cannot resolve expression type, unknown output_field' |  | ||||||
|         qs = Article.objects.annotate(constant=Value('1')).order_by('constant') |  | ||||||
|         for ordered_qs in ( |  | ||||||
|             qs, |  | ||||||
|             qs.values('headline'), |  | ||||||
|             Article.objects.order_by(Value('1')), |  | ||||||
|         ): |  | ||||||
|             with self.subTest(ordered_qs=ordered_qs), self.assertRaisesMessage(FieldError, msg): |  | ||||||
|                 ordered_qs.first() |  | ||||||
|  |  | ||||||
|     def test_related_ordering_duplicate_table_reference(self): |     def test_related_ordering_duplicate_table_reference(self): | ||||||
|         """ |         """ | ||||||
|         An ordering referencing a model with an ordering referencing a model |         An ordering referencing a model with an ordering referencing a model | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user