1
0
mirror of https://github.com/django/django.git synced 2024-12-22 09:05:43 +00:00

Fixed #34698 -- Made QuerySet.bulk_create() retrieve primary keys when updating conflicts.

This commit is contained in:
Thomas Chaumeny 2023-07-07 13:08:17 +02:00 committed by Mariusz Felisiak
parent b7a17b0ea0
commit 89c7454dbd
4 changed files with 47 additions and 11 deletions

View File

@ -1837,12 +1837,17 @@ class QuerySet(AltersData):
inserted_rows = [] inserted_rows = []
bulk_return = connection.features.can_return_rows_from_bulk_insert bulk_return = connection.features.can_return_rows_from_bulk_insert
for item in [objs[i : i + batch_size] for i in range(0, len(objs), batch_size)]: for item in [objs[i : i + batch_size] for i in range(0, len(objs), batch_size)]:
if bulk_return and on_conflict is None: if bulk_return and (
on_conflict is None or on_conflict == OnConflict.UPDATE
):
inserted_rows.extend( inserted_rows.extend(
self._insert( self._insert(
item, item,
fields=fields, fields=fields,
using=self.db, using=self.db,
on_conflict=on_conflict,
update_fields=update_fields,
unique_fields=unique_fields,
returning_fields=self.model._meta.db_returning_fields, returning_fields=self.model._meta.db_returning_fields,
) )
) )

View File

@ -2411,9 +2411,13 @@ On databases that support it (all except Oracle and SQLite < 3.24), setting the
SQLite, in addition to ``update_fields``, a list of ``unique_fields`` that may SQLite, in addition to ``update_fields``, a list of ``unique_fields`` that may
be in conflict must be provided. be in conflict must be provided.
Enabling the ``ignore_conflicts`` or ``update_conflicts`` parameter disable Enabling the ``ignore_conflicts`` parameter disables setting the primary key on
setting the primary key on each model instance (if the database normally each model instance (if the database normally supports it).
support it).
.. versionchanged:: 5.0
In older versions, enabling the ``update_conflicts`` parameter prevented
setting the primary key on each model instance.
.. warning:: .. warning::

View File

@ -357,6 +357,10 @@ Models
:meth:`.Model.save` now allows specifying a tuple of parent classes that must :meth:`.Model.save` now allows specifying a tuple of parent classes that must
be forced to be inserted. be forced to be inserted.
* :meth:`.QuerySet.bulk_create` and :meth:`.QuerySet.abulk_create` methods now
set the primary key on each model instance when the ``update_conflicts``
parameter is enabled (if the database supports it).
Pagination Pagination
~~~~~~~~~~ ~~~~~~~~~~

View File

@ -582,12 +582,16 @@ class BulkCreateTests(TestCase):
TwoFields(f1=1, f2=1, name="c"), TwoFields(f1=1, f2=1, name="c"),
TwoFields(f1=2, f2=2, name="d"), TwoFields(f1=2, f2=2, name="d"),
] ]
TwoFields.objects.bulk_create( results = TwoFields.objects.bulk_create(
conflicting_objects, conflicting_objects,
update_conflicts=True, update_conflicts=True,
unique_fields=unique_fields, unique_fields=unique_fields,
update_fields=["name"], update_fields=["name"],
) )
self.assertEqual(len(results), len(conflicting_objects))
if connection.features.can_return_rows_from_bulk_insert:
for instance in results:
self.assertIsNotNone(instance.pk)
self.assertEqual(TwoFields.objects.count(), 2) self.assertEqual(TwoFields.objects.count(), 2)
self.assertCountEqual( self.assertCountEqual(
TwoFields.objects.values("f1", "f2", "name"), TwoFields.objects.values("f1", "f2", "name"),
@ -619,7 +623,6 @@ class BulkCreateTests(TestCase):
TwoFields(f1=2, f2=2, name="b"), TwoFields(f1=2, f2=2, name="b"),
] ]
) )
self.assertEqual(TwoFields.objects.count(), 2)
obj1 = TwoFields.objects.get(f1=1) obj1 = TwoFields.objects.get(f1=1)
obj2 = TwoFields.objects.get(f1=2) obj2 = TwoFields.objects.get(f1=2)
@ -627,12 +630,16 @@ class BulkCreateTests(TestCase):
TwoFields(pk=obj1.pk, f1=3, f2=3, name="c"), TwoFields(pk=obj1.pk, f1=3, f2=3, name="c"),
TwoFields(pk=obj2.pk, f1=4, f2=4, name="d"), TwoFields(pk=obj2.pk, f1=4, f2=4, name="d"),
] ]
TwoFields.objects.bulk_create( results = TwoFields.objects.bulk_create(
conflicting_objects, conflicting_objects,
update_conflicts=True, update_conflicts=True,
unique_fields=["pk"], unique_fields=["pk"],
update_fields=["name"], update_fields=["name"],
) )
self.assertEqual(len(results), len(conflicting_objects))
if connection.features.can_return_rows_from_bulk_insert:
for instance in results:
self.assertIsNotNone(instance.pk)
self.assertEqual(TwoFields.objects.count(), 2) self.assertEqual(TwoFields.objects.count(), 2)
self.assertCountEqual( self.assertCountEqual(
TwoFields.objects.values("f1", "f2", "name"), TwoFields.objects.values("f1", "f2", "name"),
@ -680,12 +687,16 @@ class BulkCreateTests(TestCase):
description=("Japan is an island country in East Asia."), description=("Japan is an island country in East Asia."),
), ),
] ]
Country.objects.bulk_create( results = Country.objects.bulk_create(
new_data, new_data,
update_conflicts=True, update_conflicts=True,
update_fields=["description"], update_fields=["description"],
unique_fields=unique_fields, unique_fields=unique_fields,
) )
self.assertEqual(len(results), len(new_data))
if connection.features.can_return_rows_from_bulk_insert:
for instance in results:
self.assertIsNotNone(instance.pk)
self.assertEqual(Country.objects.count(), 6) self.assertEqual(Country.objects.count(), 6)
self.assertCountEqual( self.assertCountEqual(
Country.objects.values("iso_two_letter", "description"), Country.objects.values("iso_two_letter", "description"),
@ -743,12 +754,16 @@ class BulkCreateTests(TestCase):
UpsertConflict(number=2, rank=2, name="Olivia"), UpsertConflict(number=2, rank=2, name="Olivia"),
UpsertConflict(number=3, rank=1, name="Hannah"), UpsertConflict(number=3, rank=1, name="Hannah"),
] ]
UpsertConflict.objects.bulk_create( results = UpsertConflict.objects.bulk_create(
conflicting_objects, conflicting_objects,
update_conflicts=True, update_conflicts=True,
update_fields=["name", "rank"], update_fields=["name", "rank"],
unique_fields=unique_fields, unique_fields=unique_fields,
) )
self.assertEqual(len(results), len(conflicting_objects))
if connection.features.can_return_rows_from_bulk_insert:
for instance in results:
self.assertIsNotNone(instance.pk)
self.assertEqual(UpsertConflict.objects.count(), 3) self.assertEqual(UpsertConflict.objects.count(), 3)
self.assertCountEqual( self.assertCountEqual(
UpsertConflict.objects.values("number", "rank", "name"), UpsertConflict.objects.values("number", "rank", "name"),
@ -759,12 +774,16 @@ class BulkCreateTests(TestCase):
], ],
) )
UpsertConflict.objects.bulk_create( results = UpsertConflict.objects.bulk_create(
conflicting_objects + [UpsertConflict(number=4, rank=4, name="Mark")], conflicting_objects + [UpsertConflict(number=4, rank=4, name="Mark")],
update_conflicts=True, update_conflicts=True,
update_fields=["name", "rank"], update_fields=["name", "rank"],
unique_fields=unique_fields, unique_fields=unique_fields,
) )
self.assertEqual(len(results), 4)
if connection.features.can_return_rows_from_bulk_insert:
for instance in results:
self.assertIsNotNone(instance.pk)
self.assertEqual(UpsertConflict.objects.count(), 4) self.assertEqual(UpsertConflict.objects.count(), 4)
self.assertCountEqual( self.assertCountEqual(
UpsertConflict.objects.values("number", "rank", "name"), UpsertConflict.objects.values("number", "rank", "name"),
@ -803,12 +822,16 @@ class BulkCreateTests(TestCase):
FieldsWithDbColumns(rank=1, name="c"), FieldsWithDbColumns(rank=1, name="c"),
FieldsWithDbColumns(rank=2, name="d"), FieldsWithDbColumns(rank=2, name="d"),
] ]
FieldsWithDbColumns.objects.bulk_create( results = FieldsWithDbColumns.objects.bulk_create(
conflicting_objects, conflicting_objects,
update_conflicts=True, update_conflicts=True,
unique_fields=["rank"], unique_fields=["rank"],
update_fields=["name"], update_fields=["name"],
) )
self.assertEqual(len(results), len(conflicting_objects))
if connection.features.can_return_rows_from_bulk_insert:
for instance in results:
self.assertIsNotNone(instance.pk)
self.assertEqual(FieldsWithDbColumns.objects.count(), 2) self.assertEqual(FieldsWithDbColumns.objects.count(), 2)
self.assertCountEqual( self.assertCountEqual(
FieldsWithDbColumns.objects.values("rank", "name"), FieldsWithDbColumns.objects.values("rank", "name"),