diff --git a/django/db/models/query.py b/django/db/models/query.py index 4aa7f03a5f..84806a5f72 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -670,21 +670,10 @@ class QuerySet(AltersData): acreate.alters_data = True def _prepare_for_bulk_create(self, objs): - from django.db.models.expressions import DatabaseDefault - - connection = connections[self.db] for obj in objs: if not obj._is_pk_set(): # Populate new PK values. obj.pk = obj._meta.pk.get_pk_value_on_save(obj) - if not connection.features.supports_default_keyword_in_bulk_insert: - for field in obj._meta.fields: - if field.generated: - continue - value = getattr(obj, field.attname) - if isinstance(value, DatabaseDefault): - setattr(obj, field.attname, field.db_default) - obj._prepare_related_fields_for_save(operation_name="bulk_create") def _check_bulk_create_options( diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 04372c509e..3bfb3bd631 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -1810,23 +1810,65 @@ class SQLInsertCompiler(SQLCompiler): on_conflict=self.query.on_conflict, ) result = ["%s %s" % (insert_statement, qn(opts.db_table))] - fields = self.query.fields or [opts.pk] - result.append("(%s)" % ", ".join(qn(f.column) for f in fields)) - if self.query.fields: - value_rows = [ - [ - self.prepare_value(field, self.pre_save_val(field, obj)) - for field in fields + if fields := list(self.query.fields): + from django.db.models.expressions import DatabaseDefault + + supports_default_keyword_in_bulk_insert = ( + self.connection.features.supports_default_keyword_in_bulk_insert + ) + value_cols = [] + for field in list(fields): + field_prepare = partial(self.prepare_value, field) + field_pre_save = partial(self.pre_save_val, field) + field_values = [ + field_prepare(field_pre_save(obj)) for obj in self.query.objs ] - for obj in self.query.objs - ] + + if not field.has_db_default(): + value_cols.append(field_values) + continue + + # If all values are DEFAULT don't include the field and its + # values in the query as they are redundant and could prevent + # optimizations. This cannot be done if we're dealing with the + # last field as INSERT statements require at least one. + if len(fields) > 1 and all( + isinstance(value, DatabaseDefault) for value in field_values + ): + fields.remove(field) + continue + + if supports_default_keyword_in_bulk_insert: + value_cols.append(field_values) + continue + + # If the field cannot be excluded from the INSERT for the + # reasons listed above and the backend doesn't support the + # DEFAULT keyword each values must be expanded into their + # underlying expressions. + prepared_db_default = field_prepare(field.db_default) + field_values = [ + ( + prepared_db_default + if isinstance(value, DatabaseDefault) + else value + ) + for value in field_values + ] + value_cols.append(field_values) + value_rows = list(zip(*value_cols)) + result.append("(%s)" % ", ".join(qn(f.column) for f in fields)) else: - # An empty object. + # No fields were specified but an INSERT statement must include at + # least one column. This can only happen when the model's primary + # key is composed of a single auto-field so default to including it + # as a placeholder to generate a valid INSERT statement. value_rows = [ [self.connection.ops.pk_default_value()] for _ in self.query.objs ] fields = [None] + result.append("(%s)" % qn(opts.pk.column)) # Currently the backends just accept values when generating bulk # queries and generate their own placeholders. Doing that isn't diff --git a/tests/backends/models.py b/tests/backends/models.py index 1ed108c2b8..afb6ebe303 100644 --- a/tests/backends/models.py +++ b/tests/backends/models.py @@ -5,7 +5,7 @@ from django.db import models class Square(models.Model): root = models.IntegerField() - square = models.PositiveIntegerField() + square = models.PositiveIntegerField(db_default=9) def __str__(self): return "%s ** 2 == %s" % (self.root, self.square) diff --git a/tests/backends/postgresql/test_compilation.py b/tests/backends/postgresql/test_compilation.py index 67fe893e35..5a86a427ff 100644 --- a/tests/backends/postgresql/test_compilation.py +++ b/tests/backends/postgresql/test_compilation.py @@ -27,3 +27,9 @@ class BulkCreateUnnestTests(TestCase): [Square(root=2, square=4), Square(root=3, square=9)] ) self.assertIn("UNNEST", ctx[0]["sql"]) + + def test_unnest_eligible_db_default(self): + with self.assertNumQueries(1) as ctx: + squares = Square.objects.bulk_create([Square(root=3), Square(root=3)]) + self.assertIn("UNNEST", ctx[0]["sql"]) + self.assertEqual([square.square for square in squares], [9, 9]) diff --git a/tests/bulk_create/models.py b/tests/bulk_create/models.py index 8a21c7dfa1..f0df9da66e 100644 --- a/tests/bulk_create/models.py +++ b/tests/bulk_create/models.py @@ -3,6 +3,7 @@ import uuid from decimal import Decimal from django.db import models +from django.db.models.functions import Now from django.utils import timezone try: @@ -141,3 +142,8 @@ class RelatedModel(models.Model): name = models.CharField(max_length=15, null=True) country = models.OneToOneField(Country, models.CASCADE, primary_key=True) big_auto_fields = models.ManyToManyField(BigAutoFieldModel) + + +class DbDefaultModel(models.Model): + name = models.CharField(max_length=10) + created_at = models.DateTimeField(db_default=Now()) diff --git a/tests/bulk_create/tests.py b/tests/bulk_create/tests.py index 7b86a2def5..83ff8e4514 100644 --- a/tests/bulk_create/tests.py +++ b/tests/bulk_create/tests.py @@ -17,10 +17,12 @@ from django.test import ( skipIfDBFeature, skipUnlessDBFeature, ) +from django.utils import timezone from .models import ( BigAutoFieldModel, Country, + DbDefaultModel, FieldsWithDbColumns, NoFields, NullableFields, @@ -840,3 +842,27 @@ class BulkCreateTests(TestCase): {"rank": 2, "name": "d"}, ], ) + + def test_db_default_field_excluded(self): + # created_at is excluded when no db_default override is provided. + with self.assertNumQueries(1) as ctx: + DbDefaultModel.objects.bulk_create( + [DbDefaultModel(name="foo"), DbDefaultModel(name="bar")] + ) + created_at_quoted_name = connection.ops.quote_name("created_at") + self.assertEqual( + ctx[0]["sql"].count(created_at_quoted_name), + 1 if connection.features.can_return_rows_from_bulk_insert else 0, + ) + # created_at is included when a db_default override is provided. + with self.assertNumQueries(1) as ctx: + DbDefaultModel.objects.bulk_create( + [ + DbDefaultModel(name="foo", created_at=timezone.now()), + DbDefaultModel(name="bar"), + ] + ) + self.assertEqual( + ctx[0]["sql"].count(created_at_quoted_name), + 2 if connection.features.can_return_rows_from_bulk_insert else 1, + )