diff --git a/django/db/backends/oracle/operations.py b/django/db/backends/oracle/operations.py index f66fa524b4..59eecdba20 100644 --- a/django/db/backends/oracle/operations.py +++ b/django/db/backends/oracle/operations.py @@ -1,12 +1,19 @@ import datetime import uuid from functools import lru_cache +from itertools import chain from django.conf import settings from django.db import DatabaseError, NotSupportedError from django.db.backends.base.operations import BaseDatabaseOperations from django.db.backends.utils import split_tzname_delta, strip_quotes, truncate_name -from django.db.models import AutoField, Exists, ExpressionWrapper, Lookup +from django.db.models import ( + AutoField, + CompositePrimaryKey, + Exists, + ExpressionWrapper, + Lookup, +) from django.db.models.expressions import RawSQL from django.db.models.sql.where import WhereNode from django.utils import timezone @@ -708,6 +715,12 @@ END; def bulk_batch_size(self, fields, objs): """Oracle restricts the number of parameters in a query.""" + fields = list( + chain.from_iterable( + field.fields if isinstance(field, CompositePrimaryKey) else [field] + for field in fields + ) + ) if fields: return self.connection.features.max_query_params // len(fields) return len(objs) diff --git a/django/db/backends/sqlite3/operations.py b/django/db/backends/sqlite3/operations.py index 0078cc077a..08de246d70 100644 --- a/django/db/backends/sqlite3/operations.py +++ b/django/db/backends/sqlite3/operations.py @@ -36,6 +36,16 @@ class DatabaseOperations(BaseDatabaseOperations): If there's only a single field to insert, the limit is 500 (SQLITE_MAX_COMPOUND_SELECT). """ + fields = list( + chain.from_iterable( + ( + field.fields + if isinstance(field, models.CompositePrimaryKey) + else [field] + ) + for field in fields + ) + ) if len(fields) == 1: return 500 elif len(fields) > 1: diff --git a/django/db/models/deletion.py b/django/db/models/deletion.py index fd3d290a96..da2e934c96 100644 --- a/django/db/models/deletion.py +++ b/django/db/models/deletion.py @@ -230,9 +230,8 @@ class Collector: """ Return the objs in suitably sized batches for the used connection. """ - field_names = [field.name for field in fields] conn_batch_size = max( - connections[self.using].ops.bulk_batch_size(field_names, objs), 1 + connections[self.using].ops.bulk_batch_size(fields, objs), 1 ) if len(objs) > conn_batch_size: return [ diff --git a/django/db/models/query.py b/django/db/models/query.py index 25995b0d83..1730aca16d 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -875,11 +875,12 @@ class QuerySet(AltersData): objs = tuple(objs) if not all(obj._is_pk_set() for obj in objs): raise ValueError("All bulk_update() objects must have a primary key set.") - fields = [self.model._meta.get_field(name) for name in fields] + opts = self.model._meta + fields = [opts.get_field(name) for name in fields] if any(not f.concrete or f.many_to_many for f in fields): raise ValueError("bulk_update() can only be used with concrete fields.") - all_pk_fields = set(self.model._meta.pk_fields) - for parent in self.model._meta.all_parents: + all_pk_fields = set(opts.pk_fields) + for parent in opts.all_parents: all_pk_fields.update(parent._meta.pk_fields) if any(f in all_pk_fields for f in fields): raise ValueError("bulk_update() cannot be used with primary key fields.") @@ -893,7 +894,9 @@ class QuerySet(AltersData): # and once in the WHEN. Each field will also have one CAST. self._for_write = True connection = connections[self.db] - max_batch_size = connection.ops.bulk_batch_size(["pk", "pk"] + fields, objs) + max_batch_size = connection.ops.bulk_batch_size( + [opts.pk, opts.pk] + fields, objs + ) batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size requires_casting = connection.features.requires_casted_case_in_updates batches = (objs[i : i + batch_size] for i in range(0, len(objs), batch_size)) diff --git a/tests/backends/oracle/test_operations.py b/tests/backends/oracle/test_operations.py index 523bdcda8a..1f9447bde7 100644 --- a/tests/backends/oracle/test_operations.py +++ b/tests/backends/oracle/test_operations.py @@ -1,7 +1,7 @@ import unittest from django.core.management.color import no_style -from django.db import connection +from django.db import connection, models from django.test import TransactionTestCase from ..models import Person, Tag @@ -22,14 +22,25 @@ class OperationsTests(TransactionTestCase): objects = range(2**16) self.assertEqual(connection.ops.bulk_batch_size([], objects), len(objects)) # Each field is a parameter for each object. + first_name_field = Person._meta.get_field("first_name") + last_name_field = Person._meta.get_field("last_name") self.assertEqual( - connection.ops.bulk_batch_size(["id"], objects), + connection.ops.bulk_batch_size([first_name_field], objects), connection.features.max_query_params, ) self.assertEqual( - connection.ops.bulk_batch_size(["id", "other"], objects), + connection.ops.bulk_batch_size( + [first_name_field, last_name_field], + objects, + ), connection.features.max_query_params // 2, ) + composite_pk = models.CompositePrimaryKey("first_name", "last_name") + composite_pk.fields = [first_name_field, last_name_field] + self.assertEqual( + connection.ops.bulk_batch_size([composite_pk, first_name_field], objects), + connection.features.max_query_params // 3, + ) def test_sql_flush(self): statements = connection.ops.sql_flush( diff --git a/tests/backends/sqlite/test_operations.py b/tests/backends/sqlite/test_operations.py index 3ff055248d..10cbffdf80 100644 --- a/tests/backends/sqlite/test_operations.py +++ b/tests/backends/sqlite/test_operations.py @@ -1,7 +1,7 @@ import unittest from django.core.management.color import no_style -from django.db import connection +from django.db import connection, models from django.test import TestCase from ..models import Person, Tag @@ -86,3 +86,25 @@ class SQLiteOperationsTests(TestCase): "zzz'", statements[-1], ) + + def test_bulk_batch_size(self): + self.assertEqual(connection.ops.bulk_batch_size([], [Person()]), 1) + first_name_field = Person._meta.get_field("first_name") + last_name_field = Person._meta.get_field("last_name") + self.assertEqual( + connection.ops.bulk_batch_size([first_name_field], [Person()]), 500 + ) + self.assertEqual( + connection.ops.bulk_batch_size( + [first_name_field, last_name_field], [Person()] + ), + connection.features.max_query_params // 2, + ) + composite_pk = models.CompositePrimaryKey("first_name", "last_name") + composite_pk.fields = [first_name_field, last_name_field] + self.assertEqual( + connection.ops.bulk_batch_size( + [composite_pk, first_name_field], [Person()] + ), + connection.features.max_query_params // 3, + )