1
0
mirror of https://github.com/django/django.git synced 2025-07-06 18:59:13 +00:00

queryset-refactor: Ported DateQuerySet and ValueQuerySet over and fixed most of

the related tests.


git-svn-id: http://code.djangoproject.com/svn/django/branches/queryset-refactor@6486 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Malcolm Tredinnick 2007-10-14 02:12:40 +00:00
parent bcdedbbf08
commit 988b3bbdcb
4 changed files with 164 additions and 105 deletions

View File

@ -338,13 +338,15 @@ class Model(object):
def _get_next_or_previous_by_FIELD(self, field, is_next, **kwargs): def _get_next_or_previous_by_FIELD(self, field, is_next, **kwargs):
qn = connection.ops.quote_name qn = connection.ops.quote_name
op = is_next and '>' or '<' op = is_next and '>' or '<'
where = '(%s %s %%s OR (%s = %%s AND %s.%s %s %%s))' % \ where = ['(%s %s %%s OR (%s = %%s AND %s.%s %s %%s))' % \
(qn(field.column), op, qn(field.column), (qn(field.column), op, qn(field.column),
qn(self._meta.db_table), qn(self._meta.pk.column), op) qn(self._meta.db_table), qn(self._meta.pk.column), op)]
param = smart_str(getattr(self, field.attname)) param = smart_str(getattr(self, field.attname))
q = self.__class__._default_manager.filter(**kwargs).order_by((not is_next and '-' or '') + field.name, (not is_next and '-' or '') + self._meta.pk.name) order_char = not is_next and '-' or ''
q.extra(where=where, params=[param, param, q = self.__class__._default_manager.filter(**kwargs).order_by(
getattr(self, self._meta.pk.attname)]) order_char + field.name, order_char + self._meta.pk.name)
q = q.extra(where=where, params=[param, param,
getattr(self, self._meta.pk.attname)])
try: try:
return q[0] return q[0]
except IndexError: except IndexError:

View File

@ -253,7 +253,6 @@ class _QuerySet(object):
def values(self, *fields): def values(self, *fields):
return self._clone(klass=ValuesQuerySet, _fields=fields) return self._clone(klass=ValuesQuerySet, _fields=fields)
# FIXME: Not converted yet!
def dates(self, field_name, kind, order='ASC'): def dates(self, field_name, kind, order='ASC'):
""" """
Returns a list of datetime objects representing all available dates Returns a list of datetime objects representing all available dates
@ -265,8 +264,10 @@ class _QuerySet(object):
"'order' must be either 'ASC' or 'DESC'." "'order' must be either 'ASC' or 'DESC'."
# Let the FieldDoesNotExist exception propagate. # Let the FieldDoesNotExist exception propagate.
field = self.model._meta.get_field(field_name, many_to_many=False) field = self.model._meta.get_field(field_name, many_to_many=False)
assert isinstance(field, DateField), "%r isn't a DateField." % field_name assert isinstance(field, DateField), "%r isn't a DateField." \
return self._clone(klass=DateQuerySet, _field=field, _kind=kind, _order=order) % field_name
return self._clone(klass=DateQuerySet, _field=field, _kind=kind,
_order=order)
################################################################## ##################################################################
# PUBLIC METHODS THAT ALTER ATTRIBUTES AND RETURN A NEW QUERYSET # # PUBLIC METHODS THAT ALTER ATTRIBUTES AND RETURN A NEW QUERYSET #
@ -389,16 +390,8 @@ class ValuesQuerySet(QuerySet):
self.query.select_related = False self.query.select_related = False
def iterator(self): def iterator(self):
try: extra_select = self.query.extra_select.keys()
select, sql, params = self._get_sql_clause() extra_select.sort()
except EmptyResultSet:
raise StopIteration
qn = connection.ops.quote_name
# self._select is a dictionary, and dictionaries' key order is
# undefined, so we convert it to a list of tuples.
extra_select = self._select.items()
# Construct two objects -- fields and field_names. # Construct two objects -- fields and field_names.
# fields is a list of Field objects to fetch. # fields is a list of Field objects to fetch.
@ -406,39 +399,30 @@ class ValuesQuerySet(QuerySet):
# resulting dictionaries. # resulting dictionaries.
if self._fields: if self._fields:
if not extra_select: if not extra_select:
fields = [self.model._meta.get_field(f, many_to_many=False) for f in self._fields] fields = [self.model._meta.get_field(f, many_to_many=False)
for f in self._fields]
field_names = self._fields field_names = self._fields
else: else:
fields = [] fields = []
field_names = [] field_names = []
for f in self._fields: for f in self._fields:
if f in [field.name for field in self.model._meta.fields]: if f in [field.name for field in self.model._meta.fields]:
fields.append(self.model._meta.get_field(f, many_to_many=False)) fields.append(self.model._meta.get_field(f,
many_to_many=False))
field_names.append(f) field_names.append(f)
elif not self._select.has_key(f): elif not self.query.extra_select.has_key(f):
raise FieldDoesNotExist('%s has no field named %r' % (self.model._meta.object_name, f)) raise FieldDoesNotExist('%s has no field named %r'
% (self.model._meta.object_name, f))
else: # Default to all fields. else: # Default to all fields.
fields = self.model._meta.fields fields = self.model._meta.fields
field_names = [f.attname for f in fields] field_names = [f.attname for f in fields]
columns = [f.column for f in fields] self.query.add_local_columns([f.column for f in fields])
select = ['%s.%s' % (qn(self.model._meta.db_table), qn(c)) for c in columns]
if extra_select: if extra_select:
select.extend(['(%s) AS %s' % (quote_only_if_word(s[1]), qn(s[0])) for s in extra_select]) field_names.extend([f for f in extra_select])
field_names.extend([f[0] for f in extra_select])
cursor = connection.cursor() for row in self.query.results_iter():
cursor.execute("SELECT " + (self._distinct and "DISTINCT " or "") + ",".join(select) + sql, params) yield dict(zip(field_names, row))
has_resolve_columns = hasattr(self, 'resolve_columns')
while 1:
rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)
if not rows:
raise StopIteration
for row in rows:
if has_resolve_columns:
row = self.resolve_columns(row, fields)
yield dict(zip(field_names, row))
def _clone(self, klass=None, **kwargs): def _clone(self, klass=None, **kwargs):
c = super(ValuesQuerySet, self)._clone(klass, **kwargs) c = super(ValuesQuerySet, self)._clone(klass, **kwargs)
@ -447,60 +431,19 @@ class ValuesQuerySet(QuerySet):
class DateQuerySet(QuerySet): class DateQuerySet(QuerySet):
def iterator(self): def iterator(self):
from django.db.backends.util import typecast_timestamp self.query = self.query.clone(klass=sql.DateQuery)
from django.db.models.fields import DateTimeField self.query.select = []
self.query.add_date_select(self._field.column, self._kind, self._order)
qn = connection.ops.quote_name
self._order_by = () # Clear this because it'll mess things up otherwise.
if self._field.null: if self._field.null:
self._where.append('%s.%s IS NOT NULL' % \ self.query.add_filter(('%s__isnull' % self._field.name, True))
(qn(self.model._meta.db_table), qn(self._field.column))) return self.query.results_iter()
try:
select, sql, params = self._get_sql_clause()
except EmptyResultSet:
raise StopIteration
table_name = qn(self.model._meta.db_table)
field_name = qn(self._field.column)
if connection.features.allows_group_by_ordinal:
group_by = '1'
else:
group_by = connection.ops.date_trunc_sql(self._kind, '%s.%s' % (table_name, field_name))
sql = 'SELECT %s %s GROUP BY %s ORDER BY 1 %s' % \
(connection.ops.date_trunc_sql(self._kind, '%s.%s' % (qn(self.model._meta.db_table),
qn(self._field.column))), sql, group_by, self._order)
cursor = connection.cursor()
cursor.execute(sql, params)
has_resolve_columns = hasattr(self, 'resolve_columns')
needs_datetime_string_cast = connection.features.needs_datetime_string_cast
dates = []
# It would be better to use self._field here instead of DateTimeField(),
# but in Oracle that will result in a list of datetime.date instead of
# datetime.datetime.
fields = [DateTimeField()]
while 1:
rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)
if not rows:
return dates
for row in rows:
date = row[0]
if has_resolve_columns:
date = self.resolve_columns([date], fields)[0]
elif needs_datetime_string_cast:
date = typecast_timestamp(str(date))
dates.append(date)
def _clone(self, klass=None, **kwargs): def _clone(self, klass=None, **kwargs):
c = super(DateQuerySet, self)._clone(klass, **kwargs) c = super(DateQuerySet, self)._clone(klass, **kwargs)
c._field = self._field c._field = self._field
c._kind = self._kind c._kind = self._kind
c._order = self._order
return c return c
# XXX; Everything below here is done.
class EmptyQuerySet(QuerySet): class EmptyQuerySet(QuerySet):
def __init__(self, model=None): def __init__(self, model=None):
super(EmptyQuerySet, self).__init__(model) super(EmptyQuerySet, self).__init__(model)
@ -517,6 +460,11 @@ class EmptyQuerySet(QuerySet):
c._result_cache = [] c._result_cache = []
return c return c
def iterator(self):
# This slightly odd construction is because we need an empty generator
# (it should raise StopIteration immediately).
yield iter([]).next()
# QOperator, QAnd and QOr are temporarily retained for backwards compatibility. # QOperator, QAnd and QOr are temporarily retained for backwards compatibility.
# All the old functionality is now part of the 'Q' class. # All the old functionality is now part of the 'Q' class.
class QOperator(Q): class QOperator(Q):

View File

@ -57,3 +57,26 @@ class Count(Aggregate):
else: else:
return 'COUNT(%s)' % col return 'COUNT(%s)' % col
class Date(object):
"""
Add a date selection column.
"""
def __init__(self, col, lookup_type, date_sql_func):
self.col = col
self.lookup_type = lookup_type
self.date_sql_func= date_sql_func
def relabel_aliases(self, change_map):
c = self.col
if isinstance(c, (list, tuple)):
self.col = (change_map.get(c[0], c[0]), c[1])
def as_sql(self, quote_func=None):
if not quote_func:
quote_func = lambda x: x
if isinstance(self.col, (list, tuple)):
col = '%s.%s' % tuple([quote_func(c) for c in self.col])
else:
col = self.col
return self.date_sql_func(self.lookup_type, col)

View File

@ -11,8 +11,8 @@ import copy
from django.utils import tree from django.utils import tree
from django.db.models.sql.where import WhereNode, AND, OR from django.db.models.sql.where import WhereNode, AND, OR
from django.db.models.sql.datastructures import Count from django.db.models.sql.datastructures import Count, Date
from django.db.models.fields import FieldDoesNotExist from django.db.models.fields import FieldDoesNotExist, Field
from django.contrib.contenttypes import generic from django.contrib.contenttypes import generic
from datastructures import EmptyResultSet from datastructures import EmptyResultSet
from utils import handle_legacy_orderlist from utils import handle_legacy_orderlist
@ -54,6 +54,7 @@ MULTI = 'multi'
SINGLE = 'single' SINGLE = 'single'
NONE = None NONE = None
# FIXME: Add quote_name() calls around all the tables.
class Query(object): class Query(object):
""" """
A single SQL query. A single SQL query.
@ -77,8 +78,8 @@ class Query(object):
self.select = [] self.select = []
self.tables = [] # Aliases in the order they are created. self.tables = [] # Aliases in the order they are created.
self.where = WhereNode(self) self.where = WhereNode(self)
self.having = []
self.group_by = [] self.group_by = []
self.having = []
self.order_by = [] self.order_by = []
self.low_mark, self.high_mark = 0, None # Used for offset/limit self.low_mark, self.high_mark = 0, None # Used for offset/limit
self.distinct = False self.distinct = False
@ -103,12 +104,14 @@ class Query(object):
sql, params = self.as_sql() sql, params = self.as_sql()
return sql % params return sql % params
def clone(self, **kwargs): def clone(self, klass=None, **kwargs):
""" """
Creates a copy of the current instance. The 'kwargs' parameter can be Creates a copy of the current instance. The 'kwargs' parameter can be
used by clients to update attributes after copying has taken place. used by clients to update attributes after copying has taken place.
""" """
obj = self.__class__(self.model, self.connection) if not klass:
klass = self.__class__
obj = klass(self.model, self.connection)
obj.table_map = self.table_map.copy() obj.table_map = self.table_map.copy()
obj.alias_map = copy.deepcopy(self.alias_map) obj.alias_map = copy.deepcopy(self.alias_map)
obj.join_map = copy.deepcopy(self.join_map) obj.join_map = copy.deepcopy(self.join_map)
@ -198,7 +201,16 @@ class Query(object):
where, params = self.where.as_sql() where, params = self.where.as_sql()
if where: if where:
result.append('WHERE %s' % where) result.append('WHERE %s' % where)
result.append(' AND'.join(self.extra_where)) if self.extra_where:
if not where:
result.append('WHERE')
else:
result.append('AND')
result.append(' AND'.join(self.extra_where))
if self.group_by:
grouping = self.get_grouping()
result.append('GROUP BY %s' % ', '.join(grouping))
ordering = self.get_ordering() ordering = self.get_ordering()
if ordering: if ordering:
@ -312,12 +324,12 @@ class Query(object):
""" """
qn = self.connection.ops.quote_name qn = self.connection.ops.quote_name
result = [] result = []
if self.select: if self.select or self.extra_select:
for col in self.select: for col in self.select:
if isinstance(col, (list, tuple)): if isinstance(col, (list, tuple)):
result.append('%s.%s' % (qn(col[0]), qn(col[1]))) result.append('%s.%s' % (qn(col[0]), qn(col[1])))
else: else:
result.append(col.as_sql()) result.append(col.as_sql(quote_func=qn))
else: else:
table_alias = self.tables[0] table_alias = self.tables[0]
result = ['%s.%s' % (table_alias, qn(f.column)) result = ['%s.%s' % (table_alias, qn(f.column))
@ -331,6 +343,21 @@ class Query(object):
for alias, col in extra_select]) for alias, col in extra_select])
return result return result
def get_grouping(self):
"""
Returns a tuple representing the SQL elements in the "group by" clause.
"""
qn = self.connection.ops.quote_name
result = []
for col in self.group_by:
if isinstance(col, (list, tuple)):
result.append('%s.%s' % (qn(col[0]), qn(col[1])))
elif hasattr(col, 'as_sql'):
result.append(col.as_sql(qn))
else:
result.append(str(col))
return result
def get_ordering(self): def get_ordering(self):
""" """
Returns a tuple representing the SQL elements in the "order by" clause. Returns a tuple representing the SQL elements in the "order by" clause.
@ -339,10 +366,18 @@ class Query(object):
qn = self.connection.ops.quote_name qn = self.connection.ops.quote_name
opts = self.model._meta opts = self.model._meta
result = [] result = []
for field in handle_legacy_orderlist(ordering): for field in ordering:
if field == '?': if field == '?':
result.append(self.connection.ops.random_function_sql()) result.append(self.connection.ops.random_function_sql())
continue continue
if isinstance(field, int):
if field < 0:
order = 'DESC'
field = -field
else:
order = 'ASC'
result.append('%s %s' % (field, order))
continue
if field[0] == '-': if field[0] == '-':
col = field[1:] col = field[1:]
order = 'DESC' order = 'DESC'
@ -683,10 +718,28 @@ class Query(object):
""" """
self.low_mark, self.high_mark = 0, None self.low_mark, self.high_mark = 0, None
def can_filter(self):
"""
Returns True if adding filters to this instance is still possible.
Typically, this means no limits or offsets have been put on the results.
"""
return not (self.low_mark or self.high_mark)
def add_local_columns(self, columns):
"""
Adds the given column names to the select set, assuming they come from
the root model (the one given in self.model).
"""
table = self.model._meta.db_table
self.select.extend([(table, col) for col in columns])
def add_ordering(self, *ordering): def add_ordering(self, *ordering):
""" """
Adds items from the 'ordering' sequence to the query's "order by" Adds items from the 'ordering' sequence to the query's "order by"
clause. clause. These items are either field names (not column names) --
possibly with a direction prefix ('-' or '?') -- or ordinals,
corresponding to column positions in the 'select' list.
""" """
self.order_by.extend(ordering) self.order_by.extend(ordering)
@ -696,14 +749,6 @@ class Query(object):
""" """
self.order_by = [] self.order_by = []
def can_filter(self):
"""
Returns True if adding filters to this instance is still possible.
Typically, this means no limits or offsets have been put on the results.
"""
return not (self.low_mark or self.high_mark)
def add_count_column(self): def add_count_column(self):
""" """
Converts the query to do count(*) or count(distinct(pk)) in order to Converts the query to do count(*) or count(distinct(pk)) in order to
@ -713,12 +758,12 @@ class Query(object):
# that it doesn't totally overwrite the select list. # that it doesn't totally overwrite the select list.
if not self.distinct: if not self.distinct:
select = Count() select = Count()
# Distinct handling is now done in Count(), so don't do it at this
# level.
self.distinct = False
else: else:
select = Count((self.table_map[self.model._meta.db_table][0], select = Count((self.table_map[self.model._meta.db_table][0],
self.model._meta.pk.column), True) self.model._meta.pk.column), True)
# Distinct handling is done in Count(), so don't do it at this
# level.
self.distinct = False
self.select = [select] self.select = [select]
self.extra_select = {} self.extra_select = {}
@ -873,6 +918,47 @@ class UpdateQuery(Query):
values = [(related_field.column, 'NULL')] values = [(related_field.column, 'NULL')]
self.do_query(self.model._meta.db_table, values, where) self.do_query(self.model._meta.db_table, values, where)
class DateQuery(Query):
"""
A DateQuery is a normal query, except that it specifically selects a single
date field. This requires some special handling when converting the results
back to Python objects, so we put it in a separate class.
"""
def results_iter(self):
"""
Returns an iterator over the results from executing this query.
"""
resolve_columns = hasattr(self, 'resolve_columns')
if resolve_columns:
from django.db.models.fields import DateTimeField
fields = [DateTimeField()]
else:
from django.db.backends.util import typecast_timestamp
needs_string_cast = self.connection.features.needs_datetime_string_cast
for rows in self.execute_sql(MULTI):
for row in rows:
date = row[0]
if resolve_columns:
date = self.resolve_columns([date], fields)[0]
elif needs_string_cast:
date = typecast_timestamp(str(date))
yield date
def add_date_select(self, column, lookup_type, order='ASC'):
"""
Converts the query into a date extraction query.
"""
alias = self.join((None, self.model._meta.db_table, None, None))
select = Date((alias, column), lookup_type,
self.connection.ops.date_trunc_sql)
self.select = [select]
self.order_by = order == 'ASC' and [1] or [-1]
if self.connection.features.allows_group_by_ordinal:
self.group_by = [1]
else:
self.group_by = [select]
def find_field(name, field_list, related_query): def find_field(name, field_list, related_query):
""" """
Finds a field with a specific name in a list of field instances. Finds a field with a specific name in a list of field instances.