1
0
mirror of https://github.com/django/django.git synced 2025-07-04 01:39:20 +00:00

[soc2010/query-refactor] Implemented count() (and by extension the Count() aggregate on the primary key).

git-svn-id: http://code.djangoproject.com/svn/django/branches/soc2010/query-refactor@13353 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Alex Gaynor 2010-06-14 20:09:24 +00:00
parent 8f441f0962
commit f522555392
5 changed files with 44 additions and 6 deletions

View File

@ -46,6 +46,11 @@ class DatabaseOperations(object):
tables = self.connection.introspection.table_names()
for table in tables:
self.connection.db.drop_collection(table)
def check_aggregate_support(self, aggregate):
# TODO: this really should use the generic aggregates, not the SQL ones
from django.db.models.sql.aggregates import Count
return isinstance(aggregate, Count)
class DatabaseWrapper(BaseDatabaseWrapper):
def __init__(self, *args, **kwargs):

View File

@ -32,10 +32,10 @@ class SQLCompiler(object):
column = "_id"
return column, params[0]
def build_query(self):
assert not self.query.aggregates
assert len([a for a in self.query.alias_map if self.query.alias_refcount[a]]) == 1
assert self.query.default_cols
def build_query(self, aggregates=False):
assert len([a for a in self.query.alias_map if self.query.alias_refcount[a]]) <= 1
if not aggregates:
assert self.query.default_cols
assert not self.query.distinct
assert not self.query.extra
assert not self.query.having
@ -60,6 +60,17 @@ class SQLCompiler(object):
return False
else:
return True
def get_aggregates(self):
assert len(self.query.aggregates) == 1
agg = self.query.aggregates.values()[0]
assert (
isinstance(agg, self.query.aggregates_module.Count) and (
agg.col == "*" or
isinstance(agg.col, tuple) and agg.col == (self.query.model._meta.db_table, self.query.model._meta.pk.column)
)
)
return [self.build_query(aggregates=True).count()]
class SQLInsertCompiler(SQLCompiler):

View File

@ -675,7 +675,10 @@ class SQLCompiler(object):
self.query.clear_ordering(True)
self.query.set_limits(high=1)
return bool(self.execute_sql(SINGLE))
def get_aggregates(self):
return self.execute_sql(SINGLE)
def results_iter(self):
"""
Returns an iterator over the results from executing this query.

View File

@ -363,7 +363,7 @@ class Query(object):
query.related_select_cols = []
query.related_select_fields = []
result = query.get_compiler(using).execute_sql(SINGLE)
result = query.get_compiler(using).get_aggregates()
if result is None:
result = [None for q in query.aggregate_select.items()]

View File

@ -1,3 +1,4 @@
from django.db.models import Count
from django.test import TestCase
from models import Artist
@ -25,3 +26,21 @@ class MongoTestCase(TestCase):
l = Artist.objects.get(pk=pk)
self.assertTrue(not l.good)
def test_count(self):
Artist.objects.create(name="Billy Joel", good=True)
Artist.objects.create(name="John Mellencamp", good=True)
Artist.objects.create(name="Warren Zevon", good=True)
Artist.objects.create(name="Matisyahu", good=True)
Artist.objects.create(name="Gary US Bonds", good=True)
self.assertEqual(Artist.objects.count(), 5)
self.assertEqual(Artist.objects.filter(good=True).count(), 5)
Artist.objects.create(name="Bon Iver", good=False)
self.assertEqual(Artist.objects.count(), 6)
self.assertEqual(Artist.objects.filter(good=True).count(), 5)
self.assertEqual(Artist.objects.filter(good=False).count(), 1)
self.assertEqual(Artist.objects.aggregate(c=Count("pk")), {"c": 6})