mirror of
				https://github.com/django/django.git
				synced 2025-10-25 22:56:12 +00:00 
			
		
		
		
	Fixed #36118 -- Accounted for multiple primary keys in bulk_update max_batch_size.
Co-authored-by: Simon Charette <charette.s@gmail.com>
This commit is contained in:
		| @@ -1,12 +1,19 @@ | |||||||
| import datetime | import datetime | ||||||
| import uuid | import uuid | ||||||
| from functools import lru_cache | from functools import lru_cache | ||||||
|  | from itertools import chain | ||||||
|  |  | ||||||
| from django.conf import settings | from django.conf import settings | ||||||
| from django.db import NotSupportedError | from django.db import NotSupportedError | ||||||
| from django.db.backends.base.operations import BaseDatabaseOperations | from django.db.backends.base.operations import BaseDatabaseOperations | ||||||
| from django.db.backends.utils import split_tzname_delta, strip_quotes, truncate_name | from django.db.backends.utils import split_tzname_delta, strip_quotes, truncate_name | ||||||
| from django.db.models import AutoField, Exists, ExpressionWrapper, Lookup | from django.db.models import ( | ||||||
|  |     AutoField, | ||||||
|  |     CompositePrimaryKey, | ||||||
|  |     Exists, | ||||||
|  |     ExpressionWrapper, | ||||||
|  |     Lookup, | ||||||
|  | ) | ||||||
| from django.db.models.expressions import RawSQL | from django.db.models.expressions import RawSQL | ||||||
| from django.db.models.sql.where import WhereNode | from django.db.models.sql.where import WhereNode | ||||||
| from django.utils import timezone | from django.utils import timezone | ||||||
| @@ -699,6 +706,12 @@ END; | |||||||
|  |  | ||||||
|     def bulk_batch_size(self, fields, objs): |     def bulk_batch_size(self, fields, objs): | ||||||
|         """Oracle restricts the number of parameters in a query.""" |         """Oracle restricts the number of parameters in a query.""" | ||||||
|  |         fields = list( | ||||||
|  |             chain.from_iterable( | ||||||
|  |                 field.fields if isinstance(field, CompositePrimaryKey) else [field] | ||||||
|  |                 for field in fields | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|         if fields: |         if fields: | ||||||
|             return self.connection.features.max_query_params // len(fields) |             return self.connection.features.max_query_params // len(fields) | ||||||
|         return len(objs) |         return len(objs) | ||||||
|   | |||||||
| @@ -36,6 +36,16 @@ class DatabaseOperations(BaseDatabaseOperations): | |||||||
|         If there's only a single field to insert, the limit is 500 |         If there's only a single field to insert, the limit is 500 | ||||||
|         (SQLITE_MAX_COMPOUND_SELECT). |         (SQLITE_MAX_COMPOUND_SELECT). | ||||||
|         """ |         """ | ||||||
|  |         fields = list( | ||||||
|  |             chain.from_iterable( | ||||||
|  |                 ( | ||||||
|  |                     field.fields | ||||||
|  |                     if isinstance(field, models.CompositePrimaryKey) | ||||||
|  |                     else [field] | ||||||
|  |                 ) | ||||||
|  |                 for field in fields | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|         if len(fields) == 1: |         if len(fields) == 1: | ||||||
|             return 500 |             return 500 | ||||||
|         elif len(fields) > 1: |         elif len(fields) > 1: | ||||||
|   | |||||||
| @@ -230,9 +230,8 @@ class Collector: | |||||||
|         """ |         """ | ||||||
|         Return the objs in suitably sized batches for the used connection. |         Return the objs in suitably sized batches for the used connection. | ||||||
|         """ |         """ | ||||||
|         field_names = [field.name for field in fields] |  | ||||||
|         conn_batch_size = max( |         conn_batch_size = max( | ||||||
|             connections[self.using].ops.bulk_batch_size(field_names, objs), 1 |             connections[self.using].ops.bulk_batch_size(fields, objs), 1 | ||||||
|         ) |         ) | ||||||
|         if len(objs) > conn_batch_size: |         if len(objs) > conn_batch_size: | ||||||
|             return [ |             return [ | ||||||
|   | |||||||
| @@ -874,11 +874,12 @@ class QuerySet(AltersData): | |||||||
|         objs = tuple(objs) |         objs = tuple(objs) | ||||||
|         if not all(obj._is_pk_set() for obj in objs): |         if not all(obj._is_pk_set() for obj in objs): | ||||||
|             raise ValueError("All bulk_update() objects must have a primary key set.") |             raise ValueError("All bulk_update() objects must have a primary key set.") | ||||||
|         fields = [self.model._meta.get_field(name) for name in fields] |         opts = self.model._meta | ||||||
|  |         fields = [opts.get_field(name) for name in fields] | ||||||
|         if any(not f.concrete or f.many_to_many for f in fields): |         if any(not f.concrete or f.many_to_many for f in fields): | ||||||
|             raise ValueError("bulk_update() can only be used with concrete fields.") |             raise ValueError("bulk_update() can only be used with concrete fields.") | ||||||
|         all_pk_fields = set(self.model._meta.pk_fields) |         all_pk_fields = set(opts.pk_fields) | ||||||
|         for parent in self.model._meta.all_parents: |         for parent in opts.all_parents: | ||||||
|             all_pk_fields.update(parent._meta.pk_fields) |             all_pk_fields.update(parent._meta.pk_fields) | ||||||
|         if any(f in all_pk_fields for f in fields): |         if any(f in all_pk_fields for f in fields): | ||||||
|             raise ValueError("bulk_update() cannot be used with primary key fields.") |             raise ValueError("bulk_update() cannot be used with primary key fields.") | ||||||
| @@ -892,7 +893,9 @@ class QuerySet(AltersData): | |||||||
|         # and once in the WHEN. Each field will also have one CAST. |         # and once in the WHEN. Each field will also have one CAST. | ||||||
|         self._for_write = True |         self._for_write = True | ||||||
|         connection = connections[self.db] |         connection = connections[self.db] | ||||||
|         max_batch_size = connection.ops.bulk_batch_size(["pk", "pk"] + fields, objs) |         max_batch_size = connection.ops.bulk_batch_size( | ||||||
|  |             [opts.pk, opts.pk] + fields, objs | ||||||
|  |         ) | ||||||
|         batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size |         batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size | ||||||
|         requires_casting = connection.features.requires_casted_case_in_updates |         requires_casting = connection.features.requires_casted_case_in_updates | ||||||
|         batches = (objs[i : i + batch_size] for i in range(0, len(objs), batch_size)) |         batches = (objs[i : i + batch_size] for i in range(0, len(objs), batch_size)) | ||||||
|   | |||||||
| @@ -1,7 +1,7 @@ | |||||||
| import unittest | import unittest | ||||||
|  |  | ||||||
| from django.core.management.color import no_style | from django.core.management.color import no_style | ||||||
| from django.db import connection | from django.db import connection, models | ||||||
| from django.test import TransactionTestCase | from django.test import TransactionTestCase | ||||||
|  |  | ||||||
| from ..models import Person, Tag | from ..models import Person, Tag | ||||||
| @@ -22,14 +22,25 @@ class OperationsTests(TransactionTestCase): | |||||||
|         objects = range(2**16) |         objects = range(2**16) | ||||||
|         self.assertEqual(connection.ops.bulk_batch_size([], objects), len(objects)) |         self.assertEqual(connection.ops.bulk_batch_size([], objects), len(objects)) | ||||||
|         # Each field is a parameter for each object. |         # Each field is a parameter for each object. | ||||||
|  |         first_name_field = Person._meta.get_field("first_name") | ||||||
|  |         last_name_field = Person._meta.get_field("last_name") | ||||||
|         self.assertEqual( |         self.assertEqual( | ||||||
|             connection.ops.bulk_batch_size(["id"], objects), |             connection.ops.bulk_batch_size([first_name_field], objects), | ||||||
|             connection.features.max_query_params, |             connection.features.max_query_params, | ||||||
|         ) |         ) | ||||||
|         self.assertEqual( |         self.assertEqual( | ||||||
|             connection.ops.bulk_batch_size(["id", "other"], objects), |             connection.ops.bulk_batch_size( | ||||||
|  |                 [first_name_field, last_name_field], | ||||||
|  |                 objects, | ||||||
|  |             ), | ||||||
|             connection.features.max_query_params // 2, |             connection.features.max_query_params // 2, | ||||||
|         ) |         ) | ||||||
|  |         composite_pk = models.CompositePrimaryKey("first_name", "last_name") | ||||||
|  |         composite_pk.fields = [first_name_field, last_name_field] | ||||||
|  |         self.assertEqual( | ||||||
|  |             connection.ops.bulk_batch_size([composite_pk, first_name_field], objects), | ||||||
|  |             connection.features.max_query_params // 3, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     def test_sql_flush(self): |     def test_sql_flush(self): | ||||||
|         statements = connection.ops.sql_flush( |         statements = connection.ops.sql_flush( | ||||||
|   | |||||||
| @@ -1,7 +1,7 @@ | |||||||
| import unittest | import unittest | ||||||
|  |  | ||||||
| from django.core.management.color import no_style | from django.core.management.color import no_style | ||||||
| from django.db import connection | from django.db import connection, models | ||||||
| from django.test import TestCase | from django.test import TestCase | ||||||
|  |  | ||||||
| from ..models import Person, Tag | from ..models import Person, Tag | ||||||
| @@ -86,3 +86,25 @@ class SQLiteOperationsTests(TestCase): | |||||||
|             "zzz'", |             "zzz'", | ||||||
|             statements[-1], |             statements[-1], | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     def test_bulk_batch_size(self): | ||||||
|  |         self.assertEqual(connection.ops.bulk_batch_size([], [Person()]), 1) | ||||||
|  |         first_name_field = Person._meta.get_field("first_name") | ||||||
|  |         last_name_field = Person._meta.get_field("last_name") | ||||||
|  |         self.assertEqual( | ||||||
|  |             connection.ops.bulk_batch_size([first_name_field], [Person()]), 500 | ||||||
|  |         ) | ||||||
|  |         self.assertEqual( | ||||||
|  |             connection.ops.bulk_batch_size( | ||||||
|  |                 [first_name_field, last_name_field], [Person()] | ||||||
|  |             ), | ||||||
|  |             connection.features.max_query_params // 2, | ||||||
|  |         ) | ||||||
|  |         composite_pk = models.CompositePrimaryKey("first_name", "last_name") | ||||||
|  |         composite_pk.fields = [first_name_field, last_name_field] | ||||||
|  |         self.assertEqual( | ||||||
|  |             connection.ops.bulk_batch_size( | ||||||
|  |                 [composite_pk, first_name_field], [Person()] | ||||||
|  |             ), | ||||||
|  |             connection.features.max_query_params // 3, | ||||||
|  |         ) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user