diff --git a/django/contrib/postgres/aggregates/general.py b/django/contrib/postgres/aggregates/general.py index 91835a9ca3..5bbf29e8ab 100644 --- a/django/contrib/postgres/aggregates/general.py +++ b/django/contrib/postgres/aggregates/general.py @@ -37,7 +37,7 @@ class BoolOr(Aggregate): class JSONBAgg(Aggregate): function = 'JSONB_AGG' - _output_field = JSONField() + output_field = JSONField() def convert_value(self, value, expression, connection, context): if not value: diff --git a/django/contrib/postgres/fields/jsonb.py b/django/contrib/postgres/fields/jsonb.py index a3a3381745..a06187c4bc 100644 --- a/django/contrib/postgres/fields/jsonb.py +++ b/django/contrib/postgres/fields/jsonb.py @@ -115,7 +115,7 @@ class KeyTransform(Transform): class KeyTextTransform(KeyTransform): operator = '->>' nested_operator = '#>>' - _output_field = TextField() + output_field = TextField() class KeyTransformTextLookupMixin: diff --git a/django/contrib/postgres/search.py b/django/contrib/postgres/search.py index 9d66976ae0..cc47dbfeb6 100644 --- a/django/contrib/postgres/search.py +++ b/django/contrib/postgres/search.py @@ -47,7 +47,7 @@ class SearchVectorCombinable: class SearchVector(SearchVectorCombinable, Func): function = 'to_tsvector' arg_joiner = " || ' ' || " - _output_field = SearchVectorField() + output_field = SearchVectorField() config = None def __init__(self, *expressions, **extra): @@ -125,7 +125,7 @@ class SearchQueryCombinable: class SearchQuery(SearchQueryCombinable, Value): - _output_field = SearchQueryField() + output_field = SearchQueryField() def __init__(self, value, output_field=None, *, config=None, invert=False): self.config = config @@ -170,7 +170,7 @@ class CombinedSearchQuery(SearchQueryCombinable, CombinedExpression): class SearchRank(Func): function = 'ts_rank' - _output_field = FloatField() + output_field = FloatField() def __init__(self, vector, query, **extra): if not hasattr(vector, 'resolve_expression'): diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 7e5c7313ce..dc2be87453 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -125,11 +125,11 @@ class BaseExpression: # aggregate specific fields is_summary = False - _output_field = None + _output_field_resolved_to_none = False def __init__(self, output_field=None): if output_field is not None: - self._output_field = output_field + self.output_field = output_field def get_db_converters(self, connection): return [self.convert_value] + self.output_field.get_db_converters(connection) @@ -223,21 +223,23 @@ class BaseExpression: @cached_property def output_field(self): """Return the output type of this expressions.""" - if self._output_field_or_none is None: - raise FieldError("Cannot resolve expression type, unknown output_field") - return self._output_field_or_none + output_field = self._resolve_output_field() + if output_field is None: + self._output_field_resolved_to_none = True + raise FieldError('Cannot resolve expression type, unknown output_field') + return output_field @cached_property def _output_field_or_none(self): """ - Return the output field of this expression, or None if no output type - can be resolved. Note that the 'output_field' property will raise - FieldError if no type can be resolved, but this attribute allows for - None values. + Return the output field of this expression, or None if + _resolve_output_field() didn't return an output type. """ - if self._output_field is None: - self._output_field = self._resolve_output_field() - return self._output_field + try: + return self.output_field + except FieldError: + if not self._output_field_resolved_to_none: + raise def _resolve_output_field(self): """ @@ -249,9 +251,9 @@ class BaseExpression: the type here is a convenience for the common case. The user should supply their own output_field with more complex computations. - If a source does not have an `_output_field` then we exclude it from - this check. If all sources are `None`, then an error will be thrown - higher up the stack in the `output_field` property. + If a source's output field resolves to None, exclude it from this check. + If all sources are None, then an error is raised higher up the stack in + the output_field property. """ sources_iter = (source for source in self.get_source_fields() if source is not None) for output_field in sources_iter: @@ -603,14 +605,14 @@ class Value(Expression): def as_sql(self, compiler, connection): connection.ops.check_expression_support(self) val = self.value - # check _output_field to avoid triggering an exception - if self._output_field is not None: + output_field = self._output_field_or_none + if output_field is not None: if self.for_save: - val = self.output_field.get_db_prep_save(val, connection=connection) + val = output_field.get_db_prep_save(val, connection=connection) else: - val = self.output_field.get_db_prep_value(val, connection=connection) - if hasattr(self._output_field, 'get_placeholder'): - return self._output_field.get_placeholder(val, compiler, connection), [val] + val = output_field.get_db_prep_value(val, connection=connection) + if hasattr(output_field, 'get_placeholder'): + return output_field.get_placeholder(val, compiler, connection), [val] if val is None: # cx_Oracle does not always convert None to the appropriate # NULL type (like in case expressions using numbers), so we @@ -652,7 +654,7 @@ class RawSQL(Expression): return [self] def __hash__(self): - h = hash(self.sql) ^ hash(self._output_field) + h = hash(self.sql) ^ hash(self.output_field) for param in self.params: h ^= hash(param) return h @@ -998,7 +1000,7 @@ class Exists(Subquery): super().__init__(*args, **kwargs) def __invert__(self): - return type(self)(self.queryset, self.output_field, negated=(not self.negated), **self.extra) + return type(self)(self.queryset, negated=(not self.negated), **self.extra) @property def output_field(self): diff --git a/django/db/models/functions/base.py b/django/db/models/functions/base.py index c487bb4ab5..82a6083c58 100644 --- a/django/db/models/functions/base.py +++ b/django/db/models/functions/base.py @@ -24,12 +24,12 @@ class Cast(Func): def as_sql(self, compiler, connection, **extra_context): if 'db_type' not in extra_context: - extra_context['db_type'] = self._output_field.db_type(connection) + extra_context['db_type'] = self.output_field.db_type(connection) return super().as_sql(compiler, connection, **extra_context) def as_mysql(self, compiler, connection): extra_context = {} - output_field_class = type(self._output_field) + output_field_class = type(self.output_field) if output_field_class in self.mysql_types: extra_context['db_type'] = self.mysql_types[output_field_class] return self.as_sql(compiler, connection, **extra_context) diff --git a/django/db/models/functions/datetime.py b/django/db/models/functions/datetime.py index a56731e48f..52f1f73ae8 100644 --- a/django/db/models/functions/datetime.py +++ b/django/db/models/functions/datetime.py @@ -243,9 +243,8 @@ class TruncDate(TruncBase): kind = 'date' lookup_name = 'date' - @cached_property - def output_field(self): - return DateField() + def __init__(self, *args, output_field=None, **kwargs): + super().__init__(*args, output_field=DateField(), **kwargs) def as_sql(self, compiler, connection): # Cast to date rather than truncate to date. @@ -259,9 +258,8 @@ class TruncTime(TruncBase): kind = 'time' lookup_name = 'time' - @cached_property - def output_field(self): - return TimeField() + def __init__(self, *args, output_field=None, **kwargs): + super().__init__(*args, output_field=TimeField(), **kwargs) def as_sql(self, compiler, connection): # Cast to date rather than truncate to date. diff --git a/docs/releases/2.0.txt b/docs/releases/2.0.txt index 8b41076aa7..81162e549f 100644 --- a/docs/releases/2.0.txt +++ b/docs/releases/2.0.txt @@ -551,6 +551,9 @@ Miscellaneous in the cache backend as an intermediate class in ``CacheKeyWarning``'s inheritance of ``RuntimeWarning``. +* Renamed ``BaseExpression._output_field`` to ``output_field``. You may need + to update custom expressions. + .. _deprecated-features-2.0: Features deprecated in 2.0 diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py index ed3e2ce628..532690ca8d 100644 --- a/tests/expressions/tests.py +++ b/tests/expressions/tests.py @@ -532,6 +532,16 @@ class BasicExpressionsTests(TestCase): outer = Company.objects.filter(pk__in=Subquery(inner.values('pk'))) self.assertFalse(outer.exists()) + def test_explicit_output_field(self): + class FuncA(Func): + output_field = models.CharField() + + class FuncB(Func): + pass + + expr = FuncB(FuncA()) + self.assertEqual(expr.output_field, FuncA.output_field) + class IterableLookupInnerExpressionsTests(TestCase): @classmethod