diff --git a/django/contrib/mongodb/compiler.py b/django/contrib/mongodb/compiler.py index 0acdc1f4a5..bcd67b17a1 100644 --- a/django/contrib/mongodb/compiler.py +++ b/django/contrib/mongodb/compiler.py @@ -5,21 +5,27 @@ class SQLCompiler(object): self.connection = connection self.using = using - def get_filters(self, where): + def get_filters(self, where, correct=False): assert where.connector == "AND" - assert not where.negated filters = {} for child in where.children: if isinstance(child, self.query.where_class): - # TODO: probably needs to check for dupe keys - filters.update(self.get_filters(child)) + child_filters = self.get_filters(child) + for k, v in child_filters.iteritems(): + if k in filters: + v = {"$and": [filters[k], v]} + if where.negated: + v = {"$not": v} + filters[k] = v else: - field, val = self.make_atom(*child) + field, val = self.make_atom(*child, **{"negated": where.negated}) filters[field] = val + if correct: + self.correct_filters(filters) return filters - def make_atom(self, lhs, lookup_type, value_annotation, params_or_value): - assert lookup_type == "exact" + def make_atom(self, lhs, lookup_type, value_annotation, params_or_value, negated): + assert lookup_type in ["exact", "isnull"], lookup_type if hasattr(lhs, "process"): lhs, params = lhs.process(lookup_type, params_or_value, self.connection) else: @@ -30,7 +36,33 @@ class SQLCompiler(object): assert table == self.query.model._meta.db_table if column == self.query.model._meta.pk.column: column = "_id" - return column, params[0] + + if lookup_type == "exact": + val = params[0] + if negated: + val = {"$ne": val} + return column, val + elif lookup_type == "isnull": + val = None + if value_annotation == negated: + val = {"$not": val} + return column, val + + def correct_filters(self, filters): + for k, v in filters.items(): + if isinstance(v, dict) and v.keys() == ["$not"]: + if isinstance(v["$not"], dict) and v["$not"].keys() == ["$and"]: + del filters[k] + or_vals = [self.negate(k, v) for v in v["$not"]["$and"]] + assert "$or" not in filters + filters["$or"] = or_vals + + def negate(self, k, v): + if isinstance(v, dict): + if v.keys() == ["$not"]: + return {k: v["$not"]} + return {k: {"$not": v}} + return {k: {"$ne": v}} def build_query(self, aggregates=False): assert len([a for a in self.query.alias_map if self.query.alias_refcount[a]]) <= 1 @@ -42,7 +74,7 @@ class SQLCompiler(object): assert self.query.high_mark is None assert not self.query.order_by - filters = self.get_filters(self.query.where) + filters = self.get_filters(self.query.where, correct=True) return self.connection.db[self.query.model._meta.db_table].find(filters) def results_iter(self): diff --git a/tests/regressiontests/mongodb/models.py b/tests/regressiontests/mongodb/models.py index b942d12e59..183663aaf5 100644 --- a/tests/regressiontests/mongodb/models.py +++ b/tests/regressiontests/mongodb/models.py @@ -15,3 +15,5 @@ class Artist(models.Model): class Group(models.Model): id = models.NativeAutoField(primary_key=True) name = models.CharField(max_length=255) + year_formed = models.IntegerField(null=True) + diff --git a/tests/regressiontests/mongodb/tests.py b/tests/regressiontests/mongodb/tests.py index bd9cd5981f..93adcecd58 100644 --- a/tests/regressiontests/mongodb/tests.py +++ b/tests/regressiontests/mongodb/tests.py @@ -57,3 +57,26 @@ class MongoTestCase(TestCase): self.assertEqual(b.current_group_id, e.pk) self.assertFalse(hasattr(b, "_current_group_cache")) self.assertEqual(b.current_group, e) + + def test_lookup(self): + q = Group.objects.create(name="Queen", year_formed=1971) + e = Group.objects.create(name="The E Street Band", year_formed=1972) + + qs = Group.objects.exclude(year_formed=1972) + v = qs.query.get_compiler(qs.db).get_filters(qs.query.where, correct=True) + self.assertEqual(v, { + "$or": [ + {"year_formed": {"$ne": 1972}}, + {"year_formed": None}, + ] + }) + # A bug in MongoDB prevents this query from actually working, but test + # that we're at least generating the right query. + return + + self.assertQuerysetEqual( + qs, [ + "Queen", + ], + lambda g: g.name, + )