1
0
mirror of https://github.com/django/django.git synced 2025-07-05 10:19:20 +00:00

[soc2009/multidb] Fixed test failures that were introduced in r10943

git-svn-id: http://code.djangoproject.com/svn/django/branches/soc2009/multidb@10951 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Alex Gaynor 2009-06-08 01:14:44 +00:00
parent 1653ffb571
commit 33ed158dc4
5 changed files with 23 additions and 11 deletions

View File

@ -21,6 +21,7 @@ class Aggregate(object):
is_ordinal = False is_ordinal = False
is_computed = False is_computed = False
sql_template = '%(function)s(%(field)s)' sql_template = '%(function)s(%(field)s)'
as_sql_takes_connection = True
def __init__(self, col, source=None, is_summary=False, **extra): def __init__(self, col, source=None, is_summary=False, **extra):
"""Instantiate an SQL aggregate """Instantiate an SQL aggregate
@ -72,15 +73,16 @@ class Aggregate(object):
if isinstance(self.col, (list, tuple)): if isinstance(self.col, (list, tuple)):
self.col = (change_map.get(self.col[0], self.col[0]), self.col[1]) self.col = (change_map.get(self.col[0], self.col[0]), self.col[1])
def as_sql(self, quote_func=None): def as_sql(self, qn, connection):
"Return the aggregate, rendered as SQL." "Return the aggregate, rendered as SQL."
if not quote_func:
quote_func = lambda x: x
if hasattr(self.col, 'as_sql'): if hasattr(self.col, 'as_sql'):
field_name = self.col.as_sql(quote_func) if getattr(self.col, 'as_sql_takes_connection', False):
field_name = self.col.as_sql(qn, connection)
else:
field_name = self.col.as_sql(qn)
elif isinstance(self.col, (list, tuple)): elif isinstance(self.col, (list, tuple)):
field_name = '.'.join([quote_func(c) for c in self.col]) field_name = '.'.join([qn(c) for c in self.col])
else: else:
field_name = self.col field_name = self.col
@ -127,4 +129,3 @@ class Variance(Aggregate):
def __init__(self, col, sample=False, **extra): def __init__(self, col, sample=False, **extra):
super(Variance, self).__init__(col, **extra) super(Variance, self).__init__(col, **extra)
self.sql_function = sample and 'VAR_SAMP' or 'VAR_POP' self.sql_function = sample and 'VAR_SAMP' or 'VAR_POP'

View File

@ -78,6 +78,9 @@ class SQLEvaluator(object):
def evaluate_leaf(self, node, qn, connection): def evaluate_leaf(self, node, qn, connection):
col = self.cols[node] col = self.cols[node]
if hasattr(col, 'as_sql'): if hasattr(col, 'as_sql'):
return col.as_sql(qn), () if getattr(col, 'as_sql_takes_connection', False):
return col.as_sql(qn, connection), ()
else:
return col.as_sql(qn)
else: else:
return '%s.%s' % (qn(col[0]), qn(col[1])), () return '%s.%s' % (qn(col[0]), qn(col[1])), ()

View File

@ -746,7 +746,7 @@ class BaseQuery(object):
result.extend([ result.extend([
'%s%s' % ( '%s%s' % (
aggregate.as_sql(quote_func=qn), aggregate.as_sql(qn, self.connection),
alias is not None and ' AS %s' % qn(alias) or '' alias is not None and ' AS %s' % qn(alias) or ''
) )
for alias, aggregate in self.aggregate_select.items() for alias, aggregate in self.aggregate_select.items()

View File

@ -421,17 +421,21 @@ class AggregateQuery(Query):
An AggregateQuery takes another query as a parameter to the FROM An AggregateQuery takes another query as a parameter to the FROM
clause and only selects the elements in the provided list. clause and only selects the elements in the provided list.
""" """
as_sql_takes_connection = True
def add_subquery(self, query): def add_subquery(self, query):
self.subquery, self.sub_params = query.as_sql(with_col_aliases=True) self.subquery, self.sub_params = query.as_sql(with_col_aliases=True)
def as_sql(self, quote_func=None): def as_sql(self, qn=None):
""" """
Creates the SQL for this query. Returns the SQL string and list of Creates the SQL for this query. Returns the SQL string and list of
parameters. parameters.
""" """
if qn is None:
qn = self.quote_name_unless_alias
sql = ('SELECT %s FROM (%s) subquery' % ( sql = ('SELECT %s FROM (%s) subquery' % (
', '.join([ ', '.join([
aggregate.as_sql() aggregate.as_sql(qn, self.connection)
for aggregate in self.aggregate_select.values() for aggregate in self.aggregate_select.values()
]), ]),
self.subquery) self.subquery)

View File

@ -152,7 +152,11 @@ class WhereNode(tree.Node):
field_sql = self.sql_for_columns(lvalue, qn, connection) field_sql = self.sql_for_columns(lvalue, qn, connection)
else: else:
# A smart object with an as_sql() method. # A smart object with an as_sql() method.
field_sql = lvalue.as_sql(quote_func=qn) if getattr(lvalue, 'as_sql_takes_connection', False):
field_sql = lvalue.as_sql(qn, connection)
else:
field_sql = lvalue.as_sql(qn)
if value_annot is datetime.datetime: if value_annot is datetime.datetime:
cast_sql = connection.ops.datetime_cast_sql() cast_sql = connection.ops.datetime_cast_sql()