diff --git a/django/contrib/gis/db/models/functions.py b/django/contrib/gis/db/models/functions.py index 8ab5de6db7..1f2d372ebb 100644 --- a/django/contrib/gis/db/models/functions.py +++ b/django/contrib/gis/db/models/functions.py @@ -101,10 +101,13 @@ class SQLiteDecimalToFloatMixin: is not acceptable by the GIS functions expecting numeric values. """ def as_sqlite(self, compiler, connection, **extra_context): - for expr in self.get_source_expressions(): - if hasattr(expr, 'value') and isinstance(expr.value, Decimal): - expr.value = float(expr.value) - return super().as_sql(compiler, connection, **extra_context) + copy = self.copy() + copy.set_source_expressions([ + Value(float(expr.value)) if hasattr(expr, 'value') and isinstance(expr.value, Decimal) + else expr + for expr in copy.get_source_expressions() + ]) + return copy.as_sql(compiler, connection, **extra_context) class OracleToleranceMixin: diff --git a/django/contrib/postgres/fields/ranges.py b/django/contrib/postgres/fields/ranges.py index c2f24eb5ed..8eab2cd2d9 100644 --- a/django/contrib/postgres/fields/ranges.py +++ b/django/contrib/postgres/fields/ranges.py @@ -173,8 +173,7 @@ class DateTimeRangeContains(PostgresOperatorLookup): def process_rhs(self, compiler, connection): # Transform rhs value for db lookup. if isinstance(self.rhs, datetime.date): - output_field = models.DateTimeField() if isinstance(self.rhs, datetime.datetime) else models.DateField() - value = models.Value(self.rhs, output_field=output_field) + value = models.Value(self.rhs) self.rhs = value.resolve_expression(compiler.query) return super().process_rhs(compiler, connection) diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 8cceb7d966..5b5a0ae4aa 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -1,7 +1,9 @@ import copy import datetime +import functools import inspect from decimal import Decimal +from uuid import UUID from django.core.exceptions import EmptyResultSet, FieldError from django.db import NotSupportedError, connection @@ -56,12 +58,7 @@ class Combinable: def _combine(self, other, connector, reversed): if not hasattr(other, 'resolve_expression'): # everything must be resolvable to an expression - output_field = ( - fields.DurationField() - if isinstance(other, datetime.timedelta) else - None - ) - other = Value(other, output_field=output_field) + other = Value(other) if reversed: return CombinedExpression(other, connector, self) @@ -422,6 +419,25 @@ class Expression(BaseExpression, Combinable): pass +_connector_combinators = { + connector: [ + (fields.IntegerField, fields.DecimalField, fields.DecimalField), + (fields.DecimalField, fields.IntegerField, fields.DecimalField), + (fields.IntegerField, fields.FloatField, fields.FloatField), + (fields.FloatField, fields.IntegerField, fields.FloatField), + ] + for connector in (Combinable.ADD, Combinable.SUB, Combinable.MUL, Combinable.DIV) +} + + +@functools.lru_cache(maxsize=128) +def _resolve_combined_type(connector, lhs_type, rhs_type): + combinators = _connector_combinators.get(connector, ()) + for combinator_lhs_type, combinator_rhs_type, combined_type in combinators: + if issubclass(lhs_type, combinator_lhs_type) and issubclass(rhs_type, combinator_rhs_type): + return combined_type + + class CombinedExpression(SQLiteNumericMixin, Expression): def __init__(self, lhs, connector, rhs, output_field=None): @@ -442,6 +458,19 @@ class CombinedExpression(SQLiteNumericMixin, Expression): def set_source_expressions(self, exprs): self.lhs, self.rhs = exprs + def _resolve_output_field(self): + try: + return super()._resolve_output_field() + except FieldError: + combined_type = _resolve_combined_type( + self.connector, + type(self.lhs.output_field), + type(self.rhs.output_field), + ) + if combined_type is None: + raise + return combined_type() + def as_sql(self, compiler, connection): expressions = [] expression_params = [] @@ -721,6 +750,30 @@ class Value(Expression): def get_group_by_cols(self, alias=None): return [] + def _resolve_output_field(self): + if isinstance(self.value, str): + return fields.CharField() + if isinstance(self.value, bool): + return fields.BooleanField() + if isinstance(self.value, int): + return fields.IntegerField() + if isinstance(self.value, float): + return fields.FloatField() + if isinstance(self.value, datetime.datetime): + return fields.DateTimeField() + if isinstance(self.value, datetime.date): + return fields.DateField() + if isinstance(self.value, datetime.time): + return fields.TimeField() + if isinstance(self.value, datetime.timedelta): + return fields.DurationField() + if isinstance(self.value, Decimal): + return fields.DecimalField() + if isinstance(self.value, bytes): + return fields.BinaryField() + if isinstance(self.value, UUID): + return fields.UUIDField() + class RawSQL(Expression): def __init__(self, sql, params, output_field=None): @@ -1177,7 +1230,6 @@ class OrderBy(BaseExpression): copy.expression = Case( When(self.expression, then=True), default=False, - output_field=fields.BooleanField(), ) return copy.as_sql(compiler, connection) return self.as_sql(compiler, connection) diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index 79313ddd46..f358e50d5b 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -6,7 +6,7 @@ from copy import copy from django.core.exceptions import EmptyResultSet from django.db.models.expressions import Case, Exists, Func, Value, When from django.db.models.fields import ( - BooleanField, CharField, DateTimeField, Field, IntegerField, UUIDField, + CharField, DateTimeField, Field, IntegerField, UUIDField, ) from django.db.models.query_utils import RegisterLookupMixin from django.utils.datastructures import OrderedSet @@ -123,7 +123,7 @@ class Lookup: exprs = [] for expr in (self.lhs, self.rhs): if isinstance(expr, Exists): - expr = Case(When(expr, then=True), default=False, output_field=BooleanField()) + expr = Case(When(expr, then=True), default=False) wrapped = True exprs.append(expr) lookup = type(self)(*exprs) if wrapped else self diff --git a/docs/ref/models/expressions.txt b/docs/ref/models/expressions.txt index 252966bb6c..31d2572288 100644 --- a/docs/ref/models/expressions.txt +++ b/docs/ref/models/expressions.txt @@ -484,7 +484,15 @@ The ``output_field`` argument should be a model field instance, like after it's retrieved from the database. Usually no arguments are needed when instantiating the model field as any arguments relating to data validation (``max_length``, ``max_digits``, etc.) will not be enforced on the expression's -output value. +output value. If no ``output_field`` is specified it will be tentatively +inferred from the :py:class:`type` of the provided ``value``, if possible. For +example, passing an instance of :py:class:`datetime.datetime` as ``value`` +would default ``output_field`` to :class:`~django.db.models.DateTimeField`. + +.. versionchanged:: 3.2 + + Support for inferring a default ``output_field`` from the type of ``value`` + was added. ``ExpressionWrapper()`` expressions ----------------------------------- diff --git a/docs/releases/3.2.txt b/docs/releases/3.2.txt index e0e80323e3..aa25b0f511 100644 --- a/docs/releases/3.2.txt +++ b/docs/releases/3.2.txt @@ -233,6 +233,15 @@ Models * The ``of`` argument of :meth:`.QuerySet.select_for_update()` is now allowed on MySQL 8.0.1+. +* :class:`Value() ` expression now + automatically resolves its ``output_field`` to the appropriate + :class:`Field ` subclass based on the type of + it's provided ``value`` for :py:class:`bool`, :py:class:`bytes`, + :py:class:`float`, :py:class:`int`, :py:class:`str`, + :py:class:`datetime.date`, :py:class:`datetime.datetime`, + :py:class:`datetime.time`, :py:class:`datetime.timedelta`, + :py:class:`decimal.Decimal`, and :py:class:`uuid.UUID` instances. + Requests and Responses ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/aggregation/tests.py b/tests/aggregation/tests.py index a8377c9a26..da78b8d9a9 100644 --- a/tests/aggregation/tests.py +++ b/tests/aggregation/tests.py @@ -848,10 +848,6 @@ class AggregateTestCase(TestCase): book = Book.objects.annotate(val=Max(2, output_field=IntegerField())).first() self.assertEqual(book.val, 2) - def test_missing_output_field_raises_error(self): - with self.assertRaisesMessage(FieldError, 'Cannot resolve expression type, unknown output_field'): - Book.objects.annotate(val=Max(2)).first() - def test_annotation_expressions(self): authors = Author.objects.annotate(combined_ages=Sum(F('age') + F('friends__age'))).order_by('name') authors2 = Author.objects.annotate(combined_ages=Sum('age') + Sum('friends__age')).order_by('name') @@ -893,7 +889,7 @@ class AggregateTestCase(TestCase): def test_combine_different_types(self): msg = ( - 'Expression contains mixed types: FloatField, IntegerField. ' + 'Expression contains mixed types: FloatField, DecimalField. ' 'You must set output_field.' ) qs = Book.objects.annotate(sums=Sum('rating') + Sum('pages') + Sum('price')) diff --git a/tests/aggregation_regress/tests.py b/tests/aggregation_regress/tests.py index bdfcb1d89b..b298a1f132 100644 --- a/tests/aggregation_regress/tests.py +++ b/tests/aggregation_regress/tests.py @@ -388,7 +388,7 @@ class AggregationTests(TestCase): ) def test_annotated_conditional_aggregate(self): - annotated_qs = Book.objects.annotate(discount_price=F('price') * 0.75) + annotated_qs = Book.objects.annotate(discount_price=F('price') * Decimal('0.75')) self.assertAlmostEqual( annotated_qs.aggregate(test=Avg(Case( When(pages__lt=400, then='discount_price'), diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py index 42b8c8f34b..c89bb1d69e 100644 --- a/tests/expressions/tests.py +++ b/tests/expressions/tests.py @@ -3,15 +3,17 @@ import pickle import unittest import uuid from copy import deepcopy +from decimal import Decimal from unittest import mock from django.core.exceptions import FieldError from django.db import DatabaseError, NotSupportedError, connection from django.db.models import ( - Avg, BooleanField, Case, CharField, Count, DateField, DateTimeField, - DurationField, Exists, Expression, ExpressionList, ExpressionWrapper, F, - Func, IntegerField, Max, Min, Model, OrderBy, OuterRef, Q, StdDev, - Subquery, Sum, TimeField, UUIDField, Value, Variance, When, + Avg, BinaryField, BooleanField, Case, CharField, Count, DateField, + DateTimeField, DecimalField, DurationField, Exists, Expression, + ExpressionList, ExpressionWrapper, F, FloatField, Func, IntegerField, Max, + Min, Model, OrderBy, OuterRef, Q, StdDev, Subquery, Sum, TimeField, + UUIDField, Value, Variance, When, ) from django.db.models.expressions import Col, Combinable, Random, RawSQL, Ref from django.db.models.functions import ( @@ -1711,6 +1713,30 @@ class ValueTests(TestCase): value = Value('foo', output_field=CharField()) self.assertEqual(value.as_sql(compiler, connection), ('%s', ['foo'])) + def test_resolve_output_field(self): + value_types = [ + ('str', CharField), + (True, BooleanField), + (42, IntegerField), + (3.14, FloatField), + (datetime.date(2019, 5, 15), DateField), + (datetime.datetime(2019, 5, 15), DateTimeField), + (datetime.time(3, 16), TimeField), + (datetime.timedelta(1), DurationField), + (Decimal('3.14'), DecimalField), + (b'', BinaryField), + (uuid.uuid4(), UUIDField), + ] + for value, ouput_field_type in value_types: + with self.subTest(type=type(value)): + expr = Value(value) + self.assertIsInstance(expr.output_field, ouput_field_type) + + def test_resolve_output_field_failure(self): + msg = 'Cannot resolve expression type, unknown output_field' + with self.assertRaisesMessage(FieldError, msg): + Value(object()).output_field + class FieldTransformTests(TestCase): @@ -1848,7 +1874,9 @@ class ExpressionWrapperTests(SimpleTestCase): self.assertEqual(expr.get_group_by_cols(alias=None), []) def test_non_empty_group_by(self): - expr = ExpressionWrapper(Lower(Value('f')), output_field=IntegerField()) + value = Value('f') + value.output_field = None + expr = ExpressionWrapper(Lower(value), output_field=IntegerField()) group_by_cols = expr.get_group_by_cols(alias=None) self.assertEqual(group_by_cols, [expr.expression]) self.assertEqual(group_by_cols[0].output_field, expr.output_field) diff --git a/tests/ordering/tests.py b/tests/ordering/tests.py index 61ec3a8592..fe319b3859 100644 --- a/tests/ordering/tests.py +++ b/tests/ordering/tests.py @@ -1,7 +1,6 @@ from datetime import datetime from operator import attrgetter -from django.core.exceptions import FieldError from django.db.models import ( CharField, DateTimeField, F, Max, OuterRef, Subquery, Value, ) @@ -439,17 +438,6 @@ class OrderingTests(TestCase): qs = Article.objects.order_by(Value('1', output_field=CharField()), '-headline') self.assertSequenceEqual(qs, [self.a4, self.a3, self.a2, self.a1]) - def test_order_by_constant_value_without_output_field(self): - msg = 'Cannot resolve expression type, unknown output_field' - qs = Article.objects.annotate(constant=Value('1')).order_by('constant') - for ordered_qs in ( - qs, - qs.values('headline'), - Article.objects.order_by(Value('1')), - ): - with self.subTest(ordered_qs=ordered_qs), self.assertRaisesMessage(FieldError, msg): - ordered_qs.first() - def test_related_ordering_duplicate_table_reference(self): """ An ordering referencing a model with an ordering referencing a model