From f59fd15c4928caf3dfcbd50f6ab47be409a43b01 Mon Sep 17 00:00:00 2001 From: Josh Smeaton Date: Thu, 26 Dec 2013 00:13:18 +1100 Subject: [PATCH] Fixed #14030 -- Allowed annotations to accept all expressions --- AUTHORS | 1 + django/contrib/contenttypes/fields.py | 2 +- django/contrib/gis/db/backends/base.py | 12 +- .../gis/db/backends/mysql/operations.py | 6 +- .../gis/db/backends/oracle/operations.py | 21 +- .../gis/db/backends/postgis/operations.py | 15 +- .../gis/db/backends/spatialite/operations.py | 9 +- django/contrib/gis/db/models/aggregates.py | 55 +- django/contrib/gis/db/models/fields.py | 21 +- django/contrib/gis/db/models/lookups.py | 13 +- django/contrib/gis/db/models/query.py | 2 +- .../contrib/gis/db/models/sql/aggregates.py | 13 +- django/contrib/gis/db/models/sql/compiler.py | 4 +- django/contrib/gis/db/models/sql/query.py | 25 +- django/db/backends/sqlite3/base.py | 6 +- django/db/models/__init__.py | 2 +- django/db/models/aggregates.py | 154 ++++-- django/db/models/expressions.py | 493 +++++++++++++++-- django/db/models/fields/__init__.py | 2 - django/db/models/fields/related.py | 2 +- django/db/models/query.py | 113 ++-- django/db/models/query_utils.py | 15 + django/db/models/sql/aggregates.py | 8 + django/db/models/sql/compiler.py | 94 ++-- django/db/models/sql/datastructures.py | 66 --- django/db/models/sql/expressions.py | 119 ---- django/db/models/sql/query.py | 384 +++++++------ django/db/models/sql/subqueries.py | 4 +- django/db/models/sql/where.py | 9 +- docs/index.txt | 3 +- docs/internals/deprecation.txt | 11 + docs/ref/models/expressions.txt | 522 ++++++++++++++++++ docs/ref/models/index.txt | 1 + docs/ref/models/queries.txt | 109 ---- docs/ref/models/querysets.txt | 101 +++- docs/releases/1.8.txt | 32 ++ docs/topics/db/aggregation.txt | 5 + tests/aggregation/tests.py | 279 +++++++++- tests/annotations/__init__.py | 0 tests/annotations/fixtures/annotations.json | 243 ++++++++ tests/annotations/models.py | 86 +++ 
tests/annotations/tests.py | 288 ++++++++++ tests/expressions/tests.py | 23 +- 43 files changed, 2572 insertions(+), 801 deletions(-) delete mode 100644 django/db/models/sql/expressions.py create mode 100644 docs/ref/models/expressions.txt create mode 100644 tests/annotations/__init__.py create mode 100644 tests/annotations/fixtures/annotations.json create mode 100644 tests/annotations/models.py create mode 100644 tests/annotations/tests.py diff --git a/AUTHORS b/AUTHORS index 475ac4585e..dd78347071 100644 --- a/AUTHORS +++ b/AUTHORS @@ -347,6 +347,7 @@ answer newbie questions, and generally made Django that much better: Jorge Bastida Jorge Gajon Joseph Kocherhans + Josh Smeaton Joshua Ginsberg Jozko Skrablin J. Pablo Fernandez diff --git a/django/contrib/contenttypes/fields.py b/django/contrib/contenttypes/fields.py index 201016ae6e..cc115b01aa 100644 --- a/django/contrib/contenttypes/fields.py +++ b/django/contrib/contenttypes/fields.py @@ -10,7 +10,7 @@ from django.db.models import signals, FieldDoesNotExist, DO_NOTHING from django.db.models.base import ModelBase from django.db.models.fields.related import ForeignObject, ForeignObjectRel from django.db.models.related import PathInfo -from django.db.models.sql.datastructures import Col +from django.db.models.expressions import Col from django.contrib.contenttypes.models import ContentType from django.utils.encoding import smart_text, python_2_unicode_compatible diff --git a/django/contrib/gis/db/backends/base.py b/django/contrib/gis/db/backends/base.py index c6e48ce0ef..0d2a9db870 100644 --- a/django/contrib/gis/db/backends/base.py +++ b/django/contrib/gis/db/backends/base.py @@ -186,7 +186,7 @@ class BaseSpatialOperations(object): """ raise NotImplementedError('Distance operations not available on this spatial backend.') - def get_geom_placeholder(self, f, value): + def get_geom_placeholder(self, f, value, qn): """ Returns the placeholder for the given geometry field with the given value. 
Depending on the spatial backend, the placeholder may contain a @@ -195,16 +195,6 @@ class BaseSpatialOperations(object): """ raise NotImplementedError('subclasses of BaseSpatialOperations must provide a geo_db_placeholder() method') - def get_expression_column(self, evaluator): - """ - Helper method to return the quoted column string from the evaluator - for its expression. - """ - for expr, col_tup in evaluator.cols: - if expr is evaluator.expression: - return '%s.%s' % tuple(map(self.quote_name, col_tup)) - raise Exception("Could not find the column for the expression.") - # Spatial SQL Construction def spatial_aggregate_sql(self, agg): raise NotImplementedError('Aggregate support not implemented for this spatial backend.') diff --git a/django/contrib/gis/db/backends/mysql/operations.py b/django/contrib/gis/db/backends/mysql/operations.py index bcbd634fd8..191e2c8956 100644 --- a/django/contrib/gis/db/backends/mysql/operations.py +++ b/django/contrib/gis/db/backends/mysql/operations.py @@ -35,14 +35,14 @@ class MySQLOperations(DatabaseOperations, BaseSpatialOperations): def geo_db_type(self, f): return f.geom_type - def get_geom_placeholder(self, f, value): + def get_geom_placeholder(self, f, value, qn): """ The placeholder here has to include MySQL's WKT constructor. Because MySQL does not support spatial transformations, there is no need to modify the placeholder based on the contents of the given value. 
""" - if hasattr(value, 'expression'): - placeholder = self.get_expression_column(value) + if hasattr(value, 'as_sql'): + placeholder, _ = qn.compile(value) else: placeholder = '%s(%%s)' % self.from_text return placeholder diff --git a/django/contrib/gis/db/backends/oracle/operations.py b/django/contrib/gis/db/backends/oracle/operations.py index d4671e6cd6..aa002d3b82 100644 --- a/django/contrib/gis/db/backends/oracle/operations.py +++ b/django/contrib/gis/db/backends/oracle/operations.py @@ -9,7 +9,7 @@ """ import re -from django.db.backends.oracle.base import DatabaseOperations +from django.db.backends.oracle.base import DatabaseOperations, Database from django.contrib.gis.db.backends.base import BaseSpatialOperations from django.contrib.gis.db.backends.oracle.adapter import OracleSpatialAdapter from django.contrib.gis.db.backends.utils import SpatialOperator @@ -145,9 +145,11 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations): else: return None - def convert_geom(self, clob, geo_field): - if clob: - return Geometry(clob.read(), geo_field.srid) + def convert_geom(self, value, geo_field): + if value: + if isinstance(value, Database.LOB): + value = value.read() + return Geometry(value, geo_field.srid) else: return None @@ -184,7 +186,7 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations): return [dist_param] - def get_geom_placeholder(self, f, value): + def get_geom_placeholder(self, f, value, qn): """ Provides a proper substitution value for Geometries that are not in the SRID of the field. 
Specifically, this routine will substitute in the @@ -196,14 +198,15 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations): def transform_value(val, srid): return val.srid != srid - if hasattr(value, 'expression'): + if hasattr(value, 'as_sql'): if transform_value(value, f.srid): placeholder = '%s(%%s, %s)' % (self.transform, f.srid) else: placeholder = '%s' # No geometry value used for F expression, substitute in # the column name instead. - return placeholder % self.get_expression_column(value) + sql, _ = qn.compile(value) + return placeholder % sql else: if transform_value(value, f.srid): return '%s(SDO_GEOMETRY(%%s, %s), %s)' % (self.transform, value.srid, f.srid) @@ -219,9 +222,9 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations): if agg_name == 'union': agg_name += 'agg' if agg.is_extent: - sql_template = '%(function)s(%(field)s)' + sql_template = '%(function)s(%(expressions)s)' else: - sql_template = '%(function)s(SDOAGGRTYPE(%(field)s,%(tolerance)s))' + sql_template = '%(function)s(SDOAGGRTYPE(%(expressions)s,%(tolerance)s))' sql_function = getattr(self, agg_name) return self.select % sql_template, sql_function diff --git a/django/contrib/gis/db/backends/postgis/operations.py b/django/contrib/gis/db/backends/postgis/operations.py index 690194fa9e..d78b081950 100644 --- a/django/contrib/gis/db/backends/postgis/operations.py +++ b/django/contrib/gis/db/backends/postgis/operations.py @@ -22,7 +22,7 @@ class PostGISOperator(SpatialOperator): super(PostGISOperator, self).__init__(**kwargs) def as_sql(self, connection, lookup, *args): - if lookup.lhs.source.geography and not self.geography: + if lookup.lhs.output_field.geography and not self.geography: raise ValueError('PostGIS geography does not support the "%s" ' 'function/operator.' 
% (self.func or self.op,)) return super(PostGISOperator, self).as_sql(connection, lookup, *args) @@ -32,7 +32,7 @@ class PostGISDistanceOperator(PostGISOperator): sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %%s' def as_sql(self, connection, lookup, template_params, sql_params): - if not lookup.lhs.source.geography and lookup.lhs.source.geodetic(connection): + if not lookup.lhs.output_field.geography and lookup.lhs.output_field.geodetic(connection): sql_template = self.sql_template if len(lookup.rhs) == 3 and lookup.rhs[-1] == 'spheroid': template_params.update({'op': self.op, 'func': 'ST_Distance_Spheroid'}) @@ -215,7 +215,7 @@ class PostGISOperations(DatabaseOperations, BaseSpatialOperations): Converts the geometry returned from PostGIS aggretates. """ if hex: - return Geometry(hex) + return Geometry(hex, srid=geo_field.srid) else: return None @@ -284,7 +284,7 @@ class PostGISOperations(DatabaseOperations, BaseSpatialOperations): else: return [dist_param] - def get_geom_placeholder(self, f, value): + def get_geom_placeholder(self, f, value, qn): """ Provides a proper substitution value for Geometries that are not in the SRID of the field. Specifically, this routine will substitute in the @@ -296,11 +296,12 @@ class PostGISOperations(DatabaseOperations, BaseSpatialOperations): # Adding Transform() to the SQL placeholder. placeholder = '%s(%%s, %s)' % (self.transform, f.srid) - if hasattr(value, 'expression'): + if hasattr(value, 'as_sql'): # If this is an F expression, then we don't really want # a placeholder and instead substitute in the column # of the expression. 
- placeholder = placeholder % self.get_expression_column(value) + sql, _ = qn.compile(value) + placeholder = placeholder % sql return placeholder @@ -375,7 +376,7 @@ class PostGISOperations(DatabaseOperations, BaseSpatialOperations): agg_name = agg_name.lower() if agg_name == 'union': agg_name += 'agg' - sql_template = '%(function)s(%(field)s)' + sql_template = '%(function)s(%(expressions)s)' sql_function = getattr(self, agg_name) return sql_template, sql_function diff --git a/django/contrib/gis/db/backends/spatialite/operations.py b/django/contrib/gis/db/backends/spatialite/operations.py index d7a4912fbe..4ec98c402d 100644 --- a/django/contrib/gis/db/backends/spatialite/operations.py +++ b/django/contrib/gis/db/backends/spatialite/operations.py @@ -178,7 +178,7 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations): dist_param = value return [dist_param] - def get_geom_placeholder(self, f, value): + def get_geom_placeholder(self, f, value, qn): """ Provides a proper substitution value for Geometries that are not in the SRID of the field. Specifically, this routine will substitute in the @@ -186,14 +186,15 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations): """ def transform_value(value, srid): return not (value is None or value.srid == srid) - if hasattr(value, 'expression'): + if hasattr(value, 'as_sql'): if transform_value(value, f.srid): placeholder = '%s(%%s, %s)' % (self.transform, f.srid) else: placeholder = '%s' # No geometry value used for F expression, substitute in # the column name instead. - return placeholder % self.get_expression_column(value) + sql, _ = qn.compile(value) + return placeholder % sql else: if transform_value(value, f.srid): # Adding Transform() to the SQL placeholder. 
@@ -255,7 +256,7 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations): agg_name = agg_name.lower() if agg_name == 'union': agg_name += 'agg' - sql_template = self.select % '%(function)s(%(field)s)' + sql_template = self.select % '%(function)s(%(expressions)s)' sql_function = getattr(self, agg_name) return sql_template, sql_function diff --git a/django/contrib/gis/db/models/aggregates.py b/django/contrib/gis/db/models/aggregates.py index 43e9d1a0ae..0cf0a8b266 100644 --- a/django/contrib/gis/db/models/aggregates.py +++ b/django/contrib/gis/db/models/aggregates.py @@ -1,23 +1,66 @@ -from django.db.models import Aggregate +from django.db.models.aggregates import Aggregate +from django.contrib.gis.db.models.fields import GeometryField, ExtentField __all__ = ['Collect', 'Extent', 'Extent3D', 'MakeLine', 'Union'] -class Collect(Aggregate): +class GeoAggregate(Aggregate): + template = None + function = None + is_extent = False + + def as_sql(self, compiler, connection): + if connection.ops.oracle: + if not hasattr(self, 'tolerance'): + self.tolerance = 0.05 + self.extra['tolerance'] = self.tolerance + + template, function = connection.ops.spatial_aggregate_sql(self) + if template is None: + template = '%(function)s(%(expressions)s)' + self.extra['template'] = self.extra.get('template', template) + self.extra['function'] = self.extra.get('function', function) + return super(GeoAggregate, self).as_sql(compiler, connection) + + def prepare(self, query=None, allow_joins=True, reuse=None, summarize=False): + c = super(GeoAggregate, self).prepare(query, allow_joins, reuse, summarize) + if not isinstance(self.expressions[0].output_field, GeometryField): + raise ValueError('Geospatial aggregates only allowed on geometry fields.') + return c + + def convert_value(self, value, connection): + return connection.ops.convert_geom(value, self.output_field) + + +class Collect(GeoAggregate): name = 'Collect' -class Extent(Aggregate): +class Extent(GeoAggregate): name = 
'Extent' + is_extent = '2D' + + def __init__(self, expression, **extra): + super(Extent, self).__init__(expression, output_field=ExtentField(), **extra) + + def convert_value(self, value, connection): + return connection.ops.convert_extent(value) -class Extent3D(Aggregate): +class Extent3D(GeoAggregate): name = 'Extent3D' + is_extent = '3D' + + def __init__(self, expression, **extra): + super(Extent3D, self).__init__(expression, output_field=ExtentField(), **extra) + + def convert_value(self, value, connection): + return connection.ops.convert_extent3d(value) -class MakeLine(Aggregate): +class MakeLine(GeoAggregate): name = 'MakeLine' -class Union(Aggregate): +class Union(GeoAggregate): name = 'Union' diff --git a/django/contrib/gis/db/models/fields.py b/django/contrib/gis/db/models/fields.py index c538c18d63..1d64a06be2 100644 --- a/django/contrib/gis/db/models/fields.py +++ b/django/contrib/gis/db/models/fields.py @@ -1,5 +1,5 @@ from django.db.models.fields import Field -from django.db.models.sql.expressions import SQLEvaluator +from django.db.models.expressions import ExpressionNode from django.utils.translation import ugettext_lazy as _ from django.contrib.gis import forms from django.contrib.gis.db.models.lookups import gis_lookups @@ -165,7 +165,7 @@ class GeometryField(Field): returning to the caller. 
""" value = super(GeometryField, self).get_prep_value(value) - if isinstance(value, SQLEvaluator): + if isinstance(value, ExpressionNode): return value elif isinstance(value, (tuple, list)): geom = value[0] @@ -197,7 +197,7 @@ class GeometryField(Field): return geom def from_db_value(self, value, connection): - if value: + if value and not isinstance(value, Geometry): value = Geometry(value) return value @@ -259,7 +259,7 @@ class GeometryField(Field): pass else: params += value[1:] - elif isinstance(value, SQLEvaluator): + elif isinstance(value, ExpressionNode): params = [] else: params = [connection.ops.Adapter(value)] @@ -282,12 +282,12 @@ class GeometryField(Field): else: return connection.ops.Adapter(self.get_prep_value(value)) - def get_placeholder(self, value, connection): + def get_placeholder(self, value, qn, connection): """ Returns the placeholder for the geometry column for the given value. """ - return connection.ops.get_geom_placeholder(self, value) + return connection.ops.get_geom_placeholder(self, value, qn) for klass in gis_lookups.values(): @@ -335,3 +335,12 @@ class GeometryCollectionField(GeometryField): geom_type = 'GEOMETRYCOLLECTION' form_class = forms.GeometryCollectionField description = _("Geometry collection") + + +class ExtentField(Field): + "Used as a return value from an extent aggregate" + + description = _("Extent Aggregate Field") + + def get_internal_type(self): + return "ExtentField" diff --git a/django/contrib/gis/db/models/lookups.py b/django/contrib/gis/db/models/lookups.py index d3b53b32f0..889237751a 100644 --- a/django/contrib/gis/db/models/lookups.py +++ b/django/contrib/gis/db/models/lookups.py @@ -4,7 +4,7 @@ import re from django.db.models.constants import LOOKUP_SEP from django.db.models.fields import FieldDoesNotExist from django.db.models.lookups import Lookup -from django.db.models.sql.expressions import SQLEvaluator +from django.db.models.expressions import ExpressionNode, Col from django.utils import six gis_lookups 
= {} @@ -68,18 +68,19 @@ class GISLookup(Lookup): rhs, rhs_params = super(GISLookup, self).process_rhs(qn, connection) geom = self.rhs - if isinstance(self.rhs, SQLEvaluator): + if isinstance(self.rhs, Col): # Make sure the F Expression destination field exists, and # set an `srid` attribute with the same as that of the # destination. - geo_fld = self._check_geo_field(self.rhs.opts, self.rhs.expression.name) - if not geo_fld: + geo_fld = self.rhs.output_field + if not hasattr(geo_fld, 'srid'): raise ValueError('No geographic field found in expression.') self.rhs.srid = geo_fld.srid + elif isinstance(self.rhs, ExpressionNode): + raise ValueError('Complex expressions not supported for GeometryField') elif isinstance(self.rhs, (list, tuple)): geom = self.rhs[0] - - rhs = connection.ops.get_geom_placeholder(self.lhs.source, geom) + rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, geom, qn) return rhs, rhs_params def as_sql(self, qn, connection): diff --git a/django/contrib/gis/db/models/query.py b/django/contrib/gis/db/models/query.py index cbc1536b8a..f2e7657850 100644 --- a/django/contrib/gis/db/models/query.py +++ b/django/contrib/gis/db/models/query.py @@ -530,7 +530,7 @@ class GeoQuerySet(QuerySet): # transformation SQL. geom = geo_field.get_prep_value(settings['procedure_args'][name]) params = geo_field.get_db_prep_lookup('contains', geom, connection=connection) - geom_placeholder = geo_field.get_placeholder(geom, connection) + geom_placeholder = geo_field.get_placeholder(geom, None, connection) # Replacing the procedure format with that of any needed # transformation SQL. 
diff --git a/django/contrib/gis/db/models/sql/aggregates.py b/django/contrib/gis/db/models/sql/aggregates.py index c0a7d894eb..c3943eb9f6 100644 --- a/django/contrib/gis/db/models/sql/aggregates.py +++ b/django/contrib/gis/db/models/sql/aggregates.py @@ -6,12 +6,15 @@ from django.contrib.gis.db.models.fields import GeometryField __all__ = ['Collect', 'Extent', 'Extent3D', 'MakeLine', 'Union'] + aggregates.__all__ +warnings.warn( + "django.contrib.gis.db.models.sql.aggregates is deprecated. Use " + "django.contrib.gis.db.models.aggregates instead.", + RemovedInDjango20Warning, stacklevel=2) + + class GeoAggregate(Aggregate): # Default SQL template for spatial aggregates. - sql_template = '%(function)s(%(field)s)' - - # Conversion class, if necessary. - conversion_class = None + sql_template = '%(function)s(%(expressions)s)' # Flags for indicating the type of the aggregate. is_extent = False @@ -45,7 +48,7 @@ class GeoAggregate(Aggregate): substitutions = { 'function': sql_function, - 'field': field_name + 'expressions': field_name } substitutions.update(self.extra) diff --git a/django/contrib/gis/db/models/sql/compiler.py b/django/contrib/gis/db/models/sql/compiler.py index 05491ef930..915f33fc2b 100644 --- a/django/contrib/gis/db/models/sql/compiler.py +++ b/django/contrib/gis/db/models/sql/compiler.py @@ -70,8 +70,8 @@ class GeoSQLCompiler(compiler.SQLCompiler): aliases.update(new_aliases) max_name_length = self.connection.ops.max_name_length() - for alias, aggregate in self.query.aggregate_select.items(): - agg_sql, agg_params = aggregate.as_sql(qn, self.connection) + for alias, annotation in self.query.annotation_select.items(): + agg_sql, agg_params = self.compile(annotation) if alias is None: result.append(agg_sql) else: diff --git a/django/contrib/gis/db/models/sql/query.py b/django/contrib/gis/db/models/sql/query.py index 7e7715c682..9c071b85fc 100644 --- a/django/contrib/gis/db/models/sql/query.py +++ b/django/contrib/gis/db/models/sql/query.py @@ -4,7 +4,7 
@@ from django.db.models.sql.constants import QUERY_TERMS from django.contrib.gis.db.models.fields import GeometryField from django.contrib.gis.db.models.lookups import GISLookup -from django.contrib.gis.db.models.sql import aggregates as gis_aggregates +from django.contrib.gis.db.models import aggregates as gis_aggregates from django.contrib.gis.db.models.sql.conversion import GeomField @@ -14,7 +14,6 @@ class GeoQuery(sql.Query): """ # Overriding the valid query terms. query_terms = QUERY_TERMS | set(GeometryField.class_lookups.keys()) - aggregates_module = gis_aggregates compiler = 'GeoSQLCompiler' @@ -40,28 +39,12 @@ class GeoQuery(sql.Query): # Remove any aggregates marked for reduction from the subquery # and move them to the outer AggregateQuery. connection = connections[using] - for alias, aggregate in self.aggregate_select.items(): - if isinstance(aggregate, gis_aggregates.GeoAggregate): - if not getattr(aggregate, 'is_extent', False) or connection.ops.oracle: + for alias, annotation in self.annotation_select.items(): + if isinstance(annotation, gis_aggregates.GeoAggregate): + if not getattr(annotation, 'is_extent', False) or connection.ops.oracle: self.extra_select_fields[alias] = GeomField() return super(GeoQuery, self).get_aggregation(using, force_subq) - def resolve_aggregate(self, value, aggregate, connection): - """ - Overridden from GeoQuery's normalize to handle the conversion of - GeoAggregate objects. - """ - if isinstance(aggregate, self.aggregates_module.GeoAggregate): - if aggregate.is_extent: - if aggregate.is_extent == '3D': - return connection.ops.convert_extent3d(value) - else: - return connection.ops.convert_extent(value) - else: - return connection.ops.convert_geom(value, aggregate.source) - else: - return super(GeoQuery, self).resolve_aggregate(value, aggregate, connection) - # Private API utilities, subject to change. 
def _geo_field(self, field_name=None): """ diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py index 349fae0253..4799576ba9 100644 --- a/django/db/backends/sqlite3/base.py +++ b/django/db/backends/sqlite3/base.py @@ -20,8 +20,7 @@ from django.db.backends.sqlite3.client import DatabaseClient from django.db.backends.sqlite3.creation import DatabaseCreation from django.db.backends.sqlite3.introspection import DatabaseIntrospection from django.db.backends.sqlite3.schema import DatabaseSchemaEditor -from django.db.models import fields -from django.db.models.sql import aggregates +from django.db.models import fields, aggregates from django.utils.dateparse import parse_date, parse_datetime, parse_time from django.utils.encoding import force_text from django.utils.functional import cached_property @@ -163,8 +162,7 @@ class DatabaseOperations(BaseDatabaseOperations): bad_fields = (fields.DateField, fields.DateTimeField, fields.TimeField) bad_aggregates = (aggregates.Sum, aggregates.Avg, aggregates.Variance, aggregates.StdDev) - if (isinstance(aggregate.source, bad_fields) and - isinstance(aggregate, bad_aggregates)): + if aggregate.refs_field(bad_aggregates, bad_fields): raise NotImplementedError( 'You cannot use Sum, Avg, StdDev and Variance aggregations ' 'on date/time fields in sqlite3 ' diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py index c054542e3e..6fe8fe8aae 100644 --- a/django/db/models/__init__.py +++ b/django/db/models/__init__.py @@ -4,7 +4,7 @@ import warnings from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured # NOQA from django.db.models.query import Q, QuerySet, Prefetch # NOQA -from django.db.models.expressions import F # NOQA +from django.db.models.expressions import ExpressionNode, F, Value, Func # NOQA from django.db.models.manager import Manager # NOQA from django.db.models.base import Model # NOQA from django.db.models.aggregates import * # NOQA diff --git 
a/django/db/models/aggregates.py b/django/db/models/aggregates.py index e31d228aa5..c68378c7da 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -1,94 +1,152 @@ """ Classes to represent the definitions of aggregate functions. """ -from django.db.models.constants import LOOKUP_SEP +from django.core.exceptions import FieldError +from django.db.models.expressions import Func, Value +from django.db.models.fields import IntegerField, FloatField __all__ = [ 'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance', ] -def refs_aggregate(lookup_parts, aggregates): - """ - A little helper method to check if the lookup_parts contains references - to the given aggregates set. Because the LOOKUP_SEP is contained in the - default annotation names we must check each prefix of the lookup_parts - for match. - """ - for n in range(len(lookup_parts) + 1): - level_n_lookup = LOOKUP_SEP.join(lookup_parts[0:n]) - if level_n_lookup in aggregates: - return aggregates[level_n_lookup], lookup_parts[n:] - return False, () +class Aggregate(Func): + contains_aggregate = True + name = None + def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False): + assert len(self.source_expressions) == 1 + c = super(Aggregate, self).resolve_expression(query, allow_joins, reuse, summarize) + if c.source_expressions[0].contains_aggregate and not summarize: + name = self.source_expressions[0].name + raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % ( + c.name, name, name)) + c._patch_aggregate(query) # backward-compatibility support + return c -class Aggregate(object): - """ - Default Aggregate definition. - """ - def __init__(self, lookup, **extra): - """Instantiate a new aggregate. 
+ def refs_field(self, aggregate_types, field_types): + try: + return (isinstance(self, aggregate_types) and + isinstance(self.input_field._output_field_or_none, field_types)) + except FieldError: + # Sometimes we don't know the input_field's output type (for example, + # doing Sum(F('datetimefield') + F('datefield'), output_type=DateTimeField()) + # is OK, but the Expression(F('datetimefield') + F('datefield')) doesn't + # have any output field. + return False - * lookup is the field on which the aggregate operates. - * extra is a dictionary of additional data to provide for the - aggregate definition + @property + def input_field(self): + return self.source_expressions[0] - Also utilizes the class variables: - * name, the identifier for this aggregate function. + @property + def default_alias(self): + if hasattr(self.source_expressions[0], 'name'): + return '%s__%s' % (self.source_expressions[0].name, self.name.lower()) + raise TypeError("Complex expressions require an alias") + + def get_group_by_cols(self): + return [] + + def _patch_aggregate(self, query): """ - self.lookup = lookup - self.extra = extra + Helper method for patching 3rd party aggregates that do not yet support + the new way of subclassing. This method should be removed in 2.0 - def _default_alias(self): - return '%s__%s' % (self.lookup, self.name.lower()) - default_alias = property(_default_alias) + add_to_query(query, alias, col, source, is_summary) will be defined on + legacy aggregates which, in turn, instantiates the SQL implementation of + the aggregate. In all the cases found, the general implementation of + add_to_query looks like: - def add_to_query(self, query, alias, col, source, is_summary): - """Add the aggregate to the nominated query. 
+ def add_to_query(self, query, alias, col, source, is_summary): + klass = SQLImplementationAggregate + aggregate = klass(col, source=source, is_summary=is_summary, **self.extra) + query.aggregates[alias] = aggregate - This method is used to convert the generic Aggregate definition into a - backend-specific definition. - - * query is the backend-specific query instance to which the aggregate - is to be added. - * col is a column reference describing the subject field - of the aggregate. It can be an alias, or a tuple describing - a table and column name. - * source is the underlying field or aggregate definition for - the column reference. If the aggregate is not an ordinal or - computed type, this reference is used to determine the coerced - output type of the aggregate. - * is_summary is a boolean that is set True if the aggregate is a - summary value rather than an annotation. + By supplying a known alias, we can get the SQLAggregate out of the + aggregates dict, and use the sql_function and sql_template attributes + to patch *this* aggregate. 
""" - klass = getattr(query.aggregates_module, self.name) - aggregate = klass(col, source=source, is_summary=is_summary, **self.extra) - query.aggregates[alias] = aggregate + if not hasattr(self, 'add_to_query') or self.function is not None: + return + + placeholder_alias = "_XXXXXXXX_" + self.add_to_query(query, placeholder_alias, None, None, None) + sql_aggregate = query.aggregates.pop(placeholder_alias) + if 'sql_function' not in self.extra and hasattr(sql_aggregate, 'sql_function'): + self.extra['function'] = sql_aggregate.sql_function + + if hasattr(sql_aggregate, 'sql_template'): + self.extra['template'] = sql_aggregate.sql_template class Avg(Aggregate): + function = 'AVG' name = 'Avg' + def __init__(self, expression, **extra): + super(Avg, self).__init__(expression, output_field=FloatField(), **extra) + + def convert_value(self, value, connection): + if value is None: + return value + return float(value) + class Count(Aggregate): + function = 'COUNT' name = 'Count' + template = '%(function)s(%(distinct)s%(expressions)s)' + + def __init__(self, expression, distinct=False, **extra): + if expression == '*': + expression = Value(expression) + expression._output_field = IntegerField() + super(Count, self).__init__( + expression, distinct='DISTINCT ' if distinct else '', output_field=IntegerField(), **extra) + + def convert_value(self, value, connection): + if value is None: + return 0 + return int(value) class Max(Aggregate): + function = 'MAX' name = 'Max' class Min(Aggregate): + function = 'MIN' name = 'Min' class StdDev(Aggregate): name = 'StdDev' + def __init__(self, expression, sample=False, **extra): + self.function = 'STDDEV_SAMP' if sample else 'STDDEV_POP' + super(StdDev, self).__init__(expression, output_field=FloatField(), **extra) + + def convert_value(self, value, connection): + if value is None: + return value + return float(value) + class Sum(Aggregate): + function = 'SUM' name = 'Sum' class Variance(Aggregate): name = 'Variance' + + def 
__init__(self, expression, sample=False, **extra): + self.function = 'VAR_SAMP' if sample else 'VAR_POP' + super(Variance, self).__init__(expression, output_field=FloatField(), **extra) + + def convert_value(self, value, connection): + if value is None: + return value + return float(value) diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 5484ca7f47..22a5e3ab1e 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -1,14 +1,20 @@ +import copy import datetime -from django.db.models.aggregates import refs_aggregate +from django.core.exceptions import FieldError +from django.db.backends import utils as backend_utils +from django.db.models import fields from django.db.models.constants import LOOKUP_SEP -from django.utils import tree +from django.db.models.query_utils import refs_aggregate +from django.utils.functional import cached_property -class ExpressionNode(tree.Node): +class CombinableMixin(object): """ - Base class for all query expressions. + Provides the ability to combine one or two objects with + some connector. For example F('foo') + F('bar'). 
""" + # Arithmetic connectors ADD = '+' SUB = '-' @@ -25,44 +31,17 @@ class ExpressionNode(tree.Node): BITAND = '&' BITOR = '|' - def __init__(self, children=None, connector=None, negated=False): - if children is not None and len(children) > 1 and connector is None: - raise TypeError('You have to specify a connector.') - super(ExpressionNode, self).__init__(children, connector, negated) - def _combine(self, other, connector, reversed, node=None): if isinstance(other, datetime.timedelta): - return DateModifierNode([self, other], connector) + return DateModifierNode(self, connector, other) + + if not hasattr(other, 'resolve_expression'): + # everything must be resolvable to an expression + other = Value(other) if reversed: - obj = ExpressionNode([other], connector) - obj.add(node or self, connector) - else: - obj = node or ExpressionNode([self], connector) - obj.add(other, connector) - return obj - - def contains_aggregate(self, existing_aggregates): - if self.children: - return any(child.contains_aggregate(existing_aggregates) - for child in self.children - if hasattr(child, 'contains_aggregate')) - else: - return refs_aggregate(self.name.split(LOOKUP_SEP), - existing_aggregates) - - def prepare_database_save(self, unused): - return self - - ################### - # VISITOR METHODS # - ################### - - def prepare(self, evaluator, query, allow_joins): - return evaluator.prepare_node(self, query, allow_joins) - - def evaluate(self, evaluator, qn, connection): - return evaluator.evaluate_node(self, qn, connection) + return Expression(other, connector, self) + return Expression(self, connector, other) ############# # OPERATORS # @@ -137,27 +116,240 @@ class ExpressionNode(tree.Node): ) -class F(ExpressionNode): +class ExpressionNode(CombinableMixin): """ - An expression representing the value of the given field. + Base class for all query expressions. 
""" - def __init__(self, name): - super(F, self).__init__(None, None, False) - self.name = name - def __deepcopy__(self, memodict): - obj = super(F, self).__deepcopy__(memodict) - obj.name = self.name - return obj + # aggregate specific fields + is_summary = False - def prepare(self, evaluator, query, allow_joins): - return evaluator.prepare_leaf(self, query, allow_joins) + def __init__(self, output_field=None): + self._output_field = output_field - def evaluate(self, evaluator, qn, connection): - return evaluator.evaluate_leaf(self, qn, connection) + def get_source_expressions(self): + return [] + + def set_source_expressions(self, exprs): + assert len(exprs) == 0 + + def as_sql(self, compiler, connection): + """ + Responsible for returning a (sql, [params]) tuple to be included + in the current query. + + Different backends can provide their own implementation, by + providing an `as_{vendor}` method and patching the Expression: + + ``` + def override_as_sql(self, compiler, connection): + # custom logic + return super(ExpressionNode, self).as_sql(compiler, connection) + setattr(ExpressionNode, 'as_' + connection.vendor, override_as_sql) + ``` + + Arguments: + * compiler: the query compiler responsible for generating the query. + Must have a compile method, returning a (sql, [params]) tuple. + Calling compiler(value) will return a quoted `value`. + + * connection: the database connection used for the current query. + + Returns: (sql, params) + Where `sql` is a string containing ordered sql parameters to be + replaced with the elements of the list `params`. 
+ """ + raise NotImplementedError("Subclasses must implement as_sql()") + + @cached_property + def contains_aggregate(self): + for expr in self.get_source_expressions(): + if expr and expr.contains_aggregate: + return True + return False + + def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False): + """ + Provides the chance to do any preprocessing or validation before being + added to the query. + + Arguments: + * query: the backend query implementation + * allow_joins: boolean allowing or denying use of joins + in this query + * reuse: a set of reusable joins for multijoins + * summarize: a terminal aggregate clause + + Returns: an ExpressionNode to be added to the query. + """ + c = self.copy() + c.is_summary = summarize + return c + + def _prepare(self): + """ + Hook used by Field.get_prep_lookup() to do custom preparation. + """ + return self + + @property + def field(self): + return self.output_field + + @cached_property + def output_field(self): + """ + Returns the output type of this expression. + """ + if self._output_field_or_none is None: + raise FieldError("Cannot resolve expression type, unknown output_field") + return self._output_field_or_none + + @cached_property + def _output_field_or_none(self): + """ + Returns the output field of this expression, or None if no output type + can be resolved. Note that the 'output_field' property will raise + FieldError if no type can be resolved, but this attribute allows for + None values. + """ + if self._output_field is None: + self._resolve_output_field() + return self._output_field + + def _resolve_output_field(self): + """ + Attempts to infer the output type of the expression. If the output + fields of all source fields match then we can simply infer the same + type here.
+ """ + if self._output_field is None: + sources = self.get_source_fields() + num_sources = len(sources) + if num_sources == 0: + self._output_field = None + else: + self._output_field = sources[0] + for source in sources: + if source is not None and not isinstance(self._output_field, source.__class__): + raise FieldError( + "Expression contains mixed types. You must set output_field") + + def convert_value(self, value, connection): + """ + Expressions provide their own converters because users have the option + of manually specifying the output_field which may be a different type + from the one the database returns. + """ + field = self.output_field + internal_type = field.get_internal_type() + if value is None: + return value + elif internal_type == 'FloatField': + return float(value) + elif internal_type.endswith('IntegerField'): + return int(value) + elif internal_type == 'DecimalField': + return backend_utils.typecast_decimal(field.format_number(value)) + return value + + def get_lookup(self, lookup): + return self.output_field.get_lookup(lookup) + + def get_transform(self, name): + return self.output_field.get_transform(name) + + def relabeled_clone(self, change_map): + clone = self.copy() + clone.set_source_expressions( + [e.relabeled_clone(change_map) for e in self.get_source_expressions()]) + return clone + + def copy(self): + c = copy.copy(self) + c.copied = True + return c + + def refs_aggregate(self, existing_aggregates): + """ + Does this expression contain a reference to some of the + existing aggregates? If so, returns the aggregate and also + the lookup parts that *weren't* found. So, if + existing_aggregates = {'max_id': Max('id')} + self.name = 'max_id' + queryset.filter(max_id__range=[10,100]) + then this method will return Max('id') and those parts of the + name that weren't found. In this case `max_id` is found and the range + portion is returned as ('range',).
+ """ + for node in self.get_source_expressions(): + agg, lookup = node.refs_aggregate(existing_aggregates) + if agg: + return agg, lookup + return False, () + + def refs_field(self, aggregate_types, field_types): + """ + Helper method for check_aggregate_support on backends + """ + return any( + node.refs_field(aggregate_types, field_types) + for node in self.get_source_expressions()) + + def prepare_database_save(self, field): + return self + + def get_group_by_cols(self): + cols = [] + for source in self.get_source_expressions(): + cols.extend(source.get_group_by_cols()) + return cols + + def get_source_fields(self): + """ + Returns the underlying field types used by this + aggregate. + """ + return [e._output_field_or_none for e in self.get_source_expressions()] -class DateModifierNode(ExpressionNode): +class Expression(ExpressionNode): + + def __init__(self, lhs, connector, rhs, output_field=None): + super(Expression, self).__init__(output_field=output_field) + self.connector = connector + self.lhs = lhs + self.rhs = rhs + + def get_source_expressions(self): + return [self.lhs, self.rhs] + + def set_source_expressions(self, exprs): + self.lhs, self.rhs = exprs + + def as_sql(self, compiler, connection): + expressions = [] + expression_params = [] + sql, params = compiler.compile(self.lhs) + expressions.append(sql) + expression_params.extend(params) + sql, params = compiler.compile(self.rhs) + expressions.append(sql) + expression_params.extend(params) + # order of precedence + expression_wrapper = '(%s)' + sql = connection.ops.combine_expression(self.connector, expressions) + return expression_wrapper % sql, expression_params + + def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False): + c = self.copy() + c.is_summary = summarize + c.lhs = c.lhs.resolve_expression(query, allow_joins, reuse, summarize) + c.rhs = c.rhs.resolve_expression(query, allow_joins, reuse, summarize) + return c + + +class DateModifierNode(Expression): """ 
Node that implements the following syntax: filter(end_date__gt=F('start_date') + datetime.timedelta(days=3, seconds=200)) @@ -183,14 +375,195 @@ class DateModifierNode(ExpressionNode): Only adding and subtracting timedeltas is supported, attempts to use other operations raise a TypeError. """ - def __init__(self, children, connector, negated=False): - if len(children) != 2: - raise TypeError('Must specify a node and a timedelta.') - if not isinstance(children[1], datetime.timedelta): - raise TypeError('Second child must be a timedelta.') + def __init__(self, lhs, connector, rhs): + if not isinstance(rhs, datetime.timedelta): + raise TypeError('rhs must be a timedelta.') if connector not in (self.ADD, self.SUB): raise TypeError('Connector must be + or -, not %s' % connector) - super(DateModifierNode, self).__init__(children, connector, negated) + super(DateModifierNode, self).__init__(lhs, connector, Value(rhs)) - def evaluate(self, evaluator, qn, connection): - return evaluator.evaluate_date_modifier_node(self, qn, connection) + def as_sql(self, compiler, connection): + timedelta = self.rhs.value + sql, params = compiler.compile(self.lhs) + if (timedelta.days == timedelta.seconds == timedelta.microseconds == 0): + return sql, params + return connection.ops.date_interval_sql(sql, self.connector, timedelta), params + + +class F(CombinableMixin): + """ + An object capable of resolving references to existing query objects. + """ + def __init__(self, name): + """ + Arguments: + * name: the name of the field this expression references + """ + self.name = name + + def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False): + return query.resolve_ref(self.name, allow_joins, reuse, summarize) + + def refs_aggregate(self, existing_aggregates): + return refs_aggregate(self.name.split(LOOKUP_SEP), existing_aggregates) + + +class Func(ExpressionNode): + """ + A SQL function call. 
+ """ + function = None + template = '%(function)s(%(expressions)s)' + arg_joiner = ', ' + + def __init__(self, *expressions, **extra): + output_field = extra.pop('output_field', None) + super(Func, self).__init__(output_field=output_field) + self.source_expressions = self._parse_expressions(*expressions) + self.extra = extra + + def get_source_expressions(self): + return self.source_expressions + + def set_source_expressions(self, exprs): + self.source_expressions = exprs + + def _parse_expressions(self, *expressions): + return [ + arg if hasattr(arg, 'resolve_expression') else F(arg) + for arg in expressions + ] + + def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False): + c = self.copy() + c.is_summary = summarize + for pos, arg in enumerate(c.source_expressions): + c.source_expressions[pos] = arg.resolve_expression(query, allow_joins, reuse, summarize) + return c + + def as_sql(self, compiler, connection, function=None, template=None): + sql_parts = [] + params = [] + for arg in self.source_expressions: + arg_sql, arg_params = compiler.compile(arg) + sql_parts.append(arg_sql) + params.extend(arg_params) + if function is None: + self.extra['function'] = self.extra.get('function', self.function) + else: + self.extra['function'] = function + self.extra['expressions'] = self.extra['field'] = self.arg_joiner.join(sql_parts) + template = template or self.extra.get('template', self.template) + return template % self.extra, params + + def copy(self): + copy = super(Func, self).copy() + copy.source_expressions = self.source_expressions[:] + copy.extra = self.extra.copy() + return copy + + +class Value(ExpressionNode): + """ + Represents a wrapped value as a node within an expression + """ + def __init__(self, value, output_field=None): + """ + Arguments: + * value: the value this expression represents. The value will be + added into the sql parameter list and properly quoted. 
+ + * output_field: an instance of the model field type that this + expression will return, such as IntegerField() or CharField(). + """ + super(Value, self).__init__(output_field=output_field) + self.value = value + + def as_sql(self, compiler, connection): + return '%s', [self.value] + + +class Col(ExpressionNode): + def __init__(self, alias, target, source=None): + if source is None: + source = target + super(Col, self).__init__(output_field=source) + self.alias, self.target = alias, target + + def as_sql(self, qn, connection): + return "%s.%s" % (qn(self.alias), qn(self.target.column)), [] + + def relabeled_clone(self, relabels): + return self.__class__(relabels.get(self.alias, self.alias), self.target, self.output_field) + + def get_group_by_cols(self): + return [(self.alias, self.target.column)] + + +class Ref(ExpressionNode): + """ + Reference to column alias of the query. For example, Ref('sum_cost') in + qs.annotate(sum_cost=Sum('cost')) query. + """ + def __init__(self, refs, source): + super(Ref, self).__init__() + self.source = source + self.refs = refs + + def get_source_expressions(self): + return [self.source] + + def set_source_expressions(self, exprs): + self.source, = exprs + + def relabeled_clone(self, relabels): + return self + + def as_sql(self, compiler, connection): + return "%s" % compiler(self.refs), [] + + def get_group_by_cols(self): + return [(None, self.refs)] + + +class Date(ExpressionNode): + """ + Add a date selection column. 
+ """ + def __init__(self, col, lookup_type): + super(Date, self).__init__(output_field=fields.DateField()) + self.col = col + self.lookup_type = lookup_type + + def get_source_expressions(self): + return [self.col] + + def set_source_expressions(self, exprs): + self.col, = exprs + + def as_sql(self, qn, connection): + sql, params = self.col.as_sql(qn, connection) + assert not(params) + return connection.ops.date_trunc_sql(self.lookup_type, sql), [] + + +class DateTime(ExpressionNode): + """ + Add a datetime selection column. + """ + def __init__(self, col, lookup_type, tzname): + super(DateTime, self).__init__(output_field=fields.DateTimeField()) + self.col = col + self.lookup_type = lookup_type + self.tzname = tzname + + def get_source_expressions(self): + return [self.col] + + def set_source_expressions(self, exprs): + self.col, = exprs + + def as_sql(self, qn, connection): + sql, params = self.col.as_sql(qn, connection) + assert not(params) + return connection.ops.datetime_trunc_sql(self.lookup_type, sql, self.tzname) diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index 6eed25d96a..552f868de2 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -637,8 +637,6 @@ class Field(RegisterLookupMixin): """ Perform preliminary non-db specific lookup checks and conversions """ - if hasattr(value, 'prepare'): - return value.prepare() if hasattr(value, '_prepare'): return value._prepare() diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index b52d59394d..f7699f5152 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -13,7 +13,7 @@ from django.db.models.fields import (AutoField, Field, IntegerField, from django.db.models.lookups import IsNull from django.db.models.related import RelatedObject, PathInfo from django.db.models.query import QuerySet -from django.db.models.sql.datastructures import Col +from
django.db.models.expressions import Col from django.utils.encoding import force_text, smart_text from django.utils import six from django.utils.translation import ugettext_lazy as _ diff --git a/django/db/models/query.py b/django/db/models/query.py index d35f54b0b5..a7474bbc45 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -154,8 +154,7 @@ class QuerySet(object): 2. sql/compiler.results_iter() - Returns one row at time. At this point the rows are still just tuples. In some cases the return values are converted to - Python values at this location (see resolve_columns(), - resolve_aggregate()). + Python values at this location. 3. self.iterator() - Responsible for turning the rows into model objects. """ @@ -241,7 +240,7 @@ class QuerySet(object): max_depth = self.query.max_depth extra_select = list(self.query.extra_select) - aggregate_select = list(self.query.aggregate_select) + annotation_select = list(self.query.annotation_select) only_load = self.query.get_loaded_field_names() fields = self.model._meta.concrete_fields @@ -282,7 +281,7 @@ class QuerySet(object): db = self.db compiler = self.query.get_compiler(using=db) index_start = len(extra_select) - aggregate_start = index_start + len(init_list) + annotation_start = index_start + len(init_list) if fill_cache: klass_info = get_klass_info(model_cls, max_depth=max_depth, @@ -290,18 +289,18 @@ class QuerySet(object): for row in compiler.results_iter(): if fill_cache: obj, _ = get_cached_row(row, index_start, db, klass_info, - offset=len(aggregate_select)) + offset=len(annotation_select)) else: - obj = model_cls.from_db(db, init_list, row[index_start:aggregate_start]) + obj = model_cls.from_db(db, init_list, row[index_start:annotation_start]) if extra_select: for i, k in enumerate(extra_select): setattr(obj, k, row[i]) - # Add the aggregates to the model - if aggregate_select: - for i, aggregate in enumerate(aggregate_select): - setattr(obj, aggregate, row[i + aggregate_start]) + # Add the 
annotations to the model + if annotation_select: + for i, annotation in enumerate(annotation_select): + setattr(obj, annotation, row[i + annotation_start]) # Add the known related objects to the model, if there are any if self._known_related_objects: @@ -330,13 +329,16 @@ class QuerySet(object): if self.query.distinct_fields: raise NotImplementedError("aggregate() + distinct(fields) not implemented.") for arg in args: + if not hasattr(arg, 'default_alias'): + raise TypeError("Complex aggregates require an alias") kwargs[arg.default_alias] = arg query = self.query.clone() force_subq = query.low_mark != 0 or query.high_mark is not None for (alias, aggregate_expr) in kwargs.items(): - query.add_aggregate(aggregate_expr, self.model, alias, - is_summary=True) + query.add_annotation(aggregate_expr, self.model, alias, is_summary=True) + if not query.annotations[alias].contains_aggregate: + raise TypeError("%s is not an aggregate expression" % alias) return query.get_aggregation(using=self.db, force_subq=force_subq) def count(self): @@ -787,33 +789,40 @@ class QuerySet(object): def annotate(self, *args, **kwargs): """ Return a query set in which the returned objects have been annotated - with data aggregated from related fields. + with extra data or aggregations. """ - aggrs = OrderedDict() # To preserve ordering of args + annotations = OrderedDict() # To preserve ordering of args for arg in args: - if arg.default_alias in kwargs: - raise ValueError("The named annotation '%s' conflicts with the " - "default name for another annotation." - % arg.default_alias) - aggrs[arg.default_alias] = arg - aggrs.update(kwargs) + try: + # we can't do an hasattr here because py2 returns False + # if default_alias exists but throws a TypeError + if arg.default_alias in kwargs: + raise ValueError("The named annotation '%s' conflicts with the " + "default name for another annotation." 
+ % arg.default_alias) + except AttributeError: # default_alias + raise TypeError("Complex annotations require an alias") + annotations[arg.default_alias] = arg + annotations.update(kwargs) + obj = self._clone() names = getattr(self, '_fields', None) if names is None: names = set(self.model._meta.get_all_field_names()) - for aggregate in aggrs: - if aggregate in names: + + # Add the annotations to the query + for alias, annotation in annotations.items(): + if alias in names: raise ValueError("The annotation '%s' conflicts with a field on " - "the model." % aggregate) - - obj = self._clone() - - obj._setup_aggregate_query(list(aggrs)) - - # Add the aggregates to the query - for (alias, aggregate_expr) in aggrs.items(): - obj.query.add_aggregate(aggregate_expr, self.model, alias, - is_summary=False) + "the model." % alias) + obj.query.add_annotation(annotation, self.model, alias, is_summary=False) + # expressions need to be added to the query before we know if they contain aggregates + added_aggregates = [] + for alias, annotation in obj.query.annotations.items(): + if alias in annotations and annotation.contains_aggregate: + added_aggregates.append(alias) + if added_aggregates: + obj._setup_aggregate_query(list(added_aggregates)) return obj @@ -1096,9 +1105,9 @@ class ValuesQuerySet(QuerySet): # Purge any extra columns that haven't been explicitly asked for extra_names = list(self.query.extra_select) field_names = self.field_names - aggregate_names = list(self.query.aggregate_select) + annotation_names = list(self.query.annotation_select) - names = extra_names + field_names + aggregate_names + names = extra_names + field_names + annotation_names for row in self.query.get_compiler(self.db).results_iter(): yield dict(zip(names, row)) @@ -1122,9 +1131,9 @@ class ValuesQuerySet(QuerySet): if self._fields: self.extra_names = [] - self.aggregate_names = [] - if not self.query._extra and not self.query._aggregates: - # Short cut - if there are no extra or aggregates, then 
+ self.annotation_names = [] + if not self.query._extra and not self.query._annotations: + # Short cut - if there are no extra or annotations, then # the values() clause must be just field names. self.field_names = list(self._fields) else: @@ -1136,22 +1145,22 @@ class ValuesQuerySet(QuerySet): # had selected previously. if self.query._extra and f in self.query._extra: self.extra_names.append(f) - elif f in self.query.aggregate_select: - self.aggregate_names.append(f) + elif f in self.query.annotation_select: + self.annotation_names.append(f) else: self.field_names.append(f) else: # Default to all fields. self.extra_names = None self.field_names = [f.attname for f in self.model._meta.concrete_fields] - self.aggregate_names = None + self.annotation_names = None self.query.select = [] if self.extra_names is not None: self.query.set_extra_mask(self.extra_names) self.query.add_fields(self.field_names, True) - if self.aggregate_names is not None: - self.query.set_aggregate_mask(self.aggregate_names) + if self.annotation_names is not None: + self.query.set_annotation_mask(self.annotation_names) def _clone(self, klass=None, setup=False, **kwargs): """ @@ -1164,7 +1173,7 @@ class ValuesQuerySet(QuerySet): c._fields = self._fields[:] c.field_names = self.field_names c.extra_names = self.extra_names - c.aggregate_names = self.aggregate_names + c.annotation_names = self.annotation_names if setup and hasattr(c, '_setup_query'): c._setup_query() return c @@ -1173,7 +1182,7 @@ class ValuesQuerySet(QuerySet): super(ValuesQuerySet, self)._merge_sanity_check(other) if (set(self.extra_names) != set(other.extra_names) or set(self.field_names) != set(other.field_names) or - self.aggregate_names != other.aggregate_names): + self.annotation_names != other.annotation_names): raise TypeError("Merging '%s' classes must involve the same values in each case." 
% self.__class__.__name__) @@ -1183,9 +1192,9 @@ class ValuesQuerySet(QuerySet): """ self.query.set_group_by() - if self.aggregate_names is not None: - self.aggregate_names.extend(aggregates) - self.query.set_aggregate_mask(self.aggregate_names) + if self.annotation_names is not None: + self.annotation_names.extend(aggregates) + self.query.set_annotation_mask(self.annotation_names) super(ValuesQuerySet, self)._setup_aggregate_query(aggregates) @@ -1231,7 +1240,7 @@ class ValuesListQuerySet(ValuesQuerySet): if self.flat and len(self._fields) == 1: for row in self.query.get_compiler(self.db).results_iter(): yield row[0] - elif not self.query.extra_select and not self.query.aggregate_select: + elif not self.query.extra_select and not self.query.annotation_select: for row in self.query.get_compiler(self.db).results_iter(): yield tuple(row) else: @@ -1240,14 +1249,14 @@ class ValuesListQuerySet(ValuesQuerySet): # the fields to match the order in self._fields. extra_names = list(self.query.extra_select) field_names = self.field_names - aggregate_names = list(self.query.aggregate_select) + annotation_names = list(self.query.annotation_select) - names = extra_names + field_names + aggregate_names + names = extra_names + field_names + annotation_names # If a field list has been specified, use it. Otherwise, use the - # full list of fields, including extras and aggregates. + # full list of fields, including extras and annotations. 
if self._fields: - fields = list(self._fields) + [f for f in aggregate_names if f not in self._fields] + fields = list(self._fields) + [f for f in annotation_names if f not in self._fields] else: fields = names diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py index a8699f6334..59cd453722 100644 --- a/django/db/models/query_utils.py +++ b/django/db/models/query_utils.py @@ -9,6 +9,7 @@ from __future__ import unicode_literals from django.apps import apps from django.db.backends import utils +from django.db.models.constants import LOOKUP_SEP from django.utils import six from django.utils import tree @@ -220,3 +221,17 @@ def deferred_class_factory(model, attrs): # The above function is also used to unpickle model instances with deferred # fields. deferred_class_factory.__safe_for_unpickling__ = True + + +def refs_aggregate(lookup_parts, aggregates): + """ + A little helper method to check if the lookup_parts contains references + to the given aggregates set. Because the LOOKUP_SEP is contained in the + default annotation names we must check each prefix of the lookup_parts + for a match. 
+ """ + for n in range(len(lookup_parts) + 1): + level_n_lookup = LOOKUP_SEP.join(lookup_parts[0:n]) + if level_n_lookup in aggregates and aggregates[level_n_lookup].contains_aggregate: + return aggregates[level_n_lookup], lookup_parts[n:] + return False, () diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py index 8274d43621..6ebf5fb966 100644 --- a/django/db/models/sql/aggregates.py +++ b/django/db/models/sql/aggregates.py @@ -2,15 +2,23 @@ Classes to represent the default SQL aggregate functions """ import copy +import warnings from django.db.models.fields import IntegerField, FloatField from django.db.models.lookups import RegisterLookupMixin +from django.utils.deprecation import RemovedInDjango20Warning from django.utils.functional import cached_property __all__ = ['Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance'] +warnings.warn( + "django.db.models.sql.aggregates is deprecated. Use " + "django.db.models.aggregates instead.", + RemovedInDjango20Warning, stacklevel=2) + + class Aggregate(RegisterLookupMixin): """ Default SQL Aggregate. 
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 33fe343b5b..5f425a7543 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -4,12 +4,10 @@ from django.conf import settings from django.core.exceptions import FieldError from django.db.backends.utils import truncate_name from django.db.models.constants import LOOKUP_SEP -from django.db.models.expressions import ExpressionNode from django.db.models.query_utils import select_related_descend, QueryWrapper from django.db.models.sql.constants import (CURSOR, SINGLE, MULTI, NO_RESULTS, ORDER_DIR, GET_ITERATOR_CHUNK_SIZE, SelectInfo) from django.db.models.sql.datastructures import EmptyResultSet -from django.db.models.sql.expressions import SQLEvaluator from django.db.models.sql.query import get_order_dir, Query from django.db.transaction import TransactionManagementError from django.db.utils import DatabaseError @@ -248,8 +246,8 @@ class SQLCompiler(object): aliases.update(new_aliases) max_name_length = self.connection.ops.max_name_length() - for alias, aggregate in self.query.aggregate_select.items(): - agg_sql, agg_params = self.compile(aggregate) + for alias, annotation in self.query.annotation_select.items(): + agg_sql, agg_params = self.compile(annotation) if alias is None: result.append(agg_sql) else: @@ -409,7 +407,7 @@ class SQLCompiler(object): group_by.append((str(field), [])) continue col, order = get_order_dir(field, asc) - if col in self.query.aggregate_select: + if col in self.query.annotation_select: result.append('%s %s' % (qn(col), order)) continue if '.' 
in field: @@ -718,25 +716,17 @@ class SQLCompiler(object): """ fields = None converters = None - has_aggregate_select = bool(self.query.aggregate_select) + has_annotation_select = bool(self.query.annotation_select) for rows in self.execute_sql(MULTI): for row in rows: - if has_aggregate_select: - loaded_fields = ( - self.query.get_loaded_field_names().get(self.query.model, set()) or - self.query.select - ) - aggregate_start = len(self.query.extra_select) + len(loaded_fields) - aggregate_end = aggregate_start + len(self.query.aggregate_select) if fields is None: # We only set this up here because # related_select_cols isn't populated until # execute_sql() has been called. - # We also include types of fields of related models that - # will be included via select_related() for the benefit - # of MySQL/MySQLdb when boolean fields are involved - # (#15040). + # If the field was deferred, exclude it from being passed + # into `get_converters` because it wasn't selected. + only_load = self.deferred_to_columns() # This code duplicates the logic for the order of fields # found in get_columns(). It would be nice to clean this up. @@ -746,30 +736,45 @@ class SQLCompiler(object): fields = self.query.get_meta().concrete_fields else: fields = [] - fields = fields + [f.field for f in self.query.related_select_cols] - # If the field was deferred, exclude it from being passed - # into `get_converters` because it wasn't selected. 
- only_load = self.deferred_to_columns() if only_load: - fields = [f for f in fields if f.model._meta.db_table not in only_load or - f.column in only_load[f.model._meta.db_table]] - if has_aggregate_select: - # pad None in to fields for aggregates - fields = fields[:aggregate_start] + [ - None for x in range(0, aggregate_end - aggregate_start) - ] + fields[aggregate_start:] + # strip deferred fields + fields = [ + f for f in fields if + f.model._meta.db_table not in only_load or + f.column in only_load[f.model._meta.db_table] + ] + + # annotations come before the related cols + if has_annotation_select: + # extra is always at the start of the field list + prepended_cols = len(self.query.extra_select) + annotation_start = len(fields) + prepended_cols + fields = fields + [ + anno.output_field for alias, anno in self.query.annotation_select.items()] + annotation_end = len(fields) + prepended_cols + + # add related fields + fields = fields + [ + # strip deferred + f.field for f in self.query.related_select_cols if + f.field.model._meta.db_table not in only_load or + f.field.column in only_load[f.field.model._meta.db_table] + ] + converters = self.get_converters(fields) + if has_annotation_select: + for (alias, annotation), position in zip( + self.query.annotation_select.items(), + range(annotation_start, annotation_end + 1)): + if position in converters: + # annotation conversions always run first + converters[position][1].insert(0, annotation.convert_value) + else: + converters[position] = ([], [annotation.convert_value], annotation.output_field) + if converters: row = self.apply_converters(row, converters) - - if has_aggregate_select: - row = tuple(row[:aggregate_start]) + tuple( - self.query.resolve_aggregate(value, aggregate, self.connection) - for (alias, aggregate), value - in zip(self.query.aggregate_select.items(), row[aggregate_start:aggregate_end]) - ) + tuple(row[aggregate_end:]) - yield row def has_results(self): @@ -878,7 +883,7 @@ class 
SQLInsertCompiler(SQLCompiler): elif hasattr(field, 'get_placeholder'): # Some fields (e.g. geo fields) need special munging before # they can be inserted. - return field.get_placeholder(val, self.connection) + return field.get_placeholder(val, self, self.connection) else: # Return the common case for the placeholder return '%s' @@ -985,8 +990,10 @@ class SQLUpdateCompiler(SQLCompiler): result.append('SET') values, update_params = [], [] for field, model, val in self.query.values: - if hasattr(val, 'prepare_database_save'): - if field.rel or isinstance(val, ExpressionNode): + if hasattr(val, 'resolve_expression'): + val = val.resolve_expression(self.query, allow_joins=False) + elif hasattr(val, 'prepare_database_save'): + if field.rel: val = val.prepare_database_save(field) else: raise TypeError("Database is trying to update a relational field " @@ -998,12 +1005,9 @@ class SQLUpdateCompiler(SQLCompiler): # Getting the placeholder for the field. if hasattr(field, 'get_placeholder'): - placeholder = field.get_placeholder(val, self.connection) + placeholder = field.get_placeholder(val, self, self.connection) else: placeholder = '%s' - - if hasattr(val, 'evaluate'): - val = SQLEvaluator(val, self.query, allow_joins=False) name = field.column if hasattr(val, 'as_sql'): sql, params = self.compile(val) @@ -1103,8 +1107,8 @@ class SQLAggregateCompiler(SQLCompiler): qn = self sql, params = [], [] - for aggregate in self.query.aggregate_select.values(): - agg_sql, agg_params = self.compile(aggregate) + for annotation in self.query.annotation_select.values(): + agg_sql, agg_params = self.compile(annotation) sql.append(agg_sql) params.extend(agg_params) sql = ', '.join(sql) diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py index f9c9c259de..321451ac42 100644 --- a/django/db/models/sql/datastructures.py +++ b/django/db/models/sql/datastructures.py @@ -4,33 +4,6 @@ the SQL domain. 
""" -class Col(object): - def __init__(self, alias, target, source): - self.alias, self.target, self.source = alias, target, source - - def as_sql(self, qn, connection): - return "%s.%s" % (qn(self.alias), qn(self.target.column)), [] - - @property - def output_field(self): - return self.source - - def relabeled_clone(self, relabels): - return self.__class__(relabels.get(self.alias, self.alias), self.target, self.source) - - def get_group_by_cols(self): - return [(self.alias, self.target.column)] - - def get_lookup(self, name): - return self.output_field.get_lookup(name) - - def get_transform(self, name): - return self.output_field.get_transform(name) - - def prepare(self): - return self - - class EmptyResultSet(Exception): pass @@ -49,42 +22,3 @@ class MultiJoin(Exception): class Empty(object): pass - - -class Date(object): - """ - Add a date selection column. - """ - def __init__(self, col, lookup_type): - self.col = col - self.lookup_type = lookup_type - - def relabeled_clone(self, change_map): - return self.__class__((change_map.get(self.col[0], self.col[0]), self.col[1])) - - def as_sql(self, qn, connection): - if isinstance(self.col, (list, tuple)): - col = '%s.%s' % tuple(qn(c) for c in self.col) - else: - col = self.col - return connection.ops.date_trunc_sql(self.lookup_type, col), [] - - -class DateTime(object): - """ - Add a datetime selection column. 
- """ - def __init__(self, col, lookup_type, tzname): - self.col = col - self.lookup_type = lookup_type - self.tzname = tzname - - def relabeled_clone(self, change_map): - return self.__class__((change_map.get(self.col[0], self.col[0]), self.col[1])) - - def as_sql(self, qn, connection): - if isinstance(self.col, (list, tuple)): - col = '%s.%s' % tuple(qn(c) for c in self.col) - else: - col = self.col - return connection.ops.datetime_trunc_sql(self.lookup_type, col, self.tzname) diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py deleted file mode 100644 index e15cc2642c..0000000000 --- a/django/db/models/sql/expressions.py +++ /dev/null @@ -1,119 +0,0 @@ -import copy - -from django.core.exceptions import FieldError -from django.db.models.constants import LOOKUP_SEP -from django.db.models.fields import FieldDoesNotExist - - -class SQLEvaluator(object): - def __init__(self, expression, query, allow_joins=True, reuse=None): - self.expression = expression - self.opts = query.get_meta() - self.reuse = reuse - self.cols = [] - self.expression.prepare(self, query, allow_joins) - - def relabeled_clone(self, change_map): - clone = copy.copy(self) - clone.cols = [] - for node, col in self.cols: - if hasattr(col, 'relabeled_clone'): - clone.cols.append((node, col.relabeled_clone(change_map))) - else: - clone.cols.append((node, - (change_map.get(col[0], col[0]), col[1]))) - return clone - - def get_group_by_cols(self): - cols = [] - for node, col in self.cols: - if hasattr(node, 'get_group_by_cols'): - cols.extend(node.get_group_by_cols()) - elif isinstance(col, tuple): - cols.append(col) - return cols - - def prepare(self): - return self - - def as_sql(self, qn, connection): - return self.expression.evaluate(self, qn, connection) - - ##################################################### - # Visitor methods for initial expression preparation # - ##################################################### - - def prepare_node(self, node, query, 
allow_joins): - for child in node.children: - if hasattr(child, 'prepare'): - child.prepare(self, query, allow_joins) - - def prepare_leaf(self, node, query, allow_joins): - if not allow_joins and LOOKUP_SEP in node.name: - raise FieldError("Joined field references are not permitted in this query") - - field_list = node.name.split(LOOKUP_SEP) - if node.name in query.aggregates: - self.cols.append((node, query.aggregate_select[node.name])) - else: - try: - _, sources, _, join_list, path = query.setup_joins( - field_list, query.get_meta(), query.get_initial_alias(), - can_reuse=self.reuse) - self._used_joins = join_list - targets, _, join_list = query.trim_joins(sources, join_list, path) - if self.reuse is not None: - self.reuse.update(join_list) - for t in targets: - self.cols.append((node, (join_list[-1], t.column))) - except FieldDoesNotExist: - raise FieldError("Cannot resolve keyword %r into field. " - "Choices are: %s" % (self.name, - [f.name for f in self.opts.fields])) - - ################################################## - # Visitor methods for final expression evaluation # - ################################################## - - def evaluate_node(self, node, qn, connection): - expressions = [] - expression_params = [] - for child in node.children: - if hasattr(child, 'evaluate'): - sql, params = child.evaluate(self, qn, connection) - else: - sql, params = '%s', (child,) - - if len(getattr(child, 'children', [])) > 1: - format = '(%s)' - else: - format = '%s' - - if sql: - expressions.append(format % sql) - expression_params.extend(params) - - return connection.ops.combine_expression(node.connector, expressions), expression_params - - def evaluate_leaf(self, node, qn, connection): - col = None - for n, c in self.cols: - if n is node: - col = c - break - if col is None: - raise ValueError("Given node not found") - if hasattr(col, 'as_sql'): - return col.as_sql(qn, connection) - else: - return '%s.%s' % (qn(col[0]), qn(col[1])), [] - - def 
evaluate_date_modifier_node(self, node, qn, connection): - timedelta = node.children.pop() - sql, params = self.evaluate_node(node, qn, connection) - node.children.append(timedelta) - - if (timedelta.days == timedelta.seconds == timedelta.microseconds == 0): - return sql, params - - return connection.ops.date_interval_sql(sql, node.connector, timedelta), params diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 856bc51f4f..a17cd62f29 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -14,20 +14,18 @@ import warnings from django.core.exceptions import FieldError from django.db import connections, DEFAULT_DB_ALIAS from django.db.models.constants import LOOKUP_SEP -from django.db.models.aggregates import refs_aggregate -from django.db.models.expressions import ExpressionNode +from django.db.models.expressions import Col, Ref from django.db.models.fields import FieldDoesNotExist -from django.db.models.query_utils import Q +from django.db.models.query_utils import Q, refs_aggregate from django.db.models.related import PathInfo -from django.db.models.sql import aggregates as base_aggregates_module +from django.db.models.aggregates import Count from django.db.models.sql.constants import (QUERY_TERMS, ORDER_DIR, SINGLE, ORDER_PATTERN, JoinInfo, SelectInfo) -from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin, Col -from django.db.models.sql.expressions import SQLEvaluator +from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode, ExtraWhere, AND, OR, EmptyWhere) from django.utils import six -from django.utils.deprecation import RemovedInDjango19Warning +from django.utils.deprecation import RemovedInDjango19Warning, RemovedInDjango20Warning from django.utils.encoding import force_text from django.utils.tree import Node @@ -49,7 +47,7 @@ class RawQuery(object): # the compiler can be used 
to process results. self.low_mark, self.high_mark = 0, None # Used for offset/limit self.extra_select = {} - self.aggregate_select = {} + self.annotation_select = {} def clone(self, using): return RawQuery(self.sql, using, params=self.params) @@ -97,7 +95,6 @@ class Query(object): alias_prefix = 'T' subq_aliases = frozenset([alias_prefix]) query_terms = QUERY_TERMS - aggregates_module = base_aggregates_module compiler = 'SQLCompiler' @@ -140,13 +137,13 @@ class Query(object): self.select_for_update_nowait = False self.select_related = False - # SQL aggregate-related attributes - # The _aggregates will be an OrderedDict when used. Due to the cost + # SQL annotation-related attributes + # The _annotations will be an OrderedDict when used. Due to the cost # of creating OrderedDict this attribute is created lazily (in - # self.aggregates property). - self._aggregates = None # Maps alias -> SQL aggregate function - self.aggregate_select_mask = None - self._aggregate_select_cache = None + # self.annotations property). + self._annotations = None # Maps alias -> Annotation Expression + self.annotation_select_mask = None + self._annotation_select_cache = None # Arbitrary maximum limit for select_related. Prevents infinite # recursion. Can be changed by the depth parameter to select_related(). @@ -155,7 +152,7 @@ class Query(object): # These are for extensions. The contents are more or less appended # verbatim to the appropriate clause. # The _extra attribute is an OrderedDict, lazily created similarly to - # .aggregates + # .annotations self._extra = None # Maps col_alias -> (col_sql, params). 
self.extra_select_mask = None self._extra_select_cache = None @@ -174,11 +171,18 @@ class Query(object): self._extra = OrderedDict() return self._extra + @property + def annotations(self): + if self._annotations is None: + self._annotations = OrderedDict() + return self._annotations + @property def aggregates(self): - if self._aggregates is None: - self._aggregates = OrderedDict() - return self._aggregates + warnings.warn( + "The aggregates property is deprecated. Use annotations instead.", + RemovedInDjango20Warning, stacklevel=2) + return self.annotations def __str__(self): """ @@ -203,7 +207,7 @@ class Query(object): memo[id(self)] = result return result - def prepare(self): + def _prepare(self): return self def get_compiler(self, using=None, connection=None): @@ -213,8 +217,8 @@ class Query(object): connection = connections[using] # Check that the compiler will be able to execute the query - for alias, aggregate in self.aggregate_select.items(): - connection.ops.check_aggregate_support(aggregate) + for alias, annotation in self.annotation_select.items(): + connection.ops.check_aggregate_support(annotation) return connection.ops.compiler(self.compiler)(self, connection, using) @@ -260,17 +264,17 @@ class Query(object): obj.select_for_update_nowait = self.select_for_update_nowait obj.select_related = self.select_related obj.related_select_cols = [] - obj._aggregates = self._aggregates.copy() if self._aggregates is not None else None - if self.aggregate_select_mask is None: - obj.aggregate_select_mask = None + obj._annotations = self._annotations.copy() if self._annotations is not None else None + if self.annotation_select_mask is None: + obj.annotation_select_mask = None else: - obj.aggregate_select_mask = self.aggregate_select_mask.copy() - # _aggregate_select_cache cannot be copied, as doing so breaks the - # (necessary) state in which both aggregates and - # _aggregate_select_cache point to the same underlying objects. 
+ obj.annotation_select_mask = self.annotation_select_mask.copy() + # _annotation_select_cache cannot be copied, as doing so breaks the + # (necessary) state in which both annotations and + # _annotation_select_cache point to the same underlying objects. # It will get re-populated in the cloned queryset the next time it's # used. - obj._aggregate_select_cache = None + obj._annotation_select_cache = None obj.max_depth = self.max_depth obj._extra = self._extra.copy() if self._extra is not None else None if self.extra_select_mask is None: @@ -299,94 +303,84 @@ class Query(object): obj._setup_query() return obj - def resolve_aggregate(self, value, aggregate, connection): - """Resolve the value of aggregates returned by the database to - consistent (and reasonable) types. - - This is required because of the predisposition of certain backends - to return Decimal and long types when they are not needed. - """ - if value is None: - if aggregate.is_ordinal: - return 0 - # Return None as-is - return value - elif aggregate.is_ordinal: - # Any ordinal aggregate (e.g., count) returns an int - return int(value) - elif aggregate.is_computed: - # Any computed aggregate (e.g., avg) returns a float - return float(value) - else: - # Return value depends on the type of the field being processed. - backend_converters = connection.ops.get_db_converters(aggregate.field.get_internal_type()) - field_converters = aggregate.field.get_db_converters(connection) - for converter in backend_converters: - value = converter(value, aggregate.field) - for converter in field_converters: - value = converter(value, connection) - return value - def get_aggregation(self, using, force_subq=False): """ Returns the dictionary with the values of the existing aggregations. 
""" - if not self.aggregate_select: + if not self.annotation_select: return {} + # annotations must be forced into subquery + has_annotation = any( + annotation for alias, annotation + in self.annotation_select.items() + if not annotation.contains_aggregate) + # If there is a group by clause, aggregating does not add useful # information but retrieves only the first row. Aggregate # over the subquery instead. - if self.group_by is not None or force_subq: + if self.group_by is not None or force_subq or has_annotation: from django.db.models.sql.subqueries import AggregateQuery - query = AggregateQuery(self.model) - obj = self.clone() + outer_query = AggregateQuery(self.model) + inner_query = self.clone() if not force_subq: # In forced subq case the ordering and limits will likely # affect the results. - obj.clear_ordering(True) - obj.clear_limits() - obj.select_for_update = False - obj.select_related = False - obj.related_select_cols = [] + inner_query.clear_ordering(True) + inner_query.clear_limits() + inner_query.select_for_update = False + inner_query.select_related = False + inner_query.related_select_cols = [] - relabels = dict((t, 'subquery') for t in self.tables) + relabels = dict((t, 'subquery') for t in inner_query.tables) + relabels[None] = 'subquery' # Remove any aggregates marked for reduction from the subquery # and move them to the outer AggregateQuery. - for alias, aggregate in self.aggregate_select.items(): - if aggregate.is_summary: - query.aggregates[alias] = aggregate.relabeled_clone(relabels) - del obj.aggregate_select[alias] - + for alias, annotation in inner_query.annotation_select.items(): + if annotation.is_summary: + # The annotation is already referring the subquery alias, so we + # just need to move the annotation to the outer query. 
+ outer_query.annotations[alias] = annotation.relabeled_clone(relabels) + del inner_query.annotation_select[alias] try: - query.add_subquery(obj, using) + outer_query.add_subquery(inner_query, using) except EmptyResultSet: return dict( (alias, None) - for alias in query.aggregate_select + for alias in outer_query.annotation_select ) else: - query = self + outer_query = self self.select = [] self.default_cols = False self._extra = {} self.remove_inherited_models() - query.clear_ordering(True) - query.clear_limits() - query.select_for_update = False - query.select_related = False - query.related_select_cols = [] - - result = query.get_compiler(using).execute_sql(SINGLE) + outer_query.clear_ordering(True) + outer_query.clear_limits() + outer_query.select_for_update = False + outer_query.select_related = False + outer_query.related_select_cols = [] + compiler = outer_query.get_compiler(using) + result = compiler.execute_sql(SINGLE) if result is None: - result = [None for q in query.aggregate_select.items()] + result = [None for q in outer_query.annotation_select.items()] + + fields = [annotation.output_field + for alias, annotation in outer_query.annotation_select.items()] + converters = compiler.get_converters(fields) + for position, (alias, annotation) in enumerate(outer_query.annotation_select.items()): + if position in converters: + converters[position][1].insert(0, annotation.convert_value) + else: + converters[position] = ([], [annotation.convert_value], annotation.output_field) + result = compiler.apply_converters(result, converters) return dict( - (alias, self.resolve_aggregate(val, aggregate, connection=connections[using])) - for (alias, aggregate), val - in zip(query.aggregate_select.items(), result) + (alias, val) + for (alias, annotation), val + in zip(outer_query.annotation_select.items(), result) ) def get_count(self, using): @@ -394,7 +388,7 @@ class Query(object): Performs a COUNT() query using the current filter constraints. 
""" obj = self.clone() - if len(self.select) > 1 or self.aggregate_select or (self.distinct and self.distinct_fields): + if len(self.select) > 1 or self.annotation_select or (self.distinct and self.distinct_fields): # If a select clause exists, then the query has already started to # specify the columns that are to be returned. # In this case, we need to use a subquery to evaluate the count. @@ -769,9 +763,9 @@ class Query(object): self.group_by = [relabel_column(col) for col in self.group_by] self.select = [SelectInfo(relabel_column(s.col), s.field) for s in self.select] - if self._aggregates: - self._aggregates = OrderedDict( - (key, relabel_column(col)) for key, col in self._aggregates.items()) + if self._annotations: + self._annotations = OrderedDict( + (key, relabel_column(col)) for key, col in self._annotations.items()) # 2. Rename the alias in the internal table/alias datastructures. for ident, aliases in self.join_map.items(): @@ -974,52 +968,18 @@ class Query(object): self.included_inherited_models = {} def add_aggregate(self, aggregate, model, alias, is_summary): + warnings.warn( + "add_aggregate() is deprecated. 
Use add_annotation() instead.", + RemovedInDjango20Warning, stacklevel=2) + self.add_annotation(aggregate, model, alias, is_summary) + + def add_annotation(self, annotation, model, alias, is_summary): """ - Adds a single aggregate expression to the Query + Adds a single annotation expression to the Query """ - opts = model._meta - field_list = aggregate.lookup.split(LOOKUP_SEP) - if len(field_list) == 1 and self._aggregates and aggregate.lookup in self.aggregates: - # Aggregate is over an annotation - field_name = field_list[0] - col = field_name - source = self.aggregates[field_name] - if not is_summary: - raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % ( - aggregate.name, field_name, field_name)) - elif ((len(field_list) > 1) or - (field_list[0] not in [i.name for i in opts.fields]) or - self.group_by is None or - not is_summary): - # If: - # - the field descriptor has more than one part (foo__bar), or - # - the field descriptor is referencing an m2m/m2o field, or - # - this is a reference to a model field (possibly inherited), or - # - this is an annotation over a model field - # then we need to explore the joins that are required. - - # Join promotion note - we must not remove any rows here, so use - # outer join if there isn't any existing join. - _, sources, opts, join_list, path = self.setup_joins( - field_list, opts, self.get_initial_alias()) - - # Process the join chain to see if it can be trimmed - targets, _, join_list = self.trim_joins(sources, join_list, path) - - col = targets[0].column - source = sources[0] - col = (join_list[-1], col) - else: - # The simplest cases. No joins required - - # just reference the provided column alias. - field_name = field_list[0] - source = opts.get_field(field_name) - col = field_name - # We want to have the alias in SELECT clause even if mask is set. 
- self.append_aggregate_mask([alias]) - - # Add the aggregate to the query - aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary) + annotation = annotation.resolve_expression(self, summarize=is_summary) + self.append_annotation_mask([alias]) + self.annotations[alias] = annotation def prepare_lookup_value(self, value, lookups, can_reuse): # Default lookup if none given is exact. @@ -1037,9 +997,8 @@ class Query(object): "Passing callable arguments to queryset is deprecated.", RemovedInDjango19Warning, stacklevel=2) value = value() - elif isinstance(value, ExpressionNode): - # If value is a query expression, evaluate it - value = SQLEvaluator(value, self, reuse=can_reuse) + elif hasattr(value, 'resolve_expression'): + value = value.resolve_expression(self, reuse=can_reuse) if hasattr(value, 'query') and hasattr(value.query, 'bump_prefix'): value = value._clone() value.query.bump_prefix(self) @@ -1061,8 +1020,8 @@ class Query(object): Solve the lookup type from the lookup (eg: 'foobar__id__icontains') """ lookup_splitted = lookup.split(LOOKUP_SEP) - if self._aggregates: - aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.aggregates) + if self._annotations: + aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.annotations) if aggregate: return aggregate_lookups, (), aggregate _, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta()) @@ -1232,7 +1191,11 @@ class Query(object): lookup_type = lookups[-1] else: assert(len(targets) == 1) - col = Col(alias, targets[0], field) + if hasattr(targets[0], 'as_sql'): + # handle Expressions as annotations + col = targets[0] + else: + col = Col(alias, targets[0], field) condition = self.build_lookup(lookups, col, value) if not condition: # Backwards compat for custom lookups @@ -1278,12 +1241,12 @@ class Query(object): Returns whether or not all elements of this q_object need to be put together in the HAVING clause. 
""" - if not self._aggregates: + if not self._annotations: return False if not isinstance(obj, Node): - return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.aggregates)[0] - or (hasattr(obj[1], 'contains_aggregate') - and obj[1].contains_aggregate(self.aggregates))) + return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.annotations)[0] + or (hasattr(obj[1], 'refs_aggregate') + and obj[1].refs_aggregate(self.annotations)[0])) return any(self.need_having(c) for c in obj.children) def split_having_parts(self, q_object, negated=False): @@ -1390,13 +1353,21 @@ class Query(object): if name == 'pk': name = opts.pk.name try: - field, model, direct, m2m = opts.get_field_by_name(name) + field, model, _, _ = opts.get_field_by_name(name) except FieldDoesNotExist: + # is it an annotation? + if self._annotations and name in self._annotations: + field, model = self._annotations[name], None + if not field.contains_aggregate: + # Local non-relational field. + final_field = field + targets = (field,) + break # We didn't find the current field, so move position back # one step. pos -= 1 if pos == -1 or fail_on_missing: - available = opts.get_all_field_names() + list(self.aggregate_select) + available = opts.get_all_field_names() + list(self.annotation_select) raise FieldError("Cannot resolve keyword %r into field. " "Choices are: %s" % (name, ", ".join(available))) break @@ -1445,6 +1416,11 @@ class Query(object): break return path, final_field, targets, names[pos + 1:] + def raise_field_error(self, opts, name): + available = opts.get_all_field_names() + list(self.annotation_select) + raise FieldError("Cannot resolve keyword %r into field. 
" + "Choices are: %s" % (name, ", ".join(available))) + def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True): """ Compute the necessary table joins for the passage through the fields @@ -1519,6 +1495,29 @@ class Query(object): self.unref_alias(joins.pop()) return targets, joins[-1], joins + def resolve_ref(self, name, allow_joins, reuse, summarize): + if not allow_joins and LOOKUP_SEP in name: + raise FieldError("Joined field references are not permitted in this query") + if name in self.annotations: + if summarize: + return Ref(name, self.annotation_select[name]) + else: + return self.annotation_select[name] + else: + field_list = name.split(LOOKUP_SEP) + field, sources, opts, join_list, path = self.setup_joins( + field_list, self.get_meta(), + self.get_initial_alias(), reuse) + targets, _, join_list = self.trim_joins(sources, join_list, path) + if len(targets) > 1: + raise FieldError("Referencing multicolumn fields with F() objects " + "isn't supported") + if reuse is not None: + reuse.update(join_list) + col = Col(join_list[-1], targets[0], sources[0]) + col._used_joins = join_list + return col + def split_exclude(self, filter_expr, prefix, can_reuse, names_with_path): """ When doing an exclude against any kind of N-to-many relation, we need @@ -1633,7 +1632,7 @@ class Query(object): self.default_cols = False self.select_related = False self.set_extra_mask(()) - self.set_aggregate_mask(()) + self.set_annotation_mask(()) def clear_select_fields(self): """ @@ -1676,7 +1675,7 @@ class Query(object): raise else: names = sorted(opts.get_all_field_names() + list(self.extra) - + list(self.aggregate_select)) + + list(self.annotation_select)) raise FieldError("Cannot resolve keyword %r into field. 
" "Choices are: %s" % (name, ", ".join(names))) self.remove_inherited_models() @@ -1725,39 +1724,55 @@ class Query(object): for col, _ in self.select: self.group_by.append(col) + if self._annotations: + for alias, annotation in six.iteritems(self.annotations): + for col in annotation.get_group_by_cols(): + self.group_by.append(col) + def add_count_column(self): """ Converts the query to do count(...) or count(distinct(pk)) in order to get its size. """ + summarize = False if not self.distinct: if not self.select: - count = self.aggregates_module.Count('*', is_summary=True) + count = Count('*') + summarize = True else: assert len(self.select) == 1, \ "Cannot add count col with multiple cols in 'select': %r" % self.select - count = self.aggregates_module.Count(self.select[0].col) + col = self.select[0].col + if isinstance(col, (tuple, list)): + count = Count(col[1]) + else: + count = Count(col) + else: opts = self.get_meta() if not self.select: - count = self.aggregates_module.Count( - (self.join((None, opts.db_table, None)), opts.pk.column), - is_summary=True, distinct=True) + lookup = self.join((None, opts.db_table, None)), opts.pk.column + count = Count(lookup[1], distinct=True) + summarize = True else: # Because of SQL portability issues, multi-column, distinct # counts need a sub-query -- see get_count() for details. assert len(self.select) == 1, \ "Cannot add count col with multiple cols in 'select'." - - count = self.aggregates_module.Count(self.select[0].col, distinct=True) + col = self.select[0].col + if isinstance(col, (tuple, list)): + count = Count(col[1], distinct=True) + else: + count = Count(col, distinct=True) # Distinct handling is done in Count(), so don't do it at this # level. self.distinct = False # Set only aggregate to be the count column. - # Clear out the select cache to reflect the new unmasked aggregates. 
- self._aggregates = {None: count} - self.set_aggregate_mask(None) + # Clear out the select cache to reflect the new unmasked annotations. + count = count.resolve_expression(self, summarize=summarize) + self._annotations = {None: count} + self.set_annotation_mask(None) self.group_by = None def add_select_related(self, fields): @@ -1886,16 +1901,28 @@ class Query(object): target[model] = set(f.name for f in fields) def set_aggregate_mask(self, names): - "Set the mask of aggregates that will actually be returned by the SELECT" + warnings.warn( + "set_aggregate_mask() is deprecated. Use set_annotation_mask() instead.", + RemovedInDjango20Warning, stacklevel=2) + self.set_annotation_mask(names) + + def set_annotation_mask(self, names): + "Set the mask of annotations that will actually be returned by the SELECT" if names is None: - self.aggregate_select_mask = None + self.annotation_select_mask = None else: - self.aggregate_select_mask = set(names) - self._aggregate_select_cache = None + self.annotation_select_mask = set(names) + self._annotation_select_cache = None def append_aggregate_mask(self, names): - if self.aggregate_select_mask is not None: - self.set_aggregate_mask(set(names).union(self.aggregate_select_mask)) + warnings.warn( + "append_aggregate_mask() is deprecated. Use append_annotation_mask() instead.", + RemovedInDjango20Warning, stacklevel=2) + self.append_annotation_mask(names) + + def append_annotation_mask(self, names): + if self.annotation_select_mask is not None: + self.set_annotation_mask(set(names).union(self.annotation_select_mask)) def set_extra_mask(self, names): """ @@ -1910,24 +1937,31 @@ class Query(object): self._extra_select_cache = None @property - def aggregate_select(self): + def annotation_select(self): """The OrderedDict of aggregate columns that are not masked, and should be used in the SELECT clause. This result is cached for optimization purposes. 
""" - if self._aggregate_select_cache is not None: - return self._aggregate_select_cache - elif not self._aggregates: + if self._annotation_select_cache is not None: + return self._annotation_select_cache + elif not self._annotations: return {} - elif self.aggregate_select_mask is not None: - self._aggregate_select_cache = OrderedDict( - (k, v) for k, v in self.aggregates.items() - if k in self.aggregate_select_mask + elif self.annotation_select_mask is not None: + self._annotation_select_cache = OrderedDict( + (k, v) for k, v in self.annotations.items() + if k in self.annotation_select_mask ) - return self._aggregate_select_cache + return self._annotation_select_cache else: - return self.aggregates + return self.annotations + + @property + def aggregate_select(self): + warnings.warn( + "aggregate_select() is deprecated. Use annotation_select() instead.", + RemovedInDjango20Warning, stacklevel=2) + return self.annotation_select @property def extra_select(self): diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index 2f0de5b80c..6f3f7358d3 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -7,9 +7,9 @@ from django.core.exceptions import FieldError from django.db import connections from django.db.models.query_utils import Q from django.db.models.constants import LOOKUP_SEP +from django.db.models.expressions import Date, DateTime, Col from django.db.models.fields import DateField, DateTimeField, FieldDoesNotExist from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, NO_RESULTS, SelectInfo -from django.db.models.sql.datastructures import Date, DateTime from django.db.models.sql.query import Query from django.utils import six from django.utils import timezone @@ -229,7 +229,7 @@ class DateQuery(Query): )) self._check_field(field) # overridden in DateTimeQuery alias = joins[-1] - select = self._get_select((alias, field.column), lookup_type) + select = self._get_select(Col(alias, 
field), lookup_type) self.clear_select_clause() self.select = [SelectInfo(select, None)] self.distinct = True diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index f65e593a3a..13815cb68c 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -10,7 +10,6 @@ import warnings from django.conf import settings from django.db.models.fields import DateTimeField, Field from django.db.models.sql.datastructures import EmptyResultSet, Empty -from django.db.models.sql.aggregates import Aggregate from django.utils.deprecation import RemovedInDjango19Warning from django.utils.six.moves import xrange from django.utils import timezone @@ -78,7 +77,7 @@ class WhereNode(tree.Node): else: value_annotation = bool(value) - if hasattr(obj, "prepare"): + if hasattr(obj, 'prepare'): value = obj.prepare(lookup_type, value) return (obj, lookup_type, value_annotation, value) @@ -187,11 +186,9 @@ class WhereNode(tree.Node): lvalue, params = lvalue.process(lookup_type, params_or_value, connection) except EmptyShortCircuit: raise EmptyResultSet - elif isinstance(lvalue, Aggregate): - params = lvalue.field.get_db_prep_lookup(lookup_type, params_or_value, connection) else: - raise TypeError("'make_atom' expects a Constraint or an Aggregate " - "as the first item of its 'child' argument.") + raise TypeError("'make_atom' expects a Constraint as the first " + "item of its 'child' argument.") if isinstance(lvalue, tuple): # A direct database column lookup. diff --git a/docs/index.txt b/docs/index.txt index 9aaf5e181d..5ce5f69b19 100644 --- a/docs/index.txt +++ b/docs/index.txt @@ -86,7 +86,8 @@ manipulating the data of your Web application. 
Learn more about it below: :doc:`Aggregation ` | :doc:`Custom fields ` | :doc:`Multiple databases ` | - :doc:`Custom lookups ` + :doc:`Custom lookups ` | + :doc:`Query Expressions ` * **Other:** :doc:`Supported databases ` | diff --git a/docs/internals/deprecation.txt b/docs/internals/deprecation.txt index 57334840d4..0bba849713 100644 --- a/docs/internals/deprecation.txt +++ b/docs/internals/deprecation.txt @@ -41,6 +41,17 @@ details on these changes. :class:`~django.core.management.BaseCommand` instead, which takes no arguments by default. +* ``django.db.models.sql.aggregates`` module will be removed. + +* ``django.contrib.gis.db.models.sql.aggregates`` module will be removed. + +* The following methods and properties of ``django.db.models.sql.query.Query`` will + be removed: + + * Properties: ``aggregates`` and ``aggregate_select`` + * Methods: ``add_aggregate``, ``set_aggregate_mask``, and + ``append_aggregate_mask``. + * ``django.template.resolve_variable`` will be removed. * The ``error_message`` argument of ``django.forms.RegexField`` will be removed. diff --git a/docs/ref/models/expressions.txt b/docs/ref/models/expressions.txt new file mode 100644 index 0000000000..5dadf35873 --- /dev/null +++ b/docs/ref/models/expressions.txt @@ -0,0 +1,522 @@ +================= +Query Expressions +================= + +.. currentmodule:: django.db.models + +Query expressions describe a value or a computation that can be used as part of +a filter, an annotation, or an aggregation. There are a number of built-in +expressions (documented below) that can be used to help you write queries. +Expressions can be combined, or in some cases nested, to form more complex +computations. + +Supported arithmetic +==================== + +Django supports addition, subtraction, multiplication, division, modulo +arithmetic, and the power operator on query expressions, using Python constants, +variables, and even other expressions. + +..
versionadded:: 1.7 + + Support for the power operator ``**`` was added. + +Some examples +============= + +.. versionchanged:: 1.8 + + Some of the examples rely on functionality that is new in Django 1.8. + +.. code-block:: python + + # Find companies that have more employees than chairs. + Company.objects.filter(num_employees__gt=F('num_chairs')) + + # Find companies that have at least twice as many employees + # as chairs. Both the querysets below are equivalent. + Company.objects.filter(num_employees__gt=F('num_chairs') * 2) + Company.objects.filter( + num_employees__gt=F('num_chairs') + F('num_chairs')) + + # How many chairs are needed for each company to seat all employees? + >>> company = Company.objects.filter( + ... num_employees__gt=F('num_chairs')).annotate( + ... chairs_needed=F('num_employees') - F('num_chairs')).first() + >>> company.num_employees + 120 + >>> company.num_chairs + 50 + >>> company.chairs_needed + 70 + + # Annotate models with an aggregated value. Both forms + # below are equivalent. + Company.objects.annotate(num_products=Count('products')) + Company.objects.annotate(num_products=Count(F('products'))) + + # Aggregates can contain complex computations also + Company.objects.annotate(num_offerings=Count(F('products') + F('services'))) + + +Built-in Expressions +==================== + +``F()`` expressions +------------------- + +.. class:: F + +An ``F()`` object represents the value of a model field or annotated column. It +makes it possible to refer to model field values and perform database +operations using them without actually having to pull them out of the database +into Python memory. + +Instead, Django uses the ``F()`` object to generate a SQL expression that +describes the required operation at the database level. + +This is easiest to understand through an example. Normally, one might do +something like this:: + + # Tintin filed a news story! 
+ reporter = Reporters.objects.get(name='Tintin') + reporter.stories_filed += 1 + reporter.save() + +Here, we have pulled the value of ``reporter.stories_filed`` from the database +into memory and manipulated it using familiar Python operators, and then saved +the object back to the database. But instead we could also have done:: + + from django.db.models import F + reporter = Reporters.objects.get(name='Tintin') + reporter.stories_filed = F('stories_filed') + 1 + reporter.save() + +Although ``reporter.stories_filed = F('stories_filed') + 1`` looks like a +normal Python assignment of value to an instance attribute, in fact it's an SQL +construct describing an operation on the database. + +When Django encounters an instance of ``F()``, it overrides the standard Python +operators to create an encapsulated SQL expression; in this case, one which +instructs the database to increment the database field represented by +``reporter.stories_filed``. + +Whatever value is or was on ``reporter.stories_filed``, Python never gets to +know about it - it is dealt with entirely by the database. All Python does, +through Django's ``F()`` class, is create the SQL syntax to refer to the field +and describe the operation. + +.. note:: + + In order to access the new value that has been saved in this way, the object + will need to be reloaded:: + + reporter = Reporters.objects.get(pk=reporter.pk) + +As well as being used in operations on single instances as above, ``F()`` can +be used on ``QuerySets`` of object instances, with ``update()``. 
This reduces +the two queries we were using above - the ``get()`` and the +:meth:`~Model.save()` - to just one:: + + reporter = Reporters.objects.filter(name='Tintin') + reporter.update(stories_filed=F('stories_filed') + 1) + +We can also use :meth:`~django.db.models.query.QuerySet.update()` to increment +the field value on multiple objects - which could be very much faster than +pulling them all into Python from the database, looping over them, incrementing +the field value of each one, and saving each one back to the database:: + + Reporter.objects.all().update(stories_filed=F('stories_filed') + 1) + +``F()`` therefore can offer performance advantages by: + +* getting the database, rather than Python, to do work +* reducing the number of queries some operations require + +.. _avoiding-race-conditions-using-f: + +Avoiding race conditions using ``F()`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Another useful benefit of ``F()`` is that having the database - rather than +Python - update a field's value avoids a *race condition*. + +If two Python threads execute the code in the first example above, one thread +could retrieve, increment, and save a field's value after the other has +retrieved it from the database. The value that the second thread saves will be +based on the original value; the work of the first thread will simply be lost. + +If the database is responsible for updating the field, the process is more +robust: it will only ever update the field based on the value of the field in +the database when the :meth:`~Model.save()` or ``update()`` is executed, rather +than based on its value when the instance was retrieved. + +Using ``F()`` in filters +~~~~~~~~~~~~~~~~~~~~~~~~ + +``F()`` is also very useful in ``QuerySet`` filters, where they make it +possible to filter a set of objects against criteria based on their field +values, rather than on Python values. + +This is documented in :ref:`using F() expressions in queries +`. + + +..
_func-expressions: + +``Func()`` expressions +---------------------- + +.. versionadded:: 1.8 + +``Func()`` expressions are the base type of all expressions that involve +database functions like ``COALESCE`` and ``LOWER``, or aggregates like ``SUM``. +They can be used directly:: + + queryset.annotate(field_lower=Func(F('field'), function='LOWER')) + +or they can be used to build a library of database functions:: + + class Lower(Func): + function = 'LOWER' + + queryset.annotate(field_lower=Lower(F('field'))) + +But both cases will result in a queryset where each model is annotated with an +extra attribute ``field_lower`` produced, roughly, from the following SQL:: + + SELECT + ... + LOWER("app_label"."field") as "field_lower" + +The ``Func`` API is as follows: + +.. class:: Func(*expressions, **extra) + + .. attribute:: function + + A class attribute describing the function that will be generated. + Specifically, the ``function`` will be interpolated as the ``function`` + placeholder within :attr:`template`. Defaults to ``None``. + + .. attribute:: template + + A class attribute, as a format string, that describes the SQL that is + generated for this function. Defaults to + ``'%(function)s(%(expressions)s)'``. + + .. attribute:: arg_joiner + + A class attribute that denotes the character used to join the list of + ``expressions`` together. Defaults to ``', '``. + +The ``*expressions`` argument is a list of positional expressions that the +function will be applied to. The expressions will be converted to strings, +joined together with ``arg_joiner``, and then interpolated into the ``template`` +as the ``expressions`` placeholder. + +The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated +into the ``template`` attribute. Note that the keywords ``function`` and +``template`` can be used to replace the ``function`` and ``template`` +attributes respectively, without having to define your own class. 
+``output_field`` can be used to define the expected return type. + +``Aggregate()`` expressions +--------------------------- + +An aggregate expression is a special case of a :ref:`Func() expression +` that informs the query that a ``GROUP BY`` clause +is required. All of the :ref:`aggregate functions `, +like ``Sum()`` and ``Count()``, inherit from ``Aggregate()``. + +Since ``Aggregate``\s are expressions and wrap expressions, you can represent +some complex computations:: + + Company.objects.annotate( + managers_required=(Count('num_employees') / 4) + Count('num_managers')) + +The ``Aggregate`` API is as follows: + +.. class:: Aggregate(expression, output_field=None, **extra) + + .. attribute:: template + + A class attribute, as a format string, that describes the SQL that is + generated for this aggregate. Defaults to + ``'%(function)s( %(expressions)s )'``. + + .. attribute:: function + + A class attribute describing the aggregate function that will be + generated. Specifically, the ``function`` will be interpolated as the + ``function`` placeholder within :attr:`template`. Defaults to ``None``. + +The ``expression`` argument can be the name of a field on the model, or another +expression. It will be converted to a string and used as the ``expressions`` +placeholder within the ``template``. + +The ``output_field`` argument requires a model field instance, like +``IntegerField()`` or ``BooleanField()``, into which Django will load the value +after it's retrieved from the database. + +Note that ``output_field`` is only required when Django is unable to determine +what field type the result should be. Complex expressions that mix field types +should define the desired ``output_field``. For example, adding an +``IntegerField()`` and a ``FloatField()`` together should probably have +``output_field=FloatField()`` defined. + +.. versionchanged:: 1.8 + + ``output_field`` is a new parameter. 
+ +The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated +into the ``template`` attribute. + +.. versionadded:: 1.8 + + Aggregate functions can now use arithmetic and reference multiple + model fields in a single function. + +Creating your own Aggregate Functions +------------------------------------- + +Creating your own aggregate is extremely easy. At a minimum, you need +to define ``function``, but you can also completely customize the +SQL that is generated. Here's a brief example:: + + class Count(Aggregate): + # supports COUNT(distinct field) + function = 'COUNT' + template = '%(function)s(%(distinct)s%(expressions)s)' + + def __init__(self, expression, distinct=False, **extra): + super(Count, self).__init__( + expression, + distinct='DISTINCT ' if distinct else '', + output_field=IntegerField(), + **extra) + + +``Value()`` expressions +----------------------- + +.. class:: Value(value, output_field=None) + + +A ``Value()`` object represents the smallest possible component of an +expression: a simple value. When you need to represent the value of an integer, +boolean, or string within an expression, you can wrap that value within a +``Value()``. + +You will rarely need to use ``Value()`` directly. When you write the expression +``F('field') + 1``, Django implicitly wraps the ``1`` in a ``Value()``, +allowing simple values to be used in more complex expressions. + +The ``value`` argument describes the value to be included in the expression, +such as ``1``, ``True``, or ``None``. Django knows how to convert these Python +values into their corresponding database type. + +The ``output_field`` argument should be a model field instance, like +``IntegerField()`` or ``BooleanField()``, into which Django will load the value +after it's retrieved from the database. + + +Technical Information +===================== + +Below you'll find technical implementation details that may be useful to +library authors. 
The technical API and examples below will help with +creating generic query expressions that can extend the built-in functionality +that Django provides. + +Expression API +-------------- + +Query expressions implement the :ref:`query expression API `, +but also expose a number of extra methods and attributes listed below. All +query expressions must inherit from ``ExpressionNode()`` or a relevant +subclass. + +When a query expression wraps another expression, it is responsible for +calling the appropriate methods on the wrapped expression. + +.. class:: ExpressionNode + + .. attribute:: contains_aggregate + + Tells Django that this expression contains an aggregate and that a + ``GROUP BY`` clause needs to be added to the query. + + .. method:: resolve_expression(query=None, allow_joins=True, reuse=None, summarize=False) + + Provides the chance to do any pre-processing or validation of + the expression before it's added to the query. ``resolve_expression()`` + must also be called on any nested expressions. A ``copy()`` of ``self`` + should be returned with any necessary transformations. + + ``query`` is the backend query implementation. + + ``allow_joins`` is a boolean that allows or denies the use of + joins in the query. + + ``reuse`` is a set of reusable joins for multi-join scenarios. + + ``summarize`` is a boolean that, when ``True``, signals that the + query being computed is a terminal aggregate query. + + .. method:: get_source_expressions() + + Returns an ordered list of inner expressions. For example:: + + >>> Sum(F('foo')).get_source_expressions() + [F('foo')] + + .. method:: set_source_expressions(expressions) + + Takes a list of expressions and stores them such that + ``get_source_expressions()`` can return them. + + .. method:: relabeled_clone(change_map) + + Returns a clone (copy) of ``self``, with any column aliases relabeled. + Column aliases are renamed when subqueries are created. 
+ ``relabeled_clone()`` should also be called on any nested expressions + and assigned to the clone. + + ``change_map`` is a dictionary mapping old aliases to new aliases. + + Example:: + + def relabeled_clone(self, change_map): + clone = copy.copy(self) + clone.expression = self.expression.relabeled_clone(change_map) + return clone + + .. method:: convert_value(self, value, connection) + + A hook allowing the expression to coerce ``value`` into a more + appropriate type. + + .. method:: refs_aggregate(existing_aggregates) + + Returns a tuple containing the ``(aggregate, lookup_path)`` of the + first aggregate that this expression (or any nested expression) + references, or ``(False, ())`` if no aggregate is referenced. + For example:: + + queryset.filter(num_chairs__gt=F('sum__employees')) + + The ``F()`` expression here references a previous ``Sum()`` + computation which means that this filter expression should be + added to the ``HAVING`` clause rather than the ``WHERE`` clause. + + In the majority of cases, returning the result of ``refs_aggregate`` + on any nested expression should be appropriate, as the necessary + built-in expressions will return the correct values. + + .. method:: get_group_by_cols() + + Responsible for returning the list of columns referenced by + this expression. ``get_group_by_cols()`` should be called on any + nested expressions. ``F()`` objects, in particular, hold a reference + to a column. + +Writing your own Query Expressions +---------------------------------- + +You can write your own query expression classes that use, and can integrate +with, other query expressions. Let's step through an example by writing an +implementation of the ``COALESCE`` SQL function, without using the built-in +:ref:`Func() expressions `. + +The ``COALESCE`` SQL function is defined as taking a list of columns or +values. It will return the first column or value that isn't ``NULL``.
+ +We'll start by defining the template to be used for SQL generation and +an ``__init__()`` method to set some attributes:: + + import copy + from django.db.models import ExpressionNode + + class Coalesce(ExpressionNode): + template = 'COALESCE( %(expressions)s )' + + def __init__(self, expressions, output_field, **extra): + super(Coalesce, self).__init__(output_field=output_field) + if len(expressions) < 2: + raise ValueError('expressions must have at least 2 elements') + for expression in expressions: + if not hasattr(expression, 'resolve_expression'): + raise TypeError('%r is not an Expression' % expression) + self.expressions = expressions + self.extra = extra + +We do some basic validation on the parameters, including requiring at least +2 columns or values, and ensuring they are expressions. We are requiring +``output_field`` here so that Django knows what kind of model field to assign +the eventual result to. + +Now we implement the pre-processing and validation. Since we do not have +any of our own validation at this point, we just delegate to the nested +expressions:: + + def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False): + c = self.copy() + c.is_summary = summarize + for pos, expression in enumerate(self.expressions): + c.expressions[pos] = expression.resolve_expression(query, allow_joins, reuse, summarize) + return c + +Next, we write the method responsible for generating the SQL:: + + def as_sql(self, compiler, connection): + sql_expressions, sql_params = [], [] + for expression in self.expressions: + sql, params = compiler.compile(expression) + sql_expressions.append(sql) + sql_params.extend(params) + self.extra['expressions'] = ','.join(sql_expressions) + return self.template % self.extra, sql_params + + def as_oracle(self, compiler, connection): + """ + Example of vendor specific handling (Oracle in this case). + Let's make the function name lowercase.
+ """ + self.template = 'coalesce( %(expressions)s )' + return self.as_sql(compiler, connection) + +We generate the SQL for each of the ``expressions`` by using the +``compiler.compile()`` method, and join the result together with commas. +Then the template is filled out with our data and the SQL and parameters +are returned. + +We've also defined a custom implementation that is specific to the Oracle +backend. The ``as_oracle()`` function will be called instead of ``as_sql()`` +if the Oracle backend is in use. + +Finally, we implement the rest of the methods that allow our query expression +to play nice with other query expressions:: + + def get_source_expressions(self): + return self.expressions + + def set_source_expressions(self, expressions): + self.expressions = expressions + +Let's see how it works:: + + >>> qs = Company.objects.annotate( + ... tagline=Coalesce([ + ... F('motto'), + ... F('ticker_name'), + ... F('description'), + ... Value('No Tagline') + ... ], output_field=CharField())) + >>> for c in qs: + ... print("%s: %s" % (c.name, c.tagline)) + ... + Google: Do No Evil + Apple: AAPL + Yahoo: Internet Company + Django Software Foundation: No Tagline diff --git a/docs/ref/models/index.txt b/docs/ref/models/index.txt index 57385b200d..b7f5ab6635 100644 --- a/docs/ref/models/index.txt +++ b/docs/ref/models/index.txt @@ -15,3 +15,4 @@ Model API reference. For introductory material, see :doc:`/topics/db/models`. querysets queries lookups + expressions diff --git a/docs/ref/models/queries.txt b/docs/ref/models/queries.txt index 1eab21c528..283983aef3 100644 --- a/docs/ref/models/queries.txt +++ b/docs/ref/models/queries.txt @@ -7,115 +7,6 @@ Query-related classes This document provides reference material for query-related tools not documented elsewhere. -``F()`` expressions -=================== - -.. class:: F - -An ``F()`` object represents the value of a model field.
It makes it possible -to refer to model field values and perform database operations using them -without actually having to pull them out of the database into Python memory. - -Instead, Django uses the ``F()`` object to generate a SQL expression that -describes the required operation at the database level. - -This is easiest to understand through an example. Normally, one might do -something like this:: - - # Tintin filed a news story! - reporter = Reporters.objects.get(name='Tintin') - reporter.stories_filed += 1 - reporter.save() - -Here, we have pulled the value of ``reporter.stories_filed`` from the database -into memory and manipulated it using familiar Python operators, and then saved -the object back to the database. But instead we could also have done:: - - from django.db.models import F - reporter = Reporters.objects.get(name='Tintin') - reporter.stories_filed = F('stories_filed') + 1 - reporter.save() - -Although ``reporter.stories_filed = F('stories_filed') + 1`` looks like a -normal Python assignment of value to an instance attribute, in fact it's an SQL -construct describing an operation on the database. - -When Django encounters an instance of ``F()``, it overrides the standard Python -operators to create an encapsulated SQL expression; in this case, one which -instructs the database to increment the database field represented by -``reporter.stories_filed``. - -Whatever value is or was on ``reporter.stories_filed``, Python never gets to -know about it - it is dealt with entirely by the database. All Python does, -through Django's ``F()`` class, is create the SQL syntax to refer to the field -and describe the operation. - -.. note:: - - In order to access the new value that has been saved in this way, the object - will need to be reloaded:: - - reporter = Reporters.objects.get(pk=reporter.pk) - -As well as being used in operations on single instances as above, ``F()`` can -be used on ``QuerySets`` of object instances, with ``update()``. 
This reduces -the two queries we were using above - the ``get()`` and the -:meth:`~Model.save()` - to just one:: - - reporter = Reporters.objects.filter(name='Tintin') - reporter.update(stories_filed=F('stories_filed') + 1) - -We can also use :meth:`~django.db.models.query.QuerySet.update()` to increment -the field value on multiple objects - which could be very much faster than -pulling them all into Python from the database, looping over them, incrementing -the field value of each one, and saving each one back to the database:: - - Reporter.objects.all().update(stories_filed=F('stories_filed') + 1) - -``F()`` therefore can offer performance advantages by: - -* getting the database, rather than Python, to do work -* reducing the number of queries some operations require - -.. _avoiding-race-conditions-using-f: - -Avoiding race conditions using ``F()`` --------------------------------------- - -Another useful benefit of ``F()`` is that having the database - rather than -Python - update a field's value avoids a *race condition*. - -If two Python threads execute the code in the first example above, one thread -could retrieve, increment, and save a field's value after the other has -retrieved it from the database. The value that the second thread saves will be -based on the original value; the work of the first thread will simply be lost. - -If the database is responsible for updating the field, the process is more -robust: it will only ever update the field based on the value of the field in -the database when the :meth:`~Model.save()` or ``update()`` is executed, rather -than based on its value when the instance was retrieved. - -Using ``F()`` in filters ------------------------- - -``F()`` is also very useful in ``QuerySet`` filters, where they make it -possible to filter a set of objects against criteria based on their field -values, rather than on Python values. 
- -This is documented in :ref:`using F() expressions in queries -` - -Supported operations with ``F()`` ---------------------------------- - -As well as addition, Django supports subtraction, multiplication, division, -and modulo arithmetic with ``F()`` objects, using Python constants, -variables, and even other ``F()`` objects. - -.. versionadded:: 1.7 - - The power operator ``**`` is also supported. - ``Q()`` objects =============== diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt index 218a4ff35f..1cf215bc46 100644 --- a/docs/ref/models/querysets.txt +++ b/docs/ref/models/querysets.txt @@ -220,9 +220,18 @@ annotate .. method:: annotate(*args, **kwargs) -Annotates each object in the ``QuerySet`` with the provided list of -aggregate values (averages, sums, etc) that have been computed over -the objects that are related to the objects in the ``QuerySet``. +Annotates each object in the ``QuerySet`` with the provided list of :doc:`query +expressions `. An expression may be a simple value, a +reference to a field on the model (or any related models), or an aggregate +expression (averages, sums, etc) that has been computed over the objects that +are related to the objects in the ``QuerySet``. + +.. versionadded:: 1.8 + + Previous versions of Django only allowed aggregate functions to be used as + annotations. It is now possible to annotate a model with all kinds of + expressions. + Each argument to ``annotate()`` is an annotation that will be added to each object in the ``QuerySet`` that is returned. @@ -232,7 +241,9 @@ in `Aggregation Functions`_ below. Annotations specified using keyword arguments will use the keyword as the alias for the annotation. Anonymous arguments will have an alias generated for them based upon the name of the aggregate function and -the model field that is being aggregated. +the model field that is being aggregated. Only aggregate expressions +that reference a single field can be anonymous arguments. 
Everything +else must be a keyword argument. For example, if you were manipulating a list of blogs, you may want to determine how many entries have been made in each blog:: @@ -1886,12 +1897,15 @@ the ``QuerySet``. Each argument to ``aggregate()`` specifies a value that will be included in the dictionary that is returned. The aggregation functions that are provided by Django are described in -`Aggregation Functions`_ below. +`Aggregation Functions`_ below. Since aggregates are also :doc:`query +expressions `, you may combine aggregates with other +aggregates or values to create complex aggregates. Aggregates specified using keyword arguments will use the keyword as the name for the annotation. Anonymous arguments will have a name generated for them based upon the name of the aggregate function and the model field that is being -aggregated. +aggregated. Complex aggregates cannot use anonymous arguments and must specify +a keyword argument as an alias. For example, when you are working with blog entries, you may want to know the number of authors that have contributed blog entries:: @@ -2667,8 +2681,9 @@ Aggregation functions Django provides the following aggregation functions in the ``django.db.models`` module. For details on how to use these -aggregate functions, see -:doc:`the topic guide on aggregation `. +aggregate functions, see :doc:`the topic guide on aggregation +`. See the :class:`~django.db.models.Aggregate` +documentation to learn how to create your aggregates. .. warning:: @@ -2685,12 +2700,47 @@ aggregate functions, see instead of ``0`` if the ``QuerySet`` contains no entries. An exception is ``Count``, which does return ``0`` if the ``QuerySet`` is empty. +All aggregates have the following parameters in common: + +``expression`` +~~~~~~~~~~~~~~ + +A string that references a field on the model, or a :doc:`query expression +`. + +.. versionadded:: 1.8 + + Aggregate functions are now able to reference multiple fields in complex + computations. 
+ +``output_field`` +~~~~~~~~~~~~~~~~ + +An optional argument that represents the :doc:`model field ` +of the return value + +.. versionadded:: 1.8 + + The ``output_field`` argument was added. + +.. note:: + + When combining multiple field types, Django can only determine the + ``output_field`` if all fields are of the same type. Otherwise, you + must provide the ``output_field`` yourself. + +``**extra`` +~~~~~~~~~~~ + +Keyword arguments that can provide extra context for the SQL generated +by the aggregate. + Avg ~~~ -.. class:: Avg(field) +.. class:: Avg(expression, output_field=None, **extra) - Returns the mean value of the given field, which must be numeric. + Returns the mean value of the given expression, which must be numeric. * Default alias: ``__avg`` * Return type: ``float`` @@ -2698,9 +2748,10 @@ Avg Count ~~~~~ -.. class:: Count(field, distinct=False) +.. class:: Count(expression, distinct=False, **extra) - Returns the number of objects that are related through the provided field. + Returns the number of objects that are related through the provided + expression. * Default alias: ``__count`` * Return type: ``int`` @@ -2716,29 +2767,29 @@ Count Max ~~~ -.. class:: Max(field) +.. class:: Max(expression, output_field=None, **extra) - Returns the maximum value of the given field. + Returns the maximum value of the given expression. * Default alias: ``__max`` - * Return type: same as input field + * Return type: same as input field, or ``output_field`` if supplied Min ~~~ -.. class:: Min(field) +.. class:: Min(expression, output_field=None, **extra) - Returns the minimum value of the given field. + Returns the minimum value of the given expression. * Default alias: ``__min`` - * Return type: same as input field + * Return type: same as input field, or ``output_field`` if supplied StdDev ~~~~~~ -.. class:: StdDev(field, sample=False) +.. class:: StdDev(expression, sample=False, **extra) - Returns the standard deviation of the data in the provided field. 
+ Returns the standard deviation of the data in the provided expression. * Default alias: ``<field>__stddev`` * Return type: ``float`` @@ -2760,19 +2811,19 @@ StdDev Sum ~~~ -.. class:: Sum(field) +.. class:: Sum(expression, output_field=None, **extra) - Computes the sum of all values of the given field. + Computes the sum of all values of the given expression. * Default alias: ``<field>__sum`` - * Return type: same as input field + * Return type: same as input field, or ``output_field`` if supplied Variance ~~~~~~~~ -.. class:: Variance(field, sample=False) +.. class:: Variance(expression, sample=False, **extra) - Returns the variance of the data in the provided field. + Returns the variance of the data in the provided expression. * Default alias: ``<field>__variance`` * Return type: ``float`` diff --git a/docs/releases/1.8.txt b/docs/releases/1.8.txt index 729c19f3c2..6bfe9435a6 100644 --- a/docs/releases/1.8.txt +++ b/docs/releases/1.8.txt @@ -52,6 +52,15 @@ New data types `. It is stored as the native ``uuid`` data type on PostgreSQL and as a fixed length character field on other backends. +Query Expressions +~~~~~~~~~~~~~~~~~ + +:doc:`Query Expressions </ref/models/expressions>` allow users to create, +customize, and compose complex SQL expressions. This has enabled annotate +to accept expressions other than aggregates. Aggregates are now able to +reference multiple fields, as well as perform arithmetic, similar to ``F()`` +objects. + Minor features ~~~~~~~~~~~~~~ @@ -857,6 +866,29 @@ or ``name='django.contrib.gis.sitemaps.views.kmz'``. .. _security issue: https://www.djangoproject.com/weblog/2014/apr/21/security/#s-issue-unexpected-code-execution-using-reverse +Aggregate methods and modules +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The ``django.db.models.sql.aggregates`` and +``django.contrib.gis.db.models.sql.aggregates`` modules (both private API) have +been deprecated as ``django.db.models.aggregates`` and +``django.contrib.gis.db.models.aggregates`` are now also responsible +for SQL generation.
The old modules will be removed in Django 2.0. + +If you were using the old modules, see :doc:`Query Expressions +</ref/models/expressions>` for instructions on rewriting custom aggregates +using the new stable API. + +The following methods and properties of ``django.db.models.sql.query.Query`` +have also been deprecated and the backwards compatibility shims will be removed +in Django 2.0: + +* ``Query.aggregates``, replaced by ``annotations``. +* ``Query.aggregate_select``, replaced by ``annotation_select``. +* ``Query.add_aggregate()``, replaced by ``add_annotation()``. +* ``Query.set_aggregate_mask()``, replaced by ``set_annotation_mask()``. +* ``Query.append_aggregate_mask()``, replaced by ``append_annotation_mask()``. + Extending management command arguments through ``Command.option_list`` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/topics/db/aggregation.txt b/docs/topics/db/aggregation.txt index 436256483d..175cc21782 100644 --- a/docs/topics/db/aggregation.txt +++ b/docs/topics/db/aggregation.txt @@ -67,6 +67,11 @@ In a hurry? Here's how to do common aggregate queries, assuming the models above >>> Book.objects.all().aggregate(Max('price')) {'price__max': Decimal('81.20')} + # Cost per page + >>> Book.objects.all().aggregate( + ...
price_per_page=Sum(F('price')/F('pages'), output_field=FloatField())) + {'price_per_page': 0.4470664529184653} + # All the following queries involve traversing the Book<->Publisher # many-to-many relationship backward diff --git a/tests/aggregation/tests.py b/tests/aggregation/tests.py index b1b6199ffa..851ce69db0 100644 --- a/tests/aggregation/tests.py +++ b/tests/aggregation/tests.py @@ -3,12 +3,21 @@ from __future__ import unicode_literals import datetime from decimal import Decimal import re +import warnings +from django.core.exceptions import FieldError from django.db import connection -from django.db.models import Avg, Sum, Count, Max, Min +from django.db.models import ( + Avg, Sum, Count, Max, Min, + Aggregate, F, Value, Func, + IntegerField, FloatField, DecimalField) +with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + from django.db.models.sql import aggregates as sql_aggregates from django.test import TestCase from django.test.utils import Approximate from django.test.utils import CaptureQueriesContext +from django.utils.deprecation import RemovedInDjango20Warning from .models import Author, Publisher, Book, Store @@ -678,3 +687,271 @@ class BaseAggregateTestCase(TestCase): else: self.assertNotIn('order by', qstr) self.assertEqual(qstr.count(' join '), 0) + + +class ComplexAggregateTestCase(TestCase): + fixtures = ["aggregation.json"] + + def test_nonaggregate_aggregation_throws(self): + with self.assertRaisesRegexp(TypeError, 'fail is not an aggregate expression'): + Book.objects.aggregate(fail=F('price')) + + def test_nonfield_annotation(self): + book = Book.objects.annotate(val=Max(Value(2, output_field=IntegerField())))[0] + self.assertEqual(book.val, 2) + book = Book.objects.annotate(val=Max(Value(2), output_field=IntegerField()))[0] + self.assertEqual(book.val, 2) + + def test_missing_output_field_raises_error(self): + with self.assertRaisesRegexp(FieldError, 'Cannot resolve expression type, unknown output_field'): + 
Book.objects.annotate(val=Max(Value(2)))[0] + + def test_annotation_expressions(self): + authors = Author.objects.annotate(combined_ages=Sum(F('age') + F('friends__age'))).order_by('name') + authors2 = Author.objects.annotate(combined_ages=Sum('age') + Sum('friends__age')).order_by('name') + for qs in (authors, authors2): + self.assertEqual(len(qs), 9) + self.assertQuerysetEqual( + qs, [ + ('Adrian Holovaty', 132), + ('Brad Dayley', None), + ('Jacob Kaplan-Moss', 129), + ('James Bennett', 63), + ('Jeffrey Forcier', 128), + ('Paul Bissex', 120), + ('Peter Norvig', 103), + ('Stuart Russell', 103), + ('Wesley J. Chun', 176) + ], + lambda a: (a.name, a.combined_ages) + ) + + def test_aggregation_expressions(self): + a1 = Author.objects.aggregate(av_age=Sum('age') / Count('*')) + a2 = Author.objects.aggregate(av_age=Sum('age') / Count('age')) + a3 = Author.objects.aggregate(av_age=Avg('age')) + self.assertEqual(a1, {'av_age': 37}) + self.assertEqual(a2, {'av_age': 37}) + self.assertEqual(a3, {'av_age': Approximate(37.4, places=1)}) + + def test_order_of_precedence(self): + p1 = Book.objects.filter(rating=4).aggregate(avg_price=(Avg('price') + 2) * 3) + self.assertEqual(p1, {'avg_price': Approximate(148.18, places=2)}) + + p2 = Book.objects.filter(rating=4).aggregate(avg_price=Avg('price') + 2 * 3) + self.assertEqual(p2, {'avg_price': Approximate(53.39, places=2)}) + + def test_combine_different_types(self): + with self.assertRaisesRegexp(FieldError, 'Expression contains mixed types. 
You must set output_field'): + Book.objects.annotate(sums=Sum('rating') + Sum('pages') + Sum('price')).get(pk=4) + + b1 = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'), + output_field=IntegerField())).get(pk=4) + self.assertEqual(b1.sums, 383) + + b2 = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'), + output_field=FloatField())).get(pk=4) + self.assertEqual(b2.sums, 383.69) + + b3 = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'), + output_field=DecimalField(max_digits=6, decimal_places=2))).get(pk=4) + self.assertEqual(b3.sums, Decimal("383.69")) + + def test_complex_aggregations_require_kwarg(self): + with self.assertRaisesRegexp(TypeError, 'Complex expressions require an alias'): + Author.objects.annotate(Sum(F('age') + F('friends__age'))) + with self.assertRaisesRegexp(TypeError, 'Complex aggregates require an alias'): + Author.objects.aggregate(Sum('age') / Count('age')) + + def test_aggregate_over_complex_annotation(self): + qs = Author.objects.annotate( + combined_ages=Sum(F('age') + F('friends__age'))) + + age = qs.aggregate(max_combined_age=Max('combined_ages')) + self.assertEqual(age['max_combined_age'], 176) + + age = qs.aggregate(max_combined_age_doubled=Max('combined_ages') * 2) + self.assertEqual(age['max_combined_age_doubled'], 176 * 2) + + age = qs.aggregate( + max_combined_age_doubled=Max('combined_ages') + Max('combined_ages')) + self.assertEqual(age['max_combined_age_doubled'], 176 * 2) + + age = qs.aggregate( + max_combined_age_doubled=Max('combined_ages') + Max('combined_ages'), + sum_combined_age=Sum('combined_ages')) + self.assertEqual(age['max_combined_age_doubled'], 176 * 2) + self.assertEqual(age['sum_combined_age'], 954) + + age = qs.aggregate( + max_combined_age_doubled=Max('combined_ages') + Max('combined_ages'), + sum_combined_age_doubled=Sum('combined_ages') + Sum('combined_ages')) + self.assertEqual(age['max_combined_age_doubled'], 176 * 2) + 
self.assertEqual(age['sum_combined_age_doubled'], 954 * 2) + + def test_values_annotation_with_expression(self): + # ensure the F() is promoted to the group by clause + qs = Author.objects.values('name').annotate(another_age=Sum('age') + F('age')) + a = qs.get(pk=1) + self.assertEqual(a['another_age'], 68) + + qs = qs.annotate(friend_count=Count('friends')) + a = qs.get(pk=1) + self.assertEqual(a['friend_count'], 2) + + qs = qs.annotate(combined_age=Sum('age') + F('friends__age')).filter(pk=1).order_by('-combined_age') + self.assertEqual( + list(qs), [ + { + "name": 'Adrian Holovaty', + "another_age": 68, + "friend_count": 1, + "combined_age": 69 + }, + { + "name": 'Adrian Holovaty', + "another_age": 68, + "friend_count": 1, + "combined_age": 63 + } + ] + ) + + vals = qs.values('name', 'combined_age') + self.assertEqual( + list(vals), [ + { + "name": 'Adrian Holovaty', + "combined_age": 69 + }, + { + "name": 'Adrian Holovaty', + "combined_age": 63 + } + ] + ) + + def test_annotate_values_aggregate(self): + alias_age = Author.objects.annotate( + age_alias=F('age') + ).values( + 'age_alias', + ).aggregate(sum_age=Sum('age_alias')) + + age = Author.objects.values('age').aggregate(sum_age=Sum('age')) + + self.assertEqual(alias_age['sum_age'], age['sum_age']) + + def test_annotate_over_annotate(self): + author = Author.objects.annotate( + age_alias=F('age') + ).annotate( + sum_age=Sum('age_alias') + ).get(pk=1) + + other_author = Author.objects.annotate( + sum_age=Sum('age') + ).get(pk=1) + + self.assertEqual(author.sum_age, other_author.sum_age) + + def test_annotated_aggregate_over_annotated_aggregate(self): + with self.assertRaisesRegexp(FieldError, "Cannot compute Sum\('id__max'\): 'id__max' is an aggregate"): + Book.objects.annotate(Max('id')).annotate(Sum('id__max')) + + def test_add_implementation(self): + try: + # test completely changing how the output is rendered + def lower_case_function_override(self, qn, connection): + sql, params = 
qn.compile(self.source_expressions[0]) + substitutions = dict(function=self.function.lower(), expressions=sql) + substitutions.update(self.extra) + return self.template % substitutions, params + setattr(Sum, 'as_' + connection.vendor, lower_case_function_override) + + qs = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'), + output_field=IntegerField())) + self.assertEqual(str(qs.query).count('sum('), 1) + b1 = qs.get(pk=4) + self.assertEqual(b1.sums, 383) + + # test changing the dict and delegating + def lower_case_function_super(self, qn, connection): + self.extra['function'] = self.function.lower() + return super(Sum, self).as_sql(qn, connection) + setattr(Sum, 'as_' + connection.vendor, lower_case_function_super) + + qs = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'), + output_field=IntegerField())) + self.assertEqual(str(qs.query).count('sum('), 1) + b1 = qs.get(pk=4) + self.assertEqual(b1.sums, 383) + + # test overriding all parts of the template + def be_evil(self, qn, connection): + substitutions = dict(function='MAX', expressions='2') + substitutions.update(self.extra) + return self.template % substitutions, () + setattr(Sum, 'as_' + connection.vendor, be_evil) + + qs = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'), + output_field=IntegerField())) + self.assertEqual(str(qs.query).count('MAX('), 1) + b1 = qs.get(pk=4) + self.assertEqual(b1.sums, 2) + finally: + delattr(Sum, 'as_' + connection.vendor) + + def test_complex_values_aggregation(self): + max_rating = Book.objects.values('rating').aggregate( + double_max_rating=Max('rating') + Max('rating')) + self.assertEqual(max_rating['double_max_rating'], 5 * 2) + + max_books_per_rating = Book.objects.values('rating').annotate( + books_per_rating=Count('id') + 5 + ).aggregate(Max('books_per_rating')) + self.assertEqual( + max_books_per_rating, + {'books_per_rating__max': 3 + 5}) + + def test_expression_on_aggregation(self): + + # Create a plain 
expression + class Greatest(Func): + function = 'GREATEST' + + def as_sqlite(self, qn, connection): + return super(Greatest, self).as_sql(qn, connection, function='MAX') + + qs = Publisher.objects.annotate( + price_or_median=Greatest(Avg('book__rating'), Avg('book__price')) + ).filter(price_or_median__gte=F('num_awards')).order_by('pk') + self.assertQuerysetEqual( + qs, [1, 2, 3, 4], lambda v: v.pk) + + qs2 = Publisher.objects.annotate( + rating_or_num_awards=Greatest(Avg('book__rating'), F('num_awards'), + output_field=FloatField()) + ).filter(rating_or_num_awards__gt=F('num_awards')).order_by('pk') + self.assertQuerysetEqual( + qs2, [1, 2], lambda v: v.pk) + + def test_backwards_compatibility(self): + + class SqlNewSum(sql_aggregates.Aggregate): + sql_function = 'SUM' + + class NewSum(Aggregate): + name = 'Sum' + + def add_to_query(self, query, alias, col, source, is_summary): + klass = SqlNewSum + aggregate = klass( + col, source=source, is_summary=is_summary, **self.extra) + query.annotations[alias] = aggregate + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RemovedInDjango20Warning) + qs = Author.objects.values('name').annotate(another_age=NewSum('age') + F('age')) + a = qs.get(pk=1) + self.assertEqual(a['another_age'], 68) diff --git a/tests/annotations/__init__.py b/tests/annotations/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/annotations/fixtures/annotations.json b/tests/annotations/fixtures/annotations.json new file mode 100644 index 0000000000..09c9b8346b --- /dev/null +++ b/tests/annotations/fixtures/annotations.json @@ -0,0 +1,243 @@ +[ + { + "pk": 1, + "model": "annotations.publisher", + "fields": { + "name": "Apress", + "num_awards": 3 + } + }, + { + "pk": 2, + "model": "annotations.publisher", + "fields": { + "name": "Sams", + "num_awards": 1 + } + }, + { + "pk": 3, + "model": "annotations.publisher", + "fields": { + "name": "Prentice Hall", + "num_awards": 7 + } + }, + { + "pk": 4, + 
"model": "annotations.publisher", + "fields": { + "name": "Morgan Kaufmann", + "num_awards": 9 + } + }, + { + "pk": 5, + "model": "annotations.publisher", + "fields": { + "name": "Jonno's House of Books", + "num_awards": 0 + } + }, + { + "pk": 1, + "model": "annotations.book", + "fields": { + "publisher": 1, + "isbn": "159059725", + "name": "The Definitive Guide to Django: Web Development Done Right", + "price": "30.00", + "rating": 4.5, + "authors": [1, 2], + "contact": 1, + "pages": 447, + "pubdate": "2007-12-6" + } + }, + { + "pk": 2, + "model": "annotations.book", + "fields": { + "publisher": 2, + "isbn": "067232959", + "name": "Sams Teach Yourself Django in 24 Hours", + "price": "23.09", + "rating": 3.0, + "authors": [3], + "contact": 3, + "pages": 528, + "pubdate": "2008-3-3" + } + }, + { + "pk": 3, + "model": "annotations.book", + "fields": { + "publisher": 1, + "isbn": "159059996", + "name": "Practical Django Projects", + "price": "29.69", + "rating": 4.0, + "authors": [4], + "contact": 4, + "pages": 300, + "pubdate": "2008-6-23" + } + }, + { + "pk": 4, + "model": "annotations.book", + "fields": { + "publisher": 3, + "isbn": "013235613", + "name": "Python Web Development with Django", + "price": "29.69", + "rating": 4.0, + "authors": [5, 6, 7], + "contact": 5, + "pages": 350, + "pubdate": "2008-11-3" + } + }, + { + "pk": 5, + "model": "annotations.book", + "fields": { + "publisher": 3, + "isbn": "013790395", + "name": "Artificial Intelligence: A Modern Approach", + "price": "82.80", + "rating": 4.0, + "authors": [8, 9], + "contact": 8, + "pages": 1132, + "pubdate": "1995-1-15" + } + }, + { + "pk": 6, + "model": "annotations.book", + "fields": { + "publisher": 4, + "isbn": "155860191", + "name": "Paradigms of Artificial Intelligence Programming: Case Studies in Common Lisp", + "price": "75.00", + "rating": 5.0, + "authors": [8], + "contact": 8, + "pages": 946, + "pubdate": "1991-10-15" + } + }, + { + "pk": 1, + "model": "annotations.store", + "fields": { + 
"books": [1, 2, 3, 4, 5, 6], + "name": "Amazon.com", + "original_opening": "1994-4-23 9:17:42", + "friday_night_closing": "23:59:59" + } + }, + { + "pk": 2, + "model": "annotations.store", + "fields": { + "books": [1, 3, 5, 6], + "name": "Books.com", + "original_opening": "2001-3-15 11:23:37", + "friday_night_closing": "23:59:59" + } + }, + { + "pk": 3, + "model": "annotations.store", + "fields": { + "books": [3, 4, 6], + "name": "Mamma and Pappa's Books", + "original_opening": "1945-4-25 16:24:14", + "friday_night_closing": "21:30:00" + } + }, + { + "pk": 1, + "model": "annotations.author", + "fields": { + "age": 34, + "friends": [2, 4], + "name": "Adrian Holovaty" + } + }, + { + "pk": 2, + "model": "annotations.author", + "fields": { + "age": 35, + "friends": [1, 7], + "name": "Jacob Kaplan-Moss" + } + }, + { + "pk": 3, + "model": "annotations.author", + "fields": { + "age": 45, + "friends": [], + "name": "Brad Dayley" + } + }, + { + "pk": 4, + "model": "annotations.author", + "fields": { + "age": 29, + "friends": [1], + "name": "James Bennett" + } + }, + { + "pk": 5, + "model": "annotations.author", + "fields": { + "age": 37, + "friends": [6, 7], + "name": "Jeffrey Forcier" + } + }, + { + "pk": 6, + "model": "annotations.author", + "fields": { + "age": 29, + "friends": [5, 7], + "name": "Paul Bissex" + } + }, + { + "pk": 7, + "model": "annotations.author", + "fields": { + "age": 25, + "friends": [2, 5, 6], + "name": "Wesley J. 
Chun" + } + }, + { + "pk": 8, + "model": "annotations.author", + "fields": { + "age": 57, + "friends": [9], + "name": "Peter Norvig" + } + }, + { + "pk": 9, + "model": "annotations.author", + "fields": { + "age": 46, + "friends": [8], + "name": "Stuart Russell" + } + } +] diff --git a/tests/annotations/models.py b/tests/annotations/models.py new file mode 100644 index 0000000000..0c438b99f8 --- /dev/null +++ b/tests/annotations/models.py @@ -0,0 +1,86 @@ +# coding: utf-8 +from django.db import models +from django.utils.encoding import python_2_unicode_compatible + + +@python_2_unicode_compatible +class Author(models.Model): + name = models.CharField(max_length=100) + age = models.IntegerField() + friends = models.ManyToManyField('self', blank=True) + + def __str__(self): + return self.name + + +@python_2_unicode_compatible +class Publisher(models.Model): + name = models.CharField(max_length=255) + num_awards = models.IntegerField() + + def __str__(self): + return self.name + + +@python_2_unicode_compatible +class Book(models.Model): + isbn = models.CharField(max_length=9) + name = models.CharField(max_length=255) + pages = models.IntegerField() + rating = models.FloatField() + price = models.DecimalField(decimal_places=2, max_digits=6) + authors = models.ManyToManyField(Author) + contact = models.ForeignKey(Author, related_name='book_contact_set') + publisher = models.ForeignKey(Publisher) + pubdate = models.DateField() + + def __str__(self): + return self.name + + +@python_2_unicode_compatible +class Store(models.Model): + name = models.CharField(max_length=255) + books = models.ManyToManyField(Book) + original_opening = models.DateTimeField() + friday_night_closing = models.TimeField() + + def __str__(self): + return self.name + + +@python_2_unicode_compatible +class DepartmentStore(Store): + chain = models.CharField(max_length=255) + + def __str__(self): + return '%s - %s ' % (self.chain, self.name) + + +@python_2_unicode_compatible +class 
Employee(models.Model): + # The order of these fields matter, do not change. Certain backends + # rely on field ordering to perform database conversions, and this + # model helps to test that. + first_name = models.CharField(max_length=20) + manager = models.BooleanField(default=False) + last_name = models.CharField(max_length=20) + store = models.ForeignKey(Store) + age = models.IntegerField() + salary = models.DecimalField(max_digits=8, decimal_places=2) + + def __str__(self): + return '%s %s' % (self.first_name, self.last_name) + + +@python_2_unicode_compatible +class Company(models.Model): + name = models.CharField(max_length=200) + motto = models.CharField(max_length=200, null=True, blank=True) + ticker_name = models.CharField(max_length=10, null=True, blank=True) + description = models.CharField(max_length=200, null=True, blank=True) + + def __str__(self): + return ('Company(name=%s, motto=%s, ticker_name=%s, description=%s)' + % (self.name, self.motto, self.ticker_name, self.description) + ) diff --git a/tests/annotations/tests.py b/tests/annotations/tests.py new file mode 100644 index 0000000000..afc23b85df --- /dev/null +++ b/tests/annotations/tests.py @@ -0,0 +1,288 @@ +from __future__ import unicode_literals +import datetime +from decimal import Decimal + +from django.core.exceptions import FieldError +from django.db.models import ( + Sum, Count, + F, Value, Func, + IntegerField, BooleanField, CharField) +from django.db.models.fields import FieldDoesNotExist +from django.test import TestCase + +from .models import Author, Book, Store, DepartmentStore, Company, Employee + + +class NonAggregateAnnotationTestCase(TestCase): + fixtures = ["annotations.json"] + + def test_basic_annotation(self): + books = Book.objects.annotate( + is_book=Value(1, output_field=IntegerField())) + for book in books: + self.assertEqual(book.is_book, 1) + + def test_basic_f_annotation(self): + books = Book.objects.annotate(another_rating=F('rating')) + for book in books: + 
self.assertEqual(book.another_rating, book.rating) + + def test_joined_annotation(self): + books = Book.objects.select_related('publisher').annotate( + num_awards=F('publisher__num_awards')) + for book in books: + self.assertEqual(book.num_awards, book.publisher.num_awards) + + def test_annotate_with_aggregation(self): + books = Book.objects.annotate( + is_book=Value(1, output_field=IntegerField()), + rating_count=Count('rating')) + for book in books: + self.assertEqual(book.is_book, 1) + self.assertEqual(book.rating_count, 1) + + def test_aggregate_over_annotation(self): + agg = Author.objects.annotate(other_age=F('age')).aggregate(otherage_sum=Sum('other_age')) + other_agg = Author.objects.aggregate(age_sum=Sum('age')) + self.assertEqual(agg['otherage_sum'], other_agg['age_sum']) + + def test_filter_annotation(self): + books = Book.objects.annotate( + is_book=Value(1, output_field=IntegerField()) + ).filter(is_book=1) + for book in books: + self.assertEqual(book.is_book, 1) + + def test_filter_annotation_with_f(self): + books = Book.objects.annotate( + other_rating=F('rating') + ).filter(other_rating=3.5) + for book in books: + self.assertEqual(book.other_rating, 3.5) + + def test_filter_annotation_with_double_f(self): + books = Book.objects.annotate( + other_rating=F('rating') + ).filter(other_rating=F('rating')) + for book in books: + self.assertEqual(book.other_rating, book.rating) + + def test_filter_agg_with_double_f(self): + books = Book.objects.annotate( + sum_rating=Sum('rating') + ).filter(sum_rating=F('sum_rating')) + for book in books: + self.assertEqual(book.sum_rating, book.rating) + + def test_filter_wrong_annotation(self): + with self.assertRaisesRegexp(FieldError, "Cannot resolve keyword .*"): + list(Book.objects.annotate( + sum_rating=Sum('rating') + ).filter(sum_rating=F('nope'))) + + def test_update_with_annotation(self): + book_preupdate = Book.objects.get(pk=2) + Book.objects.annotate(other_rating=F('rating') - 
1).update(rating=F('other_rating')) + book_postupdate = Book.objects.get(pk=2) + self.assertEqual(book_preupdate.rating - 1, book_postupdate.rating) + + def test_annotation_with_m2m(self): + books = Book.objects.annotate(author_age=F('authors__age')).filter(pk=1).order_by('author_age') + self.assertEqual(books[0].author_age, 34) + self.assertEqual(books[1].author_age, 35) + + def test_annotation_reverse_m2m(self): + books = Book.objects.annotate( + store_name=F('store__name')).filter( + name='Practical Django Projects').order_by( + 'store_name') + + self.assertQuerysetEqual( + books, [ + 'Amazon.com', + 'Books.com', + 'Mamma and Pappa\'s Books' + ], + lambda b: b.store_name + ) + + def test_values_annotation(self): + """ + Annotations can reference fields in a values clause, + and contribute to an existing values clause. + """ + # annotate references a field in values() + qs = Book.objects.values('rating').annotate(other_rating=F('rating') - 1) + book = qs.get(pk=1) + self.assertEqual(book['rating'] - 1, book['other_rating']) + + # filter refs the annotated value + book = qs.get(other_rating=4) + self.assertEqual(book['other_rating'], 4) + + # can annotate an existing values with a new field + book = qs.annotate(other_isbn=F('isbn')).get(other_rating=4) + self.assertEqual(book['other_rating'], 4) + self.assertEqual(book['other_isbn'], '155860191') + + def test_defer_annotation(self): + """ + Deferred attributes can be referenced by an annotation, + but they are not themselves deferred, and cannot be deferred. 
+ """ + qs = Book.objects.defer('rating').annotate(other_rating=F('rating') - 1) + + with self.assertNumQueries(2): + book = qs.get(other_rating=4) + self.assertEqual(book.rating, 5) + self.assertEqual(book.other_rating, 4) + + with self.assertRaisesRegexp(FieldDoesNotExist, "\w has no field named u?'other_rating'"): + book = qs.defer('other_rating').get(other_rating=4) + + def test_mti_annotations(self): + """ + Fields on an inherited model can be referenced by an + annotated field. + """ + d = DepartmentStore.objects.create( + name='Angus & Robinson', + original_opening=datetime.date(2014, 3, 8), + friday_night_closing=datetime.time(21, 00, 00), + chain='Westfield' + ) + + books = Book.objects.filter(rating__gt=4) + for b in books: + d.books.add(b) + + qs = DepartmentStore.objects.annotate( + other_name=F('name'), + other_chain=F('chain'), + is_open=Value(True, BooleanField()), + book_isbn=F('books__isbn') + ).select_related('store').order_by('book_isbn').filter(chain='Westfield') + + self.assertQuerysetEqual( + qs, [ + ('Angus & Robinson', 'Westfield', True, '155860191'), + ('Angus & Robinson', 'Westfield', True, '159059725') + ], + lambda d: (d.other_name, d.other_chain, d.is_open, d.book_isbn) + ) + + def test_column_field_ordering(self): + """ + Test that columns are aligned in the correct order for + resolve_columns. This test will fail on mysql if column + ordering is out. Column fields should be aligned as: + 1. extra_select + 2. model_fields + 3. annotation_fields + 4. 
model_related_fields + """ + store = Store.objects.first() + Employee.objects.create(id=1, first_name='Max', manager=True, last_name='Paine', + store=store, age=23, salary=Decimal(50000.00)) + Employee.objects.create(id=2, first_name='Buffy', manager=False, last_name='Summers', + store=store, age=18, salary=Decimal(40000.00)) + + qs = Employee.objects.extra( + select={'random_value': '42'} + ).select_related('store').annotate( + annotated_value=Value(17, output_field=IntegerField()) + ) + + rows = [ + (1, 'Max', True, 42, 'Paine', 23, Decimal(50000.00), store.name, 17), + (2, 'Buffy', False, 42, 'Summers', 18, Decimal(40000.00), store.name, 17) + ] + + self.assertQuerysetEqual( + qs.order_by('id'), rows, + lambda e: ( + e.id, e.first_name, e.manager, e.random_value, e.last_name, e.age, + e.salary, e.store.name, e.annotated_value)) + + def test_column_field_ordering_with_deferred(self): + store = Store.objects.first() + Employee.objects.create(id=1, first_name='Max', manager=True, last_name='Paine', + store=store, age=23, salary=Decimal(50000.00)) + Employee.objects.create(id=2, first_name='Buffy', manager=False, last_name='Summers', + store=store, age=18, salary=Decimal(40000.00)) + + qs = Employee.objects.extra( + select={'random_value': '42'} + ).select_related('store').annotate( + annotated_value=Value(17, output_field=IntegerField()) + ) + + rows = [ + (1, 'Max', True, 42, 'Paine', 23, Decimal(50000.00), store.name, 17), + (2, 'Buffy', False, 42, 'Summers', 18, Decimal(40000.00), store.name, 17) + ] + + # and we respect deferred columns! 
+ self.assertQuerysetEqual( + qs.defer('age').order_by('id'), rows, + lambda e: ( + e.id, e.first_name, e.manager, e.random_value, e.last_name, e.age, + e.salary, e.store.name, e.annotated_value)) + + def test_custom_functions(self): + Company(name='Apple', motto=None, ticker_name='APPL', description='Beautiful Devices').save() + Company(name='Django Software Foundation', motto=None, ticker_name=None, description=None).save() + Company(name='Google', motto='Do No Evil', ticker_name='GOOG', description='Internet Company').save() + Company(name='Yahoo', motto=None, ticker_name=None, description='Internet Company').save() + + qs = Company.objects.annotate( + tagline=Func( + F('motto'), + F('ticker_name'), + F('description'), + Value('No Tag'), + function='COALESCE') + ).order_by('name') + + self.assertQuerysetEqual( + qs, [ + ('Apple', 'APPL'), + ('Django Software Foundation', 'No Tag'), + ('Google', 'Do No Evil'), + ('Yahoo', 'Internet Company') + ], + lambda c: (c.name, c.tagline) + ) + + def test_custom_functions_can_ref_other_functions(self): + Company(name='Apple', motto=None, ticker_name='APPL', description='Beautiful Devices').save() + Company(name='Django Software Foundation', motto=None, ticker_name=None, description=None).save() + Company(name='Google', motto='Do No Evil', ticker_name='GOOG', description='Internet Company').save() + Company(name='Yahoo', motto=None, ticker_name=None, description='Internet Company').save() + + class Lower(Func): + function = 'LOWER' + + qs = Company.objects.annotate( + tagline=Func( + F('motto'), + F('ticker_name'), + F('description'), + Value('No Tag'), + function='COALESCE') + ).annotate( + tagline_lower=Lower(F('tagline'), output_field=CharField()) + ).order_by('name') + + # LOWER function supported by: + # oracle, postgres, mysql, sqlite, sqlserver + + self.assertQuerysetEqual( + qs, [ + ('Apple', 'APPL'.lower()), + ('Django Software Foundation', 'No Tag'.lower()), + ('Google', 'Do No Evil'.lower()), + ('Yahoo', 'Internet 
Company'.lower()) + ], + lambda c: (c.name, c.tagline_lower) + ) diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py index bd9eed9603..2756c8e10d 100644 --- a/tests/expressions/tests.py +++ b/tests/expressions/tests.py @@ -296,6 +296,21 @@ class ExpressionsTests(TestCase): g = deepcopy(f) self.assertEqual(f.name, g.name) + def test_f_reuse(self): + f = F('id') + n = Number.objects.create(integer=-1) + c = Company.objects.create( + name="Example Inc.", num_employees=2300, num_chairs=5, + ceo=Employee.objects.create(firstname="Joe", lastname="Smith") + ) + c_qs = Company.objects.filter(id=f) + self.assertEqual(c_qs.get(), c) + # Reuse the same F-object for another queryset + n_qs = Number.objects.filter(id=f) + self.assertEqual(n_qs.get(), n) + # The original query still works correctly + self.assertEqual(c_qs.get(), c) + class ExpressionsNumericTests(TestCase): @@ -362,12 +377,16 @@ class ExpressionsNumericTests(TestCase): Complex expressions of different connection types are possible. """ n = Number.objects.create(integer=10, float=123.45) - self.assertEqual(Number.objects.filter(pk=n.pk) - .update(float=F('integer') + F('float') * 2), 1) + self.assertEqual(Number.objects.filter(pk=n.pk).update( + float=F('integer') + F('float') * 2), 1) self.assertEqual(Number.objects.get(pk=n.pk).integer, 10) self.assertEqual(Number.objects.get(pk=n.pk).float, Approximate(256.900, places=3)) + def test_incorrect_field_expression(self): + with self.assertRaisesRegexp(FieldError, "Cannot resolve keyword u?'nope' into field.*"): + list(Employee.objects.filter(firstname=F('nope'))) + class ExpressionOperatorTests(TestCase): def setUp(self):