mirror of
				https://github.com/django/django.git
				synced 2025-10-25 14:46:09 +00:00 
			
		
		
		
	Fixed #24485 -- Allowed combined expressions to set output_field
This commit is contained in:
		| @@ -2,7 +2,9 @@ from functools import wraps | |||||||
|  |  | ||||||
| from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured  # NOQA | from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured  # NOQA | ||||||
| from django.db.models.query import Q, QuerySet, Prefetch  # NOQA | from django.db.models.query import Q, QuerySet, Prefetch  # NOQA | ||||||
| from django.db.models.expressions import Expression, F, Value, Func, Case, When  # NOQA | from django.db.models.expressions import (  # NOQA | ||||||
|  |     Expression, ExpressionWrapper, F, Value, Func, Case, When, | ||||||
|  | ) | ||||||
| from django.db.models.manager import Manager  # NOQA | from django.db.models.manager import Manager  # NOQA | ||||||
| from django.db.models.base import Model  # NOQA | from django.db.models.base import Model  # NOQA | ||||||
| from django.db.models.aggregates import *  # NOQA | from django.db.models.aggregates import *  # NOQA | ||||||
|   | |||||||
| @@ -126,12 +126,12 @@ class BaseExpression(object): | |||||||
|     # aggregate specific fields |     # aggregate specific fields | ||||||
|     is_summary = False |     is_summary = False | ||||||
|  |  | ||||||
|     def get_db_converters(self, connection): |  | ||||||
|         return [self.convert_value] + self.output_field.get_db_converters(connection) |  | ||||||
|  |  | ||||||
|     def __init__(self, output_field=None): |     def __init__(self, output_field=None): | ||||||
|         self._output_field = output_field |         self._output_field = output_field | ||||||
|  |  | ||||||
|  |     def get_db_converters(self, connection): | ||||||
|  |         return [self.convert_value] + self.output_field.get_db_converters(connection) | ||||||
|  |  | ||||||
|     def get_source_expressions(self): |     def get_source_expressions(self): | ||||||
|         return [] |         return [] | ||||||
|  |  | ||||||
| @@ -656,6 +656,29 @@ class Ref(Expression): | |||||||
|         return [self] |         return [self] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ExpressionWrapper(Expression): | ||||||
|  |     """ | ||||||
|  |     An expression that can wrap another expression so that it can provide | ||||||
|  |     extra context to the inner expression, such as the output_field. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def __init__(self, expression, output_field): | ||||||
|  |         super(ExpressionWrapper, self).__init__(output_field=output_field) | ||||||
|  |         self.expression = expression | ||||||
|  |  | ||||||
|  |     def set_source_expressions(self, exprs): | ||||||
|  |         self.expression = exprs[0] | ||||||
|  |  | ||||||
|  |     def get_source_expressions(self): | ||||||
|  |         return [self.expression] | ||||||
|  |  | ||||||
|  |     def as_sql(self, compiler, connection): | ||||||
|  |         return self.expression.as_sql(compiler, connection) | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return "{}({})".format(self.__class__.__name__, self.expression) | ||||||
|  |  | ||||||
|  |  | ||||||
| class When(Expression): | class When(Expression): | ||||||
|     template = 'WHEN %(condition)s THEN %(result)s' |     template = 'WHEN %(condition)s THEN %(result)s' | ||||||
|  |  | ||||||
|   | |||||||
| @@ -161,6 +161,27 @@ values, rather than on Python values. | |||||||
| This is documented in :ref:`using F() expressions in queries | This is documented in :ref:`using F() expressions in queries | ||||||
| <using-f-expressions-in-filters>`. | <using-f-expressions-in-filters>`. | ||||||
|  |  | ||||||
|  | .. _using-f-with-annotations: | ||||||
|  |  | ||||||
|  | Using ``F()`` with annotations | ||||||
|  | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||||
|  |  | ||||||
|  | ``F()`` can be used to create dynamic fields on your models by combining | ||||||
|  | different fields with arithmetic:: | ||||||
|  |  | ||||||
|  |     company = Company.objects.annotate( | ||||||
|  |         chairs_needed=F('num_employees') - F('num_chairs')) | ||||||
|  |  | ||||||
|  | If the fields that you're combining are of different types you'll need | ||||||
|  | to tell Django what kind of field will be returned. Since ``F()`` does not | ||||||
|  | directly support ``output_field`` you will need to wrap the expression with | ||||||
|  | :class:`ExpressionWrapper`:: | ||||||
|  |  | ||||||
|  |     from django.db.models import DateTimeField, ExpressionWrapper, F | ||||||
|  |  | ||||||
|  |     Ticket.objects.annotate( | ||||||
|  |         expires=ExpressionWrapper( | ||||||
|  |             F('active_at') + F('duration'), output_field=DateTimeField())) | ||||||
|  |  | ||||||
| .. _func-expressions: | .. _func-expressions: | ||||||
|  |  | ||||||
| @@ -274,17 +295,6 @@ should define the desired ``output_field``. For example, adding an | |||||||
| ``IntegerField()`` and a ``FloatField()`` together should probably have | ``IntegerField()`` and a ``FloatField()`` together should probably have | ||||||
| ``output_field=FloatField()`` defined. | ``output_field=FloatField()`` defined. | ||||||
|  |  | ||||||
| .. note:: |  | ||||||
|  |  | ||||||
|     When you need to define the ``output_field`` for ``F`` expression |  | ||||||
|     arithmetic between different types, it's necessary to surround the |  | ||||||
|     expression in another expression:: |  | ||||||
|  |  | ||||||
|         from django.db.models import DateTimeField, Expression, F |  | ||||||
|  |  | ||||||
|         Race.objects.annotate(finish=Expression( |  | ||||||
|             F('start') + F('duration'), output_field=DateTimeField())) |  | ||||||
|  |  | ||||||
| .. versionchanged:: 1.8 | .. versionchanged:: 1.8 | ||||||
|  |  | ||||||
|     ``output_field`` is a new parameter. |     ``output_field`` is a new parameter. | ||||||
| @@ -343,6 +353,19 @@ instantiating the model field as any arguments relating to data validation | |||||||
| (``max_length``, ``max_digits``, etc.) will not be enforced on the expression's | (``max_length``, ``max_digits``, etc.) will not be enforced on the expression's | ||||||
| output value. | output value. | ||||||
|  |  | ||||||
|  | ``ExpressionWrapper()`` expressions | ||||||
|  | ----------------------------------- | ||||||
|  |  | ||||||
|  | .. class:: ExpressionWrapper(expression, output_field) | ||||||
|  |  | ||||||
|  | .. versionadded:: 1.8 | ||||||
|  |  | ||||||
|  | ``ExpressionWrapper`` simply surrounds another expression and provides access | ||||||
|  | to properties, such as ``output_field``, that may not be available on other | ||||||
|  | expressions. ``ExpressionWrapper`` is necessary when using arithmetic on | ||||||
|  | ``F()`` expressions with different types as described in | ||||||
|  | :ref:`using-f-with-annotations`. | ||||||
|  |  | ||||||
| Conditional expressions | Conditional expressions | ||||||
| ----------------------- | ----------------------- | ||||||
|  |  | ||||||
|   | |||||||
| @@ -84,3 +84,12 @@ class Company(models.Model): | |||||||
|         return ('Company(name=%s, motto=%s, ticker_name=%s, description=%s)' |         return ('Company(name=%s, motto=%s, ticker_name=%s, description=%s)' | ||||||
|             % (self.name, self.motto, self.ticker_name, self.description) |             % (self.name, self.motto, self.ticker_name, self.description) | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @python_2_unicode_compatible | ||||||
|  | class Ticket(models.Model): | ||||||
|  |     active_at = models.DateTimeField() | ||||||
|  |     duration = models.DurationField() | ||||||
|  |  | ||||||
|  |     def __str__(self): | ||||||
|  |         return '{} - {}'.format(self.active_at, self.duration) | ||||||
|   | |||||||
| @@ -5,13 +5,14 @@ from decimal import Decimal | |||||||
|  |  | ||||||
| from django.core.exceptions import FieldDoesNotExist, FieldError | from django.core.exceptions import FieldDoesNotExist, FieldError | ||||||
| from django.db.models import ( | from django.db.models import ( | ||||||
|     F, BooleanField, CharField, Count, Func, IntegerField, Sum, Value, |     F, BooleanField, CharField, Count, DateTimeField, ExpressionWrapper, Func, | ||||||
|  |     IntegerField, Sum, Value, | ||||||
| ) | ) | ||||||
| from django.test import TestCase | from django.test import TestCase | ||||||
| from django.utils import six | from django.utils import six | ||||||
|  |  | ||||||
| from .models import ( | from .models import ( | ||||||
|     Author, Book, Company, DepartmentStore, Employee, Publisher, Store, |     Author, Book, Company, DepartmentStore, Employee, Publisher, Store, Ticket, | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -135,6 +136,24 @@ class NonAggregateAnnotationTestCase(TestCase): | |||||||
|         for book in books: |         for book in books: | ||||||
|             self.assertEqual(book.num_awards, book.publisher.num_awards) |             self.assertEqual(book.num_awards, book.publisher.num_awards) | ||||||
|  |  | ||||||
|  |     def test_mixed_type_annotation_date_interval(self): | ||||||
|  |         active = datetime.datetime(2015, 3, 20, 14, 0, 0) | ||||||
|  |         duration = datetime.timedelta(hours=1) | ||||||
|  |         expires = datetime.datetime(2015, 3, 20, 14, 0, 0) + duration | ||||||
|  |         Ticket.objects.create(active_at=active, duration=duration) | ||||||
|  |         t = Ticket.objects.annotate( | ||||||
|  |             expires=ExpressionWrapper(F('active_at') + F('duration'), output_field=DateTimeField()) | ||||||
|  |         ).first() | ||||||
|  |         self.assertEqual(t.expires, expires) | ||||||
|  |  | ||||||
|  |     def test_mixed_type_annotation_numbers(self): | ||||||
|  |         test = self.b1 | ||||||
|  |         b = Book.objects.annotate( | ||||||
|  |             combined=ExpressionWrapper(F('pages') + F('rating'), output_field=IntegerField()) | ||||||
|  |         ).get(isbn=test.isbn) | ||||||
|  |         combined = int(test.pages + test.rating) | ||||||
|  |         self.assertEqual(b.combined, combined) | ||||||
|  |  | ||||||
|     def test_annotate_with_aggregation(self): |     def test_annotate_with_aggregation(self): | ||||||
|         books = Book.objects.annotate( |         books = Book.objects.annotate( | ||||||
|             is_book=Value(1, output_field=IntegerField()), |             is_book=Value(1, output_field=IntegerField()), | ||||||
|   | |||||||
| @@ -11,8 +11,8 @@ from django.db.models.aggregates import ( | |||||||
|     Avg, Count, Max, Min, StdDev, Sum, Variance, |     Avg, Count, Max, Min, StdDev, Sum, Variance, | ||||||
| ) | ) | ||||||
| from django.db.models.expressions import ( | from django.db.models.expressions import ( | ||||||
|     F, Case, Col, Date, DateTime, Func, OrderBy, Random, RawSQL, Ref, Value, |     F, Case, Col, Date, DateTime, ExpressionWrapper, Func, OrderBy, Random, | ||||||
|     When, |     RawSQL, Ref, Value, When, | ||||||
| ) | ) | ||||||
| from django.db.models.functions import ( | from django.db.models.functions import ( | ||||||
|     Coalesce, Concat, Length, Lower, Substr, Upper, |     Coalesce, Concat, Length, Lower, Substr, Upper, | ||||||
| @@ -855,6 +855,10 @@ class ReprTests(TestCase): | |||||||
|         self.assertEqual(repr(DateTime('published', 'exact', utc)), "DateTime(published, exact, %s)" % utc) |         self.assertEqual(repr(DateTime('published', 'exact', utc)), "DateTime(published, exact, %s)" % utc) | ||||||
|         self.assertEqual(repr(F('published')), "F(published)") |         self.assertEqual(repr(F('published')), "F(published)") | ||||||
|         self.assertEqual(repr(F('cost') + F('tax')), "<CombinedExpression: F(cost) + F(tax)>") |         self.assertEqual(repr(F('cost') + F('tax')), "<CombinedExpression: F(cost) + F(tax)>") | ||||||
|  |         self.assertEqual( | ||||||
|  |             repr(ExpressionWrapper(F('cost') + F('tax'), models.IntegerField())), | ||||||
|  |             "ExpressionWrapper(F(cost) + F(tax))" | ||||||
|  |         ) | ||||||
|         self.assertEqual(repr(Func('published', function='TO_CHAR')), "Func(F(published), function=TO_CHAR)") |         self.assertEqual(repr(Func('published', function='TO_CHAR')), "Func(F(published), function=TO_CHAR)") | ||||||
|         self.assertEqual(repr(OrderBy(Value(1))), 'OrderBy(Value(1), descending=False)') |         self.assertEqual(repr(OrderBy(Value(1))), 'OrderBy(Value(1), descending=False)') | ||||||
|         self.assertEqual(repr(Random()), "Random()") |         self.assertEqual(repr(Random()), "Random()") | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user