mirror of
https://github.com/django/django.git
synced 2025-10-24 06:06:09 +00:00
Fixed #14030 -- Allowed annotations to accept all expressions
This commit is contained in:
committed by
Marc Tamlyn
parent
39e3ef88c2
commit
f59fd15c49
1
AUTHORS
1
AUTHORS
@@ -347,6 +347,7 @@ answer newbie questions, and generally made Django that much better:
|
||||
Jorge Bastida <me@jorgebastida.com>
|
||||
Jorge Gajon <gajon@gajon.org>
|
||||
Joseph Kocherhans <joseph@jkocherhans.com>
|
||||
Josh Smeaton <josh.smeaton@gmail.com>
|
||||
Joshua Ginsberg <jag@flowtheory.net>
|
||||
Jozko Skrablin <jozko.skrablin@gmail.com>
|
||||
J. Pablo Fernandez <pupeno@pupeno.com>
|
||||
|
@@ -10,7 +10,7 @@ from django.db.models import signals, FieldDoesNotExist, DO_NOTHING
|
||||
from django.db.models.base import ModelBase
|
||||
from django.db.models.fields.related import ForeignObject, ForeignObjectRel
|
||||
from django.db.models.related import PathInfo
|
||||
from django.db.models.sql.datastructures import Col
|
||||
from django.db.models.expressions import Col
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.utils.encoding import smart_text, python_2_unicode_compatible
|
||||
|
||||
|
@@ -186,7 +186,7 @@ class BaseSpatialOperations(object):
|
||||
"""
|
||||
raise NotImplementedError('Distance operations not available on this spatial backend.')
|
||||
|
||||
def get_geom_placeholder(self, f, value):
|
||||
def get_geom_placeholder(self, f, value, qn):
|
||||
"""
|
||||
Returns the placeholder for the given geometry field with the given
|
||||
value. Depending on the spatial backend, the placeholder may contain a
|
||||
@@ -195,16 +195,6 @@ class BaseSpatialOperations(object):
|
||||
"""
|
||||
raise NotImplementedError('subclasses of BaseSpatialOperations must provide a geo_db_placeholder() method')
|
||||
|
||||
def get_expression_column(self, evaluator):
|
||||
"""
|
||||
Helper method to return the quoted column string from the evaluator
|
||||
for its expression.
|
||||
"""
|
||||
for expr, col_tup in evaluator.cols:
|
||||
if expr is evaluator.expression:
|
||||
return '%s.%s' % tuple(map(self.quote_name, col_tup))
|
||||
raise Exception("Could not find the column for the expression.")
|
||||
|
||||
# Spatial SQL Construction
|
||||
def spatial_aggregate_sql(self, agg):
|
||||
raise NotImplementedError('Aggregate support not implemented for this spatial backend.')
|
||||
|
@@ -35,14 +35,14 @@ class MySQLOperations(DatabaseOperations, BaseSpatialOperations):
|
||||
def geo_db_type(self, f):
|
||||
return f.geom_type
|
||||
|
||||
def get_geom_placeholder(self, f, value):
|
||||
def get_geom_placeholder(self, f, value, qn):
|
||||
"""
|
||||
The placeholder here has to include MySQL's WKT constructor. Because
|
||||
MySQL does not support spatial transformations, there is no need to
|
||||
modify the placeholder based on the contents of the given value.
|
||||
"""
|
||||
if hasattr(value, 'expression'):
|
||||
placeholder = self.get_expression_column(value)
|
||||
if hasattr(value, 'as_sql'):
|
||||
placeholder, _ = qn.compile(value)
|
||||
else:
|
||||
placeholder = '%s(%%s)' % self.from_text
|
||||
return placeholder
|
||||
|
@@ -9,7 +9,7 @@
|
||||
"""
|
||||
import re
|
||||
|
||||
from django.db.backends.oracle.base import DatabaseOperations
|
||||
from django.db.backends.oracle.base import DatabaseOperations, Database
|
||||
from django.contrib.gis.db.backends.base import BaseSpatialOperations
|
||||
from django.contrib.gis.db.backends.oracle.adapter import OracleSpatialAdapter
|
||||
from django.contrib.gis.db.backends.utils import SpatialOperator
|
||||
@@ -145,9 +145,11 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations):
|
||||
else:
|
||||
return None
|
||||
|
||||
def convert_geom(self, clob, geo_field):
|
||||
if clob:
|
||||
return Geometry(clob.read(), geo_field.srid)
|
||||
def convert_geom(self, value, geo_field):
|
||||
if value:
|
||||
if isinstance(value, Database.LOB):
|
||||
value = value.read()
|
||||
return Geometry(value, geo_field.srid)
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -184,7 +186,7 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations):
|
||||
|
||||
return [dist_param]
|
||||
|
||||
def get_geom_placeholder(self, f, value):
|
||||
def get_geom_placeholder(self, f, value, qn):
|
||||
"""
|
||||
Provides a proper substitution value for Geometries that are not in the
|
||||
SRID of the field. Specifically, this routine will substitute in the
|
||||
@@ -196,14 +198,15 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations):
|
||||
def transform_value(val, srid):
|
||||
return val.srid != srid
|
||||
|
||||
if hasattr(value, 'expression'):
|
||||
if hasattr(value, 'as_sql'):
|
||||
if transform_value(value, f.srid):
|
||||
placeholder = '%s(%%s, %s)' % (self.transform, f.srid)
|
||||
else:
|
||||
placeholder = '%s'
|
||||
# No geometry value used for F expression, substitute in
|
||||
# the column name instead.
|
||||
return placeholder % self.get_expression_column(value)
|
||||
sql, _ = qn.compile(value)
|
||||
return placeholder % sql
|
||||
else:
|
||||
if transform_value(value, f.srid):
|
||||
return '%s(SDO_GEOMETRY(%%s, %s), %s)' % (self.transform, value.srid, f.srid)
|
||||
@@ -219,9 +222,9 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations):
|
||||
if agg_name == 'union':
|
||||
agg_name += 'agg'
|
||||
if agg.is_extent:
|
||||
sql_template = '%(function)s(%(field)s)'
|
||||
sql_template = '%(function)s(%(expressions)s)'
|
||||
else:
|
||||
sql_template = '%(function)s(SDOAGGRTYPE(%(field)s,%(tolerance)s))'
|
||||
sql_template = '%(function)s(SDOAGGRTYPE(%(expressions)s,%(tolerance)s))'
|
||||
sql_function = getattr(self, agg_name)
|
||||
return self.select % sql_template, sql_function
|
||||
|
||||
|
@@ -22,7 +22,7 @@ class PostGISOperator(SpatialOperator):
|
||||
super(PostGISOperator, self).__init__(**kwargs)
|
||||
|
||||
def as_sql(self, connection, lookup, *args):
|
||||
if lookup.lhs.source.geography and not self.geography:
|
||||
if lookup.lhs.output_field.geography and not self.geography:
|
||||
raise ValueError('PostGIS geography does not support the "%s" '
|
||||
'function/operator.' % (self.func or self.op,))
|
||||
return super(PostGISOperator, self).as_sql(connection, lookup, *args)
|
||||
@@ -32,7 +32,7 @@ class PostGISDistanceOperator(PostGISOperator):
|
||||
sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %%s'
|
||||
|
||||
def as_sql(self, connection, lookup, template_params, sql_params):
|
||||
if not lookup.lhs.source.geography and lookup.lhs.source.geodetic(connection):
|
||||
if not lookup.lhs.output_field.geography and lookup.lhs.output_field.geodetic(connection):
|
||||
sql_template = self.sql_template
|
||||
if len(lookup.rhs) == 3 and lookup.rhs[-1] == 'spheroid':
|
||||
template_params.update({'op': self.op, 'func': 'ST_Distance_Spheroid'})
|
||||
@@ -215,7 +215,7 @@ class PostGISOperations(DatabaseOperations, BaseSpatialOperations):
|
||||
Converts the geometry returned from PostGIS aggretates.
|
||||
"""
|
||||
if hex:
|
||||
return Geometry(hex)
|
||||
return Geometry(hex, srid=geo_field.srid)
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -284,7 +284,7 @@ class PostGISOperations(DatabaseOperations, BaseSpatialOperations):
|
||||
else:
|
||||
return [dist_param]
|
||||
|
||||
def get_geom_placeholder(self, f, value):
|
||||
def get_geom_placeholder(self, f, value, qn):
|
||||
"""
|
||||
Provides a proper substitution value for Geometries that are not in the
|
||||
SRID of the field. Specifically, this routine will substitute in the
|
||||
@@ -296,11 +296,12 @@ class PostGISOperations(DatabaseOperations, BaseSpatialOperations):
|
||||
# Adding Transform() to the SQL placeholder.
|
||||
placeholder = '%s(%%s, %s)' % (self.transform, f.srid)
|
||||
|
||||
if hasattr(value, 'expression'):
|
||||
if hasattr(value, 'as_sql'):
|
||||
# If this is an F expression, then we don't really want
|
||||
# a placeholder and instead substitute in the column
|
||||
# of the expression.
|
||||
placeholder = placeholder % self.get_expression_column(value)
|
||||
sql, _ = qn.compile(value)
|
||||
placeholder = placeholder % sql
|
||||
|
||||
return placeholder
|
||||
|
||||
@@ -375,7 +376,7 @@ class PostGISOperations(DatabaseOperations, BaseSpatialOperations):
|
||||
agg_name = agg_name.lower()
|
||||
if agg_name == 'union':
|
||||
agg_name += 'agg'
|
||||
sql_template = '%(function)s(%(field)s)'
|
||||
sql_template = '%(function)s(%(expressions)s)'
|
||||
sql_function = getattr(self, agg_name)
|
||||
return sql_template, sql_function
|
||||
|
||||
|
@@ -178,7 +178,7 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations):
|
||||
dist_param = value
|
||||
return [dist_param]
|
||||
|
||||
def get_geom_placeholder(self, f, value):
|
||||
def get_geom_placeholder(self, f, value, qn):
|
||||
"""
|
||||
Provides a proper substitution value for Geometries that are not in the
|
||||
SRID of the field. Specifically, this routine will substitute in the
|
||||
@@ -186,14 +186,15 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations):
|
||||
"""
|
||||
def transform_value(value, srid):
|
||||
return not (value is None or value.srid == srid)
|
||||
if hasattr(value, 'expression'):
|
||||
if hasattr(value, 'as_sql'):
|
||||
if transform_value(value, f.srid):
|
||||
placeholder = '%s(%%s, %s)' % (self.transform, f.srid)
|
||||
else:
|
||||
placeholder = '%s'
|
||||
# No geometry value used for F expression, substitute in
|
||||
# the column name instead.
|
||||
return placeholder % self.get_expression_column(value)
|
||||
sql, _ = qn.compile(value)
|
||||
return placeholder % sql
|
||||
else:
|
||||
if transform_value(value, f.srid):
|
||||
# Adding Transform() to the SQL placeholder.
|
||||
@@ -255,7 +256,7 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations):
|
||||
agg_name = agg_name.lower()
|
||||
if agg_name == 'union':
|
||||
agg_name += 'agg'
|
||||
sql_template = self.select % '%(function)s(%(field)s)'
|
||||
sql_template = self.select % '%(function)s(%(expressions)s)'
|
||||
sql_function = getattr(self, agg_name)
|
||||
return sql_template, sql_function
|
||||
|
||||
|
@@ -1,23 +1,66 @@
|
||||
from django.db.models import Aggregate
|
||||
from django.db.models.aggregates import Aggregate
|
||||
from django.contrib.gis.db.models.fields import GeometryField, ExtentField
|
||||
|
||||
__all__ = ['Collect', 'Extent', 'Extent3D', 'MakeLine', 'Union']
|
||||
|
||||
|
||||
class Collect(Aggregate):
|
||||
class GeoAggregate(Aggregate):
|
||||
template = None
|
||||
function = None
|
||||
is_extent = False
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if connection.ops.oracle:
|
||||
if not hasattr(self, 'tolerance'):
|
||||
self.tolerance = 0.05
|
||||
self.extra['tolerance'] = self.tolerance
|
||||
|
||||
template, function = connection.ops.spatial_aggregate_sql(self)
|
||||
if template is None:
|
||||
template = '%(function)s(%(expressions)s)'
|
||||
self.extra['template'] = self.extra.get('template', template)
|
||||
self.extra['function'] = self.extra.get('function', function)
|
||||
return super(GeoAggregate, self).as_sql(compiler, connection)
|
||||
|
||||
def prepare(self, query=None, allow_joins=True, reuse=None, summarize=False):
|
||||
c = super(GeoAggregate, self).prepare(query, allow_joins, reuse, summarize)
|
||||
if not isinstance(self.expressions[0].output_field, GeometryField):
|
||||
raise ValueError('Geospatial aggregates only allowed on geometry fields.')
|
||||
return c
|
||||
|
||||
def convert_value(self, value, connection):
|
||||
return connection.ops.convert_geom(value, self.output_field)
|
||||
|
||||
|
||||
class Collect(GeoAggregate):
|
||||
name = 'Collect'
|
||||
|
||||
|
||||
class Extent(Aggregate):
|
||||
class Extent(GeoAggregate):
|
||||
name = 'Extent'
|
||||
is_extent = '2D'
|
||||
|
||||
def __init__(self, expression, **extra):
|
||||
super(Extent, self).__init__(expression, output_field=ExtentField(), **extra)
|
||||
|
||||
def convert_value(self, value, connection):
|
||||
return connection.ops.convert_extent(value)
|
||||
|
||||
|
||||
class Extent3D(Aggregate):
|
||||
class Extent3D(GeoAggregate):
|
||||
name = 'Extent3D'
|
||||
is_extent = '3D'
|
||||
|
||||
def __init__(self, expression, **extra):
|
||||
super(Extent3D, self).__init__(expression, output_field=ExtentField(), **extra)
|
||||
|
||||
def convert_value(self, value, connection):
|
||||
return connection.ops.convert_extent3d(value)
|
||||
|
||||
|
||||
class MakeLine(Aggregate):
|
||||
class MakeLine(GeoAggregate):
|
||||
name = 'MakeLine'
|
||||
|
||||
|
||||
class Union(Aggregate):
|
||||
class Union(GeoAggregate):
|
||||
name = 'Union'
|
||||
|
@@ -1,5 +1,5 @@
|
||||
from django.db.models.fields import Field
|
||||
from django.db.models.sql.expressions import SQLEvaluator
|
||||
from django.db.models.expressions import ExpressionNode
|
||||
from django.utils.translation import ugettext_lazy as _
|
||||
from django.contrib.gis import forms
|
||||
from django.contrib.gis.db.models.lookups import gis_lookups
|
||||
@@ -165,7 +165,7 @@ class GeometryField(Field):
|
||||
returning to the caller.
|
||||
"""
|
||||
value = super(GeometryField, self).get_prep_value(value)
|
||||
if isinstance(value, SQLEvaluator):
|
||||
if isinstance(value, ExpressionNode):
|
||||
return value
|
||||
elif isinstance(value, (tuple, list)):
|
||||
geom = value[0]
|
||||
@@ -197,7 +197,7 @@ class GeometryField(Field):
|
||||
return geom
|
||||
|
||||
def from_db_value(self, value, connection):
|
||||
if value:
|
||||
if value and not isinstance(value, Geometry):
|
||||
value = Geometry(value)
|
||||
return value
|
||||
|
||||
@@ -259,7 +259,7 @@ class GeometryField(Field):
|
||||
pass
|
||||
else:
|
||||
params += value[1:]
|
||||
elif isinstance(value, SQLEvaluator):
|
||||
elif isinstance(value, ExpressionNode):
|
||||
params = []
|
||||
else:
|
||||
params = [connection.ops.Adapter(value)]
|
||||
@@ -282,12 +282,12 @@ class GeometryField(Field):
|
||||
else:
|
||||
return connection.ops.Adapter(self.get_prep_value(value))
|
||||
|
||||
def get_placeholder(self, value, connection):
|
||||
def get_placeholder(self, value, qn, connection):
|
||||
"""
|
||||
Returns the placeholder for the geometry column for the
|
||||
given value.
|
||||
"""
|
||||
return connection.ops.get_geom_placeholder(self, value)
|
||||
return connection.ops.get_geom_placeholder(self, value, qn)
|
||||
|
||||
|
||||
for klass in gis_lookups.values():
|
||||
@@ -335,3 +335,12 @@ class GeometryCollectionField(GeometryField):
|
||||
geom_type = 'GEOMETRYCOLLECTION'
|
||||
form_class = forms.GeometryCollectionField
|
||||
description = _("Geometry collection")
|
||||
|
||||
|
||||
class ExtentField(Field):
|
||||
"Used as a return value from an extent aggregate"
|
||||
|
||||
description = _("Extent Aggregate Field")
|
||||
|
||||
def get_internal_type(self):
|
||||
return "ExtentField"
|
||||
|
@@ -4,7 +4,7 @@ import re
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.db.models.fields import FieldDoesNotExist
|
||||
from django.db.models.lookups import Lookup
|
||||
from django.db.models.sql.expressions import SQLEvaluator
|
||||
from django.db.models.expressions import ExpressionNode, Col
|
||||
from django.utils import six
|
||||
|
||||
gis_lookups = {}
|
||||
@@ -68,18 +68,19 @@ class GISLookup(Lookup):
|
||||
rhs, rhs_params = super(GISLookup, self).process_rhs(qn, connection)
|
||||
|
||||
geom = self.rhs
|
||||
if isinstance(self.rhs, SQLEvaluator):
|
||||
if isinstance(self.rhs, Col):
|
||||
# Make sure the F Expression destination field exists, and
|
||||
# set an `srid` attribute with the same as that of the
|
||||
# destination.
|
||||
geo_fld = self._check_geo_field(self.rhs.opts, self.rhs.expression.name)
|
||||
if not geo_fld:
|
||||
geo_fld = self.rhs.output_field
|
||||
if not hasattr(geo_fld, 'srid'):
|
||||
raise ValueError('No geographic field found in expression.')
|
||||
self.rhs.srid = geo_fld.srid
|
||||
elif isinstance(self.rhs, ExpressionNode):
|
||||
raise ValueError('Complex expressions not supported for GeometryField')
|
||||
elif isinstance(self.rhs, (list, tuple)):
|
||||
geom = self.rhs[0]
|
||||
|
||||
rhs = connection.ops.get_geom_placeholder(self.lhs.source, geom)
|
||||
rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, geom, qn)
|
||||
return rhs, rhs_params
|
||||
|
||||
def as_sql(self, qn, connection):
|
||||
|
@@ -530,7 +530,7 @@ class GeoQuerySet(QuerySet):
|
||||
# transformation SQL.
|
||||
geom = geo_field.get_prep_value(settings['procedure_args'][name])
|
||||
params = geo_field.get_db_prep_lookup('contains', geom, connection=connection)
|
||||
geom_placeholder = geo_field.get_placeholder(geom, connection)
|
||||
geom_placeholder = geo_field.get_placeholder(geom, None, connection)
|
||||
|
||||
# Replacing the procedure format with that of any needed
|
||||
# transformation SQL.
|
||||
|
@@ -6,12 +6,15 @@ from django.contrib.gis.db.models.fields import GeometryField
|
||||
__all__ = ['Collect', 'Extent', 'Extent3D', 'MakeLine', 'Union'] + aggregates.__all__
|
||||
|
||||
|
||||
warnings.warn(
|
||||
"django.contrib.gis.db.models.sql.aggregates is deprecated. Use "
|
||||
"django.contrib.gis.db.models.aggregates instead.",
|
||||
RemovedInDjango20Warning, stacklevel=2)
|
||||
|
||||
|
||||
class GeoAggregate(Aggregate):
|
||||
# Default SQL template for spatial aggregates.
|
||||
sql_template = '%(function)s(%(field)s)'
|
||||
|
||||
# Conversion class, if necessary.
|
||||
conversion_class = None
|
||||
sql_template = '%(function)s(%(expressions)s)'
|
||||
|
||||
# Flags for indicating the type of the aggregate.
|
||||
is_extent = False
|
||||
@@ -45,7 +48,7 @@ class GeoAggregate(Aggregate):
|
||||
|
||||
substitutions = {
|
||||
'function': sql_function,
|
||||
'field': field_name
|
||||
'expressions': field_name
|
||||
}
|
||||
substitutions.update(self.extra)
|
||||
|
||||
|
@@ -70,8 +70,8 @@ class GeoSQLCompiler(compiler.SQLCompiler):
|
||||
aliases.update(new_aliases)
|
||||
|
||||
max_name_length = self.connection.ops.max_name_length()
|
||||
for alias, aggregate in self.query.aggregate_select.items():
|
||||
agg_sql, agg_params = aggregate.as_sql(qn, self.connection)
|
||||
for alias, annotation in self.query.annotation_select.items():
|
||||
agg_sql, agg_params = self.compile(annotation)
|
||||
if alias is None:
|
||||
result.append(agg_sql)
|
||||
else:
|
||||
|
@@ -4,7 +4,7 @@ from django.db.models.sql.constants import QUERY_TERMS
|
||||
|
||||
from django.contrib.gis.db.models.fields import GeometryField
|
||||
from django.contrib.gis.db.models.lookups import GISLookup
|
||||
from django.contrib.gis.db.models.sql import aggregates as gis_aggregates
|
||||
from django.contrib.gis.db.models import aggregates as gis_aggregates
|
||||
from django.contrib.gis.db.models.sql.conversion import GeomField
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ class GeoQuery(sql.Query):
|
||||
"""
|
||||
# Overriding the valid query terms.
|
||||
query_terms = QUERY_TERMS | set(GeometryField.class_lookups.keys())
|
||||
aggregates_module = gis_aggregates
|
||||
|
||||
compiler = 'GeoSQLCompiler'
|
||||
|
||||
@@ -40,28 +39,12 @@ class GeoQuery(sql.Query):
|
||||
# Remove any aggregates marked for reduction from the subquery
|
||||
# and move them to the outer AggregateQuery.
|
||||
connection = connections[using]
|
||||
for alias, aggregate in self.aggregate_select.items():
|
||||
if isinstance(aggregate, gis_aggregates.GeoAggregate):
|
||||
if not getattr(aggregate, 'is_extent', False) or connection.ops.oracle:
|
||||
for alias, annotation in self.annotation_select.items():
|
||||
if isinstance(annotation, gis_aggregates.GeoAggregate):
|
||||
if not getattr(annotation, 'is_extent', False) or connection.ops.oracle:
|
||||
self.extra_select_fields[alias] = GeomField()
|
||||
return super(GeoQuery, self).get_aggregation(using, force_subq)
|
||||
|
||||
def resolve_aggregate(self, value, aggregate, connection):
|
||||
"""
|
||||
Overridden from GeoQuery's normalize to handle the conversion of
|
||||
GeoAggregate objects.
|
||||
"""
|
||||
if isinstance(aggregate, self.aggregates_module.GeoAggregate):
|
||||
if aggregate.is_extent:
|
||||
if aggregate.is_extent == '3D':
|
||||
return connection.ops.convert_extent3d(value)
|
||||
else:
|
||||
return connection.ops.convert_extent(value)
|
||||
else:
|
||||
return connection.ops.convert_geom(value, aggregate.source)
|
||||
else:
|
||||
return super(GeoQuery, self).resolve_aggregate(value, aggregate, connection)
|
||||
|
||||
# Private API utilities, subject to change.
|
||||
def _geo_field(self, field_name=None):
|
||||
"""
|
||||
|
@@ -20,8 +20,7 @@ from django.db.backends.sqlite3.client import DatabaseClient
|
||||
from django.db.backends.sqlite3.creation import DatabaseCreation
|
||||
from django.db.backends.sqlite3.introspection import DatabaseIntrospection
|
||||
from django.db.backends.sqlite3.schema import DatabaseSchemaEditor
|
||||
from django.db.models import fields
|
||||
from django.db.models.sql import aggregates
|
||||
from django.db.models import fields, aggregates
|
||||
from django.utils.dateparse import parse_date, parse_datetime, parse_time
|
||||
from django.utils.encoding import force_text
|
||||
from django.utils.functional import cached_property
|
||||
@@ -163,8 +162,7 @@ class DatabaseOperations(BaseDatabaseOperations):
|
||||
bad_fields = (fields.DateField, fields.DateTimeField, fields.TimeField)
|
||||
bad_aggregates = (aggregates.Sum, aggregates.Avg,
|
||||
aggregates.Variance, aggregates.StdDev)
|
||||
if (isinstance(aggregate.source, bad_fields) and
|
||||
isinstance(aggregate, bad_aggregates)):
|
||||
if aggregate.refs_field(bad_aggregates, bad_fields):
|
||||
raise NotImplementedError(
|
||||
'You cannot use Sum, Avg, StdDev and Variance aggregations '
|
||||
'on date/time fields in sqlite3 '
|
||||
|
@@ -4,7 +4,7 @@ import warnings
|
||||
|
||||
from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured # NOQA
|
||||
from django.db.models.query import Q, QuerySet, Prefetch # NOQA
|
||||
from django.db.models.expressions import F # NOQA
|
||||
from django.db.models.expressions import ExpressionNode, F, Value, Func # NOQA
|
||||
from django.db.models.manager import Manager # NOQA
|
||||
from django.db.models.base import Model # NOQA
|
||||
from django.db.models.aggregates import * # NOQA
|
||||
|
@@ -1,94 +1,152 @@
|
||||
"""
|
||||
Classes to represent the definitions of aggregate functions.
|
||||
"""
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db.models.expressions import Func, Value
|
||||
from django.db.models.fields import IntegerField, FloatField
|
||||
|
||||
__all__ = [
|
||||
'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance',
|
||||
]
|
||||
|
||||
|
||||
def refs_aggregate(lookup_parts, aggregates):
|
||||
"""
|
||||
A little helper method to check if the lookup_parts contains references
|
||||
to the given aggregates set. Because the LOOKUP_SEP is contained in the
|
||||
default annotation names we must check each prefix of the lookup_parts
|
||||
for match.
|
||||
"""
|
||||
for n in range(len(lookup_parts) + 1):
|
||||
level_n_lookup = LOOKUP_SEP.join(lookup_parts[0:n])
|
||||
if level_n_lookup in aggregates:
|
||||
return aggregates[level_n_lookup], lookup_parts[n:]
|
||||
return False, ()
|
||||
class Aggregate(Func):
|
||||
contains_aggregate = True
|
||||
name = None
|
||||
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False):
|
||||
assert len(self.source_expressions) == 1
|
||||
c = super(Aggregate, self).resolve_expression(query, allow_joins, reuse, summarize)
|
||||
if c.source_expressions[0].contains_aggregate and not summarize:
|
||||
name = self.source_expressions[0].name
|
||||
raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (
|
||||
c.name, name, name))
|
||||
c._patch_aggregate(query) # backward-compatibility support
|
||||
return c
|
||||
|
||||
class Aggregate(object):
|
||||
"""
|
||||
Default Aggregate definition.
|
||||
"""
|
||||
def __init__(self, lookup, **extra):
|
||||
"""Instantiate a new aggregate.
|
||||
def refs_field(self, aggregate_types, field_types):
|
||||
try:
|
||||
return (isinstance(self, aggregate_types) and
|
||||
isinstance(self.input_field._output_field_or_none, field_types))
|
||||
except FieldError:
|
||||
# Sometimes we don't know the input_field's output type (for example,
|
||||
# doing Sum(F('datetimefield') + F('datefield'), output_type=DateTimeField())
|
||||
# is OK, but the Expression(F('datetimefield') + F('datefield')) doesn't
|
||||
# have any output field.
|
||||
return False
|
||||
|
||||
* lookup is the field on which the aggregate operates.
|
||||
* extra is a dictionary of additional data to provide for the
|
||||
aggregate definition
|
||||
@property
|
||||
def input_field(self):
|
||||
return self.source_expressions[0]
|
||||
|
||||
Also utilizes the class variables:
|
||||
* name, the identifier for this aggregate function.
|
||||
@property
|
||||
def default_alias(self):
|
||||
if hasattr(self.source_expressions[0], 'name'):
|
||||
return '%s__%s' % (self.source_expressions[0].name, self.name.lower())
|
||||
raise TypeError("Complex expressions require an alias")
|
||||
|
||||
def get_group_by_cols(self):
|
||||
return []
|
||||
|
||||
def _patch_aggregate(self, query):
|
||||
"""
|
||||
self.lookup = lookup
|
||||
self.extra = extra
|
||||
Helper method for patching 3rd party aggregates that do not yet support
|
||||
the new way of subclassing. This method should be removed in 2.0
|
||||
|
||||
def _default_alias(self):
|
||||
return '%s__%s' % (self.lookup, self.name.lower())
|
||||
default_alias = property(_default_alias)
|
||||
add_to_query(query, alias, col, source, is_summary) will be defined on
|
||||
legacy aggregates which, in turn, instantiates the SQL implementation of
|
||||
the aggregate. In all the cases found, the general implementation of
|
||||
add_to_query looks like:
|
||||
|
||||
def add_to_query(self, query, alias, col, source, is_summary):
|
||||
"""Add the aggregate to the nominated query.
|
||||
def add_to_query(self, query, alias, col, source, is_summary):
|
||||
klass = SQLImplementationAggregate
|
||||
aggregate = klass(col, source=source, is_summary=is_summary, **self.extra)
|
||||
query.aggregates[alias] = aggregate
|
||||
|
||||
This method is used to convert the generic Aggregate definition into a
|
||||
backend-specific definition.
|
||||
|
||||
* query is the backend-specific query instance to which the aggregate
|
||||
is to be added.
|
||||
* col is a column reference describing the subject field
|
||||
of the aggregate. It can be an alias, or a tuple describing
|
||||
a table and column name.
|
||||
* source is the underlying field or aggregate definition for
|
||||
the column reference. If the aggregate is not an ordinal or
|
||||
computed type, this reference is used to determine the coerced
|
||||
output type of the aggregate.
|
||||
* is_summary is a boolean that is set True if the aggregate is a
|
||||
summary value rather than an annotation.
|
||||
By supplying a known alias, we can get the SQLAggregate out of the
|
||||
aggregates dict, and use the sql_function and sql_template attributes
|
||||
to patch *this* aggregate.
|
||||
"""
|
||||
klass = getattr(query.aggregates_module, self.name)
|
||||
aggregate = klass(col, source=source, is_summary=is_summary, **self.extra)
|
||||
query.aggregates[alias] = aggregate
|
||||
if not hasattr(self, 'add_to_query') or self.function is not None:
|
||||
return
|
||||
|
||||
placeholder_alias = "_XXXXXXXX_"
|
||||
self.add_to_query(query, placeholder_alias, None, None, None)
|
||||
sql_aggregate = query.aggregates.pop(placeholder_alias)
|
||||
if 'sql_function' not in self.extra and hasattr(sql_aggregate, 'sql_function'):
|
||||
self.extra['function'] = sql_aggregate.sql_function
|
||||
|
||||
if hasattr(sql_aggregate, 'sql_template'):
|
||||
self.extra['template'] = sql_aggregate.sql_template
|
||||
|
||||
|
||||
class Avg(Aggregate):
|
||||
function = 'AVG'
|
||||
name = 'Avg'
|
||||
|
||||
def __init__(self, expression, **extra):
|
||||
super(Avg, self).__init__(expression, output_field=FloatField(), **extra)
|
||||
|
||||
def convert_value(self, value, connection):
|
||||
if value is None:
|
||||
return value
|
||||
return float(value)
|
||||
|
||||
|
||||
class Count(Aggregate):
|
||||
function = 'COUNT'
|
||||
name = 'Count'
|
||||
template = '%(function)s(%(distinct)s%(expressions)s)'
|
||||
|
||||
def __init__(self, expression, distinct=False, **extra):
|
||||
if expression == '*':
|
||||
expression = Value(expression)
|
||||
expression._output_field = IntegerField()
|
||||
super(Count, self).__init__(
|
||||
expression, distinct='DISTINCT ' if distinct else '', output_field=IntegerField(), **extra)
|
||||
|
||||
def convert_value(self, value, connection):
|
||||
if value is None:
|
||||
return 0
|
||||
return int(value)
|
||||
|
||||
|
||||
class Max(Aggregate):
|
||||
function = 'MAX'
|
||||
name = 'Max'
|
||||
|
||||
|
||||
class Min(Aggregate):
|
||||
function = 'MIN'
|
||||
name = 'Min'
|
||||
|
||||
|
||||
class StdDev(Aggregate):
|
||||
name = 'StdDev'
|
||||
|
||||
def __init__(self, expression, sample=False, **extra):
|
||||
self.function = 'STDDEV_SAMP' if sample else 'STDDEV_POP'
|
||||
super(StdDev, self).__init__(expression, output_field=FloatField(), **extra)
|
||||
|
||||
def convert_value(self, value, connection):
|
||||
if value is None:
|
||||
return value
|
||||
return float(value)
|
||||
|
||||
|
||||
class Sum(Aggregate):
|
||||
function = 'SUM'
|
||||
name = 'Sum'
|
||||
|
||||
|
||||
class Variance(Aggregate):
|
||||
name = 'Variance'
|
||||
|
||||
def __init__(self, expression, sample=False, **extra):
|
||||
self.function = 'VAR_SAMP' if sample else 'VAR_POP'
|
||||
super(Variance, self).__init__(expression, output_field=FloatField(), **extra)
|
||||
|
||||
def convert_value(self, value, connection):
|
||||
if value is None:
|
||||
return value
|
||||
return float(value)
|
||||
|
@@ -1,14 +1,20 @@
|
||||
import copy
|
||||
import datetime
|
||||
|
||||
from django.db.models.aggregates import refs_aggregate
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db.backends import utils as backend_utils
|
||||
from django.db.models import fields
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.utils import tree
|
||||
from django.db.models.query_utils import refs_aggregate
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
|
||||
class ExpressionNode(tree.Node):
|
||||
class CombinableMixin(object):
|
||||
"""
|
||||
Base class for all query expressions.
|
||||
Provides the ability to combine one or two objects with
|
||||
some connector. For example F('foo') + F('bar').
|
||||
"""
|
||||
|
||||
# Arithmetic connectors
|
||||
ADD = '+'
|
||||
SUB = '-'
|
||||
@@ -25,44 +31,17 @@ class ExpressionNode(tree.Node):
|
||||
BITAND = '&'
|
||||
BITOR = '|'
|
||||
|
||||
def __init__(self, children=None, connector=None, negated=False):
|
||||
if children is not None and len(children) > 1 and connector is None:
|
||||
raise TypeError('You have to specify a connector.')
|
||||
super(ExpressionNode, self).__init__(children, connector, negated)
|
||||
|
||||
def _combine(self, other, connector, reversed, node=None):
|
||||
if isinstance(other, datetime.timedelta):
|
||||
return DateModifierNode([self, other], connector)
|
||||
return DateModifierNode(self, connector, other)
|
||||
|
||||
if not hasattr(other, 'resolve_expression'):
|
||||
# everything must be resolvable to an expression
|
||||
other = Value(other)
|
||||
|
||||
if reversed:
|
||||
obj = ExpressionNode([other], connector)
|
||||
obj.add(node or self, connector)
|
||||
else:
|
||||
obj = node or ExpressionNode([self], connector)
|
||||
obj.add(other, connector)
|
||||
return obj
|
||||
|
||||
def contains_aggregate(self, existing_aggregates):
|
||||
if self.children:
|
||||
return any(child.contains_aggregate(existing_aggregates)
|
||||
for child in self.children
|
||||
if hasattr(child, 'contains_aggregate'))
|
||||
else:
|
||||
return refs_aggregate(self.name.split(LOOKUP_SEP),
|
||||
existing_aggregates)
|
||||
|
||||
def prepare_database_save(self, unused):
|
||||
return self
|
||||
|
||||
###################
|
||||
# VISITOR METHODS #
|
||||
###################
|
||||
|
||||
def prepare(self, evaluator, query, allow_joins):
|
||||
return evaluator.prepare_node(self, query, allow_joins)
|
||||
|
||||
def evaluate(self, evaluator, qn, connection):
|
||||
return evaluator.evaluate_node(self, qn, connection)
|
||||
return Expression(other, connector, self)
|
||||
return Expression(self, connector, other)
|
||||
|
||||
#############
|
||||
# OPERATORS #
|
||||
@@ -137,27 +116,240 @@ class ExpressionNode(tree.Node):
|
||||
)
|
||||
|
||||
|
||||
class F(ExpressionNode):
|
||||
class ExpressionNode(CombinableMixin):
|
||||
"""
|
||||
An expression representing the value of the given field.
|
||||
Base class for all query expressions.
|
||||
"""
|
||||
def __init__(self, name):
|
||||
super(F, self).__init__(None, None, False)
|
||||
self.name = name
|
||||
|
||||
def __deepcopy__(self, memodict):
|
||||
obj = super(F, self).__deepcopy__(memodict)
|
||||
obj.name = self.name
|
||||
return obj
|
||||
# aggregate specific fields
|
||||
is_summary = False
|
||||
|
||||
def prepare(self, evaluator, query, allow_joins):
|
||||
return evaluator.prepare_leaf(self, query, allow_joins)
|
||||
def __init__(self, output_field=None):
|
||||
self._output_field = output_field
|
||||
|
||||
def evaluate(self, evaluator, qn, connection):
|
||||
return evaluator.evaluate_leaf(self, qn, connection)
|
||||
def get_source_expressions(self):
|
||||
return []
|
||||
|
||||
def set_source_expressions(self, exprs):
|
||||
assert len(exprs) == 0
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
"""
|
||||
Responsible for returning a (sql, [params]) tuple to be included
|
||||
in the current query.
|
||||
|
||||
Different backends can provide their own implementation, by
|
||||
providing an `as_{vendor}` method and patching the Expression:
|
||||
|
||||
```
|
||||
def override_as_sql(self, compiler, connection):
|
||||
# custom logic
|
||||
return super(ExpressionNode, self).as_sql(compiler, connection)
|
||||
setattr(ExpressionNode, 'as_' + connection.vendor, override_as_sql)
|
||||
```
|
||||
|
||||
Arguments:
|
||||
* compiler: the query compiler responsible for generating the query.
|
||||
Must have a compile method, returning a (sql, [params]) tuple.
|
||||
Calling compiler(value) will return a quoted `value`.
|
||||
|
||||
* connection: the database connection used for the current query.
|
||||
|
||||
Returns: (sql, params)
|
||||
Where `sql` is a string containing ordered sql parameters to be
|
||||
replaced with the elements of the list `params`.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement as_sql()")
|
||||
|
||||
@cached_property
|
||||
def contains_aggregate(self):
|
||||
for expr in self.get_source_expressions():
|
||||
if expr and expr.contains_aggregate:
|
||||
return True
|
||||
return False
|
||||
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False):
|
||||
"""
|
||||
Provides the chance to do any preprocessing or validation before being
|
||||
added to the query.
|
||||
|
||||
Arguments:
|
||||
* query: the backend query implementation
|
||||
* allow_joins: boolean allowing or denying use of joins
|
||||
in this query
|
||||
* reuse: a set of reusable joins for multijoins
|
||||
* summarize: a terminal aggregate clause
|
||||
|
||||
Returns: an ExpressionNode to be added to the query.
|
||||
"""
|
||||
c = self.copy()
|
||||
c.is_summary = summarize
|
||||
return c
|
||||
|
||||
def _prepare(self):
|
||||
"""
|
||||
Hook used by Field.get_prep_lookup() to do custom preparation.
|
||||
"""
|
||||
return self
|
||||
|
||||
@property
|
||||
def field(self):
|
||||
return self.output_field
|
||||
|
||||
@cached_property
|
||||
def output_field(self):
|
||||
"""
|
||||
Returns the output type of this expressions.
|
||||
"""
|
||||
if self._output_field_or_none is None:
|
||||
raise FieldError("Cannot resolve expression type, unknown output_field")
|
||||
return self._output_field_or_none
|
||||
|
||||
@cached_property
|
||||
def _output_field_or_none(self):
|
||||
"""
|
||||
Returns the output field of this expression, or None if no output type
|
||||
can be resolved. Note that the 'output_field' property will raise
|
||||
FieldError if no type can be resolved, but this attribute allows for
|
||||
None values.
|
||||
"""
|
||||
if self._output_field is None:
|
||||
self._resolve_output_field()
|
||||
return self._output_field
|
||||
|
||||
def _resolve_output_field(self):
|
||||
"""
|
||||
Attempts to infer the output type of the expression. If the output
|
||||
fields of all source fields match then we can simply infer the same
|
||||
type here.
|
||||
"""
|
||||
if self._output_field is None:
|
||||
sources = self.get_source_fields()
|
||||
num_sources = len(sources)
|
||||
if num_sources == 0:
|
||||
self._output_field = None
|
||||
else:
|
||||
self._output_field = sources[0]
|
||||
for source in sources:
|
||||
if source is not None and not isinstance(self._output_field, source.__class__):
|
||||
raise FieldError(
|
||||
"Expression contains mixed types. You must set output_field")
|
||||
|
||||
def convert_value(self, value, connection):
|
||||
"""
|
||||
Expressions provide their own converters because users have the option
|
||||
of manually specifying the output_field which may be a different type
|
||||
from the one the database returns.
|
||||
"""
|
||||
field = self.output_field
|
||||
internal_type = field.get_internal_type()
|
||||
if value is None:
|
||||
return value
|
||||
elif internal_type == 'FloatField':
|
||||
return float(value)
|
||||
elif internal_type.endswith('IntegerField'):
|
||||
return int(value)
|
||||
elif internal_type == 'DecimalField':
|
||||
return backend_utils.typecast_decimal(field.format_number(value))
|
||||
return value
|
||||
|
||||
def get_lookup(self, lookup):
|
||||
return self.output_field.get_lookup(lookup)
|
||||
|
||||
def get_transform(self, name):
|
||||
return self.output_field.get_transform(name)
|
||||
|
||||
def relabeled_clone(self, change_map):
|
||||
clone = self.copy()
|
||||
clone.set_source_expressions(
|
||||
[e.relabeled_clone(change_map) for e in self.get_source_expressions()])
|
||||
return clone
|
||||
|
||||
def copy(self):
|
||||
c = copy.copy(self)
|
||||
c.copied = True
|
||||
return c
|
||||
|
||||
def refs_aggregate(self, existing_aggregates):
|
||||
"""
|
||||
Does this expression contain a reference to some of the
|
||||
existing aggregates? If so, returns the aggregate and also
|
||||
the lookup parts that *weren't* found. So, if
|
||||
exsiting_aggregates = {'max_id': Max('id')}
|
||||
self.name = 'max_id'
|
||||
queryset.filter(max_id__range=[10,100])
|
||||
then this method will return Max('id') and those parts of the
|
||||
name that weren't found. In this case `max_id` is found and the range
|
||||
portion is returned as ('range',).
|
||||
"""
|
||||
for node in self.get_source_expressions():
|
||||
agg, lookup = node.refs_aggregate(existing_aggregates)
|
||||
if agg:
|
||||
return agg, lookup
|
||||
return False, ()
|
||||
|
||||
def refs_field(self, aggregate_types, field_types):
|
||||
"""
|
||||
Helper method for check_aggregate_support on backends
|
||||
"""
|
||||
return any(
|
||||
node.refs_field(aggregate_types, field_types)
|
||||
for node in self.get_source_expressions())
|
||||
|
||||
def prepare_database_save(self, field):
|
||||
return self
|
||||
|
||||
def get_group_by_cols(self):
|
||||
cols = []
|
||||
for source in self.get_source_expressions():
|
||||
cols.extend(source.get_group_by_cols())
|
||||
return cols
|
||||
|
||||
def get_source_fields(self):
|
||||
"""
|
||||
Returns the underlying field types used by this
|
||||
aggregate.
|
||||
"""
|
||||
return [e._output_field_or_none for e in self.get_source_expressions()]
|
||||
|
||||
|
||||
class DateModifierNode(ExpressionNode):
|
||||
class Expression(ExpressionNode):
|
||||
|
||||
def __init__(self, lhs, connector, rhs, output_field=None):
|
||||
super(Expression, self).__init__(output_field=output_field)
|
||||
self.connector = connector
|
||||
self.lhs = lhs
|
||||
self.rhs = rhs
|
||||
|
||||
def get_source_expressions(self):
|
||||
return [self.lhs, self.rhs]
|
||||
|
||||
def set_source_expressions(self, exprs):
|
||||
self.lhs, self.rhs = exprs
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
expressions = []
|
||||
expression_params = []
|
||||
sql, params = compiler.compile(self.lhs)
|
||||
expressions.append(sql)
|
||||
expression_params.extend(params)
|
||||
sql, params = compiler.compile(self.rhs)
|
||||
expressions.append(sql)
|
||||
expression_params.extend(params)
|
||||
# order of precedence
|
||||
expression_wrapper = '(%s)'
|
||||
sql = connection.ops.combine_expression(self.connector, expressions)
|
||||
return expression_wrapper % sql, expression_params
|
||||
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False):
|
||||
c = self.copy()
|
||||
c.is_summary = summarize
|
||||
c.lhs = c.lhs.resolve_expression(query, allow_joins, reuse, summarize)
|
||||
c.rhs = c.rhs.resolve_expression(query, allow_joins, reuse, summarize)
|
||||
return c
|
||||
|
||||
|
||||
class DateModifierNode(Expression):
|
||||
"""
|
||||
Node that implements the following syntax:
|
||||
filter(end_date__gt=F('start_date') + datetime.timedelta(days=3, seconds=200))
|
||||
@@ -183,14 +375,195 @@ class DateModifierNode(ExpressionNode):
|
||||
Only adding and subtracting timedeltas is supported, attempts to use other
|
||||
operations raise a TypeError.
|
||||
"""
|
||||
def __init__(self, children, connector, negated=False):
|
||||
if len(children) != 2:
|
||||
raise TypeError('Must specify a node and a timedelta.')
|
||||
if not isinstance(children[1], datetime.timedelta):
|
||||
raise TypeError('Second child must be a timedelta.')
|
||||
def __init__(self, lhs, connector, rhs):
|
||||
if not isinstance(rhs, datetime.timedelta):
|
||||
raise TypeError('rhs must be a timedelta.')
|
||||
if connector not in (self.ADD, self.SUB):
|
||||
raise TypeError('Connector must be + or -, not %s' % connector)
|
||||
super(DateModifierNode, self).__init__(children, connector, negated)
|
||||
super(DateModifierNode, self).__init__(lhs, connector, Value(rhs))
|
||||
|
||||
def evaluate(self, evaluator, qn, connection):
|
||||
return evaluator.evaluate_date_modifier_node(self, qn, connection)
|
||||
def as_sql(self, compiler, connection):
|
||||
timedelta = self.rhs.value
|
||||
sql, params = compiler.compile(self.lhs)
|
||||
if (timedelta.days == timedelta.seconds == timedelta.microseconds == 0):
|
||||
return sql, params
|
||||
return connection.ops.date_interval_sql(sql, self.connector, timedelta), params
|
||||
|
||||
|
||||
class F(CombinableMixin):
|
||||
"""
|
||||
An object capable of resolving references to existing query objects.
|
||||
"""
|
||||
def __init__(self, name):
|
||||
"""
|
||||
Arguments:
|
||||
* name: the name of the field this expression references
|
||||
"""
|
||||
self.name = name
|
||||
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False):
|
||||
return query.resolve_ref(self.name, allow_joins, reuse, summarize)
|
||||
|
||||
def refs_aggregate(self, existing_aggregates):
|
||||
return refs_aggregate(self.name.split(LOOKUP_SEP), existing_aggregates)
|
||||
|
||||
|
||||
class Func(ExpressionNode):
|
||||
"""
|
||||
A SQL function call.
|
||||
"""
|
||||
function = None
|
||||
template = '%(function)s(%(expressions)s)'
|
||||
arg_joiner = ', '
|
||||
|
||||
def __init__(self, *expressions, **extra):
|
||||
output_field = extra.pop('output_field', None)
|
||||
super(Func, self).__init__(output_field=output_field)
|
||||
self.source_expressions = self._parse_expressions(*expressions)
|
||||
self.extra = extra
|
||||
|
||||
def get_source_expressions(self):
|
||||
return self.source_expressions
|
||||
|
||||
def set_source_expressions(self, exprs):
|
||||
self.source_expressions = exprs
|
||||
|
||||
def _parse_expressions(self, *expressions):
|
||||
return [
|
||||
arg if hasattr(arg, 'resolve_expression') else F(arg)
|
||||
for arg in expressions
|
||||
]
|
||||
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False):
|
||||
c = self.copy()
|
||||
c.is_summary = summarize
|
||||
for pos, arg in enumerate(c.source_expressions):
|
||||
c.source_expressions[pos] = arg.resolve_expression(query, allow_joins, reuse, summarize)
|
||||
return c
|
||||
|
||||
def as_sql(self, compiler, connection, function=None, template=None):
|
||||
sql_parts = []
|
||||
params = []
|
||||
for arg in self.source_expressions:
|
||||
arg_sql, arg_params = compiler.compile(arg)
|
||||
sql_parts.append(arg_sql)
|
||||
params.extend(arg_params)
|
||||
if function is None:
|
||||
self.extra['function'] = self.extra.get('function', self.function)
|
||||
else:
|
||||
self.extra['function'] = function
|
||||
self.extra['expressions'] = self.extra['field'] = self.arg_joiner.join(sql_parts)
|
||||
template = template or self.extra.get('template', self.template)
|
||||
return template % self.extra, params
|
||||
|
||||
def copy(self):
|
||||
copy = super(Func, self).copy()
|
||||
copy.source_expressions = self.source_expressions[:]
|
||||
copy.extra = self.extra.copy()
|
||||
return copy
|
||||
|
||||
|
||||
class Value(ExpressionNode):
|
||||
"""
|
||||
Represents a wrapped value as a node within an expression
|
||||
"""
|
||||
def __init__(self, value, output_field=None):
|
||||
"""
|
||||
Arguments:
|
||||
* value: the value this expression represents. The value will be
|
||||
added into the sql parameter list and properly quoted.
|
||||
|
||||
* output_field: an instance of the model field type that this
|
||||
expression will return, such as IntegerField() or CharField().
|
||||
"""
|
||||
super(Value, self).__init__(output_field=output_field)
|
||||
self.value = value
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
return '%s', [self.value]
|
||||
|
||||
|
||||
class Col(ExpressionNode):
|
||||
def __init__(self, alias, target, source=None):
|
||||
if source is None:
|
||||
source = target
|
||||
super(Col, self).__init__(output_field=source)
|
||||
self.alias, self.target = alias, target
|
||||
|
||||
def as_sql(self, qn, connection):
|
||||
return "%s.%s" % (qn(self.alias), qn(self.target.column)), []
|
||||
|
||||
def relabeled_clone(self, relabels):
|
||||
return self.__class__(relabels.get(self.alias, self.alias), self.target, self.output_field)
|
||||
|
||||
def get_group_by_cols(self):
|
||||
return [(self.alias, self.target.column)]
|
||||
|
||||
|
||||
class Ref(ExpressionNode):
|
||||
"""
|
||||
Reference to column alias of the query. For example, Ref('sum_cost') in
|
||||
qs.annotate(sum_cost=Sum('cost')) query.
|
||||
"""
|
||||
def __init__(self, refs, source):
|
||||
super(Ref, self).__init__()
|
||||
self.source = source
|
||||
self.refs = refs
|
||||
|
||||
def get_source_expressions(self):
|
||||
return [self.source]
|
||||
|
||||
def set_source_expressions(self, exprs):
|
||||
self.source, = exprs
|
||||
|
||||
def relabeled_clone(self, relabels):
|
||||
return self
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
return "%s" % compiler(self.refs), []
|
||||
|
||||
def get_group_by_cols(self):
|
||||
return [(None, self.refs)]
|
||||
|
||||
|
||||
class Date(ExpressionNode):
|
||||
"""
|
||||
Add a date selection column.
|
||||
"""
|
||||
def __init__(self, col, lookup_type):
|
||||
super(Date, self).__init__(output_field=fields.DateField())
|
||||
self.col = col
|
||||
self.lookup_type = lookup_type
|
||||
|
||||
def get_source_expressions(self):
|
||||
return [self.col]
|
||||
|
||||
def set_source_expressions(self, exprs):
|
||||
self.col, = self.exprs
|
||||
|
||||
def as_sql(self, qn, connection):
|
||||
sql, params = self.col.as_sql(qn, connection)
|
||||
assert not(params)
|
||||
return connection.ops.date_trunc_sql(self.lookup_type, sql), []
|
||||
|
||||
|
||||
class DateTime(ExpressionNode):
|
||||
"""
|
||||
Add a datetime selection column.
|
||||
"""
|
||||
def __init__(self, col, lookup_type, tzname):
|
||||
super(DateTime, self).__init__(output_field=fields.DateTimeField())
|
||||
self.col = col
|
||||
self.lookup_type = lookup_type
|
||||
self.tzname = tzname
|
||||
|
||||
def get_source_expressions(self):
|
||||
return [self.col]
|
||||
|
||||
def set_source_expressions(self, exprs):
|
||||
self.col, = exprs
|
||||
|
||||
def as_sql(self, qn, connection):
|
||||
sql, params = self.col.as_sql(qn, connection)
|
||||
assert not(params)
|
||||
return connection.ops.datetime_trunc_sql(self.lookup_type, sql, self.tzname)
|
||||
|
@@ -637,8 +637,6 @@ class Field(RegisterLookupMixin):
|
||||
"""
|
||||
Perform preliminary non-db specific lookup checks and conversions
|
||||
"""
|
||||
if hasattr(value, 'prepare'):
|
||||
return value.prepare()
|
||||
if hasattr(value, '_prepare'):
|
||||
return value._prepare()
|
||||
|
||||
|
@@ -13,7 +13,7 @@ from django.db.models.fields import (AutoField, Field, IntegerField,
|
||||
from django.db.models.lookups import IsNull
|
||||
from django.db.models.related import RelatedObject, PathInfo
|
||||
from django.db.models.query import QuerySet
|
||||
from django.db.models.sql.datastructures import Col
|
||||
from django.db.models.expressions import Col
|
||||
from django.utils.encoding import force_text, smart_text
|
||||
from django.utils import six
|
||||
from django.utils.translation import ugettext_lazy as _
|
||||
|
@@ -154,8 +154,7 @@ class QuerySet(object):
|
||||
2. sql/compiler.results_iter()
|
||||
- Returns one row at time. At this point the rows are still just
|
||||
tuples. In some cases the return values are converted to
|
||||
Python values at this location (see resolve_columns(),
|
||||
resolve_aggregate()).
|
||||
Python values at this location.
|
||||
3. self.iterator()
|
||||
- Responsible for turning the rows into model objects.
|
||||
"""
|
||||
@@ -241,7 +240,7 @@ class QuerySet(object):
|
||||
max_depth = self.query.max_depth
|
||||
|
||||
extra_select = list(self.query.extra_select)
|
||||
aggregate_select = list(self.query.aggregate_select)
|
||||
annotation_select = list(self.query.annotation_select)
|
||||
|
||||
only_load = self.query.get_loaded_field_names()
|
||||
fields = self.model._meta.concrete_fields
|
||||
@@ -282,7 +281,7 @@ class QuerySet(object):
|
||||
db = self.db
|
||||
compiler = self.query.get_compiler(using=db)
|
||||
index_start = len(extra_select)
|
||||
aggregate_start = index_start + len(init_list)
|
||||
annotation_start = index_start + len(init_list)
|
||||
|
||||
if fill_cache:
|
||||
klass_info = get_klass_info(model_cls, max_depth=max_depth,
|
||||
@@ -290,18 +289,18 @@ class QuerySet(object):
|
||||
for row in compiler.results_iter():
|
||||
if fill_cache:
|
||||
obj, _ = get_cached_row(row, index_start, db, klass_info,
|
||||
offset=len(aggregate_select))
|
||||
offset=len(annotation_select))
|
||||
else:
|
||||
obj = model_cls.from_db(db, init_list, row[index_start:aggregate_start])
|
||||
obj = model_cls.from_db(db, init_list, row[index_start:annotation_start])
|
||||
|
||||
if extra_select:
|
||||
for i, k in enumerate(extra_select):
|
||||
setattr(obj, k, row[i])
|
||||
|
||||
# Add the aggregates to the model
|
||||
if aggregate_select:
|
||||
for i, aggregate in enumerate(aggregate_select):
|
||||
setattr(obj, aggregate, row[i + aggregate_start])
|
||||
# Add the annotations to the model
|
||||
if annotation_select:
|
||||
for i, annotation in enumerate(annotation_select):
|
||||
setattr(obj, annotation, row[i + annotation_start])
|
||||
|
||||
# Add the known related objects to the model, if there are any
|
||||
if self._known_related_objects:
|
||||
@@ -330,13 +329,16 @@ class QuerySet(object):
|
||||
if self.query.distinct_fields:
|
||||
raise NotImplementedError("aggregate() + distinct(fields) not implemented.")
|
||||
for arg in args:
|
||||
if not hasattr(arg, 'default_alias'):
|
||||
raise TypeError("Complex aggregates require an alias")
|
||||
kwargs[arg.default_alias] = arg
|
||||
|
||||
query = self.query.clone()
|
||||
force_subq = query.low_mark != 0 or query.high_mark is not None
|
||||
for (alias, aggregate_expr) in kwargs.items():
|
||||
query.add_aggregate(aggregate_expr, self.model, alias,
|
||||
is_summary=True)
|
||||
query.add_annotation(aggregate_expr, self.model, alias, is_summary=True)
|
||||
if not query.annotations[alias].contains_aggregate:
|
||||
raise TypeError("%s is not an aggregate expression" % alias)
|
||||
return query.get_aggregation(using=self.db, force_subq=force_subq)
|
||||
|
||||
def count(self):
|
||||
@@ -787,33 +789,40 @@ class QuerySet(object):
|
||||
def annotate(self, *args, **kwargs):
|
||||
"""
|
||||
Return a query set in which the returned objects have been annotated
|
||||
with data aggregated from related fields.
|
||||
with extra data or aggregations.
|
||||
"""
|
||||
aggrs = OrderedDict() # To preserve ordering of args
|
||||
annotations = OrderedDict() # To preserve ordering of args
|
||||
for arg in args:
|
||||
if arg.default_alias in kwargs:
|
||||
raise ValueError("The named annotation '%s' conflicts with the "
|
||||
"default name for another annotation."
|
||||
% arg.default_alias)
|
||||
aggrs[arg.default_alias] = arg
|
||||
aggrs.update(kwargs)
|
||||
try:
|
||||
# we can't do an hasattr here because py2 returns False
|
||||
# if default_alias exists but throws a TypeError
|
||||
if arg.default_alias in kwargs:
|
||||
raise ValueError("The named annotation '%s' conflicts with the "
|
||||
"default name for another annotation."
|
||||
% arg.default_alias)
|
||||
except AttributeError: # default_alias
|
||||
raise TypeError("Complex annotations require an alias")
|
||||
annotations[arg.default_alias] = arg
|
||||
annotations.update(kwargs)
|
||||
|
||||
obj = self._clone()
|
||||
names = getattr(self, '_fields', None)
|
||||
if names is None:
|
||||
names = set(self.model._meta.get_all_field_names())
|
||||
for aggregate in aggrs:
|
||||
if aggregate in names:
|
||||
|
||||
# Add the annotations to the query
|
||||
for alias, annotation in annotations.items():
|
||||
if alias in names:
|
||||
raise ValueError("The annotation '%s' conflicts with a field on "
|
||||
"the model." % aggregate)
|
||||
|
||||
obj = self._clone()
|
||||
|
||||
obj._setup_aggregate_query(list(aggrs))
|
||||
|
||||
# Add the aggregates to the query
|
||||
for (alias, aggregate_expr) in aggrs.items():
|
||||
obj.query.add_aggregate(aggregate_expr, self.model, alias,
|
||||
is_summary=False)
|
||||
"the model." % alias)
|
||||
obj.query.add_annotation(annotation, self.model, alias, is_summary=False)
|
||||
# expressions need to be added to the query before we know if they contain aggregates
|
||||
added_aggregates = []
|
||||
for alias, annotation in obj.query.annotations.items():
|
||||
if alias in annotations and annotation.contains_aggregate:
|
||||
added_aggregates.append(alias)
|
||||
if added_aggregates:
|
||||
obj._setup_aggregate_query(list(added_aggregates))
|
||||
|
||||
return obj
|
||||
|
||||
@@ -1096,9 +1105,9 @@ class ValuesQuerySet(QuerySet):
|
||||
# Purge any extra columns that haven't been explicitly asked for
|
||||
extra_names = list(self.query.extra_select)
|
||||
field_names = self.field_names
|
||||
aggregate_names = list(self.query.aggregate_select)
|
||||
annotation_names = list(self.query.annotation_select)
|
||||
|
||||
names = extra_names + field_names + aggregate_names
|
||||
names = extra_names + field_names + annotation_names
|
||||
|
||||
for row in self.query.get_compiler(self.db).results_iter():
|
||||
yield dict(zip(names, row))
|
||||
@@ -1122,9 +1131,9 @@ class ValuesQuerySet(QuerySet):
|
||||
|
||||
if self._fields:
|
||||
self.extra_names = []
|
||||
self.aggregate_names = []
|
||||
if not self.query._extra and not self.query._aggregates:
|
||||
# Short cut - if there are no extra or aggregates, then
|
||||
self.annotation_names = []
|
||||
if not self.query._extra and not self.query._annotations:
|
||||
# Short cut - if there are no extra or annotations, then
|
||||
# the values() clause must be just field names.
|
||||
self.field_names = list(self._fields)
|
||||
else:
|
||||
@@ -1136,22 +1145,22 @@ class ValuesQuerySet(QuerySet):
|
||||
# had selected previously.
|
||||
if self.query._extra and f in self.query._extra:
|
||||
self.extra_names.append(f)
|
||||
elif f in self.query.aggregate_select:
|
||||
self.aggregate_names.append(f)
|
||||
elif f in self.query.annotation_select:
|
||||
self.annotation_names.append(f)
|
||||
else:
|
||||
self.field_names.append(f)
|
||||
else:
|
||||
# Default to all fields.
|
||||
self.extra_names = None
|
||||
self.field_names = [f.attname for f in self.model._meta.concrete_fields]
|
||||
self.aggregate_names = None
|
||||
self.annotation_names = None
|
||||
|
||||
self.query.select = []
|
||||
if self.extra_names is not None:
|
||||
self.query.set_extra_mask(self.extra_names)
|
||||
self.query.add_fields(self.field_names, True)
|
||||
if self.aggregate_names is not None:
|
||||
self.query.set_aggregate_mask(self.aggregate_names)
|
||||
if self.annotation_names is not None:
|
||||
self.query.set_annotation_mask(self.annotation_names)
|
||||
|
||||
def _clone(self, klass=None, setup=False, **kwargs):
|
||||
"""
|
||||
@@ -1164,7 +1173,7 @@ class ValuesQuerySet(QuerySet):
|
||||
c._fields = self._fields[:]
|
||||
c.field_names = self.field_names
|
||||
c.extra_names = self.extra_names
|
||||
c.aggregate_names = self.aggregate_names
|
||||
c.annotation_names = self.annotation_names
|
||||
if setup and hasattr(c, '_setup_query'):
|
||||
c._setup_query()
|
||||
return c
|
||||
@@ -1173,7 +1182,7 @@ class ValuesQuerySet(QuerySet):
|
||||
super(ValuesQuerySet, self)._merge_sanity_check(other)
|
||||
if (set(self.extra_names) != set(other.extra_names) or
|
||||
set(self.field_names) != set(other.field_names) or
|
||||
self.aggregate_names != other.aggregate_names):
|
||||
self.annotation_names != other.annotation_names):
|
||||
raise TypeError("Merging '%s' classes must involve the same values in each case."
|
||||
% self.__class__.__name__)
|
||||
|
||||
@@ -1183,9 +1192,9 @@ class ValuesQuerySet(QuerySet):
|
||||
"""
|
||||
self.query.set_group_by()
|
||||
|
||||
if self.aggregate_names is not None:
|
||||
self.aggregate_names.extend(aggregates)
|
||||
self.query.set_aggregate_mask(self.aggregate_names)
|
||||
if self.annotation_names is not None:
|
||||
self.annotation_names.extend(aggregates)
|
||||
self.query.set_annotation_mask(self.annotation_names)
|
||||
|
||||
super(ValuesQuerySet, self)._setup_aggregate_query(aggregates)
|
||||
|
||||
@@ -1231,7 +1240,7 @@ class ValuesListQuerySet(ValuesQuerySet):
|
||||
if self.flat and len(self._fields) == 1:
|
||||
for row in self.query.get_compiler(self.db).results_iter():
|
||||
yield row[0]
|
||||
elif not self.query.extra_select and not self.query.aggregate_select:
|
||||
elif not self.query.extra_select and not self.query.annotation_select:
|
||||
for row in self.query.get_compiler(self.db).results_iter():
|
||||
yield tuple(row)
|
||||
else:
|
||||
@@ -1240,14 +1249,14 @@ class ValuesListQuerySet(ValuesQuerySet):
|
||||
# the fields to match the order in self._fields.
|
||||
extra_names = list(self.query.extra_select)
|
||||
field_names = self.field_names
|
||||
aggregate_names = list(self.query.aggregate_select)
|
||||
annotation_names = list(self.query.annotation_select)
|
||||
|
||||
names = extra_names + field_names + aggregate_names
|
||||
names = extra_names + field_names + annotation_names
|
||||
|
||||
# If a field list has been specified, use it. Otherwise, use the
|
||||
# full list of fields, including extras and aggregates.
|
||||
# full list of fields, including extras and annotations.
|
||||
if self._fields:
|
||||
fields = list(self._fields) + [f for f in aggregate_names if f not in self._fields]
|
||||
fields = list(self._fields) + [f for f in annotation_names if f not in self._fields]
|
||||
else:
|
||||
fields = names
|
||||
|
||||
|
@@ -9,6 +9,7 @@ from __future__ import unicode_literals
|
||||
|
||||
from django.apps import apps
|
||||
from django.db.backends import utils
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.utils import six
|
||||
from django.utils import tree
|
||||
|
||||
@@ -220,3 +221,17 @@ def deferred_class_factory(model, attrs):
|
||||
# The above function is also used to unpickle model instances with deferred
|
||||
# fields.
|
||||
deferred_class_factory.__safe_for_unpickling__ = True
|
||||
|
||||
|
||||
def refs_aggregate(lookup_parts, aggregates):
|
||||
"""
|
||||
A little helper method to check if the lookup_parts contains references
|
||||
to the given aggregates set. Because the LOOKUP_SEP is contained in the
|
||||
default annotation names we must check each prefix of the lookup_parts
|
||||
for a match.
|
||||
"""
|
||||
for n in range(len(lookup_parts) + 1):
|
||||
level_n_lookup = LOOKUP_SEP.join(lookup_parts[0:n])
|
||||
if level_n_lookup in aggregates and aggregates[level_n_lookup].contains_aggregate:
|
||||
return aggregates[level_n_lookup], lookup_parts[n:]
|
||||
return False, ()
|
||||
|
@@ -2,15 +2,23 @@
|
||||
Classes to represent the default SQL aggregate functions
|
||||
"""
|
||||
import copy
|
||||
import warnings
|
||||
|
||||
from django.db.models.fields import IntegerField, FloatField
|
||||
from django.db.models.lookups import RegisterLookupMixin
|
||||
from django.utils.deprecation import RemovedInDjango20Warning
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
|
||||
__all__ = ['Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance']
|
||||
|
||||
|
||||
warnings.warn(
|
||||
"django.db.models.sql.aggregates is deprecated. Use "
|
||||
"django.db.models.aggregates instead.",
|
||||
RemovedInDjango20Warning, stacklevel=2)
|
||||
|
||||
|
||||
class Aggregate(RegisterLookupMixin):
|
||||
"""
|
||||
Default SQL Aggregate.
|
||||
|
@@ -4,12 +4,10 @@ from django.conf import settings
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db.backends.utils import truncate_name
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.db.models.expressions import ExpressionNode
|
||||
from django.db.models.query_utils import select_related_descend, QueryWrapper
|
||||
from django.db.models.sql.constants import (CURSOR, SINGLE, MULTI, NO_RESULTS,
|
||||
ORDER_DIR, GET_ITERATOR_CHUNK_SIZE, SelectInfo)
|
||||
from django.db.models.sql.datastructures import EmptyResultSet
|
||||
from django.db.models.sql.expressions import SQLEvaluator
|
||||
from django.db.models.sql.query import get_order_dir, Query
|
||||
from django.db.transaction import TransactionManagementError
|
||||
from django.db.utils import DatabaseError
|
||||
@@ -248,8 +246,8 @@ class SQLCompiler(object):
|
||||
aliases.update(new_aliases)
|
||||
|
||||
max_name_length = self.connection.ops.max_name_length()
|
||||
for alias, aggregate in self.query.aggregate_select.items():
|
||||
agg_sql, agg_params = self.compile(aggregate)
|
||||
for alias, annotation in self.query.annotation_select.items():
|
||||
agg_sql, agg_params = self.compile(annotation)
|
||||
if alias is None:
|
||||
result.append(agg_sql)
|
||||
else:
|
||||
@@ -409,7 +407,7 @@ class SQLCompiler(object):
|
||||
group_by.append((str(field), []))
|
||||
continue
|
||||
col, order = get_order_dir(field, asc)
|
||||
if col in self.query.aggregate_select:
|
||||
if col in self.query.annotation_select:
|
||||
result.append('%s %s' % (qn(col), order))
|
||||
continue
|
||||
if '.' in field:
|
||||
@@ -718,25 +716,17 @@ class SQLCompiler(object):
|
||||
"""
|
||||
fields = None
|
||||
converters = None
|
||||
has_aggregate_select = bool(self.query.aggregate_select)
|
||||
has_annotation_select = bool(self.query.annotation_select)
|
||||
for rows in self.execute_sql(MULTI):
|
||||
for row in rows:
|
||||
if has_aggregate_select:
|
||||
loaded_fields = (
|
||||
self.query.get_loaded_field_names().get(self.query.model, set()) or
|
||||
self.query.select
|
||||
)
|
||||
aggregate_start = len(self.query.extra_select) + len(loaded_fields)
|
||||
aggregate_end = aggregate_start + len(self.query.aggregate_select)
|
||||
if fields is None:
|
||||
# We only set this up here because
|
||||
# related_select_cols isn't populated until
|
||||
# execute_sql() has been called.
|
||||
|
||||
# We also include types of fields of related models that
|
||||
# will be included via select_related() for the benefit
|
||||
# of MySQL/MySQLdb when boolean fields are involved
|
||||
# (#15040).
|
||||
# If the field was deferred, exclude it from being passed
|
||||
# into `get_converters` because it wasn't selected.
|
||||
only_load = self.deferred_to_columns()
|
||||
|
||||
# This code duplicates the logic for the order of fields
|
||||
# found in get_columns(). It would be nice to clean this up.
|
||||
@@ -746,30 +736,45 @@ class SQLCompiler(object):
|
||||
fields = self.query.get_meta().concrete_fields
|
||||
else:
|
||||
fields = []
|
||||
fields = fields + [f.field for f in self.query.related_select_cols]
|
||||
|
||||
# If the field was deferred, exclude it from being passed
|
||||
# into `get_converters` because it wasn't selected.
|
||||
only_load = self.deferred_to_columns()
|
||||
if only_load:
|
||||
fields = [f for f in fields if f.model._meta.db_table not in only_load or
|
||||
f.column in only_load[f.model._meta.db_table]]
|
||||
if has_aggregate_select:
|
||||
# pad None in to fields for aggregates
|
||||
fields = fields[:aggregate_start] + [
|
||||
None for x in range(0, aggregate_end - aggregate_start)
|
||||
] + fields[aggregate_start:]
|
||||
# strip deferred fields
|
||||
fields = [
|
||||
f for f in fields if
|
||||
f.model._meta.db_table not in only_load or
|
||||
f.column in only_load[f.model._meta.db_table]
|
||||
]
|
||||
|
||||
# annotations come before the related cols
|
||||
if has_annotation_select:
|
||||
# extra is always at the start of the field list
|
||||
prepended_cols = len(self.query.extra_select)
|
||||
annotation_start = len(fields) + prepended_cols
|
||||
fields = fields + [
|
||||
anno.output_field for alias, anno in self.query.annotation_select.items()]
|
||||
annotation_end = len(fields) + prepended_cols
|
||||
|
||||
# add related fields
|
||||
fields = fields + [
|
||||
# strip deferred
|
||||
f.field for f in self.query.related_select_cols if
|
||||
f.field.model._meta.db_table not in only_load or
|
||||
f.field.column in only_load[f.field.model._meta.db_table]
|
||||
]
|
||||
|
||||
converters = self.get_converters(fields)
|
||||
if has_annotation_select:
|
||||
for (alias, annotation), position in zip(
|
||||
self.query.annotation_select.items(),
|
||||
range(annotation_start, annotation_end + 1)):
|
||||
if position in converters:
|
||||
# annotation conversions always run first
|
||||
converters[position][1].insert(0, annotation.convert_value)
|
||||
else:
|
||||
converters[position] = ([], [annotation.convert_value], annotation.output_field)
|
||||
|
||||
if converters:
|
||||
row = self.apply_converters(row, converters)
|
||||
|
||||
if has_aggregate_select:
|
||||
row = tuple(row[:aggregate_start]) + tuple(
|
||||
self.query.resolve_aggregate(value, aggregate, self.connection)
|
||||
for (alias, aggregate), value
|
||||
in zip(self.query.aggregate_select.items(), row[aggregate_start:aggregate_end])
|
||||
) + tuple(row[aggregate_end:])
|
||||
|
||||
yield row
|
||||
|
||||
def has_results(self):
|
||||
@@ -878,7 +883,7 @@ class SQLInsertCompiler(SQLCompiler):
|
||||
elif hasattr(field, 'get_placeholder'):
|
||||
# Some fields (e.g. geo fields) need special munging before
|
||||
# they can be inserted.
|
||||
return field.get_placeholder(val, self.connection)
|
||||
return field.get_placeholder(val, self, self.connection)
|
||||
else:
|
||||
# Return the common case for the placeholder
|
||||
return '%s'
|
||||
@@ -985,8 +990,10 @@ class SQLUpdateCompiler(SQLCompiler):
|
||||
result.append('SET')
|
||||
values, update_params = [], []
|
||||
for field, model, val in self.query.values:
|
||||
if hasattr(val, 'prepare_database_save'):
|
||||
if field.rel or isinstance(val, ExpressionNode):
|
||||
if hasattr(val, 'resolve_expression'):
|
||||
val = val.resolve_expression(self.query, allow_joins=False)
|
||||
elif hasattr(val, 'prepare_database_save'):
|
||||
if field.rel:
|
||||
val = val.prepare_database_save(field)
|
||||
else:
|
||||
raise TypeError("Database is trying to update a relational field "
|
||||
@@ -998,12 +1005,9 @@ class SQLUpdateCompiler(SQLCompiler):
|
||||
|
||||
# Getting the placeholder for the field.
|
||||
if hasattr(field, 'get_placeholder'):
|
||||
placeholder = field.get_placeholder(val, self.connection)
|
||||
placeholder = field.get_placeholder(val, self, self.connection)
|
||||
else:
|
||||
placeholder = '%s'
|
||||
|
||||
if hasattr(val, 'evaluate'):
|
||||
val = SQLEvaluator(val, self.query, allow_joins=False)
|
||||
name = field.column
|
||||
if hasattr(val, 'as_sql'):
|
||||
sql, params = self.compile(val)
|
||||
@@ -1103,8 +1107,8 @@ class SQLAggregateCompiler(SQLCompiler):
|
||||
qn = self
|
||||
|
||||
sql, params = [], []
|
||||
for aggregate in self.query.aggregate_select.values():
|
||||
agg_sql, agg_params = self.compile(aggregate)
|
||||
for annotation in self.query.annotation_select.values():
|
||||
agg_sql, agg_params = self.compile(annotation)
|
||||
sql.append(agg_sql)
|
||||
params.extend(agg_params)
|
||||
sql = ', '.join(sql)
|
||||
|
@@ -4,33 +4,6 @@ the SQL domain.
|
||||
"""
|
||||
|
||||
|
||||
class Col(object):
|
||||
def __init__(self, alias, target, source):
|
||||
self.alias, self.target, self.source = alias, target, source
|
||||
|
||||
def as_sql(self, qn, connection):
|
||||
return "%s.%s" % (qn(self.alias), qn(self.target.column)), []
|
||||
|
||||
@property
|
||||
def output_field(self):
|
||||
return self.source
|
||||
|
||||
def relabeled_clone(self, relabels):
|
||||
return self.__class__(relabels.get(self.alias, self.alias), self.target, self.source)
|
||||
|
||||
def get_group_by_cols(self):
|
||||
return [(self.alias, self.target.column)]
|
||||
|
||||
def get_lookup(self, name):
|
||||
return self.output_field.get_lookup(name)
|
||||
|
||||
def get_transform(self, name):
|
||||
return self.output_field.get_transform(name)
|
||||
|
||||
def prepare(self):
|
||||
return self
|
||||
|
||||
|
||||
class EmptyResultSet(Exception):
|
||||
pass
|
||||
|
||||
@@ -49,42 +22,3 @@ class MultiJoin(Exception):
|
||||
|
||||
class Empty(object):
|
||||
pass
|
||||
|
||||
|
||||
class Date(object):
|
||||
"""
|
||||
Add a date selection column.
|
||||
"""
|
||||
def __init__(self, col, lookup_type):
|
||||
self.col = col
|
||||
self.lookup_type = lookup_type
|
||||
|
||||
def relabeled_clone(self, change_map):
|
||||
return self.__class__((change_map.get(self.col[0], self.col[0]), self.col[1]))
|
||||
|
||||
def as_sql(self, qn, connection):
|
||||
if isinstance(self.col, (list, tuple)):
|
||||
col = '%s.%s' % tuple(qn(c) for c in self.col)
|
||||
else:
|
||||
col = self.col
|
||||
return connection.ops.date_trunc_sql(self.lookup_type, col), []
|
||||
|
||||
|
||||
class DateTime(object):
|
||||
"""
|
||||
Add a datetime selection column.
|
||||
"""
|
||||
def __init__(self, col, lookup_type, tzname):
|
||||
self.col = col
|
||||
self.lookup_type = lookup_type
|
||||
self.tzname = tzname
|
||||
|
||||
def relabeled_clone(self, change_map):
|
||||
return self.__class__((change_map.get(self.col[0], self.col[0]), self.col[1]))
|
||||
|
||||
def as_sql(self, qn, connection):
|
||||
if isinstance(self.col, (list, tuple)):
|
||||
col = '%s.%s' % tuple(qn(c) for c in self.col)
|
||||
else:
|
||||
col = self.col
|
||||
return connection.ops.datetime_trunc_sql(self.lookup_type, col, self.tzname)
|
||||
|
@@ -1,119 +0,0 @@
|
||||
import copy
|
||||
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.db.models.fields import FieldDoesNotExist
|
||||
|
||||
|
||||
class SQLEvaluator(object):
|
||||
def __init__(self, expression, query, allow_joins=True, reuse=None):
|
||||
self.expression = expression
|
||||
self.opts = query.get_meta()
|
||||
self.reuse = reuse
|
||||
self.cols = []
|
||||
self.expression.prepare(self, query, allow_joins)
|
||||
|
||||
def relabeled_clone(self, change_map):
|
||||
clone = copy.copy(self)
|
||||
clone.cols = []
|
||||
for node, col in self.cols:
|
||||
if hasattr(col, 'relabeled_clone'):
|
||||
clone.cols.append((node, col.relabeled_clone(change_map)))
|
||||
else:
|
||||
clone.cols.append((node,
|
||||
(change_map.get(col[0], col[0]), col[1])))
|
||||
return clone
|
||||
|
||||
def get_group_by_cols(self):
|
||||
cols = []
|
||||
for node, col in self.cols:
|
||||
if hasattr(node, 'get_group_by_cols'):
|
||||
cols.extend(node.get_group_by_cols())
|
||||
elif isinstance(col, tuple):
|
||||
cols.append(col)
|
||||
return cols
|
||||
|
||||
def prepare(self):
|
||||
return self
|
||||
|
||||
def as_sql(self, qn, connection):
|
||||
return self.expression.evaluate(self, qn, connection)
|
||||
|
||||
#####################################################
|
||||
# Visitor methods for initial expression preparation #
|
||||
#####################################################
|
||||
|
||||
def prepare_node(self, node, query, allow_joins):
|
||||
for child in node.children:
|
||||
if hasattr(child, 'prepare'):
|
||||
child.prepare(self, query, allow_joins)
|
||||
|
||||
def prepare_leaf(self, node, query, allow_joins):
|
||||
if not allow_joins and LOOKUP_SEP in node.name:
|
||||
raise FieldError("Joined field references are not permitted in this query")
|
||||
|
||||
field_list = node.name.split(LOOKUP_SEP)
|
||||
if node.name in query.aggregates:
|
||||
self.cols.append((node, query.aggregate_select[node.name]))
|
||||
else:
|
||||
try:
|
||||
_, sources, _, join_list, path = query.setup_joins(
|
||||
field_list, query.get_meta(), query.get_initial_alias(),
|
||||
can_reuse=self.reuse)
|
||||
self._used_joins = join_list
|
||||
targets, _, join_list = query.trim_joins(sources, join_list, path)
|
||||
if self.reuse is not None:
|
||||
self.reuse.update(join_list)
|
||||
for t in targets:
|
||||
self.cols.append((node, (join_list[-1], t.column)))
|
||||
except FieldDoesNotExist:
|
||||
raise FieldError("Cannot resolve keyword %r into field. "
|
||||
"Choices are: %s" % (self.name,
|
||||
[f.name for f in self.opts.fields]))
|
||||
|
||||
##################################################
|
||||
# Visitor methods for final expression evaluation #
|
||||
##################################################
|
||||
|
||||
def evaluate_node(self, node, qn, connection):
|
||||
expressions = []
|
||||
expression_params = []
|
||||
for child in node.children:
|
||||
if hasattr(child, 'evaluate'):
|
||||
sql, params = child.evaluate(self, qn, connection)
|
||||
else:
|
||||
sql, params = '%s', (child,)
|
||||
|
||||
if len(getattr(child, 'children', [])) > 1:
|
||||
format = '(%s)'
|
||||
else:
|
||||
format = '%s'
|
||||
|
||||
if sql:
|
||||
expressions.append(format % sql)
|
||||
expression_params.extend(params)
|
||||
|
||||
return connection.ops.combine_expression(node.connector, expressions), expression_params
|
||||
|
||||
def evaluate_leaf(self, node, qn, connection):
|
||||
col = None
|
||||
for n, c in self.cols:
|
||||
if n is node:
|
||||
col = c
|
||||
break
|
||||
if col is None:
|
||||
raise ValueError("Given node not found")
|
||||
if hasattr(col, 'as_sql'):
|
||||
return col.as_sql(qn, connection)
|
||||
else:
|
||||
return '%s.%s' % (qn(col[0]), qn(col[1])), []
|
||||
|
||||
def evaluate_date_modifier_node(self, node, qn, connection):
|
||||
timedelta = node.children.pop()
|
||||
sql, params = self.evaluate_node(node, qn, connection)
|
||||
node.children.append(timedelta)
|
||||
|
||||
if (timedelta.days == timedelta.seconds == timedelta.microseconds == 0):
|
||||
return sql, params
|
||||
|
||||
return connection.ops.date_interval_sql(sql, node.connector, timedelta), params
|
@@ -14,20 +14,18 @@ import warnings
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db import connections, DEFAULT_DB_ALIAS
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.db.models.aggregates import refs_aggregate
|
||||
from django.db.models.expressions import ExpressionNode
|
||||
from django.db.models.expressions import Col, Ref
|
||||
from django.db.models.fields import FieldDoesNotExist
|
||||
from django.db.models.query_utils import Q
|
||||
from django.db.models.query_utils import Q, refs_aggregate
|
||||
from django.db.models.related import PathInfo
|
||||
from django.db.models.sql import aggregates as base_aggregates_module
|
||||
from django.db.models.aggregates import Count
|
||||
from django.db.models.sql.constants import (QUERY_TERMS, ORDER_DIR, SINGLE,
|
||||
ORDER_PATTERN, JoinInfo, SelectInfo)
|
||||
from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin, Col
|
||||
from django.db.models.sql.expressions import SQLEvaluator
|
||||
from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin
|
||||
from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode,
|
||||
ExtraWhere, AND, OR, EmptyWhere)
|
||||
from django.utils import six
|
||||
from django.utils.deprecation import RemovedInDjango19Warning
|
||||
from django.utils.deprecation import RemovedInDjango19Warning, RemovedInDjango20Warning
|
||||
from django.utils.encoding import force_text
|
||||
from django.utils.tree import Node
|
||||
|
||||
@@ -49,7 +47,7 @@ class RawQuery(object):
|
||||
# the compiler can be used to process results.
|
||||
self.low_mark, self.high_mark = 0, None # Used for offset/limit
|
||||
self.extra_select = {}
|
||||
self.aggregate_select = {}
|
||||
self.annotation_select = {}
|
||||
|
||||
def clone(self, using):
|
||||
return RawQuery(self.sql, using, params=self.params)
|
||||
@@ -97,7 +95,6 @@ class Query(object):
|
||||
alias_prefix = 'T'
|
||||
subq_aliases = frozenset([alias_prefix])
|
||||
query_terms = QUERY_TERMS
|
||||
aggregates_module = base_aggregates_module
|
||||
|
||||
compiler = 'SQLCompiler'
|
||||
|
||||
@@ -140,13 +137,13 @@ class Query(object):
|
||||
self.select_for_update_nowait = False
|
||||
self.select_related = False
|
||||
|
||||
# SQL aggregate-related attributes
|
||||
# The _aggregates will be an OrderedDict when used. Due to the cost
|
||||
# SQL annotation-related attributes
|
||||
# The _annotations will be an OrderedDict when used. Due to the cost
|
||||
# of creating OrderedDict this attribute is created lazily (in
|
||||
# self.aggregates property).
|
||||
self._aggregates = None # Maps alias -> SQL aggregate function
|
||||
self.aggregate_select_mask = None
|
||||
self._aggregate_select_cache = None
|
||||
# self.annotations property).
|
||||
self._annotations = None # Maps alias -> Annotation Expression
|
||||
self.annotation_select_mask = None
|
||||
self._annotation_select_cache = None
|
||||
|
||||
# Arbitrary maximum limit for select_related. Prevents infinite
|
||||
# recursion. Can be changed by the depth parameter to select_related().
|
||||
@@ -155,7 +152,7 @@ class Query(object):
|
||||
# These are for extensions. The contents are more or less appended
|
||||
# verbatim to the appropriate clause.
|
||||
# The _extra attribute is an OrderedDict, lazily created similarly to
|
||||
# .aggregates
|
||||
# .annotations
|
||||
self._extra = None # Maps col_alias -> (col_sql, params).
|
||||
self.extra_select_mask = None
|
||||
self._extra_select_cache = None
|
||||
@@ -174,11 +171,18 @@ class Query(object):
|
||||
self._extra = OrderedDict()
|
||||
return self._extra
|
||||
|
||||
@property
|
||||
def annotations(self):
|
||||
if self._annotations is None:
|
||||
self._annotations = OrderedDict()
|
||||
return self._annotations
|
||||
|
||||
@property
|
||||
def aggregates(self):
|
||||
if self._aggregates is None:
|
||||
self._aggregates = OrderedDict()
|
||||
return self._aggregates
|
||||
warnings.warn(
|
||||
"The aggregates property is deprecated. Use annotations instead.",
|
||||
RemovedInDjango20Warning, stacklevel=2)
|
||||
return self.annotations
|
||||
|
||||
def __str__(self):
|
||||
"""
|
||||
@@ -203,7 +207,7 @@ class Query(object):
|
||||
memo[id(self)] = result
|
||||
return result
|
||||
|
||||
def prepare(self):
|
||||
def _prepare(self):
|
||||
return self
|
||||
|
||||
def get_compiler(self, using=None, connection=None):
|
||||
@@ -213,8 +217,8 @@ class Query(object):
|
||||
connection = connections[using]
|
||||
|
||||
# Check that the compiler will be able to execute the query
|
||||
for alias, aggregate in self.aggregate_select.items():
|
||||
connection.ops.check_aggregate_support(aggregate)
|
||||
for alias, annotation in self.annotation_select.items():
|
||||
connection.ops.check_aggregate_support(annotation)
|
||||
|
||||
return connection.ops.compiler(self.compiler)(self, connection, using)
|
||||
|
||||
@@ -260,17 +264,17 @@ class Query(object):
|
||||
obj.select_for_update_nowait = self.select_for_update_nowait
|
||||
obj.select_related = self.select_related
|
||||
obj.related_select_cols = []
|
||||
obj._aggregates = self._aggregates.copy() if self._aggregates is not None else None
|
||||
if self.aggregate_select_mask is None:
|
||||
obj.aggregate_select_mask = None
|
||||
obj._annotations = self._annotations.copy() if self._annotations is not None else None
|
||||
if self.annotation_select_mask is None:
|
||||
obj.annotation_select_mask = None
|
||||
else:
|
||||
obj.aggregate_select_mask = self.aggregate_select_mask.copy()
|
||||
# _aggregate_select_cache cannot be copied, as doing so breaks the
|
||||
# (necessary) state in which both aggregates and
|
||||
# _aggregate_select_cache point to the same underlying objects.
|
||||
obj.annotation_select_mask = self.annotation_select_mask.copy()
|
||||
# _annotation_select_cache cannot be copied, as doing so breaks the
|
||||
# (necessary) state in which both annotations and
|
||||
# _annotation_select_cache point to the same underlying objects.
|
||||
# It will get re-populated in the cloned queryset the next time it's
|
||||
# used.
|
||||
obj._aggregate_select_cache = None
|
||||
obj._annotation_select_cache = None
|
||||
obj.max_depth = self.max_depth
|
||||
obj._extra = self._extra.copy() if self._extra is not None else None
|
||||
if self.extra_select_mask is None:
|
||||
@@ -299,94 +303,84 @@ class Query(object):
|
||||
obj._setup_query()
|
||||
return obj
|
||||
|
||||
def resolve_aggregate(self, value, aggregate, connection):
|
||||
"""Resolve the value of aggregates returned by the database to
|
||||
consistent (and reasonable) types.
|
||||
|
||||
This is required because of the predisposition of certain backends
|
||||
to return Decimal and long types when they are not needed.
|
||||
"""
|
||||
if value is None:
|
||||
if aggregate.is_ordinal:
|
||||
return 0
|
||||
# Return None as-is
|
||||
return value
|
||||
elif aggregate.is_ordinal:
|
||||
# Any ordinal aggregate (e.g., count) returns an int
|
||||
return int(value)
|
||||
elif aggregate.is_computed:
|
||||
# Any computed aggregate (e.g., avg) returns a float
|
||||
return float(value)
|
||||
else:
|
||||
# Return value depends on the type of the field being processed.
|
||||
backend_converters = connection.ops.get_db_converters(aggregate.field.get_internal_type())
|
||||
field_converters = aggregate.field.get_db_converters(connection)
|
||||
for converter in backend_converters:
|
||||
value = converter(value, aggregate.field)
|
||||
for converter in field_converters:
|
||||
value = converter(value, connection)
|
||||
return value
|
||||
|
||||
def get_aggregation(self, using, force_subq=False):
|
||||
"""
|
||||
Returns the dictionary with the values of the existing aggregations.
|
||||
"""
|
||||
if not self.aggregate_select:
|
||||
if not self.annotation_select:
|
||||
return {}
|
||||
|
||||
# annotations must be forced into subquery
|
||||
has_annotation = any(
|
||||
annotation for alias, annotation
|
||||
in self.annotation_select.items()
|
||||
if not annotation.contains_aggregate)
|
||||
|
||||
# If there is a group by clause, aggregating does not add useful
|
||||
# information but retrieves only the first row. Aggregate
|
||||
# over the subquery instead.
|
||||
if self.group_by is not None or force_subq:
|
||||
if self.group_by is not None or force_subq or has_annotation:
|
||||
|
||||
from django.db.models.sql.subqueries import AggregateQuery
|
||||
query = AggregateQuery(self.model)
|
||||
obj = self.clone()
|
||||
outer_query = AggregateQuery(self.model)
|
||||
inner_query = self.clone()
|
||||
if not force_subq:
|
||||
# In forced subq case the ordering and limits will likely
|
||||
# affect the results.
|
||||
obj.clear_ordering(True)
|
||||
obj.clear_limits()
|
||||
obj.select_for_update = False
|
||||
obj.select_related = False
|
||||
obj.related_select_cols = []
|
||||
inner_query.clear_ordering(True)
|
||||
inner_query.clear_limits()
|
||||
inner_query.select_for_update = False
|
||||
inner_query.select_related = False
|
||||
inner_query.related_select_cols = []
|
||||
|
||||
relabels = dict((t, 'subquery') for t in self.tables)
|
||||
relabels = dict((t, 'subquery') for t in inner_query.tables)
|
||||
relabels[None] = 'subquery'
|
||||
# Remove any aggregates marked for reduction from the subquery
|
||||
# and move them to the outer AggregateQuery.
|
||||
for alias, aggregate in self.aggregate_select.items():
|
||||
if aggregate.is_summary:
|
||||
query.aggregates[alias] = aggregate.relabeled_clone(relabels)
|
||||
del obj.aggregate_select[alias]
|
||||
|
||||
for alias, annotation in inner_query.annotation_select.items():
|
||||
if annotation.is_summary:
|
||||
# The annotation is already referring the subquery alias, so we
|
||||
# just need to move the annotation to the outer query.
|
||||
outer_query.annotations[alias] = annotation.relabeled_clone(relabels)
|
||||
del inner_query.annotation_select[alias]
|
||||
try:
|
||||
query.add_subquery(obj, using)
|
||||
outer_query.add_subquery(inner_query, using)
|
||||
except EmptyResultSet:
|
||||
return dict(
|
||||
(alias, None)
|
||||
for alias in query.aggregate_select
|
||||
for alias in outer_query.annotation_select
|
||||
)
|
||||
else:
|
||||
query = self
|
||||
outer_query = self
|
||||
self.select = []
|
||||
self.default_cols = False
|
||||
self._extra = {}
|
||||
self.remove_inherited_models()
|
||||
|
||||
query.clear_ordering(True)
|
||||
query.clear_limits()
|
||||
query.select_for_update = False
|
||||
query.select_related = False
|
||||
query.related_select_cols = []
|
||||
|
||||
result = query.get_compiler(using).execute_sql(SINGLE)
|
||||
outer_query.clear_ordering(True)
|
||||
outer_query.clear_limits()
|
||||
outer_query.select_for_update = False
|
||||
outer_query.select_related = False
|
||||
outer_query.related_select_cols = []
|
||||
compiler = outer_query.get_compiler(using)
|
||||
result = compiler.execute_sql(SINGLE)
|
||||
if result is None:
|
||||
result = [None for q in query.aggregate_select.items()]
|
||||
result = [None for q in outer_query.annotation_select.items()]
|
||||
|
||||
fields = [annotation.output_field
|
||||
for alias, annotation in outer_query.annotation_select.items()]
|
||||
converters = compiler.get_converters(fields)
|
||||
for position, (alias, annotation) in enumerate(outer_query.annotation_select.items()):
|
||||
if position in converters:
|
||||
converters[position][1].insert(0, annotation.convert_value)
|
||||
else:
|
||||
converters[position] = ([], [annotation.convert_value], annotation.output_field)
|
||||
result = compiler.apply_converters(result, converters)
|
||||
|
||||
return dict(
|
||||
(alias, self.resolve_aggregate(val, aggregate, connection=connections[using]))
|
||||
for (alias, aggregate), val
|
||||
in zip(query.aggregate_select.items(), result)
|
||||
(alias, val)
|
||||
for (alias, annotation), val
|
||||
in zip(outer_query.annotation_select.items(), result)
|
||||
)
|
||||
|
||||
def get_count(self, using):
|
||||
@@ -394,7 +388,7 @@ class Query(object):
|
||||
Performs a COUNT() query using the current filter constraints.
|
||||
"""
|
||||
obj = self.clone()
|
||||
if len(self.select) > 1 or self.aggregate_select or (self.distinct and self.distinct_fields):
|
||||
if len(self.select) > 1 or self.annotation_select or (self.distinct and self.distinct_fields):
|
||||
# If a select clause exists, then the query has already started to
|
||||
# specify the columns that are to be returned.
|
||||
# In this case, we need to use a subquery to evaluate the count.
|
||||
@@ -769,9 +763,9 @@ class Query(object):
|
||||
self.group_by = [relabel_column(col) for col in self.group_by]
|
||||
self.select = [SelectInfo(relabel_column(s.col), s.field)
|
||||
for s in self.select]
|
||||
if self._aggregates:
|
||||
self._aggregates = OrderedDict(
|
||||
(key, relabel_column(col)) for key, col in self._aggregates.items())
|
||||
if self._annotations:
|
||||
self._annotations = OrderedDict(
|
||||
(key, relabel_column(col)) for key, col in self._annotations.items())
|
||||
|
||||
# 2. Rename the alias in the internal table/alias datastructures.
|
||||
for ident, aliases in self.join_map.items():
|
||||
@@ -974,52 +968,18 @@ class Query(object):
|
||||
self.included_inherited_models = {}
|
||||
|
||||
def add_aggregate(self, aggregate, model, alias, is_summary):
|
||||
warnings.warn(
|
||||
"add_aggregate() is deprecated. Use add_annotation() instead.",
|
||||
RemovedInDjango20Warning, stacklevel=2)
|
||||
self.add_annotation(aggregate, model, alias, is_summary)
|
||||
|
||||
def add_annotation(self, annotation, model, alias, is_summary):
|
||||
"""
|
||||
Adds a single aggregate expression to the Query
|
||||
Adds a single annotation expression to the Query
|
||||
"""
|
||||
opts = model._meta
|
||||
field_list = aggregate.lookup.split(LOOKUP_SEP)
|
||||
if len(field_list) == 1 and self._aggregates and aggregate.lookup in self.aggregates:
|
||||
# Aggregate is over an annotation
|
||||
field_name = field_list[0]
|
||||
col = field_name
|
||||
source = self.aggregates[field_name]
|
||||
if not is_summary:
|
||||
raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (
|
||||
aggregate.name, field_name, field_name))
|
||||
elif ((len(field_list) > 1) or
|
||||
(field_list[0] not in [i.name for i in opts.fields]) or
|
||||
self.group_by is None or
|
||||
not is_summary):
|
||||
# If:
|
||||
# - the field descriptor has more than one part (foo__bar), or
|
||||
# - the field descriptor is referencing an m2m/m2o field, or
|
||||
# - this is a reference to a model field (possibly inherited), or
|
||||
# - this is an annotation over a model field
|
||||
# then we need to explore the joins that are required.
|
||||
|
||||
# Join promotion note - we must not remove any rows here, so use
|
||||
# outer join if there isn't any existing join.
|
||||
_, sources, opts, join_list, path = self.setup_joins(
|
||||
field_list, opts, self.get_initial_alias())
|
||||
|
||||
# Process the join chain to see if it can be trimmed
|
||||
targets, _, join_list = self.trim_joins(sources, join_list, path)
|
||||
|
||||
col = targets[0].column
|
||||
source = sources[0]
|
||||
col = (join_list[-1], col)
|
||||
else:
|
||||
# The simplest cases. No joins required -
|
||||
# just reference the provided column alias.
|
||||
field_name = field_list[0]
|
||||
source = opts.get_field(field_name)
|
||||
col = field_name
|
||||
# We want to have the alias in SELECT clause even if mask is set.
|
||||
self.append_aggregate_mask([alias])
|
||||
|
||||
# Add the aggregate to the query
|
||||
aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary)
|
||||
annotation = annotation.resolve_expression(self, summarize=is_summary)
|
||||
self.append_annotation_mask([alias])
|
||||
self.annotations[alias] = annotation
|
||||
|
||||
def prepare_lookup_value(self, value, lookups, can_reuse):
|
||||
# Default lookup if none given is exact.
|
||||
@@ -1037,9 +997,8 @@ class Query(object):
|
||||
"Passing callable arguments to queryset is deprecated.",
|
||||
RemovedInDjango19Warning, stacklevel=2)
|
||||
value = value()
|
||||
elif isinstance(value, ExpressionNode):
|
||||
# If value is a query expression, evaluate it
|
||||
value = SQLEvaluator(value, self, reuse=can_reuse)
|
||||
elif hasattr(value, 'resolve_expression'):
|
||||
value = value.resolve_expression(self, reuse=can_reuse)
|
||||
if hasattr(value, 'query') and hasattr(value.query, 'bump_prefix'):
|
||||
value = value._clone()
|
||||
value.query.bump_prefix(self)
|
||||
@@ -1061,8 +1020,8 @@ class Query(object):
|
||||
Solve the lookup type from the lookup (eg: 'foobar__id__icontains')
|
||||
"""
|
||||
lookup_splitted = lookup.split(LOOKUP_SEP)
|
||||
if self._aggregates:
|
||||
aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.aggregates)
|
||||
if self._annotations:
|
||||
aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.annotations)
|
||||
if aggregate:
|
||||
return aggregate_lookups, (), aggregate
|
||||
_, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta())
|
||||
@@ -1232,7 +1191,11 @@ class Query(object):
|
||||
lookup_type = lookups[-1]
|
||||
else:
|
||||
assert(len(targets) == 1)
|
||||
col = Col(alias, targets[0], field)
|
||||
if hasattr(targets[0], 'as_sql'):
|
||||
# handle Expressions as annotations
|
||||
col = targets[0]
|
||||
else:
|
||||
col = Col(alias, targets[0], field)
|
||||
condition = self.build_lookup(lookups, col, value)
|
||||
if not condition:
|
||||
# Backwards compat for custom lookups
|
||||
@@ -1278,12 +1241,12 @@ class Query(object):
|
||||
Returns whether or not all elements of this q_object need to be put
|
||||
together in the HAVING clause.
|
||||
"""
|
||||
if not self._aggregates:
|
||||
if not self._annotations:
|
||||
return False
|
||||
if not isinstance(obj, Node):
|
||||
return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.aggregates)[0]
|
||||
or (hasattr(obj[1], 'contains_aggregate')
|
||||
and obj[1].contains_aggregate(self.aggregates)))
|
||||
return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.annotations)[0]
|
||||
or (hasattr(obj[1], 'refs_aggregate')
|
||||
and obj[1].refs_aggregate(self.annotations)[0]))
|
||||
return any(self.need_having(c) for c in obj.children)
|
||||
|
||||
def split_having_parts(self, q_object, negated=False):
|
||||
@@ -1390,13 +1353,21 @@ class Query(object):
|
||||
if name == 'pk':
|
||||
name = opts.pk.name
|
||||
try:
|
||||
field, model, direct, m2m = opts.get_field_by_name(name)
|
||||
field, model, _, _ = opts.get_field_by_name(name)
|
||||
except FieldDoesNotExist:
|
||||
# is it an annotation?
|
||||
if self._annotations and name in self._annotations:
|
||||
field, model = self._annotations[name], None
|
||||
if not field.contains_aggregate:
|
||||
# Local non-relational field.
|
||||
final_field = field
|
||||
targets = (field,)
|
||||
break
|
||||
# We didn't find the current field, so move position back
|
||||
# one step.
|
||||
pos -= 1
|
||||
if pos == -1 or fail_on_missing:
|
||||
available = opts.get_all_field_names() + list(self.aggregate_select)
|
||||
available = opts.get_all_field_names() + list(self.annotation_select)
|
||||
raise FieldError("Cannot resolve keyword %r into field. "
|
||||
"Choices are: %s" % (name, ", ".join(available)))
|
||||
break
|
||||
@@ -1445,6 +1416,11 @@ class Query(object):
|
||||
break
|
||||
return path, final_field, targets, names[pos + 1:]
|
||||
|
||||
def raise_field_error(self, opts, name):
|
||||
available = opts.get_all_field_names() + list(self.annotation_select)
|
||||
raise FieldError("Cannot resolve keyword %r into field. "
|
||||
"Choices are: %s" % (name, ", ".join(available)))
|
||||
|
||||
def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True):
|
||||
"""
|
||||
Compute the necessary table joins for the passage through the fields
|
||||
@@ -1519,6 +1495,29 @@ class Query(object):
|
||||
self.unref_alias(joins.pop())
|
||||
return targets, joins[-1], joins
|
||||
|
||||
def resolve_ref(self, name, allow_joins, reuse, summarize):
|
||||
if not allow_joins and LOOKUP_SEP in name:
|
||||
raise FieldError("Joined field references are not permitted in this query")
|
||||
if name in self.annotations:
|
||||
if summarize:
|
||||
return Ref(name, self.annotation_select[name])
|
||||
else:
|
||||
return self.annotation_select[name]
|
||||
else:
|
||||
field_list = name.split(LOOKUP_SEP)
|
||||
field, sources, opts, join_list, path = self.setup_joins(
|
||||
field_list, self.get_meta(),
|
||||
self.get_initial_alias(), reuse)
|
||||
targets, _, join_list = self.trim_joins(sources, join_list, path)
|
||||
if len(targets) > 1:
|
||||
raise FieldError("Referencing multicolumn fields with F() objects "
|
||||
"isn't supported")
|
||||
if reuse is not None:
|
||||
reuse.update(join_list)
|
||||
col = Col(join_list[-1], targets[0], sources[0])
|
||||
col._used_joins = join_list
|
||||
return col
|
||||
|
||||
def split_exclude(self, filter_expr, prefix, can_reuse, names_with_path):
|
||||
"""
|
||||
When doing an exclude against any kind of N-to-many relation, we need
|
||||
@@ -1633,7 +1632,7 @@ class Query(object):
|
||||
self.default_cols = False
|
||||
self.select_related = False
|
||||
self.set_extra_mask(())
|
||||
self.set_aggregate_mask(())
|
||||
self.set_annotation_mask(())
|
||||
|
||||
def clear_select_fields(self):
|
||||
"""
|
||||
@@ -1676,7 +1675,7 @@ class Query(object):
|
||||
raise
|
||||
else:
|
||||
names = sorted(opts.get_all_field_names() + list(self.extra)
|
||||
+ list(self.aggregate_select))
|
||||
+ list(self.annotation_select))
|
||||
raise FieldError("Cannot resolve keyword %r into field. "
|
||||
"Choices are: %s" % (name, ", ".join(names)))
|
||||
self.remove_inherited_models()
|
||||
@@ -1725,39 +1724,55 @@ class Query(object):
|
||||
for col, _ in self.select:
|
||||
self.group_by.append(col)
|
||||
|
||||
if self._annotations:
|
||||
for alias, annotation in six.iteritems(self.annotations):
|
||||
for col in annotation.get_group_by_cols():
|
||||
self.group_by.append(col)
|
||||
|
||||
def add_count_column(self):
|
||||
"""
|
||||
Converts the query to do count(...) or count(distinct(pk)) in order to
|
||||
get its size.
|
||||
"""
|
||||
summarize = False
|
||||
if not self.distinct:
|
||||
if not self.select:
|
||||
count = self.aggregates_module.Count('*', is_summary=True)
|
||||
count = Count('*')
|
||||
summarize = True
|
||||
else:
|
||||
assert len(self.select) == 1, \
|
||||
"Cannot add count col with multiple cols in 'select': %r" % self.select
|
||||
count = self.aggregates_module.Count(self.select[0].col)
|
||||
col = self.select[0].col
|
||||
if isinstance(col, (tuple, list)):
|
||||
count = Count(col[1])
|
||||
else:
|
||||
count = Count(col)
|
||||
|
||||
else:
|
||||
opts = self.get_meta()
|
||||
if not self.select:
|
||||
count = self.aggregates_module.Count(
|
||||
(self.join((None, opts.db_table, None)), opts.pk.column),
|
||||
is_summary=True, distinct=True)
|
||||
lookup = self.join((None, opts.db_table, None)), opts.pk.column
|
||||
count = Count(lookup[1], distinct=True)
|
||||
summarize = True
|
||||
else:
|
||||
# Because of SQL portability issues, multi-column, distinct
|
||||
# counts need a sub-query -- see get_count() for details.
|
||||
assert len(self.select) == 1, \
|
||||
"Cannot add count col with multiple cols in 'select'."
|
||||
|
||||
count = self.aggregates_module.Count(self.select[0].col, distinct=True)
|
||||
col = self.select[0].col
|
||||
if isinstance(col, (tuple, list)):
|
||||
count = Count(col[1], distinct=True)
|
||||
else:
|
||||
count = Count(col, distinct=True)
|
||||
# Distinct handling is done in Count(), so don't do it at this
|
||||
# level.
|
||||
self.distinct = False
|
||||
|
||||
# Set only aggregate to be the count column.
|
||||
# Clear out the select cache to reflect the new unmasked aggregates.
|
||||
self._aggregates = {None: count}
|
||||
self.set_aggregate_mask(None)
|
||||
# Clear out the select cache to reflect the new unmasked annotations.
|
||||
count = count.resolve_expression(self, summarize=summarize)
|
||||
self._annotations = {None: count}
|
||||
self.set_annotation_mask(None)
|
||||
self.group_by = None
|
||||
|
||||
def add_select_related(self, fields):
|
||||
@@ -1886,16 +1901,28 @@ class Query(object):
|
||||
target[model] = set(f.name for f in fields)
|
||||
|
||||
def set_aggregate_mask(self, names):
|
||||
"Set the mask of aggregates that will actually be returned by the SELECT"
|
||||
warnings.warn(
|
||||
"set_aggregate_mask() is deprecated. Use set_annotation_mask() instead.",
|
||||
RemovedInDjango20Warning, stacklevel=2)
|
||||
self.set_annotation_mask(names)
|
||||
|
||||
def set_annotation_mask(self, names):
|
||||
"Set the mask of annotations that will actually be returned by the SELECT"
|
||||
if names is None:
|
||||
self.aggregate_select_mask = None
|
||||
self.annotation_select_mask = None
|
||||
else:
|
||||
self.aggregate_select_mask = set(names)
|
||||
self._aggregate_select_cache = None
|
||||
self.annotation_select_mask = set(names)
|
||||
self._annotation_select_cache = None
|
||||
|
||||
def append_aggregate_mask(self, names):
|
||||
if self.aggregate_select_mask is not None:
|
||||
self.set_aggregate_mask(set(names).union(self.aggregate_select_mask))
|
||||
warnings.warn(
|
||||
"append_aggregate_mask() is deprecated. Use append_annotation_mask() instead.",
|
||||
RemovedInDjango20Warning, stacklevel=2)
|
||||
self.append_annotation_mask(names)
|
||||
|
||||
def append_annotation_mask(self, names):
|
||||
if self.annotation_select_mask is not None:
|
||||
self.set_annotation_mask(set(names).union(self.annotation_select_mask))
|
||||
|
||||
def set_extra_mask(self, names):
|
||||
"""
|
||||
@@ -1910,24 +1937,31 @@ class Query(object):
|
||||
self._extra_select_cache = None
|
||||
|
||||
@property
|
||||
def aggregate_select(self):
|
||||
def annotation_select(self):
|
||||
"""The OrderedDict of aggregate columns that are not masked, and should
|
||||
be used in the SELECT clause.
|
||||
|
||||
This result is cached for optimization purposes.
|
||||
"""
|
||||
if self._aggregate_select_cache is not None:
|
||||
return self._aggregate_select_cache
|
||||
elif not self._aggregates:
|
||||
if self._annotation_select_cache is not None:
|
||||
return self._annotation_select_cache
|
||||
elif not self._annotations:
|
||||
return {}
|
||||
elif self.aggregate_select_mask is not None:
|
||||
self._aggregate_select_cache = OrderedDict(
|
||||
(k, v) for k, v in self.aggregates.items()
|
||||
if k in self.aggregate_select_mask
|
||||
elif self.annotation_select_mask is not None:
|
||||
self._annotation_select_cache = OrderedDict(
|
||||
(k, v) for k, v in self.annotations.items()
|
||||
if k in self.annotation_select_mask
|
||||
)
|
||||
return self._aggregate_select_cache
|
||||
return self._annotation_select_cache
|
||||
else:
|
||||
return self.aggregates
|
||||
return self.annotations
|
||||
|
||||
@property
|
||||
def aggregate_select(self):
|
||||
warnings.warn(
|
||||
"aggregate_select() is deprecated. Use annotation_select() instead.",
|
||||
RemovedInDjango20Warning, stacklevel=2)
|
||||
return self.annotation_select
|
||||
|
||||
@property
|
||||
def extra_select(self):
|
||||
|
@@ -7,9 +7,9 @@ from django.core.exceptions import FieldError
|
||||
from django.db import connections
|
||||
from django.db.models.query_utils import Q
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.db.models.expressions import Date, DateTime, Col
|
||||
from django.db.models.fields import DateField, DateTimeField, FieldDoesNotExist
|
||||
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, NO_RESULTS, SelectInfo
|
||||
from django.db.models.sql.datastructures import Date, DateTime
|
||||
from django.db.models.sql.query import Query
|
||||
from django.utils import six
|
||||
from django.utils import timezone
|
||||
@@ -229,7 +229,7 @@ class DateQuery(Query):
|
||||
))
|
||||
self._check_field(field) # overridden in DateTimeQuery
|
||||
alias = joins[-1]
|
||||
select = self._get_select((alias, field.column), lookup_type)
|
||||
select = self._get_select(Col(alias, field), lookup_type)
|
||||
self.clear_select_clause()
|
||||
self.select = [SelectInfo(select, None)]
|
||||
self.distinct = True
|
||||
|
@@ -10,7 +10,6 @@ import warnings
|
||||
from django.conf import settings
|
||||
from django.db.models.fields import DateTimeField, Field
|
||||
from django.db.models.sql.datastructures import EmptyResultSet, Empty
|
||||
from django.db.models.sql.aggregates import Aggregate
|
||||
from django.utils.deprecation import RemovedInDjango19Warning
|
||||
from django.utils.six.moves import xrange
|
||||
from django.utils import timezone
|
||||
@@ -78,7 +77,7 @@ class WhereNode(tree.Node):
|
||||
else:
|
||||
value_annotation = bool(value)
|
||||
|
||||
if hasattr(obj, "prepare"):
|
||||
if hasattr(obj, 'prepare'):
|
||||
value = obj.prepare(lookup_type, value)
|
||||
return (obj, lookup_type, value_annotation, value)
|
||||
|
||||
@@ -187,11 +186,9 @@ class WhereNode(tree.Node):
|
||||
lvalue, params = lvalue.process(lookup_type, params_or_value, connection)
|
||||
except EmptyShortCircuit:
|
||||
raise EmptyResultSet
|
||||
elif isinstance(lvalue, Aggregate):
|
||||
params = lvalue.field.get_db_prep_lookup(lookup_type, params_or_value, connection)
|
||||
else:
|
||||
raise TypeError("'make_atom' expects a Constraint or an Aggregate "
|
||||
"as the first item of its 'child' argument.")
|
||||
raise TypeError("'make_atom' expects a Constraint as the first "
|
||||
"item of its 'child' argument.")
|
||||
|
||||
if isinstance(lvalue, tuple):
|
||||
# A direct database column lookup.
|
||||
|
@@ -86,7 +86,8 @@ manipulating the data of your Web application. Learn more about it below:
|
||||
:doc:`Aggregation <topics/db/aggregation>` |
|
||||
:doc:`Custom fields <howto/custom-model-fields>` |
|
||||
:doc:`Multiple databases <topics/db/multi-db>` |
|
||||
:doc:`Custom lookups <howto/custom-lookups>`
|
||||
:doc:`Custom lookups <howto/custom-lookups>` |
|
||||
:doc:`Query Expressions <ref/models/expressions>`
|
||||
|
||||
* **Other:**
|
||||
:doc:`Supported databases <ref/databases>` |
|
||||
|
@@ -41,6 +41,17 @@ details on these changes.
|
||||
:class:`~django.core.management.BaseCommand` instead, which takes no arguments
|
||||
by default.
|
||||
|
||||
* ``django.db.models.sql.aggregates`` module will be removed.
|
||||
|
||||
* ``django.contrib.gis.db.models.sql.aggregates`` module will be removed.
|
||||
|
||||
* The following methods and properties of ``django.db.sql.query.Query`` will
|
||||
be removed:
|
||||
|
||||
* Properties: ``aggregates`` and ``aggregate_select``
|
||||
* Methods: ``add_aggregate``, ``set_aggregate_mask``, and
|
||||
``append_aggregate_mask``.
|
||||
|
||||
* ``django.template.resolve_variable`` will be removed.
|
||||
|
||||
* The ``error_message`` argument of ``django.forms.RegexField`` will be removed.
|
||||
|
522
docs/ref/models/expressions.txt
Normal file
522
docs/ref/models/expressions.txt
Normal file
@@ -0,0 +1,522 @@
|
||||
=================
|
||||
Query Expressions
|
||||
=================
|
||||
|
||||
.. currentmodule:: django.db.models
|
||||
|
||||
Query expressions describe a value or a computation that can be used as part of
|
||||
a filter, an annotation, or an aggregation. There are a number of built-in
|
||||
expressions (documented below) that can be used to help you write queries.
|
||||
Expressions can be combined, or in some cases nested, to form more complex
|
||||
computations.
|
||||
|
||||
Supported arithmetic
|
||||
====================
|
||||
|
||||
Django supports addition, subtraction, multiplication, division, modulo
|
||||
arithmetic, and the power operator on query expressions, using Python constants,
|
||||
variables, and even other expressions.
|
||||
|
||||
.. versionadded:: 1.7
|
||||
|
||||
Support for the power operator ``**`` was added.
|
||||
|
||||
Some examples
|
||||
=============
|
||||
|
||||
.. versionchanged:: 1.8
|
||||
|
||||
Some of the examples rely on functionality that is new in Django 1.8.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Find companies that have more employees than chairs.
|
||||
Company.objects.filter(num_employees__gt=F('num_chairs'))
|
||||
|
||||
# Find companies that have at least twice as many employees
|
||||
# as chairs. Both the querysets below are equivalent.
|
||||
Company.objects.filter(num_employees__gt=F('num_chairs') * 2)
|
||||
Company.objects.filter(
|
||||
num_employees__gt=F('num_chairs') + F('num_chairs'))
|
||||
|
||||
# How many chairs are needed for each company to seat all employees?
|
||||
>>> company = Company.objects.filter(
|
||||
... num_employees__gt=F('num_chairs')).annotate(
|
||||
... chairs_needed=F('num_employees') - F('num_chairs')).first()
|
||||
>>> company.num_employees
|
||||
120
|
||||
>>> company.num_chairs
|
||||
50
|
||||
>>> company.chairs_needed
|
||||
70
|
||||
|
||||
# Annotate models with an aggregated value. Both forms
|
||||
# below are equivalent.
|
||||
Company.objects.annotate(num_products=Count('products'))
|
||||
Company.objects.annotate(num_products=Count(F('products')))
|
||||
|
||||
# Aggregates can contain complex computations also
|
||||
Company.objects.annotate(num_offerings=Count(F('products') + F('services')))
|
||||
|
||||
|
||||
Built-in Expressions
|
||||
====================
|
||||
|
||||
``F()`` expressions
|
||||
-------------------
|
||||
|
||||
.. class:: F
|
||||
|
||||
An ``F()`` object represents the value of a model field or annotated column. It
|
||||
makes it possible to refer to model field values and perform database
|
||||
operations using them without actually having to pull them out of the database
|
||||
into Python memory.
|
||||
|
||||
Instead, Django uses the ``F()`` object to generate a SQL expression that
|
||||
describes the required operation at the database level.
|
||||
|
||||
This is easiest to understand through an example. Normally, one might do
|
||||
something like this::
|
||||
|
||||
# Tintin filed a news story!
|
||||
reporter = Reporters.objects.get(name='Tintin')
|
||||
reporter.stories_filed += 1
|
||||
reporter.save()
|
||||
|
||||
Here, we have pulled the value of ``reporter.stories_filed`` from the database
|
||||
into memory and manipulated it using familiar Python operators, and then saved
|
||||
the object back to the database. But instead we could also have done::
|
||||
|
||||
from django.db.models import F
|
||||
reporter = Reporters.objects.get(name='Tintin')
|
||||
reporter.stories_filed = F('stories_filed') + 1
|
||||
reporter.save()
|
||||
|
||||
Although ``reporter.stories_filed = F('stories_filed') + 1`` looks like a
|
||||
normal Python assignment of value to an instance attribute, in fact it's an SQL
|
||||
construct describing an operation on the database.
|
||||
|
||||
When Django encounters an instance of ``F()``, it overrides the standard Python
|
||||
operators to create an encapsulated SQL expression; in this case, one which
|
||||
instructs the database to increment the database field represented by
|
||||
``reporter.stories_filed``.
|
||||
|
||||
Whatever value is or was on ``reporter.stories_filed``, Python never gets to
|
||||
know about it - it is dealt with entirely by the database. All Python does,
|
||||
through Django's ``F()`` class, is create the SQL syntax to refer to the field
|
||||
and describe the operation.
|
||||
|
||||
.. note::
|
||||
|
||||
In order to access the new value that has been saved in this way, the object
|
||||
will need to be reloaded::
|
||||
|
||||
reporter = Reporters.objects.get(pk=reporter.pk)
|
||||
|
||||
As well as being used in operations on single instances as above, ``F()`` can
|
||||
be used on ``QuerySets`` of object instances, with ``update()``. This reduces
|
||||
the two queries we were using above - the ``get()`` and the
|
||||
:meth:`~Model.save()` - to just one::
|
||||
|
||||
reporter = Reporters.objects.filter(name='Tintin')
|
||||
reporter.update(stories_filed=F('stories_filed') + 1)
|
||||
|
||||
We can also use :meth:`~django.db.models.query.QuerySet.update()` to increment
|
||||
the field value on multiple objects - which could be very much faster than
|
||||
pulling them all into Python from the database, looping over them, incrementing
|
||||
the field value of each one, and saving each one back to the database::
|
||||
|
||||
Reporter.objects.all().update(stories_filed=F('stories_filed) + 1)
|
||||
|
||||
``F()`` therefore can offer performance advantages by:
|
||||
|
||||
* getting the database, rather than Python, to do work
|
||||
* reducing the number of queries some operations require
|
||||
|
||||
.. _avoiding-race-conditions-using-f:
|
||||
|
||||
Avoiding race conditions using ``F()``
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Another useful benefit of ``F()`` is that having the database - rather than
|
||||
Python - update a field's value avoids a *race condition*.
|
||||
|
||||
If two Python threads execute the code in the first example above, one thread
|
||||
could retrieve, increment, and save a field's value after the other has
|
||||
retrieved it from the database. The value that the second thread saves will be
|
||||
based on the original value; the work of the first thread will simply be lost.
|
||||
|
||||
If the database is responsible for updating the field, the process is more
|
||||
robust: it will only ever update the field based on the value of the field in
|
||||
the database when the :meth:`~Model.save()` or ``update()`` is executed, rather
|
||||
than based on its value when the instance was retrieved.
|
||||
|
||||
Using ``F()`` in filters
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
``F()`` is also very useful in ``QuerySet`` filters, where they make it
|
||||
possible to filter a set of objects against criteria based on their field
|
||||
values, rather than on Python values.
|
||||
|
||||
This is documented in :ref:`using F() expressions in queries
|
||||
<using-f-expressions-in-filters>`.
|
||||
|
||||
|
||||
.. _func-expressions:
|
||||
|
||||
``Func()`` expressions
|
||||
----------------------
|
||||
|
||||
.. versionadded:: 1.8
|
||||
|
||||
``Func()`` expressions are the base type of all expressions that involve
|
||||
database functions like ``COALESCE`` and ``LOWER``, or aggregates like ``SUM``.
|
||||
They can be used directly::
|
||||
|
||||
queryset.annotate(field_lower=Func(F('field'), function='LOWER'))
|
||||
|
||||
or they can be used to build a library of database functions::
|
||||
|
||||
class Lower(Func):
|
||||
function = 'LOWER'
|
||||
|
||||
queryset.annotate(field_lower=Lower(F('field')))
|
||||
|
||||
But both cases will result in a queryset where each model is annotated with an
|
||||
extra attribute ``field_lower`` produced, roughly, from the following SQL::
|
||||
|
||||
SELECT
|
||||
...
|
||||
LOWER("app_label"."field") as "field_lower"
|
||||
|
||||
The ``Func`` API is as follows:
|
||||
|
||||
.. class:: Func(*expressions, **extra)
|
||||
|
||||
.. attribute:: function
|
||||
|
||||
A class attribute describing the function that will be generated.
|
||||
Specifically, the ``function`` will be interpolated as the ``function``
|
||||
placeholder within :attr:`template`. Defaults to ``None``.
|
||||
|
||||
.. attribute:: template
|
||||
|
||||
A class attribute, as a format string, that describes the SQL that is
|
||||
generated for this function. Defaults to
|
||||
``'%(function)s(%(expressions)s)'``.
|
||||
|
||||
.. attribute:: arg_joiner
|
||||
|
||||
A class attribute that denotes the character used to join the list of
|
||||
``expressions`` together. Defaults to ``', '``.
|
||||
|
||||
The ``*expressions`` argument is a list of positional expressions that the
|
||||
function will be applied to. The expressions will be converted to strings,
|
||||
joined together with ``arg_joiner``, and then interpolated into the ``template``
|
||||
as the ``expressions`` placeholder.
|
||||
|
||||
The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated
|
||||
into the ``template`` attribute. Note that the keywords ``function`` and
|
||||
``template`` can be used to replace the ``function`` and ``template``
|
||||
attributes respectively, without having to define your own class.
|
||||
``output_field`` can be used to define the expected return type.
|
||||
|
||||
``Aggregate()`` expressions
|
||||
---------------------------
|
||||
|
||||
An aggregate expression is a special case of a :ref:`Func() expression
|
||||
<func-expressions>` that informs the query that a ``GROUP BY`` clause
|
||||
is required. All of the :ref:`aggregate functions <aggregation-functions>`,
|
||||
like ``Sum()`` and ``Count()``, inherit from ``Aggregate()``.
|
||||
|
||||
Since ``Aggregate``\s are expressions and wrap expressions, you can represent
|
||||
some complex computations::
|
||||
|
||||
Company.objects.annotate(
|
||||
managers_required=(Count('num_employees') / 4) + Count('num_managers'))
|
||||
|
||||
The ``Aggregate`` API is as follows:
|
||||
|
||||
.. class:: Aggregate(expression, output_field=None, **extra)
|
||||
|
||||
.. attribute:: template
|
||||
|
||||
A class attribute, as a format string, that describes the SQL that is
|
||||
generated for this aggregate. Defaults to
|
||||
``'%(function)s( %(expressions)s )'``.
|
||||
|
||||
.. attribute:: function
|
||||
|
||||
A class attribute describing the aggregate function that will be
|
||||
generated. Specifically, the ``function`` will be interpolated as the
|
||||
``function`` placeholder within :attr:`template`. Defaults to ``None``.
|
||||
|
||||
The ``expression`` argument can be the name of a field on the model, or another
|
||||
expression. It will be converted to a string and used as the ``expressions``
|
||||
placeholder within the ``template``.
|
||||
|
||||
The ``output_field`` argument requires a model field instance, like
|
||||
``IntegerField()`` or ``BooleanField()``, into which Django will load the value
|
||||
after it's retrieved from the database.
|
||||
|
||||
Note that ``output_field`` is only required when Django is unable to determine
|
||||
what field type the result should be. Complex expressions that mix field types
|
||||
should define the desired ``output_field``. For example, adding an
|
||||
``IntegerField()`` and a ``FloatField()`` together should probably have
|
||||
``output_field=FloatField()`` defined.
|
||||
|
||||
.. versionchanged:: 1.8
|
||||
|
||||
``output_field`` is a new parameter.
|
||||
|
||||
The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated
|
||||
into the ``template`` attribute.
|
||||
|
||||
.. versionadded:: 1.8
|
||||
|
||||
Aggregate functions can now use arithmetic and reference multiple
|
||||
model fields in a single function.
|
||||
|
||||
Creating your own Aggregate Functions
|
||||
-------------------------------------
|
||||
|
||||
Creating your own aggregate is extremely easy. At a minimum, you need
|
||||
to define ``function``, but you can also completely customize the
|
||||
SQL that is generated. Here's a brief example::
|
||||
|
||||
class Count(Aggregate):
|
||||
# supports COUNT(distinct field)
|
||||
function = 'COUNT'
|
||||
template = '%(function)s(%(distinct)s%(expressions)s)'
|
||||
|
||||
def __init__(self, expression, distinct=False, **extra):
|
||||
super(Count, self).__init__(
|
||||
expression,
|
||||
distinct='DISTINCT ' if distinct else '',
|
||||
output_field=IntegerField(),
|
||||
**extra)
|
||||
|
||||
|
||||
``Value()`` expressions
|
||||
-----------------------
|
||||
|
||||
.. class:: Value(value, output_field=None)
|
||||
|
||||
|
||||
A ``Value()`` object represents the smallest possible component of an
|
||||
expression: a simple value. When you need to represent the value of an integer,
|
||||
boolean, or string within an expression, you can wrap that value within a
|
||||
``Value()``.
|
||||
|
||||
You will rarely need to use ``Value()`` directly. When you write the expression
|
||||
``F('field') + 1``, Django implicitly wraps the ``1`` in a ``Value()``,
|
||||
allowing simple values to be used in more complex expressions.
|
||||
|
||||
The ``value`` argument describes the value to be included in the expression,
|
||||
such as ``1``, ``True``, or ``None``. Django knows how to convert these Python
|
||||
values into their corresponding database type.
|
||||
|
||||
The ``output_field`` argument should be a model field instance, like
|
||||
``IntegerField()`` or ``BooleanField()``, into which Django will load the value
|
||||
after it's retrieved from the database.
|
||||
|
||||
|
||||
Technical Information
|
||||
=====================
|
||||
|
||||
Below you'll find technical implementation details that may be useful to
|
||||
library authors. The technical API and examples below will help with
|
||||
creating generic query expressions that can extend the built-in functionality
|
||||
that Django provides.
|
||||
|
||||
Expression API
|
||||
--------------
|
||||
|
||||
Query expressions implement the :ref:`query expression API <query-expression>`,
|
||||
but also expose a number of extra methods and attributes listed below. All
|
||||
query expressions must inherit from ``ExpressionNode()`` or a relevant
|
||||
subclass.
|
||||
|
||||
When a query expression wraps another expression, it is responsible for
|
||||
calling the appropriate methods on the wrapped expression.
|
||||
|
||||
.. class:: ExpressionNode
|
||||
|
||||
.. attribute:: contains_aggregate
|
||||
|
||||
Tells Django that this expression contains an aggregate and that a
|
||||
``GROUP BY`` clause needs to be added to the query.
|
||||
|
||||
.. method:: resolve_expression(query=None, allow_joins=True, reuse=None, summarize=False)
|
||||
|
||||
Provides the chance to do any pre-processing or validation of
|
||||
the expression before it's added to the query. ``resolve_expression()``
|
||||
must also be called on any nested expressions. A ``copy()`` of ``self``
|
||||
should be returned with any necessary transformations.
|
||||
|
||||
``query`` is the backend query implementation.
|
||||
|
||||
``allow_joins`` is a boolean that allows or denies the use of
|
||||
joins in the query.
|
||||
|
||||
``reuse`` is a set of reusable joins for multi-join scenarios.
|
||||
|
||||
``summarize`` is a boolean that, when ``True``, signals that the
|
||||
query being computed is a terminal aggregate query.
|
||||
|
||||
.. method:: get_source_expressions()
|
||||
|
||||
Returns an ordered list of inner expressions. For example::
|
||||
|
||||
>>> Sum(F('foo')).get_source_expressions()
|
||||
[F('foo')]
|
||||
|
||||
.. method:: set_source_expressions(expressions)
|
||||
|
||||
Takes a list of expressions and stores them such that
|
||||
``get_source_expressions()`` can return them.
|
||||
|
||||
.. method:: relabeled_clone(change_map)
|
||||
|
||||
Returns a clone (copy) of ``self``, with any column aliases relabeled.
|
||||
Column aliases are renamed when subqueries are created.
|
||||
``relabeled_clone()`` should also be called on any nested expressions
|
||||
and assigned to the clone.
|
||||
|
||||
``change_map`` is a dictionary mapping old aliases to new aliases.
|
||||
|
||||
Example::
|
||||
|
||||
def relabeled_clone(self, change_map):
|
||||
clone = copy.copy(self)
|
||||
clone.expression = self.expression.relabeled_clone(change_map)
|
||||
return clone
|
||||
|
||||
.. method:: convert_value(self, value, connection)
|
||||
|
||||
A hook allowing the expression to coerce ``value`` into a more
|
||||
appropriate type.
|
||||
|
||||
.. method:: refs_aggregate(existing_aggregates)
|
||||
|
||||
Returns a tuple containing the ``(aggregate, lookup_path)`` of the
|
||||
first aggregate that this expression (or any nested expression)
|
||||
references, or ``(False, ())`` if no aggregate is referenced.
|
||||
For example::
|
||||
|
||||
queryset.filter(num_chairs__gt=F('sum__employees'))
|
||||
|
||||
The ``F()`` expression here references a previous ``Sum()``
|
||||
computation which means that this filter expression should be
|
||||
added to the ``HAVING`` clause rather than the ``WHERE`` clause.
|
||||
|
||||
In the majority of cases, returning the result of ``refs_aggregate``
|
||||
on any nested expression should be appropriate, as the necessary
|
||||
built-in expressions will return the correct values.
|
||||
|
||||
.. method:: get_group_by_cols()
|
||||
|
||||
Responsible for returning the list of columns references by
|
||||
this expression. ``get_group_by_cols()`` should be called on any
|
||||
nested expressions. ``F()`` objects, in particular, hold a reference
|
||||
to a column.
|
||||
|
||||
Writing your own Query Expressions
|
||||
----------------------------------
|
||||
|
||||
You can write your own query expression classes that use, and can integrate
|
||||
with, other query expressions. Let's step through an example by writing an
|
||||
implementation of the ``COALESCE`` SQL function, without using the built-in
|
||||
:ref:`Func() expressions <func-expressions>`.
|
||||
|
||||
The ``COALESCE`` SQL function is defined as taking a list of columns or
|
||||
values. It will return the first column or value that isn't ``NULL``.
|
||||
|
||||
We'll start by defining the template to be used for SQL generation and
|
||||
an ``__init__()`` method to set some attributes::
|
||||
|
||||
import copy
|
||||
from django.db.models import ExpressionNode
|
||||
|
||||
class Coalesce(ExpressionNode):
|
||||
template = 'COALESCE( %(expressions)s )'
|
||||
|
||||
def __init__(self, expressions, output_field, **extra):
|
||||
super(Coalesce, self).__init__(output_field=output_field)
|
||||
if len(expressions) < 2:
|
||||
raise ValueError('expressions must have at least 2 elements')
|
||||
for expression in expressions:
|
||||
if not hasattr(expression, 'resolve_expression'):
|
||||
raise TypeError('%r is not an Expression' % expression)
|
||||
self.expressions = expressions
|
||||
self.extra = extra
|
||||
|
||||
We do some basic validation on the parameters, including requiring at least
|
||||
2 columns or values, and ensuring they are expressions. We are requiring
|
||||
``output_field`` here so that Django knows what kind of model field to assign
|
||||
the eventual result to.
|
||||
|
||||
Now we implement the pre-processing and validation. Since we do not have
|
||||
any of our own validation at this point, we just delegate to the nested
|
||||
expressions::
|
||||
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False):
|
||||
c = self.copy()
|
||||
c.is_summary = summarize
|
||||
for pos, expression in enumerate(self.expressions):
|
||||
c.expressions[pos] = expression.resolve_expression(query, allow_joins, reuse, summarize)
|
||||
return c
|
||||
|
||||
Next, we write the method responsible for generating the SQL::
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
sql_expressions, sql_params = [], []
|
||||
for expression in self.expressions:
|
||||
sql, params = compiler.compile(expression)
|
||||
sql_expressions.append(sql)
|
||||
sql_params.extend(params)
|
||||
self.extra['expressions'] = ','.join(sql_expressions)
|
||||
return self.template % self.extra, sql_params
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
"""
|
||||
Example of vendor specific handling (Oracle in this case).
|
||||
Let's make the function name lowercase.
|
||||
"""
|
||||
self.template = 'coalesce( %(expressions)s )'
|
||||
return self.as_sql(compiler, connection)
|
||||
|
||||
We generate the SQL for each of the ``expressions`` by using the
|
||||
``compiler.compile()`` method, and join the result together with commas.
|
||||
Then the template is filled out with our data and the SQL and parameters
|
||||
are returned.
|
||||
|
||||
We've also defined a custom implementation that is specific to the Oracle
|
||||
backend. The ``as_oracle()`` function will be called instead of ``as_sql()``
|
||||
if the Oracle backend is in use.
|
||||
|
||||
Finally, we implement the rest of the methods that allow our query expression
|
||||
to play nice with other query expressions::
|
||||
|
||||
def get_source_expressions(self):
|
||||
return self.expressions
|
||||
|
||||
def set_source_expressions(expressions):
|
||||
self.expressions = expressions
|
||||
|
||||
Let's see how it works::
|
||||
|
||||
>>> qs = Company.objects.annotate(
|
||||
... tagline=Coalesce([
|
||||
... F('motto'),
|
||||
... F('ticker_name'),
|
||||
... F('description'),
|
||||
... Value('No Tagline')
|
||||
... ], output_field=CharField()))
|
||||
>>> for c in qs:
|
||||
... print("%s: %s" % (c.name, c.tagline))
|
||||
...
|
||||
Google: Do No Evil
|
||||
Apple: AAPL
|
||||
Yahoo: Internet Company
|
||||
Django Software Foundation: No Tagline
|
@@ -15,3 +15,4 @@ Model API reference. For introductory material, see :doc:`/topics/db/models`.
|
||||
querysets
|
||||
queries
|
||||
lookups
|
||||
expressions
|
||||
|
@@ -7,115 +7,6 @@ Query-related classes
|
||||
This document provides reference material for query-related tools not
|
||||
documented elsewhere.
|
||||
|
||||
``F()`` expressions
|
||||
===================
|
||||
|
||||
.. class:: F
|
||||
|
||||
An ``F()`` object represents the value of a model field. It makes it possible
|
||||
to refer to model field values and perform database operations using them
|
||||
without actually having to pull them out of the database into Python memory.
|
||||
|
||||
Instead, Django uses the ``F()`` object to generate a SQL expression that
|
||||
describes the required operation at the database level.
|
||||
|
||||
This is easiest to understand through an example. Normally, one might do
|
||||
something like this::
|
||||
|
||||
# Tintin filed a news story!
|
||||
reporter = Reporters.objects.get(name='Tintin')
|
||||
reporter.stories_filed += 1
|
||||
reporter.save()
|
||||
|
||||
Here, we have pulled the value of ``reporter.stories_filed`` from the database
|
||||
into memory and manipulated it using familiar Python operators, and then saved
|
||||
the object back to the database. But instead we could also have done::
|
||||
|
||||
from django.db.models import F
|
||||
reporter = Reporters.objects.get(name='Tintin')
|
||||
reporter.stories_filed = F('stories_filed') + 1
|
||||
reporter.save()
|
||||
|
||||
Although ``reporter.stories_filed = F('stories_filed') + 1`` looks like a
|
||||
normal Python assignment of value to an instance attribute, in fact it's an SQL
|
||||
construct describing an operation on the database.
|
||||
|
||||
When Django encounters an instance of ``F()``, it overrides the standard Python
|
||||
operators to create an encapsulated SQL expression; in this case, one which
|
||||
instructs the database to increment the database field represented by
|
||||
``reporter.stories_filed``.
|
||||
|
||||
Whatever value is or was on ``reporter.stories_filed``, Python never gets to
|
||||
know about it - it is dealt with entirely by the database. All Python does,
|
||||
through Django's ``F()`` class, is create the SQL syntax to refer to the field
|
||||
and describe the operation.
|
||||
|
||||
.. note::
|
||||
|
||||
In order to access the new value that has been saved in this way, the object
|
||||
will need to be reloaded::
|
||||
|
||||
reporter = Reporters.objects.get(pk=reporter.pk)
|
||||
|
||||
As well as being used in operations on single instances as above, ``F()`` can
|
||||
be used on ``QuerySets`` of object instances, with ``update()``. This reduces
|
||||
the two queries we were using above - the ``get()`` and the
|
||||
:meth:`~Model.save()` - to just one::
|
||||
|
||||
reporter = Reporters.objects.filter(name='Tintin')
|
||||
reporter.update(stories_filed=F('stories_filed') + 1)
|
||||
|
||||
We can also use :meth:`~django.db.models.query.QuerySet.update()` to increment
|
||||
the field value on multiple objects - which could be very much faster than
|
||||
pulling them all into Python from the database, looping over them, incrementing
|
||||
the field value of each one, and saving each one back to the database::
|
||||
|
||||
Reporter.objects.all().update(stories_filed=F('stories_filed') + 1)
|
||||
|
||||
``F()`` therefore can offer performance advantages by:
|
||||
|
||||
* getting the database, rather than Python, to do work
|
||||
* reducing the number of queries some operations require
|
||||
|
||||
.. _avoiding-race-conditions-using-f:
|
||||
|
||||
Avoiding race conditions using ``F()``
|
||||
--------------------------------------
|
||||
|
||||
Another useful benefit of ``F()`` is that having the database - rather than
|
||||
Python - update a field's value avoids a *race condition*.
|
||||
|
||||
If two Python threads execute the code in the first example above, one thread
|
||||
could retrieve, increment, and save a field's value after the other has
|
||||
retrieved it from the database. The value that the second thread saves will be
|
||||
based on the original value; the work of the first thread will simply be lost.
|
||||
|
||||
If the database is responsible for updating the field, the process is more
|
||||
robust: it will only ever update the field based on the value of the field in
|
||||
the database when the :meth:`~Model.save()` or ``update()`` is executed, rather
|
||||
than based on its value when the instance was retrieved.
|
||||
|
||||
Using ``F()`` in filters
|
||||
------------------------
|
||||
|
||||
``F()`` is also very useful in ``QuerySet`` filters, where they make it
|
||||
possible to filter a set of objects against criteria based on their field
|
||||
values, rather than on Python values.
|
||||
|
||||
This is documented in :ref:`using F() expressions in queries
|
||||
<using-f-expressions-in-filters>`
|
||||
|
||||
Supported operations with ``F()``
|
||||
---------------------------------
|
||||
|
||||
As well as addition, Django supports subtraction, multiplication, division,
|
||||
and modulo arithmetic with ``F()`` objects, using Python constants,
|
||||
variables, and even other ``F()`` objects.
|
||||
|
||||
.. versionadded:: 1.7
|
||||
|
||||
The power operator ``**`` is also supported.
|
||||
|
||||
``Q()`` objects
|
||||
===============
|
||||
|
||||
|
@@ -220,9 +220,18 @@ annotate
|
||||
|
||||
.. method:: annotate(*args, **kwargs)
|
||||
|
||||
Annotates each object in the ``QuerySet`` with the provided list of
|
||||
aggregate values (averages, sums, etc) that have been computed over
|
||||
the objects that are related to the objects in the ``QuerySet``.
|
||||
Annotates each object in the ``QuerySet`` with the provided list of :doc:`query
|
||||
expressions </ref/models/expressions>`. An expression may be a simple value, a
|
||||
reference to a field on the model (or any related models), or an aggregate
|
||||
expression (averages, sums, etc) that has been computed over the objects that
|
||||
are related to the objects in the ``QuerySet``.
|
||||
|
||||
.. versionadded:: 1.8
|
||||
|
||||
Previous versions of Django only allowed aggregate functions to be used as
|
||||
annotations. It is now possible to annotate a model with all kinds of
|
||||
expressions.
|
||||
|
||||
Each argument to ``annotate()`` is an annotation that will be added
|
||||
to each object in the ``QuerySet`` that is returned.
|
||||
|
||||
@@ -232,7 +241,9 @@ in `Aggregation Functions`_ below.
|
||||
Annotations specified using keyword arguments will use the keyword as
|
||||
the alias for the annotation. Anonymous arguments will have an alias
|
||||
generated for them based upon the name of the aggregate function and
|
||||
the model field that is being aggregated.
|
||||
the model field that is being aggregated. Only aggregate expressions
|
||||
that reference a single field can be anonymous arguments. Everything
|
||||
else must be a keyword argument.
|
||||
|
||||
For example, if you were manipulating a list of blogs, you may want
|
||||
to determine how many entries have been made in each blog::
|
||||
@@ -1886,12 +1897,15 @@ the ``QuerySet``. Each argument to ``aggregate()`` specifies a value that will
|
||||
be included in the dictionary that is returned.
|
||||
|
||||
The aggregation functions that are provided by Django are described in
|
||||
`Aggregation Functions`_ below.
|
||||
`Aggregation Functions`_ below. Since aggregates are also :doc:`query
|
||||
expressions </ref/models/expressions>`, you may combine aggregates with other
|
||||
aggregates or values to create complex aggregates.
|
||||
|
||||
Aggregates specified using keyword arguments will use the keyword as the name
|
||||
for the annotation. Anonymous arguments will have a name generated for them
|
||||
based upon the name of the aggregate function and the model field that is being
|
||||
aggregated.
|
||||
aggregated. Complex aggregates cannot use anonymous arguments and must specify
|
||||
a keyword argument as an alias.
|
||||
|
||||
For example, when you are working with blog entries, you may want to know the
|
||||
number of authors that have contributed blog entries::
|
||||
@@ -2667,8 +2681,9 @@ Aggregation functions
|
||||
|
||||
Django provides the following aggregation functions in the
|
||||
``django.db.models`` module. For details on how to use these
|
||||
aggregate functions, see
|
||||
:doc:`the topic guide on aggregation </topics/db/aggregation>`.
|
||||
aggregate functions, see :doc:`the topic guide on aggregation
|
||||
</topics/db/aggregation>`. See the :class:`~django.db.models.Aggregate`
|
||||
documentation to learn how to create your aggregates.
|
||||
|
||||
.. warning::
|
||||
|
||||
@@ -2685,12 +2700,47 @@ aggregate functions, see
|
||||
instead of ``0`` if the ``QuerySet`` contains no entries. An exception is
|
||||
``Count``, which does return ``0`` if the ``QuerySet`` is empty.
|
||||
|
||||
All aggregates have the following parameters in common:
|
||||
|
||||
``expression``
|
||||
~~~~~~~~~~~~~~
|
||||
|
||||
A string that references a field on the model, or a :doc:`query expression
|
||||
</ref/models/expressions>`.
|
||||
|
||||
.. versionadded:: 1.8
|
||||
|
||||
Aggregate functions are now able to reference multiple fields in complex
|
||||
computations.
|
||||
|
||||
``output_field``
|
||||
~~~~~~~~~~~~~~~~
|
||||
|
||||
An optional argument that represents the :doc:`model field </ref/models/fields>`
|
||||
of the return value
|
||||
|
||||
.. versionadded:: 1.8
|
||||
|
||||
The ``output_field`` argument was added.
|
||||
|
||||
.. note::
|
||||
|
||||
When combining multiple field types, Django can only determine the
|
||||
``output_field`` if all fields are of the same type. Otherwise, you
|
||||
must provide the ``output_field`` yourself.
|
||||
|
||||
``**extra``
|
||||
~~~~~~~~~~~
|
||||
|
||||
Keyword arguments that can provide extra context for the SQL generated
|
||||
by the aggregate.
|
||||
|
||||
Avg
|
||||
~~~
|
||||
|
||||
.. class:: Avg(field)
|
||||
.. class:: Avg(expression, output_field=None, **extra)
|
||||
|
||||
Returns the mean value of the given field, which must be numeric.
|
||||
Returns the mean value of the given expression, which must be numeric.
|
||||
|
||||
* Default alias: ``<field>__avg``
|
||||
* Return type: ``float``
|
||||
@@ -2698,9 +2748,10 @@ Avg
|
||||
Count
|
||||
~~~~~
|
||||
|
||||
.. class:: Count(field, distinct=False)
|
||||
.. class:: Count(expression, distinct=False, **extra)
|
||||
|
||||
Returns the number of objects that are related through the provided field.
|
||||
Returns the number of objects that are related through the provided
|
||||
expression.
|
||||
|
||||
* Default alias: ``<field>__count``
|
||||
* Return type: ``int``
|
||||
@@ -2716,29 +2767,29 @@ Count
|
||||
Max
|
||||
~~~
|
||||
|
||||
.. class:: Max(field)
|
||||
.. class:: Max(expression, output_field=None, **extra)
|
||||
|
||||
Returns the maximum value of the given field.
|
||||
Returns the maximum value of the given expression.
|
||||
|
||||
* Default alias: ``<field>__max``
|
||||
* Return type: same as input field
|
||||
* Return type: same as input field, or ``output_field`` if supplied
|
||||
|
||||
Min
|
||||
~~~
|
||||
|
||||
.. class:: Min(field)
|
||||
.. class:: Min(expression, output_field=None, **extra)
|
||||
|
||||
Returns the minimum value of the given field.
|
||||
Returns the minimum value of the given expression.
|
||||
|
||||
* Default alias: ``<field>__min``
|
||||
* Return type: same as input field
|
||||
* Return type: same as input field, or ``output_field`` if supplied
|
||||
|
||||
StdDev
|
||||
~~~~~~
|
||||
|
||||
.. class:: StdDev(field, sample=False)
|
||||
.. class:: StdDev(expression, sample=False, **extra)
|
||||
|
||||
Returns the standard deviation of the data in the provided field.
|
||||
Returns the standard deviation of the data in the provided expression.
|
||||
|
||||
* Default alias: ``<field>__stddev``
|
||||
* Return type: ``float``
|
||||
@@ -2760,19 +2811,19 @@ StdDev
|
||||
Sum
|
||||
~~~
|
||||
|
||||
.. class:: Sum(field)
|
||||
.. class:: Sum(expression, output_field=None, **extra)
|
||||
|
||||
Computes the sum of all values of the given field.
|
||||
Computes the sum of all values of the given expression.
|
||||
|
||||
* Default alias: ``<field>__sum``
|
||||
* Return type: same as input field
|
||||
* Return type: same as input field, or ``output_field`` if supplied
|
||||
|
||||
Variance
|
||||
~~~~~~~~
|
||||
|
||||
.. class:: Variance(field, sample=False)
|
||||
.. class:: Variance(expression, sample=False, **extra)
|
||||
|
||||
Returns the variance of the data in the provided field.
|
||||
Returns the variance of the data in the provided expression.
|
||||
|
||||
* Default alias: ``<field>__variance``
|
||||
* Return type: ``float``
|
||||
|
@@ -52,6 +52,15 @@ New data types
|
||||
<django.forms.UUIDField>`. It is stored as the native ``uuid`` data type on
|
||||
PostgreSQL and as a fixed length character field on other backends.
|
||||
|
||||
Query Expressions
|
||||
~~~~~~~~~~~~~~~~~
|
||||
|
||||
:doc:`Query Expressions </ref/models/expressions>` allow users to create,
|
||||
customize, and compose complex SQL expressions. This has enabled annotate
|
||||
to accept expressions other than aggregates. Aggregates are now able to
|
||||
reference multiple fields, as well as perform arithmetic, similar to ``F()``
|
||||
objects.
|
||||
|
||||
Minor features
|
||||
~~~~~~~~~~~~~~
|
||||
|
||||
@@ -857,6 +866,29 @@ or ``name='django.contrib.gis.sitemaps.views.kmz'``.
|
||||
|
||||
.. _security issue: https://www.djangoproject.com/weblog/2014/apr/21/security/#s-issue-unexpected-code-execution-using-reverse
|
||||
|
||||
Aggregate methods and modules
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
The ``django.db.models.sql.aggregates`` and
|
||||
``django.contrib.gis.db.models.sql.aggregates`` modules (both private API), have
|
||||
been deprecated as ``django.db.models.aggregates`` and
|
||||
``django.contrib.gis.db.models.aggregates`` are now also responsible
|
||||
for SQL generation. The old modules will be removed in Django 2.0.
|
||||
|
||||
If you were using the old modules, see :doc:`Query Expressions
|
||||
</ref/models/expressions>` for instructions on rewriting custom aggregates
|
||||
using the new stable API.
|
||||
|
||||
The following methods and properties of ``django.db.models.sql.query.Query``
|
||||
have also been deprecated and the backwards compatibility shims will be removed
|
||||
in Django 2.0:
|
||||
|
||||
* ``Query.aggregates``, replaced by ``annotations``.
|
||||
* ``Query.aggregate_select``, replaced by ``annotation_select``.
|
||||
* ``Query.add_aggregate()``, replaced by ``add_annotation()``.
|
||||
* ``Query.set_aggregate_mask()``, replaced by ``set_annotation_mask()``.
|
||||
* ``Query.append_aggregate_mask()``, replaced by ``append_annotation_mask()``.
|
||||
|
||||
Extending management command arguments through ``Command.option_list``
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
@@ -67,6 +67,11 @@ In a hurry? Here's how to do common aggregate queries, assuming the models above
|
||||
>>> Book.objects.all().aggregate(Max('price'))
|
||||
{'price__max': Decimal('81.20')}
|
||||
|
||||
# Cost per page
|
||||
>>> Book.objects.all().aggregate(
|
||||
... price_per_page=Sum(F('price')/F('pages'), output_field=FloatField()))
|
||||
{'price_per_page': 0.4470664529184653}
|
||||
|
||||
# All the following queries involve traversing the Book<->Publisher
|
||||
# many-to-many relationship backward
|
||||
|
||||
|
@@ -3,12 +3,21 @@ from __future__ import unicode_literals
|
||||
import datetime
|
||||
from decimal import Decimal
|
||||
import re
|
||||
import warnings
|
||||
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db import connection
|
||||
from django.db.models import Avg, Sum, Count, Max, Min
|
||||
from django.db.models import (
|
||||
Avg, Sum, Count, Max, Min,
|
||||
Aggregate, F, Value, Func,
|
||||
IntegerField, FloatField, DecimalField)
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
from django.db.models.sql import aggregates as sql_aggregates
|
||||
from django.test import TestCase
|
||||
from django.test.utils import Approximate
|
||||
from django.test.utils import CaptureQueriesContext
|
||||
from django.utils.deprecation import RemovedInDjango20Warning
|
||||
|
||||
from .models import Author, Publisher, Book, Store
|
||||
|
||||
@@ -678,3 +687,271 @@ class BaseAggregateTestCase(TestCase):
|
||||
else:
|
||||
self.assertNotIn('order by', qstr)
|
||||
self.assertEqual(qstr.count(' join '), 0)
|
||||
|
||||
|
||||
class ComplexAggregateTestCase(TestCase):
|
||||
fixtures = ["aggregation.json"]
|
||||
|
||||
def test_nonaggregate_aggregation_throws(self):
|
||||
with self.assertRaisesRegexp(TypeError, 'fail is not an aggregate expression'):
|
||||
Book.objects.aggregate(fail=F('price'))
|
||||
|
||||
def test_nonfield_annotation(self):
|
||||
book = Book.objects.annotate(val=Max(Value(2, output_field=IntegerField())))[0]
|
||||
self.assertEqual(book.val, 2)
|
||||
book = Book.objects.annotate(val=Max(Value(2), output_field=IntegerField()))[0]
|
||||
self.assertEqual(book.val, 2)
|
||||
|
||||
def test_missing_output_field_raises_error(self):
|
||||
with self.assertRaisesRegexp(FieldError, 'Cannot resolve expression type, unknown output_field'):
|
||||
Book.objects.annotate(val=Max(Value(2)))[0]
|
||||
|
||||
def test_annotation_expressions(self):
|
||||
authors = Author.objects.annotate(combined_ages=Sum(F('age') + F('friends__age'))).order_by('name')
|
||||
authors2 = Author.objects.annotate(combined_ages=Sum('age') + Sum('friends__age')).order_by('name')
|
||||
for qs in (authors, authors2):
|
||||
self.assertEqual(len(qs), 9)
|
||||
self.assertQuerysetEqual(
|
||||
qs, [
|
||||
('Adrian Holovaty', 132),
|
||||
('Brad Dayley', None),
|
||||
('Jacob Kaplan-Moss', 129),
|
||||
('James Bennett', 63),
|
||||
('Jeffrey Forcier', 128),
|
||||
('Paul Bissex', 120),
|
||||
('Peter Norvig', 103),
|
||||
('Stuart Russell', 103),
|
||||
('Wesley J. Chun', 176)
|
||||
],
|
||||
lambda a: (a.name, a.combined_ages)
|
||||
)
|
||||
|
||||
def test_aggregation_expressions(self):
|
||||
a1 = Author.objects.aggregate(av_age=Sum('age') / Count('*'))
|
||||
a2 = Author.objects.aggregate(av_age=Sum('age') / Count('age'))
|
||||
a3 = Author.objects.aggregate(av_age=Avg('age'))
|
||||
self.assertEqual(a1, {'av_age': 37})
|
||||
self.assertEqual(a2, {'av_age': 37})
|
||||
self.assertEqual(a3, {'av_age': Approximate(37.4, places=1)})
|
||||
|
||||
def test_order_of_precedence(self):
|
||||
p1 = Book.objects.filter(rating=4).aggregate(avg_price=(Avg('price') + 2) * 3)
|
||||
self.assertEqual(p1, {'avg_price': Approximate(148.18, places=2)})
|
||||
|
||||
p2 = Book.objects.filter(rating=4).aggregate(avg_price=Avg('price') + 2 * 3)
|
||||
self.assertEqual(p2, {'avg_price': Approximate(53.39, places=2)})
|
||||
|
||||
def test_combine_different_types(self):
|
||||
with self.assertRaisesRegexp(FieldError, 'Expression contains mixed types. You must set output_field'):
|
||||
Book.objects.annotate(sums=Sum('rating') + Sum('pages') + Sum('price')).get(pk=4)
|
||||
|
||||
b1 = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'),
|
||||
output_field=IntegerField())).get(pk=4)
|
||||
self.assertEqual(b1.sums, 383)
|
||||
|
||||
b2 = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'),
|
||||
output_field=FloatField())).get(pk=4)
|
||||
self.assertEqual(b2.sums, 383.69)
|
||||
|
||||
b3 = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'),
|
||||
output_field=DecimalField(max_digits=6, decimal_places=2))).get(pk=4)
|
||||
self.assertEqual(b3.sums, Decimal("383.69"))
|
||||
|
||||
def test_complex_aggregations_require_kwarg(self):
|
||||
with self.assertRaisesRegexp(TypeError, 'Complex expressions require an alias'):
|
||||
Author.objects.annotate(Sum(F('age') + F('friends__age')))
|
||||
with self.assertRaisesRegexp(TypeError, 'Complex aggregates require an alias'):
|
||||
Author.objects.aggregate(Sum('age') / Count('age'))
|
||||
|
||||
def test_aggregate_over_complex_annotation(self):
|
||||
qs = Author.objects.annotate(
|
||||
combined_ages=Sum(F('age') + F('friends__age')))
|
||||
|
||||
age = qs.aggregate(max_combined_age=Max('combined_ages'))
|
||||
self.assertEqual(age['max_combined_age'], 176)
|
||||
|
||||
age = qs.aggregate(max_combined_age_doubled=Max('combined_ages') * 2)
|
||||
self.assertEqual(age['max_combined_age_doubled'], 176 * 2)
|
||||
|
||||
age = qs.aggregate(
|
||||
max_combined_age_doubled=Max('combined_ages') + Max('combined_ages'))
|
||||
self.assertEqual(age['max_combined_age_doubled'], 176 * 2)
|
||||
|
||||
age = qs.aggregate(
|
||||
max_combined_age_doubled=Max('combined_ages') + Max('combined_ages'),
|
||||
sum_combined_age=Sum('combined_ages'))
|
||||
self.assertEqual(age['max_combined_age_doubled'], 176 * 2)
|
||||
self.assertEqual(age['sum_combined_age'], 954)
|
||||
|
||||
age = qs.aggregate(
|
||||
max_combined_age_doubled=Max('combined_ages') + Max('combined_ages'),
|
||||
sum_combined_age_doubled=Sum('combined_ages') + Sum('combined_ages'))
|
||||
self.assertEqual(age['max_combined_age_doubled'], 176 * 2)
|
||||
self.assertEqual(age['sum_combined_age_doubled'], 954 * 2)
|
||||
|
||||
def test_values_annotation_with_expression(self):
|
||||
# ensure the F() is promoted to the group by clause
|
||||
qs = Author.objects.values('name').annotate(another_age=Sum('age') + F('age'))
|
||||
a = qs.get(pk=1)
|
||||
self.assertEqual(a['another_age'], 68)
|
||||
|
||||
qs = qs.annotate(friend_count=Count('friends'))
|
||||
a = qs.get(pk=1)
|
||||
self.assertEqual(a['friend_count'], 2)
|
||||
|
||||
qs = qs.annotate(combined_age=Sum('age') + F('friends__age')).filter(pk=1).order_by('-combined_age')
|
||||
self.assertEqual(
|
||||
list(qs), [
|
||||
{
|
||||
"name": 'Adrian Holovaty',
|
||||
"another_age": 68,
|
||||
"friend_count": 1,
|
||||
"combined_age": 69
|
||||
},
|
||||
{
|
||||
"name": 'Adrian Holovaty',
|
||||
"another_age": 68,
|
||||
"friend_count": 1,
|
||||
"combined_age": 63
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
vals = qs.values('name', 'combined_age')
|
||||
self.assertEqual(
|
||||
list(vals), [
|
||||
{
|
||||
"name": 'Adrian Holovaty',
|
||||
"combined_age": 69
|
||||
},
|
||||
{
|
||||
"name": 'Adrian Holovaty',
|
||||
"combined_age": 63
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
def test_annotate_values_aggregate(self):
|
||||
alias_age = Author.objects.annotate(
|
||||
age_alias=F('age')
|
||||
).values(
|
||||
'age_alias',
|
||||
).aggregate(sum_age=Sum('age_alias'))
|
||||
|
||||
age = Author.objects.values('age').aggregate(sum_age=Sum('age'))
|
||||
|
||||
self.assertEqual(alias_age['sum_age'], age['sum_age'])
|
||||
|
||||
def test_annotate_over_annotate(self):
|
||||
author = Author.objects.annotate(
|
||||
age_alias=F('age')
|
||||
).annotate(
|
||||
sum_age=Sum('age_alias')
|
||||
).get(pk=1)
|
||||
|
||||
other_author = Author.objects.annotate(
|
||||
sum_age=Sum('age')
|
||||
).get(pk=1)
|
||||
|
||||
self.assertEqual(author.sum_age, other_author.sum_age)
|
||||
|
||||
def test_annotated_aggregate_over_annotated_aggregate(self):
|
||||
with self.assertRaisesRegexp(FieldError, "Cannot compute Sum\('id__max'\): 'id__max' is an aggregate"):
|
||||
Book.objects.annotate(Max('id')).annotate(Sum('id__max'))
|
||||
|
||||
def test_add_implementation(self):
|
||||
try:
|
||||
# test completely changing how the output is rendered
|
||||
def lower_case_function_override(self, qn, connection):
|
||||
sql, params = qn.compile(self.source_expressions[0])
|
||||
substitutions = dict(function=self.function.lower(), expressions=sql)
|
||||
substitutions.update(self.extra)
|
||||
return self.template % substitutions, params
|
||||
setattr(Sum, 'as_' + connection.vendor, lower_case_function_override)
|
||||
|
||||
qs = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'),
|
||||
output_field=IntegerField()))
|
||||
self.assertEqual(str(qs.query).count('sum('), 1)
|
||||
b1 = qs.get(pk=4)
|
||||
self.assertEqual(b1.sums, 383)
|
||||
|
||||
# test changing the dict and delegating
|
||||
def lower_case_function_super(self, qn, connection):
|
||||
self.extra['function'] = self.function.lower()
|
||||
return super(Sum, self).as_sql(qn, connection)
|
||||
setattr(Sum, 'as_' + connection.vendor, lower_case_function_super)
|
||||
|
||||
qs = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'),
|
||||
output_field=IntegerField()))
|
||||
self.assertEqual(str(qs.query).count('sum('), 1)
|
||||
b1 = qs.get(pk=4)
|
||||
self.assertEqual(b1.sums, 383)
|
||||
|
||||
# test overriding all parts of the template
|
||||
def be_evil(self, qn, connection):
|
||||
substitutions = dict(function='MAX', expressions='2')
|
||||
substitutions.update(self.extra)
|
||||
return self.template % substitutions, ()
|
||||
setattr(Sum, 'as_' + connection.vendor, be_evil)
|
||||
|
||||
qs = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'),
|
||||
output_field=IntegerField()))
|
||||
self.assertEqual(str(qs.query).count('MAX('), 1)
|
||||
b1 = qs.get(pk=4)
|
||||
self.assertEqual(b1.sums, 2)
|
||||
finally:
|
||||
delattr(Sum, 'as_' + connection.vendor)
|
||||
|
||||
def test_complex_values_aggregation(self):
|
||||
max_rating = Book.objects.values('rating').aggregate(
|
||||
double_max_rating=Max('rating') + Max('rating'))
|
||||
self.assertEqual(max_rating['double_max_rating'], 5 * 2)
|
||||
|
||||
max_books_per_rating = Book.objects.values('rating').annotate(
|
||||
books_per_rating=Count('id') + 5
|
||||
).aggregate(Max('books_per_rating'))
|
||||
self.assertEqual(
|
||||
max_books_per_rating,
|
||||
{'books_per_rating__max': 3 + 5})
|
||||
|
||||
def test_expression_on_aggregation(self):
|
||||
|
||||
# Create a plain expression
|
||||
class Greatest(Func):
|
||||
function = 'GREATEST'
|
||||
|
||||
def as_sqlite(self, qn, connection):
|
||||
return super(Greatest, self).as_sql(qn, connection, function='MAX')
|
||||
|
||||
qs = Publisher.objects.annotate(
|
||||
price_or_median=Greatest(Avg('book__rating'), Avg('book__price'))
|
||||
).filter(price_or_median__gte=F('num_awards')).order_by('pk')
|
||||
self.assertQuerysetEqual(
|
||||
qs, [1, 2, 3, 4], lambda v: v.pk)
|
||||
|
||||
qs2 = Publisher.objects.annotate(
|
||||
rating_or_num_awards=Greatest(Avg('book__rating'), F('num_awards'),
|
||||
output_field=FloatField())
|
||||
).filter(rating_or_num_awards__gt=F('num_awards')).order_by('pk')
|
||||
self.assertQuerysetEqual(
|
||||
qs2, [1, 2], lambda v: v.pk)
|
||||
|
||||
def test_backwards_compatibility(self):
|
||||
|
||||
class SqlNewSum(sql_aggregates.Aggregate):
|
||||
sql_function = 'SUM'
|
||||
|
||||
class NewSum(Aggregate):
|
||||
name = 'Sum'
|
||||
|
||||
def add_to_query(self, query, alias, col, source, is_summary):
|
||||
klass = SqlNewSum
|
||||
aggregate = klass(
|
||||
col, source=source, is_summary=is_summary, **self.extra)
|
||||
query.annotations[alias] = aggregate
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", RemovedInDjango20Warning)
|
||||
qs = Author.objects.values('name').annotate(another_age=NewSum('age') + F('age'))
|
||||
a = qs.get(pk=1)
|
||||
self.assertEqual(a['another_age'], 68)
|
||||
|
0
tests/annotations/__init__.py
Normal file
0
tests/annotations/__init__.py
Normal file
243
tests/annotations/fixtures/annotations.json
Normal file
243
tests/annotations/fixtures/annotations.json
Normal file
@@ -0,0 +1,243 @@
|
||||
[
|
||||
{
|
||||
"pk": 1,
|
||||
"model": "annotations.publisher",
|
||||
"fields": {
|
||||
"name": "Apress",
|
||||
"num_awards": 3
|
||||
}
|
||||
},
|
||||
{
|
||||
"pk": 2,
|
||||
"model": "annotations.publisher",
|
||||
"fields": {
|
||||
"name": "Sams",
|
||||
"num_awards": 1
|
||||
}
|
||||
},
|
||||
{
|
||||
"pk": 3,
|
||||
"model": "annotations.publisher",
|
||||
"fields": {
|
||||
"name": "Prentice Hall",
|
||||
"num_awards": 7
|
||||
}
|
||||
},
|
||||
{
|
||||
"pk": 4,
|
||||
"model": "annotations.publisher",
|
||||
"fields": {
|
||||
"name": "Morgan Kaufmann",
|
||||
"num_awards": 9
|
||||
}
|
||||
},
|
||||
{
|
||||
"pk": 5,
|
||||
"model": "annotations.publisher",
|
||||
"fields": {
|
||||
"name": "Jonno's House of Books",
|
||||
"num_awards": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"pk": 1,
|
||||
"model": "annotations.book",
|
||||
"fields": {
|
||||
"publisher": 1,
|
||||
"isbn": "159059725",
|
||||
"name": "The Definitive Guide to Django: Web Development Done Right",
|
||||
"price": "30.00",
|
||||
"rating": 4.5,
|
||||
"authors": [1, 2],
|
||||
"contact": 1,
|
||||
"pages": 447,
|
||||
"pubdate": "2007-12-6"
|
||||
}
|
||||
},
|
||||
{
|
||||
"pk": 2,
|
||||
"model": "annotations.book",
|
||||
"fields": {
|
||||
"publisher": 2,
|
||||
"isbn": "067232959",
|
||||
"name": "Sams Teach Yourself Django in 24 Hours",
|
||||
"price": "23.09",
|
||||
"rating": 3.0,
|
||||
"authors": [3],
|
||||
"contact": 3,
|
||||
"pages": 528,
|
||||
"pubdate": "2008-3-3"
|
||||
}
|
||||
},
|
||||
{
|
||||
"pk": 3,
|
||||
"model": "annotations.book",
|
||||
"fields": {
|
||||
"publisher": 1,
|
||||
"isbn": "159059996",
|
||||
"name": "Practical Django Projects",
|
||||
"price": "29.69",
|
||||
"rating": 4.0,
|
||||
"authors": [4],
|
||||
"contact": 4,
|
||||
"pages": 300,
|
||||
"pubdate": "2008-6-23"
|
||||
}
|
||||
},
|
||||
{
|
||||
"pk": 4,
|
||||
"model": "annotations.book",
|
||||
"fields": {
|
||||
"publisher": 3,
|
||||
"isbn": "013235613",
|
||||
"name": "Python Web Development with Django",
|
||||
"price": "29.69",
|
||||
"rating": 4.0,
|
||||
"authors": [5, 6, 7],
|
||||
"contact": 5,
|
||||
"pages": 350,
|
||||
"pubdate": "2008-11-3"
|
||||
}
|
||||
},
|
||||
{
|
||||
"pk": 5,
|
||||
"model": "annotations.book",
|
||||
"fields": {
|
||||
"publisher": 3,
|
||||
"isbn": "013790395",
|
||||
"name": "Artificial Intelligence: A Modern Approach",
|
||||
"price": "82.80",
|
||||
"rating": 4.0,
|
||||
"authors": [8, 9],
|
||||
"contact": 8,
|
||||
"pages": 1132,
|
||||
"pubdate": "1995-1-15"
|
||||
}
|
||||
},
|
||||
{
|
||||
"pk": 6,
|
||||
"model": "annotations.book",
|
||||
"fields": {
|
||||
"publisher": 4,
|
||||
"isbn": "155860191",
|
||||
"name": "Paradigms of Artificial Intelligence Programming: Case Studies in Common Lisp",
|
||||
"price": "75.00",
|
||||
"rating": 5.0,
|
||||
"authors": [8],
|
||||
"contact": 8,
|
||||
"pages": 946,
|
||||
"pubdate": "1991-10-15"
|
||||
}
|
||||
},
|
||||
{
|
||||
"pk": 1,
|
||||
"model": "annotations.store",
|
||||
"fields": {
|
||||
"books": [1, 2, 3, 4, 5, 6],
|
||||
"name": "Amazon.com",
|
||||
"original_opening": "1994-4-23 9:17:42",
|
||||
"friday_night_closing": "23:59:59"
|
||||
}
|
||||
},
|
||||
{
|
||||
"pk": 2,
|
||||
"model": "annotations.store",
|
||||
"fields": {
|
||||
"books": [1, 3, 5, 6],
|
||||
"name": "Books.com",
|
||||
"original_opening": "2001-3-15 11:23:37",
|
||||
"friday_night_closing": "23:59:59"
|
||||
}
|
||||
},
|
||||
{
|
||||
"pk": 3,
|
||||
"model": "annotations.store",
|
||||
"fields": {
|
||||
"books": [3, 4, 6],
|
||||
"name": "Mamma and Pappa's Books",
|
||||
"original_opening": "1945-4-25 16:24:14",
|
||||
"friday_night_closing": "21:30:00"
|
||||
}
|
||||
},
|
||||
{
|
||||
"pk": 1,
|
||||
"model": "annotations.author",
|
||||
"fields": {
|
||||
"age": 34,
|
||||
"friends": [2, 4],
|
||||
"name": "Adrian Holovaty"
|
||||
}
|
||||
},
|
||||
{
|
||||
"pk": 2,
|
||||
"model": "annotations.author",
|
||||
"fields": {
|
||||
"age": 35,
|
||||
"friends": [1, 7],
|
||||
"name": "Jacob Kaplan-Moss"
|
||||
}
|
||||
},
|
||||
{
|
||||
"pk": 3,
|
||||
"model": "annotations.author",
|
||||
"fields": {
|
||||
"age": 45,
|
||||
"friends": [],
|
||||
"name": "Brad Dayley"
|
||||
}
|
||||
},
|
||||
{
|
||||
"pk": 4,
|
||||
"model": "annotations.author",
|
||||
"fields": {
|
||||
"age": 29,
|
||||
"friends": [1],
|
||||
"name": "James Bennett"
|
||||
}
|
||||
},
|
||||
{
|
||||
"pk": 5,
|
||||
"model": "annotations.author",
|
||||
"fields": {
|
||||
"age": 37,
|
||||
"friends": [6, 7],
|
||||
"name": "Jeffrey Forcier"
|
||||
}
|
||||
},
|
||||
{
|
||||
"pk": 6,
|
||||
"model": "annotations.author",
|
||||
"fields": {
|
||||
"age": 29,
|
||||
"friends": [5, 7],
|
||||
"name": "Paul Bissex"
|
||||
}
|
||||
},
|
||||
{
|
||||
"pk": 7,
|
||||
"model": "annotations.author",
|
||||
"fields": {
|
||||
"age": 25,
|
||||
"friends": [2, 5, 6],
|
||||
"name": "Wesley J. Chun"
|
||||
}
|
||||
},
|
||||
{
|
||||
"pk": 8,
|
||||
"model": "annotations.author",
|
||||
"fields": {
|
||||
"age": 57,
|
||||
"friends": [9],
|
||||
"name": "Peter Norvig"
|
||||
}
|
||||
},
|
||||
{
|
||||
"pk": 9,
|
||||
"model": "annotations.author",
|
||||
"fields": {
|
||||
"age": 46,
|
||||
"friends": [8],
|
||||
"name": "Stuart Russell"
|
||||
}
|
||||
}
|
||||
]
|
86
tests/annotations/models.py
Normal file
86
tests/annotations/models.py
Normal file
@@ -0,0 +1,86 @@
|
||||
# coding: utf-8
|
||||
from django.db import models
|
||||
from django.utils.encoding import python_2_unicode_compatible
|
||||
|
||||
|
||||
@python_2_unicode_compatible
|
||||
class Author(models.Model):
|
||||
name = models.CharField(max_length=100)
|
||||
age = models.IntegerField()
|
||||
friends = models.ManyToManyField('self', blank=True)
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
@python_2_unicode_compatible
|
||||
class Publisher(models.Model):
|
||||
name = models.CharField(max_length=255)
|
||||
num_awards = models.IntegerField()
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
@python_2_unicode_compatible
|
||||
class Book(models.Model):
|
||||
isbn = models.CharField(max_length=9)
|
||||
name = models.CharField(max_length=255)
|
||||
pages = models.IntegerField()
|
||||
rating = models.FloatField()
|
||||
price = models.DecimalField(decimal_places=2, max_digits=6)
|
||||
authors = models.ManyToManyField(Author)
|
||||
contact = models.ForeignKey(Author, related_name='book_contact_set')
|
||||
publisher = models.ForeignKey(Publisher)
|
||||
pubdate = models.DateField()
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
@python_2_unicode_compatible
|
||||
class Store(models.Model):
|
||||
name = models.CharField(max_length=255)
|
||||
books = models.ManyToManyField(Book)
|
||||
original_opening = models.DateTimeField()
|
||||
friday_night_closing = models.TimeField()
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
@python_2_unicode_compatible
|
||||
class DepartmentStore(Store):
|
||||
chain = models.CharField(max_length=255)
|
||||
|
||||
def __str__(self):
|
||||
return '%s - %s ' % (self.chain, self.name)
|
||||
|
||||
|
||||
@python_2_unicode_compatible
|
||||
class Employee(models.Model):
|
||||
# The order of these fields matter, do not change. Certain backends
|
||||
# rely on field ordering to perform database conversions, and this
|
||||
# model helps to test that.
|
||||
first_name = models.CharField(max_length=20)
|
||||
manager = models.BooleanField(default=False)
|
||||
last_name = models.CharField(max_length=20)
|
||||
store = models.ForeignKey(Store)
|
||||
age = models.IntegerField()
|
||||
salary = models.DecimalField(max_digits=8, decimal_places=2)
|
||||
|
||||
def __str__(self):
|
||||
return '%s %s' % (self.first_name, self.last_name)
|
||||
|
||||
|
||||
@python_2_unicode_compatible
|
||||
class Company(models.Model):
|
||||
name = models.CharField(max_length=200)
|
||||
motto = models.CharField(max_length=200, null=True, blank=True)
|
||||
ticker_name = models.CharField(max_length=10, null=True, blank=True)
|
||||
description = models.CharField(max_length=200, null=True, blank=True)
|
||||
|
||||
def __str__(self):
|
||||
return ('Company(name=%s, motto=%s, ticker_name=%s, description=%s)'
|
||||
% (self.name, self.motto, self.ticker_name, self.description)
|
||||
)
|
288
tests/annotations/tests.py
Normal file
288
tests/annotations/tests.py
Normal file
@@ -0,0 +1,288 @@
|
||||
from __future__ import unicode_literals
|
||||
import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db.models import (
|
||||
Sum, Count,
|
||||
F, Value, Func,
|
||||
IntegerField, BooleanField, CharField)
|
||||
from django.db.models.fields import FieldDoesNotExist
|
||||
from django.test import TestCase
|
||||
|
||||
from .models import Author, Book, Store, DepartmentStore, Company, Employee
|
||||
|
||||
|
||||
class NonAggregateAnnotationTestCase(TestCase):
|
||||
fixtures = ["annotations.json"]
|
||||
|
||||
def test_basic_annotation(self):
|
||||
books = Book.objects.annotate(
|
||||
is_book=Value(1, output_field=IntegerField()))
|
||||
for book in books:
|
||||
self.assertEqual(book.is_book, 1)
|
||||
|
||||
def test_basic_f_annotation(self):
|
||||
books = Book.objects.annotate(another_rating=F('rating'))
|
||||
for book in books:
|
||||
self.assertEqual(book.another_rating, book.rating)
|
||||
|
||||
def test_joined_annotation(self):
|
||||
books = Book.objects.select_related('publisher').annotate(
|
||||
num_awards=F('publisher__num_awards'))
|
||||
for book in books:
|
||||
self.assertEqual(book.num_awards, book.publisher.num_awards)
|
||||
|
||||
def test_annotate_with_aggregation(self):
|
||||
books = Book.objects.annotate(
|
||||
is_book=Value(1, output_field=IntegerField()),
|
||||
rating_count=Count('rating'))
|
||||
for book in books:
|
||||
self.assertEqual(book.is_book, 1)
|
||||
self.assertEqual(book.rating_count, 1)
|
||||
|
||||
def test_aggregate_over_annotation(self):
|
||||
agg = Author.objects.annotate(other_age=F('age')).aggregate(otherage_sum=Sum('other_age'))
|
||||
other_agg = Author.objects.aggregate(age_sum=Sum('age'))
|
||||
self.assertEqual(agg['otherage_sum'], other_agg['age_sum'])
|
||||
|
||||
def test_filter_annotation(self):
|
||||
books = Book.objects.annotate(
|
||||
is_book=Value(1, output_field=IntegerField())
|
||||
).filter(is_book=1)
|
||||
for book in books:
|
||||
self.assertEqual(book.is_book, 1)
|
||||
|
||||
def test_filter_annotation_with_f(self):
|
||||
books = Book.objects.annotate(
|
||||
other_rating=F('rating')
|
||||
).filter(other_rating=3.5)
|
||||
for book in books:
|
||||
self.assertEqual(book.other_rating, 3.5)
|
||||
|
||||
def test_filter_annotation_with_double_f(self):
|
||||
books = Book.objects.annotate(
|
||||
other_rating=F('rating')
|
||||
).filter(other_rating=F('rating'))
|
||||
for book in books:
|
||||
self.assertEqual(book.other_rating, book.rating)
|
||||
|
||||
def test_filter_agg_with_double_f(self):
|
||||
books = Book.objects.annotate(
|
||||
sum_rating=Sum('rating')
|
||||
).filter(sum_rating=F('sum_rating'))
|
||||
for book in books:
|
||||
self.assertEqual(book.sum_rating, book.rating)
|
||||
|
||||
def test_filter_wrong_annotation(self):
|
||||
with self.assertRaisesRegexp(FieldError, "Cannot resolve keyword .*"):
|
||||
list(Book.objects.annotate(
|
||||
sum_rating=Sum('rating')
|
||||
).filter(sum_rating=F('nope')))
|
||||
|
||||
def test_update_with_annotation(self):
|
||||
book_preupdate = Book.objects.get(pk=2)
|
||||
Book.objects.annotate(other_rating=F('rating') - 1).update(rating=F('other_rating'))
|
||||
book_postupdate = Book.objects.get(pk=2)
|
||||
self.assertEqual(book_preupdate.rating - 1, book_postupdate.rating)
|
||||
|
||||
def test_annotation_with_m2m(self):
|
||||
books = Book.objects.annotate(author_age=F('authors__age')).filter(pk=1).order_by('author_age')
|
||||
self.assertEqual(books[0].author_age, 34)
|
||||
self.assertEqual(books[1].author_age, 35)
|
||||
|
||||
def test_annotation_reverse_m2m(self):
|
||||
books = Book.objects.annotate(
|
||||
store_name=F('store__name')).filter(
|
||||
name='Practical Django Projects').order_by(
|
||||
'store_name')
|
||||
|
||||
self.assertQuerysetEqual(
|
||||
books, [
|
||||
'Amazon.com',
|
||||
'Books.com',
|
||||
'Mamma and Pappa\'s Books'
|
||||
],
|
||||
lambda b: b.store_name
|
||||
)
|
||||
|
||||
def test_values_annotation(self):
|
||||
"""
|
||||
Annotations can reference fields in a values clause,
|
||||
and contribute to an existing values clause.
|
||||
"""
|
||||
# annotate references a field in values()
|
||||
qs = Book.objects.values('rating').annotate(other_rating=F('rating') - 1)
|
||||
book = qs.get(pk=1)
|
||||
self.assertEqual(book['rating'] - 1, book['other_rating'])
|
||||
|
||||
# filter refs the annotated value
|
||||
book = qs.get(other_rating=4)
|
||||
self.assertEqual(book['other_rating'], 4)
|
||||
|
||||
# can annotate an existing values with a new field
|
||||
book = qs.annotate(other_isbn=F('isbn')).get(other_rating=4)
|
||||
self.assertEqual(book['other_rating'], 4)
|
||||
self.assertEqual(book['other_isbn'], '155860191')
|
||||
|
||||
def test_defer_annotation(self):
|
||||
"""
|
||||
Deferred attributes can be referenced by an annotation,
|
||||
but they are not themselves deferred, and cannot be deferred.
|
||||
"""
|
||||
qs = Book.objects.defer('rating').annotate(other_rating=F('rating') - 1)
|
||||
|
||||
with self.assertNumQueries(2):
|
||||
book = qs.get(other_rating=4)
|
||||
self.assertEqual(book.rating, 5)
|
||||
self.assertEqual(book.other_rating, 4)
|
||||
|
||||
with self.assertRaisesRegexp(FieldDoesNotExist, "\w has no field named u?'other_rating'"):
|
||||
book = qs.defer('other_rating').get(other_rating=4)
|
||||
|
||||
def test_mti_annotations(self):
|
||||
"""
|
||||
Fields on an inherited model can be referenced by an
|
||||
annotated field.
|
||||
"""
|
||||
d = DepartmentStore.objects.create(
|
||||
name='Angus & Robinson',
|
||||
original_opening=datetime.date(2014, 3, 8),
|
||||
friday_night_closing=datetime.time(21, 00, 00),
|
||||
chain='Westfield'
|
||||
)
|
||||
|
||||
books = Book.objects.filter(rating__gt=4)
|
||||
for b in books:
|
||||
d.books.add(b)
|
||||
|
||||
qs = DepartmentStore.objects.annotate(
|
||||
other_name=F('name'),
|
||||
other_chain=F('chain'),
|
||||
is_open=Value(True, BooleanField()),
|
||||
book_isbn=F('books__isbn')
|
||||
).select_related('store').order_by('book_isbn').filter(chain='Westfield')
|
||||
|
||||
self.assertQuerysetEqual(
|
||||
qs, [
|
||||
('Angus & Robinson', 'Westfield', True, '155860191'),
|
||||
('Angus & Robinson', 'Westfield', True, '159059725')
|
||||
],
|
||||
lambda d: (d.other_name, d.other_chain, d.is_open, d.book_isbn)
|
||||
)
|
||||
|
||||
def test_column_field_ordering(self):
|
||||
"""
|
||||
Test that columns are aligned in the correct order for
|
||||
resolve_columns. This test will fail on mysql if column
|
||||
ordering is out. Column fields should be aligned as:
|
||||
1. extra_select
|
||||
2. model_fields
|
||||
3. annotation_fields
|
||||
4. model_related_fields
|
||||
"""
|
||||
store = Store.objects.first()
|
||||
Employee.objects.create(id=1, first_name='Max', manager=True, last_name='Paine',
|
||||
store=store, age=23, salary=Decimal(50000.00))
|
||||
Employee.objects.create(id=2, first_name='Buffy', manager=False, last_name='Summers',
|
||||
store=store, age=18, salary=Decimal(40000.00))
|
||||
|
||||
qs = Employee.objects.extra(
|
||||
select={'random_value': '42'}
|
||||
).select_related('store').annotate(
|
||||
annotated_value=Value(17, output_field=IntegerField())
|
||||
)
|
||||
|
||||
rows = [
|
||||
(1, 'Max', True, 42, 'Paine', 23, Decimal(50000.00), store.name, 17),
|
||||
(2, 'Buffy', False, 42, 'Summers', 18, Decimal(40000.00), store.name, 17)
|
||||
]
|
||||
|
||||
self.assertQuerysetEqual(
|
||||
qs.order_by('id'), rows,
|
||||
lambda e: (
|
||||
e.id, e.first_name, e.manager, e.random_value, e.last_name, e.age,
|
||||
e.salary, e.store.name, e.annotated_value))
|
||||
|
||||
def test_column_field_ordering_with_deferred(self):
|
||||
store = Store.objects.first()
|
||||
Employee.objects.create(id=1, first_name='Max', manager=True, last_name='Paine',
|
||||
store=store, age=23, salary=Decimal(50000.00))
|
||||
Employee.objects.create(id=2, first_name='Buffy', manager=False, last_name='Summers',
|
||||
store=store, age=18, salary=Decimal(40000.00))
|
||||
|
||||
qs = Employee.objects.extra(
|
||||
select={'random_value': '42'}
|
||||
).select_related('store').annotate(
|
||||
annotated_value=Value(17, output_field=IntegerField())
|
||||
)
|
||||
|
||||
rows = [
|
||||
(1, 'Max', True, 42, 'Paine', 23, Decimal(50000.00), store.name, 17),
|
||||
(2, 'Buffy', False, 42, 'Summers', 18, Decimal(40000.00), store.name, 17)
|
||||
]
|
||||
|
||||
# and we respect deferred columns!
|
||||
self.assertQuerysetEqual(
|
||||
qs.defer('age').order_by('id'), rows,
|
||||
lambda e: (
|
||||
e.id, e.first_name, e.manager, e.random_value, e.last_name, e.age,
|
||||
e.salary, e.store.name, e.annotated_value))
|
||||
|
||||
def test_custom_functions(self):
|
||||
Company(name='Apple', motto=None, ticker_name='APPL', description='Beautiful Devices').save()
|
||||
Company(name='Django Software Foundation', motto=None, ticker_name=None, description=None).save()
|
||||
Company(name='Google', motto='Do No Evil', ticker_name='GOOG', description='Internet Company').save()
|
||||
Company(name='Yahoo', motto=None, ticker_name=None, description='Internet Company').save()
|
||||
|
||||
qs = Company.objects.annotate(
|
||||
tagline=Func(
|
||||
F('motto'),
|
||||
F('ticker_name'),
|
||||
F('description'),
|
||||
Value('No Tag'),
|
||||
function='COALESCE')
|
||||
).order_by('name')
|
||||
|
||||
self.assertQuerysetEqual(
|
||||
qs, [
|
||||
('Apple', 'APPL'),
|
||||
('Django Software Foundation', 'No Tag'),
|
||||
('Google', 'Do No Evil'),
|
||||
('Yahoo', 'Internet Company')
|
||||
],
|
||||
lambda c: (c.name, c.tagline)
|
||||
)
|
||||
|
||||
def test_custom_functions_can_ref_other_functions(self):
|
||||
Company(name='Apple', motto=None, ticker_name='APPL', description='Beautiful Devices').save()
|
||||
Company(name='Django Software Foundation', motto=None, ticker_name=None, description=None).save()
|
||||
Company(name='Google', motto='Do No Evil', ticker_name='GOOG', description='Internet Company').save()
|
||||
Company(name='Yahoo', motto=None, ticker_name=None, description='Internet Company').save()
|
||||
|
||||
class Lower(Func):
|
||||
function = 'LOWER'
|
||||
|
||||
qs = Company.objects.annotate(
|
||||
tagline=Func(
|
||||
F('motto'),
|
||||
F('ticker_name'),
|
||||
F('description'),
|
||||
Value('No Tag'),
|
||||
function='COALESCE')
|
||||
).annotate(
|
||||
tagline_lower=Lower(F('tagline'), output_field=CharField())
|
||||
).order_by('name')
|
||||
|
||||
# LOWER function supported by:
|
||||
# oracle, postgres, mysql, sqlite, sqlserver
|
||||
|
||||
self.assertQuerysetEqual(
|
||||
qs, [
|
||||
('Apple', 'APPL'.lower()),
|
||||
('Django Software Foundation', 'No Tag'.lower()),
|
||||
('Google', 'Do No Evil'.lower()),
|
||||
('Yahoo', 'Internet Company'.lower())
|
||||
],
|
||||
lambda c: (c.name, c.tagline_lower)
|
||||
)
|
@@ -296,6 +296,21 @@ class ExpressionsTests(TestCase):
|
||||
g = deepcopy(f)
|
||||
self.assertEqual(f.name, g.name)
|
||||
|
||||
def test_f_reuse(self):
|
||||
f = F('id')
|
||||
n = Number.objects.create(integer=-1)
|
||||
c = Company.objects.create(
|
||||
name="Example Inc.", num_employees=2300, num_chairs=5,
|
||||
ceo=Employee.objects.create(firstname="Joe", lastname="Smith")
|
||||
)
|
||||
c_qs = Company.objects.filter(id=f)
|
||||
self.assertEqual(c_qs.get(), c)
|
||||
# Reuse the same F-object for another queryset
|
||||
n_qs = Number.objects.filter(id=f)
|
||||
self.assertEqual(n_qs.get(), n)
|
||||
# The original query still works correctly
|
||||
self.assertEqual(c_qs.get(), c)
|
||||
|
||||
|
||||
class ExpressionsNumericTests(TestCase):
|
||||
|
||||
@@ -362,12 +377,16 @@ class ExpressionsNumericTests(TestCase):
|
||||
Complex expressions of different connection types are possible.
|
||||
"""
|
||||
n = Number.objects.create(integer=10, float=123.45)
|
||||
self.assertEqual(Number.objects.filter(pk=n.pk)
|
||||
.update(float=F('integer') + F('float') * 2), 1)
|
||||
self.assertEqual(Number.objects.filter(pk=n.pk).update(
|
||||
float=F('integer') + F('float') * 2), 1)
|
||||
|
||||
self.assertEqual(Number.objects.get(pk=n.pk).integer, 10)
|
||||
self.assertEqual(Number.objects.get(pk=n.pk).float, Approximate(256.900, places=3))
|
||||
|
||||
def test_incorrect_field_expression(self):
|
||||
with self.assertRaisesRegexp(FieldError, "Cannot resolve keyword u?'nope' into field.*"):
|
||||
list(Employee.objects.filter(firstname=F('nope')))
|
||||
|
||||
|
||||
class ExpressionOperatorTests(TestCase):
|
||||
def setUp(self):
|
||||
|
Reference in New Issue
Block a user