1
0
mirror of https://github.com/django/django.git synced 2025-10-29 08:36:09 +00:00

Fixed #31685 -- Added support for updating conflicts to QuerySet.bulk_create().

Thanks Florian Apolloner, Chris Jerdonek, Hannes Ljungberg, Nick Pope,
and Mariusz Felisiak for reviews.
This commit is contained in:
sean_c_hsu
2020-06-15 00:58:06 +08:00
committed by Mariusz Felisiak
parent ba9de2e74e
commit 0f6946495a
16 changed files with 542 additions and 43 deletions

View File

@@ -15,7 +15,7 @@ from django.db import (
router, transaction,
)
from django.db.models import AutoField, DateField, DateTimeField, sql
from django.db.models.constants import LOOKUP_SEP
from django.db.models.constants import LOOKUP_SEP, OnConflict
from django.db.models.deletion import Collector
from django.db.models.expressions import Case, Expression, F, Ref, Value, When
from django.db.models.functions import Cast, Trunc
@@ -466,7 +466,69 @@ class QuerySet:
obj.pk = obj._meta.pk.get_pk_value_on_save(obj)
obj._prepare_related_fields_for_save(operation_name='bulk_create')
def bulk_create(self, objs, batch_size=None, ignore_conflicts=False):
def _check_bulk_create_options(
    self, ignore_conflicts, update_conflicts, update_fields, unique_fields,
):
    """
    Validate the conflict-handling arguments given to bulk_create().

    Return OnConflict.IGNORE when ignore_conflicts is set, OnConflict.UPDATE
    when update_conflicts is set, and None when neither is requested.

    Raise ValueError if the arguments are inconsistent or name fields that
    may not be used, and NotSupportedError if the database backend lacks
    the requested capability.
    """
    if ignore_conflicts and update_conflicts:
        raise ValueError(
            'ignore_conflicts and update_conflicts are mutually exclusive.'
        )
    features = connections[self.db].features

    if ignore_conflicts:
        if features.supports_ignore_conflicts:
            return OnConflict.IGNORE
        raise NotSupportedError(
            'This database backend does not support ignoring conflicts.'
        )

    if not update_conflicts:
        # No conflict handling was requested.
        return None

    if not features.supports_update_conflicts:
        raise NotSupportedError(
            'This database backend does not support updating conflicts.'
        )
    if not update_fields:
        raise ValueError(
            'Fields that will be updated when a row insertion fails '
            'on conflicts must be provided.'
        )
    # unique_fields must be given exactly when the backend accepts a
    # conflict target.
    if unique_fields:
        if not features.supports_update_conflicts_with_target:
            raise NotSupportedError(
                'This database backend does not support updating '
                'conflicts with specifying unique fields that can trigger '
                'the upsert.'
            )
    elif features.supports_update_conflicts_with_target:
        raise ValueError(
            'Unique fields that can trigger the upsert must be '
            'provided.'
        )

    # Updating primary keys and non-concrete fields is forbidden.
    opts = self.model._meta
    fields_to_update = [opts.get_field(name) for name in update_fields]
    for field in fields_to_update:
        if not field.concrete or field.many_to_many:
            raise ValueError(
                'bulk_create() can only be used with concrete fields in '
                'update_fields.'
            )
    for field in fields_to_update:
        if field.primary_key:
            raise ValueError(
                'bulk_create() cannot be used with primary keys in '
                'update_fields.'
            )

    if unique_fields:
        # Primary key is allowed in unique_fields, so the 'pk' alias is
        # left out of the field lookup.
        conflict_targets = [
            opts.get_field(name) for name in unique_fields if name != 'pk'
        ]
        for field in conflict_targets:
            if not field.concrete or field.many_to_many:
                raise ValueError(
                    'bulk_create() can only be used with concrete fields '
                    'in unique_fields.'
                )
    return OnConflict.UPDATE
def bulk_create(
self, objs, batch_size=None, ignore_conflicts=False,
update_conflicts=False, update_fields=None, unique_fields=None,
):
"""
Insert each of the instances into the database. Do *not* call
save() on each of the instances, do not send any pre/post_save
@@ -497,6 +559,12 @@ class QuerySet:
raise ValueError("Can't bulk create a multi-table inherited model")
if not objs:
return objs
on_conflict = self._check_bulk_create_options(
ignore_conflicts,
update_conflicts,
update_fields,
unique_fields,
)
self._for_write = True
opts = self.model._meta
fields = opts.concrete_fields
@@ -506,7 +574,12 @@ class QuerySet:
objs_with_pk, objs_without_pk = partition(lambda o: o.pk is None, objs)
if objs_with_pk:
returned_columns = self._batched_insert(
objs_with_pk, fields, batch_size, ignore_conflicts=ignore_conflicts,
objs_with_pk,
fields,
batch_size,
on_conflict=on_conflict,
update_fields=update_fields,
unique_fields=unique_fields,
)
for obj_with_pk, results in zip(objs_with_pk, returned_columns):
for result, field in zip(results, opts.db_returning_fields):
@@ -518,10 +591,15 @@ class QuerySet:
if objs_without_pk:
fields = [f for f in fields if not isinstance(f, AutoField)]
returned_columns = self._batched_insert(
objs_without_pk, fields, batch_size, ignore_conflicts=ignore_conflicts,
objs_without_pk,
fields,
batch_size,
on_conflict=on_conflict,
update_fields=update_fields,
unique_fields=unique_fields,
)
connection = connections[self.db]
if connection.features.can_return_rows_from_bulk_insert and not ignore_conflicts:
if connection.features.can_return_rows_from_bulk_insert and on_conflict is None:
assert len(returned_columns) == len(objs_without_pk)
for obj_without_pk, results in zip(objs_without_pk, returned_columns):
for result, field in zip(results, opts.db_returning_fields):
@@ -1293,7 +1371,10 @@ class QuerySet:
# PRIVATE METHODS #
###################
def _insert(self, objs, fields, returning_fields=None, raw=False, using=None, ignore_conflicts=False):
def _insert(
self, objs, fields, returning_fields=None, raw=False, using=None,
on_conflict=None, update_fields=None, unique_fields=None,
):
"""
Insert a new record for the given model. This provides an interface to
the InsertQuery class and is how Model.save() is implemented.
@@ -1301,33 +1382,45 @@ class QuerySet:
self._for_write = True
if using is None:
using = self.db
query = sql.InsertQuery(self.model, ignore_conflicts=ignore_conflicts)
query = sql.InsertQuery(
self.model,
on_conflict=on_conflict,
update_fields=update_fields,
unique_fields=unique_fields,
)
query.insert_values(fields, objs, raw=raw)
return query.get_compiler(using=using).execute_sql(returning_fields)
_insert.alters_data = True
_insert.queryset_only = False
# NOTE(review): this span is a rendered diff hunk with indentation stripped.
# The lines using ignore_conflicts look like the superseded version and the
# lines using on_conflict/update_fields/unique_fields like their replacement
# -- confirm against the repository before reading this as a single body.
def _batched_insert(self, objs, fields, batch_size, ignore_conflicts=False):
def _batched_insert(
self, objs, fields, batch_size, on_conflict=None, update_fields=None,
unique_fields=None,
):
"""
Helper method for bulk_create() to insert objs one batch at a time.
"""
connection = connections[self.db]
if ignore_conflicts and not connection.features.supports_ignore_conflicts:
raise NotSupportedError('This database backend does not support ignoring conflicts.')
ops = connection.ops
# The effective batch size is the caller's batch_size capped by what
# ops.bulk_batch_size() reports for these fields/objs (never below 1);
# a falsy batch_size (None or 0) means "use that cap".
max_batch_size = max(ops.bulk_batch_size(fields, objs), 1)
batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
# Rows gathered from inserts that request returning fields; this list is
# what the method returns.
inserted_rows = []
bulk_return = connection.features.can_return_rows_from_bulk_insert
# Slice objs into consecutive chunks of batch_size and insert each chunk.
for item in [objs[i:i + batch_size] for i in range(0, len(objs), batch_size)]:
# Returning fields are requested only when the backend can return rows
# from a bulk insert and no conflict handling is in effect.
if bulk_return and not ignore_conflicts:
if bulk_return and on_conflict is None:
inserted_rows.extend(self._insert(
item, fields=fields, using=self.db,
returning_fields=self.model._meta.db_returning_fields,
ignore_conflicts=ignore_conflicts,
))
else:
# Otherwise insert without asking for returned rows; nothing is added to
# inserted_rows for this chunk.
self._insert(item, fields=fields, using=self.db, ignore_conflicts=ignore_conflicts)
self._insert(
item,
fields=fields,
using=self.db,
on_conflict=on_conflict,
update_fields=update_fields,
unique_fields=unique_fields,
)
return inserted_rows
def _chain(self):