mirror of
https://github.com/django/django.git
synced 2025-07-04 01:39:20 +00:00
[soc2010/query-refactor] Implemented count() (and by extension the Count() aggregate on the primary key).
git-svn-id: http://code.djangoproject.com/svn/django/branches/soc2010/query-refactor@13353 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
parent
8f441f0962
commit
f522555392
@ -46,6 +46,11 @@ class DatabaseOperations(object):
|
||||
tables = self.connection.introspection.table_names()
|
||||
for table in tables:
|
||||
self.connection.db.drop_collection(table)
|
||||
|
||||
def check_aggregate_support(self, aggregate):
|
||||
# TODO: this really should use the generic aggregates, not the SQL ones
|
||||
from django.db.models.sql.aggregates import Count
|
||||
return isinstance(aggregate, Count)
|
||||
|
||||
class DatabaseWrapper(BaseDatabaseWrapper):
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
@ -32,10 +32,10 @@ class SQLCompiler(object):
|
||||
column = "_id"
|
||||
return column, params[0]
|
||||
|
||||
def build_query(self):
|
||||
assert not self.query.aggregates
|
||||
assert len([a for a in self.query.alias_map if self.query.alias_refcount[a]]) == 1
|
||||
assert self.query.default_cols
|
||||
def build_query(self, aggregates=False):
|
||||
assert len([a for a in self.query.alias_map if self.query.alias_refcount[a]]) <= 1
|
||||
if not aggregates:
|
||||
assert self.query.default_cols
|
||||
assert not self.query.distinct
|
||||
assert not self.query.extra
|
||||
assert not self.query.having
|
||||
@ -60,6 +60,17 @@ class SQLCompiler(object):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def get_aggregates(self):
|
||||
assert len(self.query.aggregates) == 1
|
||||
agg = self.query.aggregates.values()[0]
|
||||
assert (
|
||||
isinstance(agg, self.query.aggregates_module.Count) and (
|
||||
agg.col == "*" or
|
||||
isinstance(agg.col, tuple) and agg.col == (self.query.model._meta.db_table, self.query.model._meta.pk.column)
|
||||
)
|
||||
)
|
||||
return [self.build_query(aggregates=True).count()]
|
||||
|
||||
|
||||
class SQLInsertCompiler(SQLCompiler):
|
||||
|
@ -675,7 +675,10 @@ class SQLCompiler(object):
|
||||
self.query.clear_ordering(True)
|
||||
self.query.set_limits(high=1)
|
||||
return bool(self.execute_sql(SINGLE))
|
||||
|
||||
|
||||
def get_aggregates(self):
|
||||
return self.execute_sql(SINGLE)
|
||||
|
||||
def results_iter(self):
|
||||
"""
|
||||
Returns an iterator over the results from executing this query.
|
||||
|
@ -363,7 +363,7 @@ class Query(object):
|
||||
query.related_select_cols = []
|
||||
query.related_select_fields = []
|
||||
|
||||
result = query.get_compiler(using).execute_sql(SINGLE)
|
||||
result = query.get_compiler(using).get_aggregates()
|
||||
if result is None:
|
||||
result = [None for q in query.aggregate_select.items()]
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
from django.db.models import Count
|
||||
from django.test import TestCase
|
||||
|
||||
from models import Artist
|
||||
@ -25,3 +26,21 @@ class MongoTestCase(TestCase):
|
||||
|
||||
l = Artist.objects.get(pk=pk)
|
||||
self.assertTrue(not l.good)
|
||||
|
||||
def test_count(self):
|
||||
Artist.objects.create(name="Billy Joel", good=True)
|
||||
Artist.objects.create(name="John Mellencamp", good=True)
|
||||
Artist.objects.create(name="Warren Zevon", good=True)
|
||||
Artist.objects.create(name="Matisyahu", good=True)
|
||||
Artist.objects.create(name="Gary US Bonds", good=True)
|
||||
|
||||
self.assertEqual(Artist.objects.count(), 5)
|
||||
self.assertEqual(Artist.objects.filter(good=True).count(), 5)
|
||||
|
||||
Artist.objects.create(name="Bon Iver", good=False)
|
||||
|
||||
self.assertEqual(Artist.objects.count(), 6)
|
||||
self.assertEqual(Artist.objects.filter(good=True).count(), 5)
|
||||
self.assertEqual(Artist.objects.filter(good=False).count(), 1)
|
||||
|
||||
self.assertEqual(Artist.objects.aggregate(c=Count("pk")), {"c": 6})
|
||||
|
Loading…
x
Reference in New Issue
Block a user