From 0ebb752e89b312436b424065da05f2b1f8e22a22 Mon Sep 17 00:00:00 2001 From: Malcolm Tredinnick Date: Sun, 14 Oct 2007 02:16:38 +0000 Subject: [PATCH] queryset-refactor: Made all the changes needed to have count() work properly with ValuesQuerySet. This is the general case of #2939. At this point, all the existing tests now pass on the branch (except for Oracle). It's a bit slower than before, though, and there are still a bunch of known bugs that aren't in the tests (or only exercised for some backends). git-svn-id: http://code.djangoproject.com/svn/django/branches/queryset-refactor@6497 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- django/db/models/query.py | 73 +++++++++++++++------ django/db/models/sql/datastructures.py | 10 +-- django/db/models/sql/query.py | 85 +++++++++++++++++-------- tests/regressiontests/queries/models.py | 15 +++-- 4 files changed, 128 insertions(+), 55 deletions(-) diff --git a/django/db/models/query.py b/django/db/models/query.py index f7371ad1da..66837921fe 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -251,7 +251,7 @@ class _QuerySet(object): ################################################## def values(self, *fields): - return self._clone(klass=ValuesQuerySet, _fields=fields) + return self._clone(klass=ValuesQuerySet, setup=True, _fields=fields) def dates(self, field_name, kind, order='ASC'): """ @@ -266,8 +266,8 @@ class _QuerySet(object): field = self.model._meta.get_field(field_name, many_to_many=False) assert isinstance(field, DateField), "%r isn't a DateField." \ % field_name - return self._clone(klass=DateQuerySet, _field=field, _kind=kind, - _order=order) + return self._clone(klass=DateQuerySet, setup=True, _field=field, + _kind=kind, _order=order) ################################################################## # PUBLIC METHODS THAT ALTER ATTRIBUTES AND RETURN A NEW QUERYSET # @@ -363,13 +363,15 @@ class _QuerySet(object): # PRIVATE METHODS # ################### - def _clone(self, klass=None, **kwargs): + def _clone(self, klass=None, setup=False, **kwargs): if klass is None: klass = self.__class__ c = klass() c.model = self.model c.query = self.query.clone() c.__dict__.update(kwargs) + if setup and hasattr(c, '_setup_query'): + c._setup_query() return c def _get_data(self): @@ -389,16 +391,33 @@ class ValuesQuerySet(QuerySet): # select_related isn't supported in values(). self.query.select_related = False + # QuerySet.clone() will also set up the _fields attribute with the + # names of the model fields to select. + def iterator(self): extra_select = self.query.extra_select.keys() extra_select.sort() + if extra_select: + self.field_names.extend([f for f in extra_select]) - # Construct two objects -- fields and field_names. - # fields is a list of Field objects to fetch. - # field_names is a list of field names, which will be the keys in the - # resulting dictionaries. + for row in self.query.results_iter(): + yield dict(zip(self.field_names, row)) + + def _setup_query(self): + """ + Sets up any special features of the query attribute. + + Called by the _clone() method after initialising the rest of the + instance. + """ + # Construct two objects: + # - fields is a list of Field objects to fetch. + # - field_names is a list of field names, which will be the keys in + # the resulting dictionaries. + # 'fields' is used to configure the query, whilst field_names is stored + # in this object for use by iterator(). if self._fields: - if not extra_select: + if not self.query.extra_select: fields = [self.model._meta.get_field(f, many_to_many=False) for f in self._fields] field_names = self._fields @@ -418,30 +437,42 @@ class ValuesQuerySet(QuerySet): field_names = [f.attname for f in fields] self.query.add_local_columns([f.column for f in fields]) - if extra_select: - field_names.extend([f for f in extra_select]) + self.field_names = field_names - for row in self.query.results_iter(): - yield dict(zip(field_names, row)) - - def _clone(self, klass=None, **kwargs): + def _clone(self, klass=None, setup=False, **kwargs): + """ + Cloning a ValuesQuerySet preserves the current fields. + """ c = super(ValuesQuerySet, self)._clone(klass, **kwargs) c._fields = self._fields[:] + c.field_names = self.field_names[:] + if setup and hasattr(c, '_setup_query'): + c._setup_query() return c class DateQuerySet(QuerySet): def iterator(self): - self.query = self.query.clone(klass=sql.DateQuery) + return self.query.results_iter() + + def _setup_query(self): + """ + Sets up any special features of the query attribute. + + Called by the _clone() method after initialising the rest of the + instance. + """ + self.query = self.query.clone(klass=sql.DateQuery, setup=True) self.query.select = [] self.query.add_date_select(self._field.column, self._kind, self._order) if self._field.null: self.query.add_filter(('%s__isnull' % self._field.name, True)) - return self.query.results_iter() - def _clone(self, klass=None, **kwargs): - c = super(DateQuerySet, self)._clone(klass, **kwargs) + def _clone(self, klass=None, setup=False, **kwargs): + c = super(DateQuerySet, self)._clone(klass, False, **kwargs) c._field = self._field c._kind = self._kind + if setup and hasattr(c, '_setup_query'): + c._setup_query() return c class EmptyQuerySet(QuerySet): @@ -455,14 +486,14 @@ class EmptyQuerySet(QuerySet): def delete(self): pass - def _clone(self, klass=None, **kwargs): + def _clone(self, klass=None, setup=False, **kwargs): c = super(EmptyQuerySet, self)._clone(klass, **kwargs) c._result_cache = [] return c def iterator(self): # This slightly odd construction is because we need an empty generator - # (it should raise StopIteration immediately). + # (it raises StopIteration immediately). yield iter([]).next() # QOperator, QAnd and QOr are temporarily retained for backwards compatibility. diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py index 46411599c8..af3db23e02 100644 --- a/django/db/models/sql/datastructures.py +++ b/django/db/models/sql/datastructures.py @@ -32,12 +32,12 @@ class Count(Aggregate): """ Perform a count on the given column. """ - def __init__(self, col=None, distinct=False): + def __init__(self, col='*', distinct=False): """ Set the column to count on (defaults to '*') and set whether the count should be distinct or not. """ - self.col = col and col or '*' + self.col = col self.distinct = distinct def relabel_aliases(self, change_map): @@ -49,13 +49,13 @@ class Count(Aggregate): if not quote_func: quote_func = lambda x: x if isinstance(self.col, (list, tuple)): - col = '%s.%s' % tuple([quote_func(c) for c in self.col]) + col = ('%s.%s' % tuple([quote_func(c) for c in self.col])) else: col = self.col if self.distinct: - return 'COUNT(DISTINCT(%s))' % col + return 'COUNT(DISTINCT %s)' % col else: - return 'COUNT(%s)' % col + return 'COUNT(%s)' % self.col class Date(object): """ diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 9f4edff7cf..dbbe47f23a 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -147,15 +147,17 @@ class Query(object): def get_count(self): """ - Performs a COUNT() or COUNT(DISTINCT()) query, as appropriate, using - the current filter constraints. + Performs a COUNT() query using the current filter constraints. """ - counter = self.clone() - counter.clear_ordering() - counter.clear_limits() - counter.select_related = False - counter.add_count_column() - data = counter.execute_sql(SINGLE) + obj = self.clone() + obj.clear_ordering() + obj.clear_limits() + obj.select_related = False + if obj.distinct and len(obj.select) > 1: + obj = self.clone(CountQuery, _query=obj, where=WhereNode(self), + distinct=False) + obj.add_count_column() + data = obj.execute_sql(SINGLE) if not data: return 0 number = data[0] @@ -176,7 +178,6 @@ class Query(object): If 'with_limits' is False, any limit/offset information is not included in the query. """ - qn = self.connection.ops.quote_name self.pre_sql_setup() result = ['SELECT'] if self.distinct: @@ -185,21 +186,12 @@ class Query(object): result.append(', '.join(out_cols)) result.append('FROM') - for alias in self.tables: - if not self.alias_map[alias][ALIAS_REFCOUNT]: - continue - name, alias, join_type, lhs, lhs_col, col = \ - self.alias_map[alias][ALIAS_JOIN] - alias_str = (alias != name and ' AS %s' % alias or '') - if join_type: - result.append('%s %s%s ON (%s.%s = %s.%s)' - % (join_type, qn(name), alias_str, qn(lhs), - qn(lhs_col), qn(alias), qn(col))) - else: - result.append('%s%s' % (qn(name), alias_str)) - result.extend(self.extra_tables) + from_, f_params = self.get_from_clause() + result.extend(from_) + params = list(f_params) - where, params = self.where.as_sql() + where, w_params = self.where.as_sql() + params.extend(w_params) if where: result.append('WHERE %s' % where) if self.extra_where: @@ -348,6 +340,30 @@ class Query(object): for alias, col in extra_select]) return result + def get_from_clause(self): + """ + Returns a list of strings that are joined together to go after the + "FROM" part of the query, as well as any extra parameters that need to + be included. Sub-classes, can override this to create a from-clause via + a "select", for example (e.g. CountQuery). + """ + result = [] + qn = self.connection.ops.quote_name + for alias in self.tables: + if not self.alias_map[alias][ALIAS_REFCOUNT]: + continue + name, alias, join_type, lhs, lhs_col, col = \ + self.alias_map[alias][ALIAS_JOIN] + alias_str = (alias != name and ' AS %s' % alias or '') + if join_type: + result.append('%s %s%s ON (%s.%s = %s.%s)' + % (join_type, qn(name), alias_str, qn(lhs), + qn(lhs_col), qn(alias), qn(col))) + else: + result.append('%s%s' % (qn(name), alias_str)) + result.extend(self.extra_tables) + return result, [] + def get_grouping(self): """ Returns a tuple representing the SQL elements in the "group by" clause. @@ -787,8 +803,17 @@ class Query(object): if not self.distinct: select = Count() else: - select = Count((self.table_map[self.model._meta.db_table][0], - self.model._meta.pk.column), True) + opts = self.model._meta + if not self.select: + select = Count((self.join((None, opts.db_table, None, None)), + opts.pk.column), True) + else: + # Because of SQL portability issues, multi-column, distinct + # counts need a sub-query -- see get_count() for details. + assert len(self.select) == 1, \ + "Cannot add count col with multiple cols in 'select'." + select = Count(self.select[0], True) + # Distinct handling is done in Count(), so don't do it at this # level. self.distinct = False @@ -987,6 +1012,16 @@ class DateQuery(Query): else: self.group_by = [select] +class CountQuery(Query): + """ + A CountQuery knows how to take a normal query which would select over + multiple distinct columns and turn it into SQL that can be used on a + variety of backends (it requires a select in the FROM clause). + """ + def get_from_clause(self): + result, params = self._query.as_sql() + return ['(%s) AS A1' % result], params + def find_field(name, field_list, related_query): """ Finds a field with a specific name in a list of field instances. diff --git a/tests/regressiontests/queries/models.py b/tests/regressiontests/queries/models.py index 36161852a1..175e0c867e 100644 --- a/tests/regressiontests/queries/models.py +++ b/tests/regressiontests/queries/models.py @@ -111,10 +111,17 @@ Bug #2080, #3592 >>> Author.objects.filter(Q(name='a3') | Q(item__name='one')) [, ] -Bug #2939 -# FIXME: ValueQuerySets don't work yet. -# >>> Item.objects.values('creator').distinct().count() -# 2 +Bug #1878, #2939 +>>> Item.objects.values('creator').distinct().count() +3 + +# Create something with a duplicate 'name' so that we can test multi-column +# cases (which require some tricky SQL transformations under the covers). +>>> xx = Item(name='four', creator=a2) +>>> xx.save() +>>> Item.objects.exclude(name='two').values('creator', 'name').distinct().count() +4 +>>> xx.delete() Bug #2253 >>> q1 = Item.objects.order_by('name')