diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py index 6fdaf188c4..b7a38f5b69 100644 --- a/django/db/models/sql/aggregates.py +++ b/django/db/models/sql/aggregates.py @@ -21,6 +21,7 @@ class Aggregate(object): is_ordinal = False is_computed = False sql_template = '%(function)s(%(field)s)' + as_sql_takes_connection = True def __init__(self, col, source=None, is_summary=False, **extra): """Instantiate an SQL aggregate @@ -72,15 +73,16 @@ class Aggregate(object): if isinstance(self.col, (list, tuple)): self.col = (change_map.get(self.col[0], self.col[0]), self.col[1]) - def as_sql(self, quote_func=None): + def as_sql(self, qn, connection): "Return the aggregate, rendered as SQL." - if not quote_func: - quote_func = lambda x: x if hasattr(self.col, 'as_sql'): - field_name = self.col.as_sql(quote_func) + if getattr(self.col, 'as_sql_takes_connection', False): + field_name = self.col.as_sql(qn, connection) + else: + field_name = self.col.as_sql(qn) elif isinstance(self.col, (list, tuple)): - field_name = '.'.join([quote_func(c) for c in self.col]) + field_name = '.'.join([qn(c) for c in self.col]) else: field_name = self.col @@ -127,4 +129,3 @@ class Variance(Aggregate): def __init__(self, col, sample=False, **extra): super(Variance, self).__init__(col, **extra) self.sql_function = sample and 'VAR_SAMP' or 'VAR_POP' - diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py index 2c447022d3..e0118fd198 100644 --- a/django/db/models/sql/expressions.py +++ b/django/db/models/sql/expressions.py @@ -78,6 +78,9 @@ class SQLEvaluator(object): def evaluate_leaf(self, node, qn, connection): col = self.cols[node] if hasattr(col, 'as_sql'): - return col.as_sql(qn), () + if getattr(col, 'as_sql_takes_connection', False): + return col.as_sql(qn, connection), () + else: + return col.as_sql(qn) else: return '%s.%s' % (qn(col[0]), qn(col[1])), () diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index fb17db546a..f06a9c2300 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -746,7 +746,7 @@ class BaseQuery(object): result.extend([ '%s%s' % ( - aggregate.as_sql(quote_func=qn), + aggregate.as_sql(qn, self.connection), alias is not None and ' AS %s' % qn(alias) or '' ) for alias, aggregate in self.aggregate_select.items() diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index db514c8b75..8712a4fdac 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -421,17 +421,21 @@ class AggregateQuery(Query): An AggregateQuery takes another query as a parameter to the FROM clause and only selects the elements in the provided list. """ + as_sql_takes_connection = True + def add_subquery(self, query): self.subquery, self.sub_params = query.as_sql(with_col_aliases=True) - def as_sql(self, quote_func=None): + def as_sql(self, qn=None): """ Creates the SQL for this query. Returns the SQL string and list of parameters. """ + if qn is None: + qn = self.quote_name_unless_alias sql = ('SELECT %s FROM (%s) subquery' % ( ', '.join([ - aggregate.as_sql() + aggregate.as_sql(qn, self.connection) for aggregate in self.aggregate_select.values() ]), self.subquery) diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 69d798d3bb..ac9c375342 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -152,7 +152,11 @@ class WhereNode(tree.Node): field_sql = self.sql_for_columns(lvalue, qn, connection) else: # A smart object with an as_sql() method. - field_sql = lvalue.as_sql(quote_func=qn) + if getattr(lvalue, 'as_sql_takes_connection', False): + field_sql = lvalue.as_sql(qn, connection) + else: + field_sql = lvalue.as_sql(qn) + if value_annot is datetime.datetime: cast_sql = connection.ops.datetime_cast_sql()