Mirror of https://github.com/django/django.git, synced 2025-01-30 12:09:25 +00:00

[5.2.x] Fixed #36118 -- Accounted for multiple primary keys in bulk_update max_batch_size.

Co-authored-by: Simon Charette <charette.s@gmail.com>

Backport of 5a2c1bc07d126ce32efaa157e712a8f3a7457b74 from main.
This commit is contained in:
Sarah Boyce 2025-01-27 10:28:21 +01:00
parent 4aa2cd6f68
commit a469397dd3
6 changed files with 69 additions and 11 deletions

View File

@@ -1,12 +1,19 @@
import datetime
import uuid
from functools import lru_cache
from itertools import chain
from django.conf import settings
from django.db import DatabaseError, NotSupportedError
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.utils import split_tzname_delta, strip_quotes, truncate_name
from django.db.models import AutoField, Exists, ExpressionWrapper, Lookup
from django.db.models import (
AutoField,
CompositePrimaryKey,
Exists,
ExpressionWrapper,
Lookup,
)
from django.db.models.expressions import RawSQL
from django.db.models.sql.where import WhereNode
from django.utils import timezone
@@ -708,6 +715,12 @@ END;
def bulk_batch_size(self, fields, objs):
    """
    Return the maximum number of objects per query batch.

    Oracle restricts the number of bind parameters in a single query, so
    divide the backend's parameter limit by the number of concrete fields
    per object. A composite primary key contributes one parameter per
    component field, so expand it before counting.
    """
    flattened = []
    for field in fields:
        # A CompositePrimaryKey stands for several concrete columns; count
        # each of its component fields as a separate query parameter.
        if isinstance(field, CompositePrimaryKey):
            flattened.extend(field.fields)
        else:
            flattened.append(field)
    if not flattened:
        # No fields means no per-object parameters, so batching is unbounded.
        return len(objs)
    return self.connection.features.max_query_params // len(flattened)

View File

@@ -36,6 +36,16 @@ class DatabaseOperations(BaseDatabaseOperations):
If there's only a single field to insert, the limit is 500
(SQLITE_MAX_COMPOUND_SELECT).
"""
fields = list(
chain.from_iterable(
(
field.fields
if isinstance(field, models.CompositePrimaryKey)
else [field]
)
for field in fields
)
)
if len(fields) == 1:
return 500
elif len(fields) > 1:

View File

@@ -230,9 +230,8 @@ class Collector:
"""
Return the objs in suitably sized batches for the used connection.
"""
field_names = [field.name for field in fields]
conn_batch_size = max(
connections[self.using].ops.bulk_batch_size(field_names, objs), 1
connections[self.using].ops.bulk_batch_size(fields, objs), 1
)
if len(objs) > conn_batch_size:
return [

View File

@@ -875,11 +875,12 @@ class QuerySet(AltersData):
objs = tuple(objs)
if not all(obj._is_pk_set() for obj in objs):
raise ValueError("All bulk_update() objects must have a primary key set.")
fields = [self.model._meta.get_field(name) for name in fields]
opts = self.model._meta
fields = [opts.get_field(name) for name in fields]
if any(not f.concrete or f.many_to_many for f in fields):
raise ValueError("bulk_update() can only be used with concrete fields.")
all_pk_fields = set(self.model._meta.pk_fields)
for parent in self.model._meta.all_parents:
all_pk_fields = set(opts.pk_fields)
for parent in opts.all_parents:
all_pk_fields.update(parent._meta.pk_fields)
if any(f in all_pk_fields for f in fields):
raise ValueError("bulk_update() cannot be used with primary key fields.")
@@ -893,7 +894,9 @@
# and once in the WHEN. Each field will also have one CAST.
self._for_write = True
connection = connections[self.db]
max_batch_size = connection.ops.bulk_batch_size(["pk", "pk"] + fields, objs)
max_batch_size = connection.ops.bulk_batch_size(
[opts.pk, opts.pk] + fields, objs
)
batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
requires_casting = connection.features.requires_casted_case_in_updates
batches = (objs[i : i + batch_size] for i in range(0, len(objs), batch_size))

View File

@@ -1,7 +1,7 @@
import unittest
from django.core.management.color import no_style
from django.db import connection
from django.db import connection, models
from django.test import TransactionTestCase
from ..models import Person, Tag
@@ -22,14 +22,25 @@ class OperationsTests(TransactionTestCase):
objects = range(2**16)
self.assertEqual(connection.ops.bulk_batch_size([], objects), len(objects))
# Each field is a parameter for each object.
first_name_field = Person._meta.get_field("first_name")
last_name_field = Person._meta.get_field("last_name")
self.assertEqual(
connection.ops.bulk_batch_size(["id"], objects),
connection.ops.bulk_batch_size([first_name_field], objects),
connection.features.max_query_params,
)
self.assertEqual(
connection.ops.bulk_batch_size(["id", "other"], objects),
connection.ops.bulk_batch_size(
[first_name_field, last_name_field],
objects,
),
connection.features.max_query_params // 2,
)
composite_pk = models.CompositePrimaryKey("first_name", "last_name")
composite_pk.fields = [first_name_field, last_name_field]
self.assertEqual(
connection.ops.bulk_batch_size([composite_pk, first_name_field], objects),
connection.features.max_query_params // 3,
)
def test_sql_flush(self):
statements = connection.ops.sql_flush(

View File

@@ -1,7 +1,7 @@
import unittest
from django.core.management.color import no_style
from django.db import connection
from django.db import connection, models
from django.test import TestCase
from ..models import Person, Tag
@@ -86,3 +86,25 @@ class SQLiteOperationsTests(TestCase):
"zzz'",
statements[-1],
)
def test_bulk_batch_size(self):
    """
    bulk_batch_size() divides the parameter limit by the field count,
    counting each component of a composite primary key separately.
    """
    max_params = connection.features.max_query_params
    fname = Person._meta.get_field("first_name")
    lname = Person._meta.get_field("last_name")
    # No fields: a single object fits in one batch.
    self.assertEqual(connection.ops.bulk_batch_size([], [Person()]), 1)
    # One field: SQLite's single-field limit of 500 applies.
    self.assertEqual(connection.ops.bulk_batch_size([fname], [Person()]), 500)
    # Two fields: two parameters per object.
    self.assertEqual(
        connection.ops.bulk_batch_size([fname, lname], [Person()]),
        max_params // 2,
    )
    # A two-column composite pk plus one field: three parameters per object.
    composite_pk = models.CompositePrimaryKey("first_name", "last_name")
    composite_pk.fields = [fname, lname]
    self.assertEqual(
        connection.ops.bulk_batch_size([composite_pk, fname], [Person()]),
        max_params // 3,
    )