From a84344bc539c66589c8d4fe30c6ceaecf8ba1af3 Mon Sep 17 00:00:00 2001 From: David Sanders Date: Sun, 17 Apr 2016 10:03:08 -0700 Subject: [PATCH] Fixed #19513, #18580 -- Fixed crash on QuerySet.update() after annotate(). --- django/db/models/query.py | 2 ++ django/db/models/sql/subqueries.py | 6 +++++- tests/update/tests.py | 22 +++++++++++++++++++++- 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/django/db/models/query.py b/django/db/models/query.py index 40f7ae6ea9..4085e618cf 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -632,6 +632,8 @@ class QuerySet(object): self._for_write = True query = self.query.clone(sql.UpdateQuery) query.add_update_values(kwargs) + # Clear any annotations so that they won't be present in subqueries. + query._annotations = None with transaction.atomic(using=self.db, savepoint=False): rows = query.get_compiler(self.db).execute_sql(CURSOR) self._result_cache = None diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index 316a5c684d..fc9683064f 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -142,7 +142,11 @@ class UpdateQuery(Query): that will be used to generate the UPDATE query. Might be more usefully called add_update_targets() to hint at the extra information here. """ - self.values.extend(values_seq) + for field, model, val in values_seq: + if hasattr(val, 'resolve_expression'): + # Resolve expressions here so that annotations are no longer needed + val = val.resolve_expression(self, allow_joins=False, for_save=True) + self.values.append((field, model, val)) def add_related_update(self, model, field, value): """ diff --git a/tests/update/tests.py b/tests/update/tests.py index 3dc97c9173..89593f8dfc 100644 --- a/tests/update/tests.py +++ b/tests/update/tests.py @@ -1,7 +1,7 @@ from __future__ import unicode_literals from django.core.exceptions import FieldError -from django.db.models import F, Max +from django.db.models import Count, F, Max from django.test import TestCase from .models import A, B, Bar, D, DataPoint, Foo, RelatedPoint @@ -158,3 +158,23 @@ class AdvancedTests(TestCase): qs = DataPoint.objects.annotate(max=Max('value')) with self.assertRaisesMessage(FieldError, 'Aggregate functions are not allowed in this query'): qs.update(another_value=F('max')) + + def test_update_annotated_multi_table_queryset(self): + """ + Update of a queryset that's been annotated and involves multiple tables. + """ + # Trivial annotated update + qs = DataPoint.objects.annotate(related_count=Count('relatedpoint')) + self.assertEqual(qs.update(value='Foo'), 3) + # Update where annotation is used for filtering + qs = DataPoint.objects.annotate(related_count=Count('relatedpoint')) + self.assertEqual(qs.filter(related_count=1).update(value='Foo'), 1) + # Update where annotation is used in update parameters + # #26539 - This isn't forbidden but also doesn't generate proper SQL + # qs = RelatedPoint.objects.annotate(data_name=F('data__name')) + # updated = qs.update(name=F('data_name')) + # self.assertEqual(updated, 1) + # Update where aggregation annotation is used in update parameters + qs = RelatedPoint.objects.annotate(max=Max('data__value')) + with self.assertRaisesMessage(FieldError, 'Aggregate functions are not allowed in this query'): + qs.update(name=F('max'))