1
0
mirror of https://github.com/django/django.git synced 2025-07-06 18:59:13 +00:00

queryset-refactor: Ported DateQuerySet and ValueQuerySet over and fixed most of

the related tests.


git-svn-id: http://code.djangoproject.com/svn/django/branches/queryset-refactor@6486 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Malcolm Tredinnick 2007-10-14 02:12:40 +00:00
parent bcdedbbf08
commit 988b3bbdcb
4 changed files with 164 additions and 105 deletions

View File

@ -338,12 +338,14 @@ class Model(object):
def _get_next_or_previous_by_FIELD(self, field, is_next, **kwargs):
qn = connection.ops.quote_name
op = is_next and '>' or '<'
where = '(%s %s %%s OR (%s = %%s AND %s.%s %s %%s))' % \
where = ['(%s %s %%s OR (%s = %%s AND %s.%s %s %%s))' % \
(qn(field.column), op, qn(field.column),
qn(self._meta.db_table), qn(self._meta.pk.column), op)
qn(self._meta.db_table), qn(self._meta.pk.column), op)]
param = smart_str(getattr(self, field.attname))
q = self.__class__._default_manager.filter(**kwargs).order_by((not is_next and '-' or '') + field.name, (not is_next and '-' or '') + self._meta.pk.name)
q.extra(where=where, params=[param, param,
order_char = not is_next and '-' or ''
q = self.__class__._default_manager.filter(**kwargs).order_by(
order_char + field.name, order_char + self._meta.pk.name)
q = q.extra(where=where, params=[param, param,
getattr(self, self._meta.pk.attname)])
try:
return q[0]

View File

@ -253,7 +253,6 @@ class _QuerySet(object):
def values(self, *fields):
return self._clone(klass=ValuesQuerySet, _fields=fields)
# FIXME: Not converted yet!
def dates(self, field_name, kind, order='ASC'):
"""
Returns a list of datetime objects representing all available dates
@ -265,8 +264,10 @@ class _QuerySet(object):
"'order' must be either 'ASC' or 'DESC'."
# Let the FieldDoesNotExist exception propagate.
field = self.model._meta.get_field(field_name, many_to_many=False)
assert isinstance(field, DateField), "%r isn't a DateField." % field_name
return self._clone(klass=DateQuerySet, _field=field, _kind=kind, _order=order)
assert isinstance(field, DateField), "%r isn't a DateField." \
% field_name
return self._clone(klass=DateQuerySet, _field=field, _kind=kind,
_order=order)
##################################################################
# PUBLIC METHODS THAT ALTER ATTRIBUTES AND RETURN A NEW QUERYSET #
@ -389,16 +390,8 @@ class ValuesQuerySet(QuerySet):
self.query.select_related = False
def iterator(self):
try:
select, sql, params = self._get_sql_clause()
except EmptyResultSet:
raise StopIteration
qn = connection.ops.quote_name
# self._select is a dictionary, and dictionaries' key order is
# undefined, so we convert it to a list of tuples.
extra_select = self._select.items()
extra_select = self.query.extra_select.keys()
extra_select.sort()
# Construct two objects -- fields and field_names.
# fields is a list of Field objects to fetch.
@ -406,38 +399,29 @@ class ValuesQuerySet(QuerySet):
# resulting dictionaries.
if self._fields:
if not extra_select:
fields = [self.model._meta.get_field(f, many_to_many=False) for f in self._fields]
fields = [self.model._meta.get_field(f, many_to_many=False)
for f in self._fields]
field_names = self._fields
else:
fields = []
field_names = []
for f in self._fields:
if f in [field.name for field in self.model._meta.fields]:
fields.append(self.model._meta.get_field(f, many_to_many=False))
fields.append(self.model._meta.get_field(f,
many_to_many=False))
field_names.append(f)
elif not self._select.has_key(f):
raise FieldDoesNotExist('%s has no field named %r' % (self.model._meta.object_name, f))
elif not self.query.extra_select.has_key(f):
raise FieldDoesNotExist('%s has no field named %r'
% (self.model._meta.object_name, f))
else: # Default to all fields.
fields = self.model._meta.fields
field_names = [f.attname for f in fields]
columns = [f.column for f in fields]
select = ['%s.%s' % (qn(self.model._meta.db_table), qn(c)) for c in columns]
self.query.add_local_columns([f.column for f in fields])
if extra_select:
select.extend(['(%s) AS %s' % (quote_only_if_word(s[1]), qn(s[0])) for s in extra_select])
field_names.extend([f[0] for f in extra_select])
field_names.extend([f for f in extra_select])
cursor = connection.cursor()
cursor.execute("SELECT " + (self._distinct and "DISTINCT " or "") + ",".join(select) + sql, params)
has_resolve_columns = hasattr(self, 'resolve_columns')
while 1:
rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)
if not rows:
raise StopIteration
for row in rows:
if has_resolve_columns:
row = self.resolve_columns(row, fields)
for row in self.query.results_iter():
yield dict(zip(field_names, row))
def _clone(self, klass=None, **kwargs):
@ -447,60 +431,19 @@ class ValuesQuerySet(QuerySet):
class DateQuerySet(QuerySet):
def iterator(self):
from django.db.backends.util import typecast_timestamp
from django.db.models.fields import DateTimeField
qn = connection.ops.quote_name
self._order_by = () # Clear this because it'll mess things up otherwise.
self.query = self.query.clone(klass=sql.DateQuery)
self.query.select = []
self.query.add_date_select(self._field.column, self._kind, self._order)
if self._field.null:
self._where.append('%s.%s IS NOT NULL' % \
(qn(self.model._meta.db_table), qn(self._field.column)))
try:
select, sql, params = self._get_sql_clause()
except EmptyResultSet:
raise StopIteration
table_name = qn(self.model._meta.db_table)
field_name = qn(self._field.column)
if connection.features.allows_group_by_ordinal:
group_by = '1'
else:
group_by = connection.ops.date_trunc_sql(self._kind, '%s.%s' % (table_name, field_name))
sql = 'SELECT %s %s GROUP BY %s ORDER BY 1 %s' % \
(connection.ops.date_trunc_sql(self._kind, '%s.%s' % (qn(self.model._meta.db_table),
qn(self._field.column))), sql, group_by, self._order)
cursor = connection.cursor()
cursor.execute(sql, params)
has_resolve_columns = hasattr(self, 'resolve_columns')
needs_datetime_string_cast = connection.features.needs_datetime_string_cast
dates = []
# It would be better to use self._field here instead of DateTimeField(),
# but in Oracle that will result in a list of datetime.date instead of
# datetime.datetime.
fields = [DateTimeField()]
while 1:
rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)
if not rows:
return dates
for row in rows:
date = row[0]
if has_resolve_columns:
date = self.resolve_columns([date], fields)[0]
elif needs_datetime_string_cast:
date = typecast_timestamp(str(date))
dates.append(date)
self.query.add_filter(('%s__isnull' % self._field.name, True))
return self.query.results_iter()
def _clone(self, klass=None, **kwargs):
c = super(DateQuerySet, self)._clone(klass, **kwargs)
c._field = self._field
c._kind = self._kind
c._order = self._order
return c
# XXX; Everything below here is done.
class EmptyQuerySet(QuerySet):
def __init__(self, model=None):
super(EmptyQuerySet, self).__init__(model)
@ -517,6 +460,11 @@ class EmptyQuerySet(QuerySet):
c._result_cache = []
return c
def iterator(self):
# This slightly odd construction is because we need an empty generator
# (it should raise StopIteration immediately).
yield iter([]).next()
# QOperator, QAnd and QOr are temporarily retained for backwards compatibility.
# All the old functionality is now part of the 'Q' class.
class QOperator(Q):

View File

@ -57,3 +57,26 @@ class Count(Aggregate):
else:
return 'COUNT(%s)' % col
class Date(object):
"""
Add a date selection column.
"""
def __init__(self, col, lookup_type, date_sql_func):
self.col = col
self.lookup_type = lookup_type
self.date_sql_func= date_sql_func
def relabel_aliases(self, change_map):
c = self.col
if isinstance(c, (list, tuple)):
self.col = (change_map.get(c[0], c[0]), c[1])
def as_sql(self, quote_func=None):
if not quote_func:
quote_func = lambda x: x
if isinstance(self.col, (list, tuple)):
col = '%s.%s' % tuple([quote_func(c) for c in self.col])
else:
col = self.col
return self.date_sql_func(self.lookup_type, col)

View File

@ -11,8 +11,8 @@ import copy
from django.utils import tree
from django.db.models.sql.where import WhereNode, AND, OR
from django.db.models.sql.datastructures import Count
from django.db.models.fields import FieldDoesNotExist
from django.db.models.sql.datastructures import Count, Date
from django.db.models.fields import FieldDoesNotExist, Field
from django.contrib.contenttypes import generic
from datastructures import EmptyResultSet
from utils import handle_legacy_orderlist
@ -54,6 +54,7 @@ MULTI = 'multi'
SINGLE = 'single'
NONE = None
# FIXME: Add quote_name() calls around all the tables.
class Query(object):
"""
A single SQL query.
@ -77,8 +78,8 @@ class Query(object):
self.select = []
self.tables = [] # Aliases in the order they are created.
self.where = WhereNode(self)
self.having = []
self.group_by = []
self.having = []
self.order_by = []
self.low_mark, self.high_mark = 0, None # Used for offset/limit
self.distinct = False
@ -103,12 +104,14 @@ class Query(object):
sql, params = self.as_sql()
return sql % params
def clone(self, **kwargs):
def clone(self, klass=None, **kwargs):
"""
Creates a copy of the current instance. The 'kwargs' parameter can be
used by clients to update attributes after copying has taken place.
"""
obj = self.__class__(self.model, self.connection)
if not klass:
klass = self.__class__
obj = klass(self.model, self.connection)
obj.table_map = self.table_map.copy()
obj.alias_map = copy.deepcopy(self.alias_map)
obj.join_map = copy.deepcopy(self.join_map)
@ -198,8 +201,17 @@ class Query(object):
where, params = self.where.as_sql()
if where:
result.append('WHERE %s' % where)
if self.extra_where:
if not where:
result.append('WHERE')
else:
result.append('AND')
result.append(' AND'.join(self.extra_where))
if self.group_by:
grouping = self.get_grouping()
result.append('GROUP BY %s' % ', '.join(grouping))
ordering = self.get_ordering()
if ordering:
result.append('ORDER BY %s' % ', '.join(ordering))
@ -312,12 +324,12 @@ class Query(object):
"""
qn = self.connection.ops.quote_name
result = []
if self.select:
if self.select or self.extra_select:
for col in self.select:
if isinstance(col, (list, tuple)):
result.append('%s.%s' % (qn(col[0]), qn(col[1])))
else:
result.append(col.as_sql())
result.append(col.as_sql(quote_func=qn))
else:
table_alias = self.tables[0]
result = ['%s.%s' % (table_alias, qn(f.column))
@ -331,6 +343,21 @@ class Query(object):
for alias, col in extra_select])
return result
def get_grouping(self):
"""
Returns a tuple representing the SQL elements in the "group by" clause.
"""
qn = self.connection.ops.quote_name
result = []
for col in self.group_by:
if isinstance(col, (list, tuple)):
result.append('%s.%s' % (qn(col[0]), qn(col[1])))
elif hasattr(col, 'as_sql'):
result.append(col.as_sql(qn))
else:
result.append(str(col))
return result
def get_ordering(self):
"""
Returns a tuple representing the SQL elements in the "order by" clause.
@ -339,10 +366,18 @@ class Query(object):
qn = self.connection.ops.quote_name
opts = self.model._meta
result = []
for field in handle_legacy_orderlist(ordering):
for field in ordering:
if field == '?':
result.append(self.connection.ops.random_function_sql())
continue
if isinstance(field, int):
if field < 0:
order = 'DESC'
field = -field
else:
order = 'ASC'
result.append('%s %s' % (field, order))
continue
if field[0] == '-':
col = field[1:]
order = 'DESC'
@ -683,10 +718,28 @@ class Query(object):
"""
self.low_mark, self.high_mark = 0, None
def can_filter(self):
"""
Returns True if adding filters to this instance is still possible.
Typically, this means no limits or offsets have been put on the results.
"""
return not (self.low_mark or self.high_mark)
def add_local_columns(self, columns):
"""
Adds the given column names to the select set, assuming they come from
the root model (the one given in self.model).
"""
table = self.model._meta.db_table
self.select.extend([(table, col) for col in columns])
def add_ordering(self, *ordering):
"""
Adds items from the 'ordering' sequence to the query's "order by"
clause.
clause. These items are either field names (not column names) --
possibly with a direction prefix ('-' or '?') -- or ordinals,
corresponding to column positions in the 'select' list.
"""
self.order_by.extend(ordering)
@ -696,14 +749,6 @@ class Query(object):
"""
self.order_by = []
def can_filter(self):
"""
Returns True if adding filters to this instance is still possible.
Typically, this means no limits or offsets have been put on the results.
"""
return not (self.low_mark or self.high_mark)
def add_count_column(self):
"""
Converts the query to do count(*) or count(distinct(pk)) in order to
@ -713,12 +758,12 @@ class Query(object):
# that it doesn't totally overwrite the select list.
if not self.distinct:
select = Count()
# Distinct handling is now done in Count(), so don't do it at this
# level.
self.distinct = False
else:
select = Count((self.table_map[self.model._meta.db_table][0],
self.model._meta.pk.column), True)
# Distinct handling is done in Count(), so don't do it at this
# level.
self.distinct = False
self.select = [select]
self.extra_select = {}
@ -873,6 +918,47 @@ class UpdateQuery(Query):
values = [(related_field.column, 'NULL')]
self.do_query(self.model._meta.db_table, values, where)
class DateQuery(Query):
"""
A DateQuery is a normal query, except that it specifically selects a single
date field. This requires some special handling when converting the results
back to Python objects, so we put it in a separate class.
"""
def results_iter(self):
"""
Returns an iterator over the results from executing this query.
"""
resolve_columns = hasattr(self, 'resolve_columns')
if resolve_columns:
from django.db.models.fields import DateTimeField
fields = [DateTimeField()]
else:
from django.db.backends.util import typecast_timestamp
needs_string_cast = self.connection.features.needs_datetime_string_cast
for rows in self.execute_sql(MULTI):
for row in rows:
date = row[0]
if resolve_columns:
date = self.resolve_columns([date], fields)[0]
elif needs_string_cast:
date = typecast_timestamp(str(date))
yield date
def add_date_select(self, column, lookup_type, order='ASC'):
"""
Converts the query into a date extraction query.
"""
alias = self.join((None, self.model._meta.db_table, None, None))
select = Date((alias, column), lookup_type,
self.connection.ops.date_trunc_sql)
self.select = [select]
self.order_by = order == 'ASC' and [1] or [-1]
if self.connection.features.allows_group_by_ordinal:
self.group_by = [1]
else:
self.group_by = [select]
def find_field(name, field_list, related_query):
"""
Finds a field with a specific name in a list of field instances.