From f522555392ed9e133431437dd815ba0e84bc2394 Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Mon, 14 Jun 2010 20:09:24 +0000 Subject: [PATCH] [soc2010/query-refactor] Implemented count() (and by extension the Count() aggregate on the primary key). git-svn-id: http://code.djangoproject.com/svn/django/branches/soc2010/query-refactor@13353 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- django/contrib/mongodb/base.py | 5 +++++ django/contrib/mongodb/compiler.py | 19 +++++++++++++++---- django/db/models/sql/compiler.py | 5 ++++- django/db/models/sql/query.py | 2 +- tests/regressiontests/mongodb/tests.py | 19 +++++++++++++++++++ 5 files changed, 44 insertions(+), 6 deletions(-) diff --git a/django/contrib/mongodb/base.py b/django/contrib/mongodb/base.py index f1bacee812..70f46c91d4 100644 --- a/django/contrib/mongodb/base.py +++ b/django/contrib/mongodb/base.py @@ -46,6 +46,11 @@ class DatabaseOperations(object): tables = self.connection.introspection.table_names() for table in tables: self.connection.db.drop_collection(table) + + def check_aggregate_support(self, aggregate): + # TODO: this really should use the generic aggregates, not the SQL ones + from django.db.models.sql.aggregates import Count + return isinstance(aggregate, Count) class DatabaseWrapper(BaseDatabaseWrapper): def __init__(self, *args, **kwargs): diff --git a/django/contrib/mongodb/compiler.py b/django/contrib/mongodb/compiler.py index 83464be76e..0acdc1f4a5 100644 --- a/django/contrib/mongodb/compiler.py +++ b/django/contrib/mongodb/compiler.py @@ -32,10 +32,10 @@ class SQLCompiler(object): column = "_id" return column, params[0] - def build_query(self): - assert not self.query.aggregates - assert len([a for a in self.query.alias_map if self.query.alias_refcount[a]]) == 1 - assert self.query.default_cols + def build_query(self, aggregates=False): + assert len([a for a in self.query.alias_map if self.query.alias_refcount[a]]) <= 1 + if not aggregates: + assert self.query.default_cols assert not self.query.distinct assert not self.query.extra assert not self.query.having @@ -60,6 +60,17 @@ class SQLCompiler(object): return False else: return True + + def get_aggregates(self): + assert len(self.query.aggregates) == 1 + agg = self.query.aggregates.values()[0] + assert ( + isinstance(agg, self.query.aggregates_module.Count) and ( + agg.col == "*" or + isinstance(agg.col, tuple) and agg.col == (self.query.model._meta.db_table, self.query.model._meta.pk.column) + ) + ) + return [self.build_query(aggregates=True).count()] class SQLInsertCompiler(SQLCompiler): diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 1cd5cb535b..b4c9ea1f44 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -675,7 +675,10 @@ class SQLCompiler(object): self.query.clear_ordering(True) self.query.set_limits(high=1) return bool(self.execute_sql(SINGLE)) - + + def get_aggregates(self): + return self.execute_sql(SINGLE) + def results_iter(self): """ Returns an iterator over the results from executing this query. diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 02d13bd7a9..959990f628 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -363,7 +363,7 @@ class Query(object): query.related_select_cols = [] query.related_select_fields = [] - result = query.get_compiler(using).execute_sql(SINGLE) + result = query.get_compiler(using).get_aggregates() if result is None: result = [None for q in query.aggregate_select.items()] diff --git a/tests/regressiontests/mongodb/tests.py b/tests/regressiontests/mongodb/tests.py index 3fd0ddde5c..bdc5a10727 100644 --- a/tests/regressiontests/mongodb/tests.py +++ b/tests/regressiontests/mongodb/tests.py @@ -1,3 +1,4 @@ +from django.db.models import Count from django.test import TestCase from models import Artist @@ -25,3 +26,21 @@ class MongoTestCase(TestCase): l = Artist.objects.get(pk=pk) self.assertTrue(not l.good) + + def test_count(self): + Artist.objects.create(name="Billy Joel", good=True) + Artist.objects.create(name="John Mellencamp", good=True) + Artist.objects.create(name="Warren Zevon", good=True) + Artist.objects.create(name="Matisyahu", good=True) + Artist.objects.create(name="Gary US Bonds", good=True) + + self.assertEqual(Artist.objects.count(), 5) + self.assertEqual(Artist.objects.filter(good=True).count(), 5) + + Artist.objects.create(name="Bon Iver", good=False) + + self.assertEqual(Artist.objects.count(), 6) + self.assertEqual(Artist.objects.filter(good=True).count(), 5) + self.assertEqual(Artist.objects.filter(good=False).count(), 1) + + self.assertEqual(Artist.objects.aggregate(c=Count("pk")), {"c": 6})