diff --git a/django/db/models/query.py b/django/db/models/query.py index cc4d9c1f22..096c7fb864 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -5,6 +5,7 @@ The main QuerySet implementation. This provides the public API for the ORM. import copy import operator import warnings +from contextlib import nullcontext from functools import reduce from itertools import chain, islice @@ -802,7 +803,11 @@ class QuerySet(AltersData): fields = [f for f in opts.concrete_fields if not f.generated] objs = list(objs) objs_with_pk, objs_without_pk = self._prepare_for_bulk_create(objs) - with transaction.atomic(using=self.db, savepoint=False): + if objs_with_pk and objs_without_pk: + context = transaction.atomic(using=self.db, savepoint=False) + else: + context = nullcontext() + with context: self._handle_order_with_respect_to(objs) if objs_with_pk: returned_columns = self._batched_insert( @@ -1919,11 +1924,28 @@ class QuerySet(AltersData): batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size inserted_rows = [] bulk_return = connection.features.can_return_rows_from_bulk_insert - for item in [objs[i : i + batch_size] for i in range(0, len(objs), batch_size)]: - if bulk_return and ( - on_conflict is None or on_conflict == OnConflict.UPDATE - ): - inserted_rows.extend( + batches = [objs[i : i + batch_size] for i in range(0, len(objs), batch_size)] + if len(batches) > 1: + context = transaction.atomic(using=self.db, savepoint=False) + else: + context = nullcontext() + with context: + for item in batches: + if bulk_return and ( + on_conflict is None or on_conflict == OnConflict.UPDATE + ): + inserted_rows.extend( + self._insert( + item, + fields=fields, + using=self.db, + on_conflict=on_conflict, + update_fields=update_fields, + unique_fields=unique_fields, + returning_fields=self.model._meta.db_returning_fields, + ) + ) + else: self._insert( item, fields=fields, @@ -1931,18 +1953,7 @@ class QuerySet(AltersData): on_conflict=on_conflict, update_fields=update_fields, unique_fields=unique_fields, - returning_fields=self.model._meta.db_returning_fields, ) - ) - else: - self._insert( - item, - fields=fields, - using=self.db, - on_conflict=on_conflict, - update_fields=update_fields, - unique_fields=unique_fields, - ) return inserted_rows def _chain(self): diff --git a/tests/bulk_create/tests.py b/tests/bulk_create/tests.py index d590a292de..35180bf487 100644 --- a/tests/bulk_create/tests.py +++ b/tests/bulk_create/tests.py @@ -14,6 +14,7 @@ from django.db.models import FileField, Value from django.db.models.functions import Lower, Now from django.test import ( TestCase, + TransactionTestCase, override_settings, skipIfDBFeature, skipUnlessDBFeature, @@ -884,3 +885,35 @@ class BulkCreateTests(TestCase): def test_db_default_primary_key(self): (obj,) = DbDefaultPrimaryKey.objects.bulk_create([DbDefaultPrimaryKey()]) self.assertIsInstance(obj.id, datetime) + + +@skipUnlessDBFeature("supports_transactions", "has_bulk_insert") +class BulkCreateTransactionTests(TransactionTestCase): + available_apps = ["bulk_create"] + + def test_no_unnecessary_transaction(self): + with self.assertNumQueries(1): + Country.objects.bulk_create( + [Country(id=1, name="France", iso_two_letter="FR")] + ) + with self.assertNumQueries(1): + Country.objects.bulk_create([Country(name="Canada", iso_two_letter="CA")]) + + def test_objs_with_and_without_pk(self): + with self.assertNumQueries(4): + Country.objects.bulk_create( + [ + Country(id=1, name="France", iso_two_letter="FR"), + Country(name="Canada", iso_two_letter="CA"), + ] + ) + + def test_multiple_batches(self): + with self.assertNumQueries(4): + Country.objects.bulk_create( + [ + Country(name="France", iso_two_letter="FR"), + Country(name="Canada", iso_two_letter="CA"), + ], + batch_size=1, + )