1
0
mirror of https://github.com/django/django.git synced 2024-12-22 17:16:24 +00:00

Fixed #33463 -- Fixed QuerySet.bulk_update() with F() expressions.

This commit is contained in:
Jörg Breitbart 2022-01-27 14:42:59 +01:00 committed by Mariusz Felisiak
parent e972620ada
commit 0af9a5fc7d
2 changed files with 12 additions and 2 deletions

View File

@ -17,7 +17,7 @@ from django.db import (
from django.db.models import AutoField, DateField, DateTimeField, sql from django.db.models import AutoField, DateField, DateTimeField, sql
from django.db.models.constants import LOOKUP_SEP, OnConflict from django.db.models.constants import LOOKUP_SEP, OnConflict
from django.db.models.deletion import Collector from django.db.models.deletion import Collector
from django.db.models.expressions import Case, Expression, F, Ref, Value, When from django.db.models.expressions import Case, F, Ref, Value, When
from django.db.models.functions import Cast, Trunc from django.db.models.functions import Cast, Trunc
from django.db.models.query_utils import FilteredRelation, Q from django.db.models.query_utils import FilteredRelation, Q
from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE
@ -670,7 +670,7 @@ class QuerySet:
when_statements = [] when_statements = []
for obj in batch_objs: for obj in batch_objs:
attr = getattr(obj, field.attname) attr = getattr(obj, field.attname)
if not isinstance(attr, Expression): if not hasattr(attr, 'resolve_expression'):
attr = Value(attr, output_field=field) attr = Value(attr, output_field=field)
when_statements.append(When(pk=obj.pk, then=attr)) when_statements.append(When(pk=obj.pk, then=attr))
case_statement = Case(*when_statements, output_field=field) case_statement = Case(*when_statements, output_field=field)

View File

@ -211,6 +211,16 @@ class BulkUpdateTests(TestCase):
Number.objects.bulk_update(numbers, ['num']) Number.objects.bulk_update(numbers, ['num'])
self.assertCountEqual(Number.objects.filter(num=1), numbers) self.assertCountEqual(Number.objects.filter(num=1), numbers)
def test_f_expression(self):
notes = [
Note.objects.create(note='test_note', misc='test_misc')
for _ in range(10)
]
for note in notes:
note.misc = F('note')
Note.objects.bulk_update(notes, ['misc'])
self.assertCountEqual(Note.objects.filter(misc='test_note'), notes)
def test_booleanfield(self): def test_booleanfield(self):
individuals = [Individual.objects.create(alive=False) for _ in range(10)] individuals = [Individual.objects.create(alive=False) for _ in range(10)]
for individual in individuals: for individual in individuals: