From 988b3bbdcb52c4e20551fc8936369a7f176f48bc Mon Sep 17 00:00:00 2001 From: Malcolm Tredinnick Date: Sun, 14 Oct 2007 02:12:40 +0000 Subject: [PATCH] queryset-refactor: Ported DateQuerySet and ValueQuerySet over and fixed most of the related tests. git-svn-id: http://code.djangoproject.com/svn/django/branches/queryset-refactor@6486 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- django/db/models/base.py | 12 ++- django/db/models/query.py | 106 ++++++-------------- django/db/models/sql/datastructures.py | 23 +++++ django/db/models/sql/query.py | 128 +++++++++++++++++++++---- 4 files changed, 164 insertions(+), 105 deletions(-) diff --git a/django/db/models/base.py b/django/db/models/base.py index e88eda6dd0..379ed898f6 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -338,13 +338,15 @@ class Model(object): def _get_next_or_previous_by_FIELD(self, field, is_next, **kwargs): qn = connection.ops.quote_name op = is_next and '>' or '<' - where = '(%s %s %%s OR (%s = %%s AND %s.%s %s %%s))' % \ + where = ['(%s %s %%s OR (%s = %%s AND %s.%s %s %%s))' % \ (qn(field.column), op, qn(field.column), - qn(self._meta.db_table), qn(self._meta.pk.column), op) + qn(self._meta.db_table), qn(self._meta.pk.column), op)] param = smart_str(getattr(self, field.attname)) - q = self.__class__._default_manager.filter(**kwargs).order_by((not is_next and '-' or '') + field.name, (not is_next and '-' or '') + self._meta.pk.name) - q.extra(where=where, params=[param, param, - getattr(self, self._meta.pk.attname)]) + order_char = not is_next and '-' or '' + q = self.__class__._default_manager.filter(**kwargs).order_by( + order_char + field.name, order_char + self._meta.pk.name) + q = q.extra(where=where, params=[param, param, + getattr(self, self._meta.pk.attname)]) try: return q[0] except IndexError: diff --git a/django/db/models/query.py b/django/db/models/query.py index 7e53397161..ca85403f7e 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -253,7 +253,6 @@ class _QuerySet(object): def values(self, *fields): return self._clone(klass=ValuesQuerySet, _fields=fields) - # FIXME: Not converted yet! def dates(self, field_name, kind, order='ASC'): """ Returns a list of datetime objects representing all available dates @@ -265,8 +264,10 @@ class _QuerySet(object): "'order' must be either 'ASC' or 'DESC'." # Let the FieldDoesNotExist exception propagate. field = self.model._meta.get_field(field_name, many_to_many=False) - assert isinstance(field, DateField), "%r isn't a DateField." % field_name - return self._clone(klass=DateQuerySet, _field=field, _kind=kind, _order=order) + assert isinstance(field, DateField), "%r isn't a DateField." \ + % field_name + return self._clone(klass=DateQuerySet, _field=field, _kind=kind, + _order=order) ################################################################## # PUBLIC METHODS THAT ALTER ATTRIBUTES AND RETURN A NEW QUERYSET # @@ -389,16 +390,8 @@ class ValuesQuerySet(QuerySet): self.query.select_related = False def iterator(self): - try: - select, sql, params = self._get_sql_clause() - except EmptyResultSet: - raise StopIteration - - qn = connection.ops.quote_name - - # self._select is a dictionary, and dictionaries' key order is - # undefined, so we convert it to a list of tuples. - extra_select = self._select.items() + extra_select = self.query.extra_select.keys() + extra_select.sort() # Construct two objects -- fields and field_names. # fields is a list of Field objects to fetch. @@ -406,39 +399,30 @@ class ValuesQuerySet(QuerySet): # resulting dictionaries. if self._fields: if not extra_select: - fields = [self.model._meta.get_field(f, many_to_many=False) for f in self._fields] + fields = [self.model._meta.get_field(f, many_to_many=False) + for f in self._fields] field_names = self._fields else: fields = [] field_names = [] for f in self._fields: if f in [field.name for field in self.model._meta.fields]: - fields.append(self.model._meta.get_field(f, many_to_many=False)) + fields.append(self.model._meta.get_field(f, + many_to_many=False)) field_names.append(f) - elif not self._select.has_key(f): - raise FieldDoesNotExist('%s has no field named %r' % (self.model._meta.object_name, f)) + elif not self.query.extra_select.has_key(f): + raise FieldDoesNotExist('%s has no field named %r' + % (self.model._meta.object_name, f)) else: # Default to all fields. fields = self.model._meta.fields field_names = [f.attname for f in fields] - columns = [f.column for f in fields] - select = ['%s.%s' % (qn(self.model._meta.db_table), qn(c)) for c in columns] + self.query.add_local_columns([f.column for f in fields]) if extra_select: - select.extend(['(%s) AS %s' % (quote_only_if_word(s[1]), qn(s[0])) for s in extra_select]) - field_names.extend([f[0] for f in extra_select]) + field_names.extend([f for f in extra_select]) - cursor = connection.cursor() - cursor.execute("SELECT " + (self._distinct and "DISTINCT " or "") + ",".join(select) + sql, params) - - has_resolve_columns = hasattr(self, 'resolve_columns') - while 1: - rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE) - if not rows: - raise StopIteration - for row in rows: - if has_resolve_columns: - row = self.resolve_columns(row, fields) - yield dict(zip(field_names, row)) + for row in self.query.results_iter(): + yield dict(zip(field_names, row)) def _clone(self, klass=None, **kwargs): c = super(ValuesQuerySet, self)._clone(klass, **kwargs) @@ -447,60 +431,19 @@ class ValuesQuerySet(QuerySet): class DateQuerySet(QuerySet): def iterator(self): - from django.db.backends.util import typecast_timestamp - from django.db.models.fields import DateTimeField - - qn = connection.ops.quote_name - self._order_by = () # Clear this because it'll mess things up otherwise. + self.query = self.query.clone(klass=sql.DateQuery) + self.query.select = [] + self.query.add_date_select(self._field.column, self._kind, self._order) if self._field.null: - self._where.append('%s.%s IS NOT NULL' % \ - (qn(self.model._meta.db_table), qn(self._field.column))) - try: - select, sql, params = self._get_sql_clause() - except EmptyResultSet: - raise StopIteration - - table_name = qn(self.model._meta.db_table) - field_name = qn(self._field.column) - - if connection.features.allows_group_by_ordinal: - group_by = '1' - else: - group_by = connection.ops.date_trunc_sql(self._kind, '%s.%s' % (table_name, field_name)) - - sql = 'SELECT %s %s GROUP BY %s ORDER BY 1 %s' % \ - (connection.ops.date_trunc_sql(self._kind, '%s.%s' % (qn(self.model._meta.db_table), - qn(self._field.column))), sql, group_by, self._order) - cursor = connection.cursor() - cursor.execute(sql, params) - - has_resolve_columns = hasattr(self, 'resolve_columns') - needs_datetime_string_cast = connection.features.needs_datetime_string_cast - dates = [] - # It would be better to use self._field here instead of DateTimeField(), - # but in Oracle that will result in a list of datetime.date instead of - # datetime.datetime. - fields = [DateTimeField()] - while 1: - rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE) - if not rows: - return dates - for row in rows: - date = row[0] - if has_resolve_columns: - date = self.resolve_columns([date], fields)[0] - elif needs_datetime_string_cast: - date = typecast_timestamp(str(date)) - dates.append(date) + self.query.add_filter(('%s__isnull' % self._field.name, True)) + return self.query.results_iter() def _clone(self, klass=None, **kwargs): c = super(DateQuerySet, self)._clone(klass, **kwargs) c._field = self._field c._kind = self._kind - c._order = self._order return c -# XXX; Everything below here is done. class EmptyQuerySet(QuerySet): def __init__(self, model=None): super(EmptyQuerySet, self).__init__(model) @@ -517,6 +460,11 @@ class EmptyQuerySet(QuerySet): c._result_cache = [] return c + def iterator(self): + # This slightly odd construction is because we need an empty generator + # (it should raise StopIteration immediately). + yield iter([]).next() + # QOperator, QAnd and QOr are temporarily retained for backwards compatibility. # All the old functionality is now part of the 'Q' class. class QOperator(Q): diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py index 5b01abbb87..46411599c8 100644 --- a/django/db/models/sql/datastructures.py +++ b/django/db/models/sql/datastructures.py @@ -57,3 +57,26 @@ class Count(Aggregate): else: return 'COUNT(%s)' % col +class Date(object): + """ + Add a date selection column. + """ + def __init__(self, col, lookup_type, date_sql_func): + self.col = col + self.lookup_type = lookup_type + self.date_sql_func= date_sql_func + + def relabel_aliases(self, change_map): + c = self.col + if isinstance(c, (list, tuple)): + self.col = (change_map.get(c[0], c[0]), c[1]) + + def as_sql(self, quote_func=None): + if not quote_func: + quote_func = lambda x: x + if isinstance(self.col, (list, tuple)): + col = '%s.%s' % tuple([quote_func(c) for c in self.col]) + else: + col = self.col + return self.date_sql_func(self.lookup_type, col) + diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index eaa5f11d35..db3ece3047 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -11,8 +11,8 @@ import copy from django.utils import tree from django.db.models.sql.where import WhereNode, AND, OR -from django.db.models.sql.datastructures import Count -from django.db.models.fields import FieldDoesNotExist +from django.db.models.sql.datastructures import Count, Date +from django.db.models.fields import FieldDoesNotExist, Field from django.contrib.contenttypes import generic from datastructures import EmptyResultSet from utils import handle_legacy_orderlist @@ -54,6 +54,7 @@ MULTI = 'multi' SINGLE = 'single' NONE = None +# FIXME: Add quote_name() calls around all the tables. class Query(object): """ A single SQL query. @@ -77,8 +78,8 @@ class Query(object): self.select = [] self.tables = [] # Aliases in the order they are created. self.where = WhereNode(self) - self.having = [] self.group_by = [] + self.having = [] self.order_by = [] self.low_mark, self.high_mark = 0, None # Used for offset/limit self.distinct = False @@ -103,12 +104,14 @@ class Query(object): sql, params = self.as_sql() return sql % params - def clone(self, **kwargs): + def clone(self, klass=None, **kwargs): """ Creates a copy of the current instance. The 'kwargs' parameter can be used by clients to update attributes after copying has taken place. """ - obj = self.__class__(self.model, self.connection) + if not klass: + klass = self.__class__ + obj = klass(self.model, self.connection) obj.table_map = self.table_map.copy() obj.alias_map = copy.deepcopy(self.alias_map) obj.join_map = copy.deepcopy(self.join_map) @@ -198,7 +201,16 @@ class Query(object): where, params = self.where.as_sql() if where: result.append('WHERE %s' % where) - result.append(' AND'.join(self.extra_where)) + if self.extra_where: + if not where: + result.append('WHERE') + else: + result.append('AND') + result.append(' AND'.join(self.extra_where)) + + if self.group_by: + grouping = self.get_grouping() + result.append('GROUP BY %s' % ', '.join(grouping)) ordering = self.get_ordering() if ordering: @@ -312,12 +324,12 @@ class Query(object): """ qn = self.connection.ops.quote_name result = [] - if self.select: + if self.select or self.extra_select: for col in self.select: if isinstance(col, (list, tuple)): result.append('%s.%s' % (qn(col[0]), qn(col[1]))) else: - result.append(col.as_sql()) + result.append(col.as_sql(quote_func=qn)) else: table_alias = self.tables[0] result = ['%s.%s' % (table_alias, qn(f.column)) @@ -331,6 +343,21 @@ class Query(object): for alias, col in extra_select]) return result + def get_grouping(self): + """ + Returns a tuple representing the SQL elements in the "group by" clause. + """ + qn = self.connection.ops.quote_name + result = [] + for col in self.group_by: + if isinstance(col, (list, tuple)): + result.append('%s.%s' % (qn(col[0]), qn(col[1]))) + elif hasattr(col, 'as_sql'): + result.append(col.as_sql(qn)) + else: + result.append(str(col)) + return result + def get_ordering(self): """ Returns a tuple representing the SQL elements in the "order by" clause. @@ -339,10 +366,18 @@ class Query(object): qn = self.connection.ops.quote_name opts = self.model._meta result = [] - for field in handle_legacy_orderlist(ordering): + for field in ordering: if field == '?': result.append(self.connection.ops.random_function_sql()) continue + if isinstance(field, int): + if field < 0: + order = 'DESC' + field = -field + else: + order = 'ASC' + result.append('%s %s' % (field, order)) + continue if field[0] == '-': col = field[1:] order = 'DESC' @@ -683,10 +718,28 @@ class Query(object): """ self.low_mark, self.high_mark = 0, None + def can_filter(self): + """ + Returns True if adding filters to this instance is still possible. + + Typically, this means no limits or offsets have been put on the results. + """ + return not (self.low_mark or self.high_mark) + + def add_local_columns(self, columns): + """ + Adds the given column names to the select set, assuming they come from + the root model (the one given in self.model). + """ + table = self.model._meta.db_table + self.select.extend([(table, col) for col in columns]) + def add_ordering(self, *ordering): """ Adds items from the 'ordering' sequence to the query's "order by" - clause. + clause. These items are either field names (not column names) -- + possibly with a direction prefix ('-' or '?') -- or ordinals, + corresponding to column positions in the 'select' list. """ self.order_by.extend(ordering) @@ -696,14 +749,6 @@ class Query(object): """ self.order_by = [] - def can_filter(self): - """ - Returns True if adding filters to this instance is still possible. - - Typically, this means no limits or offsets have been put on the results. - """ - return not (self.low_mark or self.high_mark) - def add_count_column(self): """ Converts the query to do count(*) or count(distinct(pk)) in order to @@ -713,12 +758,12 @@ class Query(object): # that it doesn't totally overwrite the select list. if not self.distinct: select = Count() - # Distinct handling is now done in Count(), so don't do it at this - # level. - self.distinct = False else: select = Count((self.table_map[self.model._meta.db_table][0], self.model._meta.pk.column), True) + # Distinct handling is done in Count(), so don't do it at this + # level. + self.distinct = False self.select = [select] self.extra_select = {} @@ -873,6 +918,47 @@ class UpdateQuery(Query): values = [(related_field.column, 'NULL')] self.do_query(self.model._meta.db_table, values, where) +class DateQuery(Query): + """ + A DateQuery is a normal query, except that it specifically selects a single + date field. This requires some special handling when converting the results + back to Python objects, so we put it in a separate class. + """ + def results_iter(self): + """ + Returns an iterator over the results from executing this query. + """ + resolve_columns = hasattr(self, 'resolve_columns') + if resolve_columns: + from django.db.models.fields import DateTimeField + fields = [DateTimeField()] + else: + from django.db.backends.util import typecast_timestamp + needs_string_cast = self.connection.features.needs_datetime_string_cast + + for rows in self.execute_sql(MULTI): + for row in rows: + date = row[0] + if resolve_columns: + date = self.resolve_columns([date], fields)[0] + elif needs_string_cast: + date = typecast_timestamp(str(date)) + yield date + + def add_date_select(self, column, lookup_type, order='ASC'): + """ + Converts the query into a date extraction query. + """ + alias = self.join((None, self.model._meta.db_table, None, None)) + select = Date((alias, column), lookup_type, + self.connection.ops.date_trunc_sql) + self.select = [select] + self.order_by = order == 'ASC' and [1] or [-1] + if self.connection.features.allows_group_by_ordinal: + self.group_by = [1] + else: + self.group_by = [select] + def find_field(name, field_list, related_query): """ Finds a field with a specific name in a list of field instances.