mirror of
https://github.com/django/django.git
synced 2024-12-28 12:06:22 +00:00
[1.8.x] Fixed #24154 -- Backends can now check support for expressions
Backport of 8196e4bdf4
from master
This commit is contained in:
parent
5dff3513cc
commit
e56810e839
@ -98,12 +98,12 @@ class BaseSpatialOperations(object):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError('subclasses of BaseSpatialOperations must provide a geo_db_placeholder() method')
|
raise NotImplementedError('subclasses of BaseSpatialOperations must provide a geo_db_placeholder() method')
|
||||||
|
|
||||||
def check_aggregate_support(self, aggregate):
|
def check_expression_support(self, expression):
|
||||||
if isinstance(aggregate, self.disallowed_aggregates):
|
if isinstance(expression, self.disallowed_aggregates):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"%s spatial aggregation is not supported by this database backend." % aggregate.name
|
"%s spatial aggregation is not supported by this database backend." % expression.name
|
||||||
)
|
)
|
||||||
super(BaseSpatialOperations, self).check_aggregate_support(aggregate)
|
super(BaseSpatialOperations, self).check_expression_support(expression)
|
||||||
|
|
||||||
def spatial_aggregate_name(self, agg_name):
|
def spatial_aggregate_name(self, agg_name):
|
||||||
raise NotImplementedError('Aggregate support not implemented for this spatial backend.')
|
raise NotImplementedError('Aggregate support not implemented for this spatial backend.')
|
||||||
|
@ -9,6 +9,9 @@ class GeoAggregate(Aggregate):
|
|||||||
is_extent = False
|
is_extent = False
|
||||||
|
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection):
|
||||||
|
# this will be called again in parent, but it's needed now - before
|
||||||
|
# we get the spatial_aggregate_name
|
||||||
|
connection.ops.check_expression_support(self)
|
||||||
self.function = connection.ops.spatial_aggregate_name(self.name)
|
self.function = connection.ops.spatial_aggregate_name(self.name)
|
||||||
return super(GeoAggregate, self).as_sql(compiler, connection)
|
return super(GeoAggregate, self).as_sql(compiler, connection)
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from django.db.models.aggregates import StdDev
|
||||||
|
from django.db.models.expressions import Value
|
||||||
from django.db.utils import ProgrammingError
|
from django.db.utils import ProgrammingError
|
||||||
from django.utils.functional import cached_property
|
from django.utils.functional import cached_property
|
||||||
|
|
||||||
@ -226,12 +228,8 @@ class BaseDatabaseFeatures(object):
|
|||||||
@cached_property
|
@cached_property
|
||||||
def supports_stddev(self):
|
def supports_stddev(self):
|
||||||
"""Confirm support for STDDEV and related stats functions."""
|
"""Confirm support for STDDEV and related stats functions."""
|
||||||
class StdDevPop(object):
|
|
||||||
contains_aggregate = True
|
|
||||||
sql_function = 'STDDEV_POP'
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.connection.ops.check_aggregate_support(StdDevPop())
|
self.connection.ops.check_expression_support(StdDev(Value(1)))
|
||||||
return True
|
return True
|
||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
return False
|
return False
|
||||||
|
@ -526,12 +526,16 @@ class BaseDatabaseOperations(object):
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
def check_aggregate_support(self, aggregate_func):
|
def check_aggregate_support(self, aggregate_func):
|
||||||
"""Check that the backend supports the provided aggregate
|
return self.check_expression_support(aggregate_func)
|
||||||
|
|
||||||
This is used on specific backends to rule out known aggregates
|
def check_expression_support(self, expression):
|
||||||
that are known to have faulty implementations. If the named
|
"""
|
||||||
aggregate function has a known problem, the backend should
|
Check that the backend supports the provided expression.
|
||||||
raise NotImplementedError.
|
|
||||||
|
This is used on specific backends to rule out known expressions
|
||||||
|
that have problematic or nonexistent implementations. If the
|
||||||
|
expression has a known problem, the backend should raise
|
||||||
|
NotImplementedError.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -60,7 +60,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
|||||||
"""Confirm support for STDDEV and related stats functions
|
"""Confirm support for STDDEV and related stats functions
|
||||||
|
|
||||||
SQLite supports STDDEV as an extension package; so
|
SQLite supports STDDEV as an extension package; so
|
||||||
connection.ops.check_aggregate_support() can't unilaterally
|
connection.ops.check_expression_support() can't unilaterally
|
||||||
rule out support for STDDEV. We need to manually check
|
rule out support for STDDEV. We need to manually check
|
||||||
whether the call works.
|
whether the call works.
|
||||||
"""
|
"""
|
||||||
|
@ -4,7 +4,7 @@ import datetime
|
|||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.core.exceptions import ImproperlyConfigured
|
from django.core.exceptions import ImproperlyConfigured, FieldError
|
||||||
from django.db import utils
|
from django.db import utils
|
||||||
from django.db.backends import utils as backend_utils
|
from django.db.backends import utils as backend_utils
|
||||||
from django.db.backends.base.operations import BaseDatabaseOperations
|
from django.db.backends.base.operations import BaseDatabaseOperations
|
||||||
@ -33,15 +33,21 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||||||
limit = 999 if len(fields) > 1 else 500
|
limit = 999 if len(fields) > 1 else 500
|
||||||
return (limit // len(fields)) if len(fields) > 0 else len(objs)
|
return (limit // len(fields)) if len(fields) > 0 else len(objs)
|
||||||
|
|
||||||
def check_aggregate_support(self, aggregate):
|
def check_expression_support(self, expression):
|
||||||
bad_fields = (fields.DateField, fields.DateTimeField, fields.TimeField)
|
bad_fields = (fields.DateField, fields.DateTimeField, fields.TimeField)
|
||||||
bad_aggregates = (aggregates.Sum, aggregates.Avg,
|
bad_aggregates = (aggregates.Sum, aggregates.Avg, aggregates.Variance, aggregates.StdDev)
|
||||||
aggregates.Variance, aggregates.StdDev)
|
if isinstance(expression, bad_aggregates):
|
||||||
if aggregate.refs_field(bad_aggregates, bad_fields):
|
try:
|
||||||
|
output_field = expression.input_field.output_field
|
||||||
|
if isinstance(output_field, bad_fields):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
'You cannot use Sum, Avg, StdDev and Variance aggregations '
|
'You cannot use Sum, Avg, StdDev and Variance aggregations '
|
||||||
'on date/time fields in sqlite3 '
|
'on date/time fields in sqlite3 '
|
||||||
'since date/time is saved as text.')
|
'since date/time is saved as text.')
|
||||||
|
except FieldError:
|
||||||
|
# not every sub-expression has an output_field which is fine to
|
||||||
|
# ignore
|
||||||
|
pass
|
||||||
|
|
||||||
def date_extract_sql(self, lookup_type, field_name):
|
def date_extract_sql(self, lookup_type, field_name):
|
||||||
# sqlite doesn't support extract, so we fake it with the user-defined
|
# sqlite doesn't support extract, so we fake it with the user-defined
|
||||||
|
@ -25,17 +25,6 @@ class Aggregate(Func):
|
|||||||
c._patch_aggregate(query) # backward-compatibility support
|
c._patch_aggregate(query) # backward-compatibility support
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def refs_field(self, aggregate_types, field_types):
|
|
||||||
try:
|
|
||||||
return (isinstance(self, aggregate_types) and
|
|
||||||
isinstance(self.input_field._output_field_or_none, field_types))
|
|
||||||
except FieldError:
|
|
||||||
# Sometimes we don't know the input_field's output type (for example,
|
|
||||||
# doing Sum(F('datetimefield') + F('datefield'), output_type=DateTimeField())
|
|
||||||
# is OK, but the Expression(F('datetimefield') + F('datefield')) doesn't
|
|
||||||
# have any output field.
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_field(self):
|
def input_field(self):
|
||||||
return self.source_expressions[0]
|
return self.source_expressions[0]
|
||||||
|
@ -297,14 +297,6 @@ class BaseExpression(object):
|
|||||||
return agg, lookup
|
return agg, lookup
|
||||||
return False, ()
|
return False, ()
|
||||||
|
|
||||||
def refs_field(self, aggregate_types, field_types):
|
|
||||||
"""
|
|
||||||
Helper method for check_aggregate_support on backends
|
|
||||||
"""
|
|
||||||
return any(
|
|
||||||
node.refs_field(aggregate_types, field_types)
|
|
||||||
for node in self.get_source_expressions())
|
|
||||||
|
|
||||||
def prepare_database_save(self, field):
|
def prepare_database_save(self, field):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@ -401,6 +393,7 @@ class DurationExpression(Expression):
|
|||||||
return compiler.compile(side)
|
return compiler.compile(side)
|
||||||
|
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection):
|
||||||
|
connection.ops.check_expression_support(self)
|
||||||
expressions = []
|
expressions = []
|
||||||
expression_params = []
|
expression_params = []
|
||||||
sql, params = self.compile(self.lhs, compiler, connection)
|
sql, params = self.compile(self.lhs, compiler, connection)
|
||||||
@ -473,6 +466,7 @@ class Func(ExpressionNode):
|
|||||||
return c
|
return c
|
||||||
|
|
||||||
def as_sql(self, compiler, connection, function=None, template=None):
|
def as_sql(self, compiler, connection, function=None, template=None):
|
||||||
|
connection.ops.check_expression_support(self)
|
||||||
sql_parts = []
|
sql_parts = []
|
||||||
params = []
|
params = []
|
||||||
for arg in self.source_expressions:
|
for arg in self.source_expressions:
|
||||||
@ -511,6 +505,7 @@ class Value(ExpressionNode):
|
|||||||
self.value = value
|
self.value = value
|
||||||
|
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection):
|
||||||
|
connection.ops.check_expression_support(self)
|
||||||
val = self.value
|
val = self.value
|
||||||
# check _output_field to avoid triggering an exception
|
# check _output_field to avoid triggering an exception
|
||||||
if self._output_field is not None:
|
if self._output_field is not None:
|
||||||
@ -536,6 +531,7 @@ class Value(ExpressionNode):
|
|||||||
|
|
||||||
class DurationValue(Value):
|
class DurationValue(Value):
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection):
|
||||||
|
connection.ops.check_expression_support(self)
|
||||||
if (connection.features.has_native_duration_field and
|
if (connection.features.has_native_duration_field and
|
||||||
connection.features.driver_supports_timedelta_args):
|
connection.features.driver_supports_timedelta_args):
|
||||||
return super(DurationValue, self).as_sql(compiler, connection)
|
return super(DurationValue, self).as_sql(compiler, connection)
|
||||||
@ -650,6 +646,7 @@ class When(ExpressionNode):
|
|||||||
return c
|
return c
|
||||||
|
|
||||||
def as_sql(self, compiler, connection, template=None):
|
def as_sql(self, compiler, connection, template=None):
|
||||||
|
connection.ops.check_expression_support(self)
|
||||||
template_params = {}
|
template_params = {}
|
||||||
sql_params = []
|
sql_params = []
|
||||||
condition_sql, condition_params = compiler.compile(self.condition)
|
condition_sql, condition_params = compiler.compile(self.condition)
|
||||||
@ -715,6 +712,7 @@ class Case(ExpressionNode):
|
|||||||
return c
|
return c
|
||||||
|
|
||||||
def as_sql(self, compiler, connection, template=None, extra=None):
|
def as_sql(self, compiler, connection, template=None, extra=None):
|
||||||
|
connection.ops.check_expression_support(self)
|
||||||
if not self.cases:
|
if not self.cases:
|
||||||
return compiler.compile(self.default)
|
return compiler.compile(self.default)
|
||||||
template_params = dict(extra) if extra else {}
|
template_params = dict(extra) if extra else {}
|
||||||
@ -851,6 +849,7 @@ class OrderBy(BaseExpression):
|
|||||||
return [self.expression]
|
return [self.expression]
|
||||||
|
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection):
|
||||||
|
connection.ops.check_expression_support(self)
|
||||||
expression_sql, params = compiler.compile(self.expression)
|
expression_sql, params = compiler.compile(self.expression)
|
||||||
placeholders = {'expression': expression_sql}
|
placeholders = {'expression': expression_sql}
|
||||||
placeholders['ordering'] = 'DESC' if self.descending else 'ASC'
|
placeholders['ordering'] = 'DESC' if self.descending else 'ASC'
|
||||||
|
@ -230,11 +230,6 @@ class Query(object):
|
|||||||
raise ValueError("Need either using or connection")
|
raise ValueError("Need either using or connection")
|
||||||
if using:
|
if using:
|
||||||
connection = connections[using]
|
connection = connections[using]
|
||||||
|
|
||||||
# Check that the compiler will be able to execute the query
|
|
||||||
for alias, annotation in self.annotation_select.items():
|
|
||||||
connection.ops.check_aggregate_support(annotation)
|
|
||||||
|
|
||||||
return connection.ops.compiler(self.compiler)(self, connection, using)
|
return connection.ops.compiler(self.compiler)(self, connection, using)
|
||||||
|
|
||||||
def get_meta(self):
|
def get_meta(self):
|
||||||
|
@ -325,17 +325,6 @@ class WhereNode(tree.Node):
|
|||||||
def contains_aggregate(self):
|
def contains_aggregate(self):
|
||||||
return self._contains_aggregate(self)
|
return self._contains_aggregate(self)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _refs_field(cls, obj, aggregate_types, field_types):
|
|
||||||
if not isinstance(obj, tree.Node):
|
|
||||||
if hasattr(obj.rhs, 'refs_field'):
|
|
||||||
return obj.rhs.refs_field(aggregate_types, field_types)
|
|
||||||
return False
|
|
||||||
return any(cls._refs_field(c, aggregate_types, field_types) for c in obj.children)
|
|
||||||
|
|
||||||
def refs_field(self, aggregate_types, field_types):
|
|
||||||
return self._refs_field(self, aggregate_types, field_types)
|
|
||||||
|
|
||||||
|
|
||||||
class EmptyWhere(WhereNode):
|
class EmptyWhere(WhereNode):
|
||||||
def add(self, data, connector):
|
def add(self, data, connector):
|
||||||
|
@ -128,12 +128,19 @@ class SQLiteTests(TestCase):
|
|||||||
#19360: Raise NotImplementedError when aggregating on date/time fields.
|
#19360: Raise NotImplementedError when aggregating on date/time fields.
|
||||||
"""
|
"""
|
||||||
for aggregate in (Sum, Avg, Variance, StdDev):
|
for aggregate in (Sum, Avg, Variance, StdDev):
|
||||||
self.assertRaises(NotImplementedError,
|
self.assertRaises(
|
||||||
|
NotImplementedError,
|
||||||
models.Item.objects.all().aggregate, aggregate('time'))
|
models.Item.objects.all().aggregate, aggregate('time'))
|
||||||
self.assertRaises(NotImplementedError,
|
self.assertRaises(
|
||||||
|
NotImplementedError,
|
||||||
models.Item.objects.all().aggregate, aggregate('date'))
|
models.Item.objects.all().aggregate, aggregate('date'))
|
||||||
self.assertRaises(NotImplementedError,
|
self.assertRaises(
|
||||||
|
NotImplementedError,
|
||||||
models.Item.objects.all().aggregate, aggregate('last_modified'))
|
models.Item.objects.all().aggregate, aggregate('last_modified'))
|
||||||
|
self.assertRaises(
|
||||||
|
NotImplementedError,
|
||||||
|
models.Item.objects.all().aggregate,
|
||||||
|
**{'complex': aggregate('last_modified') + aggregate('last_modified')})
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipUnless(connection.vendor == 'postgresql', "Test only for PostgreSQL")
|
@unittest.skipUnless(connection.vendor == 'postgresql', "Test only for PostgreSQL")
|
||||||
|
Loading…
Reference in New Issue
Block a user