1
0
mirror of https://github.com/django/django.git synced 2025-07-04 17:59:13 +00:00

[soc2010/query-refactor] Implemented count() (and by extension the Count() aggregate on the primary key).

git-svn-id: http://code.djangoproject.com/svn/django/branches/soc2010/query-refactor@13353 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Alex Gaynor 2010-06-14 20:09:24 +00:00
parent 8f441f0962
commit f522555392
5 changed files with 44 additions and 6 deletions

View File

@ -47,6 +47,11 @@ class DatabaseOperations(object):
for table in tables: for table in tables:
self.connection.db.drop_collection(table) self.connection.db.drop_collection(table)
def check_aggregate_support(self, aggregate):
# TODO: this really should use the generic aggregates, not the SQL ones
from django.db.models.sql.aggregates import Count
return isinstance(aggregate, Count)
class DatabaseWrapper(BaseDatabaseWrapper): class DatabaseWrapper(BaseDatabaseWrapper):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(DatabaseWrapper, self).__init__(*args, **kwargs) super(DatabaseWrapper, self).__init__(*args, **kwargs)

View File

@ -32,9 +32,9 @@ class SQLCompiler(object):
column = "_id" column = "_id"
return column, params[0] return column, params[0]
def build_query(self): def build_query(self, aggregates=False):
assert not self.query.aggregates assert len([a for a in self.query.alias_map if self.query.alias_refcount[a]]) <= 1
assert len([a for a in self.query.alias_map if self.query.alias_refcount[a]]) == 1 if not aggregates:
assert self.query.default_cols assert self.query.default_cols
assert not self.query.distinct assert not self.query.distinct
assert not self.query.extra assert not self.query.extra
@ -61,6 +61,17 @@ class SQLCompiler(object):
else: else:
return True return True
def get_aggregates(self):
assert len(self.query.aggregates) == 1
agg = self.query.aggregates.values()[0]
assert (
isinstance(agg, self.query.aggregates_module.Count) and (
agg.col == "*" or
isinstance(agg.col, tuple) and agg.col == (self.query.model._meta.db_table, self.query.model._meta.pk.column)
)
)
return [self.build_query(aggregates=True).count()]
class SQLInsertCompiler(SQLCompiler): class SQLInsertCompiler(SQLCompiler):
def insert(self, return_id=False): def insert(self, return_id=False):

View File

@ -676,6 +676,9 @@ class SQLCompiler(object):
self.query.set_limits(high=1) self.query.set_limits(high=1)
return bool(self.execute_sql(SINGLE)) return bool(self.execute_sql(SINGLE))
def get_aggregates(self):
return self.execute_sql(SINGLE)
def results_iter(self): def results_iter(self):
""" """
Returns an iterator over the results from executing this query. Returns an iterator over the results from executing this query.

View File

@ -363,7 +363,7 @@ class Query(object):
query.related_select_cols = [] query.related_select_cols = []
query.related_select_fields = [] query.related_select_fields = []
result = query.get_compiler(using).execute_sql(SINGLE) result = query.get_compiler(using).get_aggregates()
if result is None: if result is None:
result = [None for q in query.aggregate_select.items()] result = [None for q in query.aggregate_select.items()]

View File

@ -1,3 +1,4 @@
from django.db.models import Count
from django.test import TestCase from django.test import TestCase
from models import Artist from models import Artist
@ -25,3 +26,21 @@ class MongoTestCase(TestCase):
l = Artist.objects.get(pk=pk) l = Artist.objects.get(pk=pk)
self.assertTrue(not l.good) self.assertTrue(not l.good)
def test_count(self):
Artist.objects.create(name="Billy Joel", good=True)
Artist.objects.create(name="John Mellencamp", good=True)
Artist.objects.create(name="Warren Zevon", good=True)
Artist.objects.create(name="Matisyahu", good=True)
Artist.objects.create(name="Gary US Bonds", good=True)
self.assertEqual(Artist.objects.count(), 5)
self.assertEqual(Artist.objects.filter(good=True).count(), 5)
Artist.objects.create(name="Bon Iver", good=False)
self.assertEqual(Artist.objects.count(), 6)
self.assertEqual(Artist.objects.filter(good=True).count(), 5)
self.assertEqual(Artist.objects.filter(good=False).count(), 1)
self.assertEqual(Artist.objects.aggregate(c=Count("pk")), {"c": 6})