diff --git a/django/contrib/mongodb/compiler.py b/django/contrib/mongodb/compiler.py index f4791fd674..ec8d9be2b2 100644 --- a/django/contrib/mongodb/compiler.py +++ b/django/contrib/mongodb/compiler.py @@ -2,6 +2,7 @@ import re from pymongo import ASCENDING, DESCENDING +from django.db.models import F from django.db.models.sql.datastructures import FullResultSet, EmptyResultSet @@ -153,8 +154,25 @@ class SQLUpdateCompiler(SQLCompiler): filters = self.get_filters(self.query.where) # TODO: Don't use set for everything, use INC and such where # appropriate. + vals = {} + for field, o, value in self.query.values: + if hasattr(value, "evaluate"): + assert value.connector in (value.ADD, value.SUB) + assert not value.negated + assert not value.subtree_parents + lhs, rhs = value.children + if isinstance(lhs, F): + assert not isinstance(rhs, F) + if value.connector == value.SUB: + rhs = -rhs + else: + assert value.connector == value.ADD + rhs, lhs = lhs, rhs + vals.setdefault("$inc", {})[lhs.name] = rhs + else: + vals.setdefault("$set", {})[field.column] = value return self.connection.db[self.query.model._meta.db_table].update( filters, - {"$set": dict((f.column, val) for f, o, val in self.query.values)}, + vals, multi=True ) diff --git a/django/db/models/query.py b/django/db/models/query.py index e43fc3150a..d055490713 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -8,7 +8,8 @@ from itertools import izip from django.db import connections, router, transaction, IntegrityError from django.db.models.aggregates import Aggregate from django.db.models.fields import DateField -from django.db.models.query_utils import Q, select_related_descend, CollectedObjects, CyclicDependency, deferred_class_factory, InvalidQuery +from django.db.models.query_utils import (Q, select_related_descend, + CollectedObjects, CyclicDependency, deferred_class_factory, InvalidQuery) from django.db.models import signals, sql from django.utils.copycompat import deepcopy @@ -464,7 +465,7 @@ class QuerySet(object): else: forced_managed = False try: - rows = query.get_compiler(self.db).execute_sql(None) + rows = query.get_compiler(self.db).update(None) if forced_managed: transaction.commit(using=self.db) else: diff --git a/tests/regressiontests/mongodb/models.py b/tests/regressiontests/mongodb/models.py index 183663aaf5..039fce4930 100644 --- a/tests/regressiontests/mongodb/models.py +++ b/tests/regressiontests/mongodb/models.py @@ -5,6 +5,7 @@ class Artist(models.Model): id = models.NativeAutoField(primary_key=True) name = models.CharField(max_length=255) good = models.BooleanField() + age = models.IntegerField(null=True) current_group = models.ForeignKey("Group", null=True) diff --git a/tests/regressiontests/mongodb/tests.py b/tests/regressiontests/mongodb/tests.py index ace0b4fdec..9610e7d4c4 100644 --- a/tests/regressiontests/mongodb/tests.py +++ b/tests/regressiontests/mongodb/tests.py @@ -1,4 +1,4 @@ -from django.db.models import Count +from django.db.models import Count, F from django.test import TestCase from models import Artist, Group @@ -27,6 +27,28 @@ class MongoTestCase(TestCase): l = Artist.objects.get(pk=pk) self.assertTrue(not l.good) + def test_bulk_update(self): + # Doesn't actually do an op on more than 1 item, but it's the bulk + # update syntax nonetheless + v = Artist.objects.create(name="Van Morrison", good=False) + # How do you make a mistake like this, I don't know... + Artist.objects.filter(pk=v.pk).update(good=True) + self.assertTrue(Artist.objects.get(pk=v.pk).good) + + def test_f_expressions(self): + k = Artist.objects.create(name="Keb' Mo'", age=57, good=True) + # Birthday! + Artist.objects.filter(pk=k.pk).update(age=F("age") + 1) + self.assertEqual(Artist.objects.get(pk=k.pk).age, 58) + + # Backwards birthday + Artist.objects.filter(pk=k.pk).update(age=F("age") - 1) + self.assertEqual(Artist.objects.get(pk=k.pk).age, 57) + + # Birthday again! + Artist.objects.filter(pk=k.pk).update(age=1 + F("age")) + self.assertEqual(Artist.objects.get(pk=k.pk).age, 58) + def test_count(self): Artist.objects.create(name="Billy Joel", good=True) Artist.objects.create(name="John Mellencamp", good=True) @@ -121,7 +143,7 @@ class MongoTestCase(TestCase): self.assertQuerysetEqual( Artist.objects.values(), [ - {"name": "Steve Perry", "good": True, "current_group_id": None, "id": a.pk}, + {"name": "Steve Perry", "good": True, "current_group_id": None, "id": a.pk, "age": None}, ], lambda a: a, )