diff --git a/django/contrib/mongodb/compiler.py b/django/contrib/mongodb/compiler.py index 5bd34a7569..b9a12b04bf 100644 --- a/django/contrib/mongodb/compiler.py +++ b/django/contrib/mongodb/compiler.py @@ -2,6 +2,7 @@ import re from pymongo import ASCENDING, DESCENDING +from django.db import UnsupportedDatabaseOperation from django.db.models import F from django.db.models.sql.datastructures import FullResultSet, EmptyResultSet @@ -43,10 +44,13 @@ class SQLCompiler(object): pass return filters - def make_atom(self, lhs, lookup_type, value_annotation, params_or_value, negated): + def make_atom(self, lhs, lookup_type, value_annotation, params_or_value, + negated): assert lookup_type in self.LOOKUP_TYPES, lookup_type if hasattr(lhs, "process"): - lhs, params = lhs.process(lookup_type, params_or_value, self.connection) + lhs, params = lhs.process( + lookup_type, params_or_value, self.connection + ) else: params = Field().get_db_prep_lookup(lookup_type, params_or_value, connection=self.connection, prepared=True) @@ -56,7 +60,8 @@ class SQLCompiler(object): if column == self.query.model._meta.pk.column: column = "_id" - return column, self.LOOKUP_TYPES[lookup_type](params, value_annotation, negated) + val = self.LOOKUP_TYPES[lookup_type](params, value_annotation, negated) + return column, val def negate(self, k, v): # Regex lookups are of the form {"field": re.compile("pattern") and @@ -79,14 +84,18 @@ class SQLCompiler(object): return None def build_query(self, aggregates=False): - assert len([a for a in self.query.alias_map if self.query.alias_refcount[a]]) <= 1 + if len([a for a in self.query.alias_map if self.query.alias_refcount[a]]) > 1: + raise UnsupportedDatabaseOperation("MongoDB does not support " + "operations across relations.") + if self.query.extra: + raise UnsupportedDatabaseOperation("MongoDB does not support extra().") assert not self.query.distinct - assert not self.query.extra assert not self.query.having filters = self.get_filters(self.query.where) fields = self.get_fields(aggregates=aggregates) - cursor = self.connection.db[self.query.model._meta.db_table].find(filters, fields=fields) + collection = self.connection.db[self.query.model._meta.db_table] + cursor = collection.find(filters, fields=fields) if self.query.order_by: cursor = cursor.sort([ (ordering.lstrip("-"), DESCENDING if ordering.startswith("-") else ASCENDING) @@ -125,14 +134,19 @@ class SQLCompiler(object): return True def get_aggregates(self): + if len(self.query.aggregates) != 1: + raise UnsupportedDatabaseOperation("MongoDB doesn't support " + "multiple aggregates in a single query.") assert len(self.query.aggregates) == 1 agg = self.query.aggregates.values()[0] - assert ( - isinstance(agg, self.query.aggregates_module.Count) and ( - agg.col == "*" or - isinstance(agg.col, tuple) and agg.col == (self.query.model._meta.db_table, self.query.model._meta.pk.column) - ) - ) + if not isinstance(agg, self.query.aggregates_module.Count): + raise UnsupportedDatabaseOperation("MongoDB does not support " + "aggregates other than Count.") + opts = self.query.model._meta + if not (agg.col == "*" or agg.col == (opts.db_table, opts.pk.column)): + raise UnsupportedDatabaseOperation("MongoDB does not support " + "aggregation over fields besides the primary key.") + return [self.build_query(aggregates=True).count()] @@ -152,8 +166,7 @@ class SQLUpdateCompiler(SQLCompiler): def update(self, result_type): # TODO: more asserts filters = self.get_filters(self.query.where) - # TODO: Don't use set for everything, use INC and such where - # appropriate. + vals = {} for field, o, value in self.query.values: if hasattr(value, "evaluate"): diff --git a/django/db/__init__.py b/django/db/__init__.py index 4bae04ab9a..73d25e1f3a 100644 --- a/django/db/__init__.py +++ b/django/db/__init__.py @@ -6,7 +6,7 @@ from django.db.utils import ConnectionHandler, ConnectionRouter, load_backend, D from django.utils.functional import curry __all__ = ('backend', 'connection', 'connections', 'router', 'DatabaseError', - 'IntegrityError', 'DEFAULT_DB_ALIAS') + 'IntegrityError', 'UnsupportedDatabaseOperation', 'DEFAULT_DB_ALIAS') # For backwards compatibility - Port any old database settings over to @@ -75,6 +75,12 @@ router = ConnectionRouter(settings.DATABASE_ROUTERS) connection = connections[DEFAULT_DB_ALIAS] backend = load_backend(connection.settings_dict['ENGINE']) +class UnsupportedDatabaseOperation(Exception): + """ + Raised when an operation attempted on a QuerySet is unsupported on the + database for it's execution. + """ + # Register an event that closes the database connection # when a Django request is finished. def close_connection(**kwargs): diff --git a/tests/regressiontests/mongodb/models.py b/tests/regressiontests/mongodb/models.py index 039fce4930..274c1936a1 100644 --- a/tests/regressiontests/mongodb/models.py +++ b/tests/regressiontests/mongodb/models.py @@ -7,7 +7,8 @@ class Artist(models.Model): good = models.BooleanField() age = models.IntegerField(null=True) - current_group = models.ForeignKey("Group", null=True) + current_group = models.ForeignKey("Group", null=True, + related_name="current_artists") def __unicode__(self): return self.name diff --git a/tests/regressiontests/mongodb/tests.py b/tests/regressiontests/mongodb/tests.py index 54c36094cf..01bdd6c13a 100644 --- a/tests/regressiontests/mongodb/tests.py +++ b/tests/regressiontests/mongodb/tests.py @@ -1,5 +1,5 @@ -from django.db import connection -from django.db.models import Count, F +from django.db import connection, UnsupportedDatabaseOperation +from django.db.models import Count, Sum, F from django.test import TestCase from models import Artist, Group @@ -359,3 +359,36 @@ class MongoTestCase(TestCase): # Ensure that closing a connection that was never established doesn't # blow up. connection.close() + + def assert_unsupported(self, obj): + if callable(obj): + # Queryset wrapped in a function (for aggregates and such) + self.assertRaises(UnsupportedDatabaseOperation, obj) + else: + # Just a queryset that blows up on evaluation + self.assertRaises(UnsupportedDatabaseOperation, list, obj) + + def test_unsupported_ops(self): + self.assert_unsupported( + Artist.objects.filter(current_group__name="The Beatles") + ) + + self.assert_unsupported( + Artist.objects.extra(select={"a": "1.0"}) + ) + + self.assert_unsupported( + Group.objects.annotate(artists=Count("current_artists")) + ) + + self.assert_unsupported( + lambda: Artist.objects.aggregate(Sum("age")) + ) + + self.assert_unsupported( + lambda: Artist.objects.aggregate(Count("age")) + ) + + self.assert_unsupported( + lambda: Artist.objects.aggregate(Count("id"), Count("pk")) + )