mirror of
				https://github.com/django/django.git
				synced 2025-10-25 14:46:09 +00:00 
			
		
		
		
	Refs #19527 -- Allowed QuerySet.bulk_create() to set the primary key of its objects.
PostgreSQL support only. Thanks Vladislav Manchev and alesasnouski for working on the patch.
This commit is contained in:
		| @@ -24,6 +24,7 @@ class BaseDatabaseFeatures(object): | |||||||
|  |  | ||||||
|     can_use_chunked_reads = True |     can_use_chunked_reads = True | ||||||
|     can_return_id_from_insert = False |     can_return_id_from_insert = False | ||||||
|  |     can_return_ids_from_bulk_insert = False | ||||||
|     has_bulk_insert = False |     has_bulk_insert = False | ||||||
|     uses_savepoints = False |     uses_savepoints = False | ||||||
|     can_release_savepoints = False |     can_release_savepoints = False | ||||||
|   | |||||||
| @@ -5,6 +5,7 @@ from django.db.utils import InterfaceError | |||||||
| class DatabaseFeatures(BaseDatabaseFeatures): | class DatabaseFeatures(BaseDatabaseFeatures): | ||||||
|     allows_group_by_selected_pks = True |     allows_group_by_selected_pks = True | ||||||
|     can_return_id_from_insert = True |     can_return_id_from_insert = True | ||||||
|  |     can_return_ids_from_bulk_insert = True | ||||||
|     has_real_datatype = True |     has_real_datatype = True | ||||||
|     has_native_uuid_field = True |     has_native_uuid_field = True | ||||||
|     has_native_duration_field = True |     has_native_duration_field = True | ||||||
|   | |||||||
| @@ -59,6 +59,14 @@ class DatabaseOperations(BaseDatabaseOperations): | |||||||
|     def deferrable_sql(self): |     def deferrable_sql(self): | ||||||
|         return " DEFERRABLE INITIALLY DEFERRED" |         return " DEFERRABLE INITIALLY DEFERRED" | ||||||
|  |  | ||||||
|  |     def fetch_returned_insert_ids(self, cursor): | ||||||
|  |         """ | ||||||
|  |         Given a cursor object that has just performed an INSERT...RETURNING | ||||||
|  |         statement into a table that has an auto-incrementing ID, return the | ||||||
|  |         list of newly created IDs. | ||||||
|  |         """ | ||||||
|  |         return [item[0] for item in cursor.fetchall()] | ||||||
|  |  | ||||||
|     def lookup_cast(self, lookup_type, internal_type=None): |     def lookup_cast(self, lookup_type, internal_type=None): | ||||||
|         lookup = '%s' |         lookup = '%s' | ||||||
|  |  | ||||||
|   | |||||||
| @@ -411,17 +411,21 @@ class QuerySet(object): | |||||||
|         Inserts each of the instances into the database. This does *not* call |         Inserts each of the instances into the database. This does *not* call | ||||||
|         save() on each of the instances, does not send any pre/post save |         save() on each of the instances, does not send any pre/post save | ||||||
|         signals, and does not set the primary key attribute if it is an |         signals, and does not set the primary key attribute if it is an | ||||||
|         autoincrement field. Multi-table models are not supported. |         autoincrement field (except if features.can_return_ids_from_bulk_insert=True). | ||||||
|  |         Multi-table models are not supported. | ||||||
|         """ |         """ | ||||||
|         # So this case is fun. When you bulk insert you don't get the primary |         # When you bulk insert you don't get the primary keys back (if it's an | ||||||
|         # keys back (if it's an autoincrement), so you can't insert into the |         # autoincrement, except if can_return_ids_from_bulk_insert=True), so | ||||||
|         # child tables which references this. There are two workarounds, 1) |         # you can't insert into the child tables which references this. There | ||||||
|         # this could be implemented if you didn't have an autoincrement pk, |         # are two workarounds: | ||||||
|         # and 2) you could do it by doing O(n) normal inserts into the parent |         # 1) This could be implemented if you didn't have an autoincrement pk | ||||||
|         # tables to get the primary keys back, and then doing a single bulk |         # 2) You could do it by doing O(n) normal inserts into the parent | ||||||
|         # insert into the childmost table. Some databases might allow doing |         #    tables to get the primary keys back and then doing a single bulk | ||||||
|         # this by using RETURNING clause for the insert query. We're punting |         #    insert into the childmost table. | ||||||
|         # on these for now because they are relatively rare cases. |         # We currently set the primary keys on the objects when using | ||||||
|  |         # PostgreSQL via the RETURNING ID clause. It should be possible for | ||||||
|  |         # Oracle as well, but the semantics for extracting the primary keys is | ||||||
|  |         # trickier so it's not done yet. | ||||||
|         assert batch_size is None or batch_size > 0 |         assert batch_size is None or batch_size > 0 | ||||||
|         # Check that the parents share the same concrete model with our |         # Check that the parents share the same concrete model with our | ||||||
|         # model to detect the inheritance pattern ConcreteGrandParent -> |         # model to detect the inheritance pattern ConcreteGrandParent -> | ||||||
| @@ -447,7 +451,11 @@ class QuerySet(object): | |||||||
|                     self._batched_insert(objs_with_pk, fields, batch_size) |                     self._batched_insert(objs_with_pk, fields, batch_size) | ||||||
|                 if objs_without_pk: |                 if objs_without_pk: | ||||||
|                     fields = [f for f in fields if not isinstance(f, AutoField)] |                     fields = [f for f in fields if not isinstance(f, AutoField)] | ||||||
|                     self._batched_insert(objs_without_pk, fields, batch_size) |                     ids = self._batched_insert(objs_without_pk, fields, batch_size) | ||||||
|  |                     if connection.features.can_return_ids_from_bulk_insert: | ||||||
|  |                         assert len(ids) == len(objs_without_pk) | ||||||
|  |                     for i in range(len(ids)): | ||||||
|  |                         objs_without_pk[i].pk = ids[i] | ||||||
|  |  | ||||||
|         return objs |         return objs | ||||||
|  |  | ||||||
| @@ -1051,10 +1059,19 @@ class QuerySet(object): | |||||||
|             return |             return | ||||||
|         ops = connections[self.db].ops |         ops = connections[self.db].ops | ||||||
|         batch_size = (batch_size or max(ops.bulk_batch_size(fields, objs), 1)) |         batch_size = (batch_size or max(ops.bulk_batch_size(fields, objs), 1)) | ||||||
|         for batch in [objs[i:i + batch_size] |         inserted_ids = [] | ||||||
|                       for i in range(0, len(objs), batch_size)]: |         for item in [objs[i:i + batch_size] for i in range(0, len(objs), batch_size)]: | ||||||
|             self.model._base_manager._insert(batch, fields=fields, |             if connections[self.db].features.can_return_ids_from_bulk_insert: | ||||||
|                                              using=self.db) |                 inserted_id = self.model._base_manager._insert( | ||||||
|  |                     item, fields=fields, using=self.db, return_id=True | ||||||
|  |                 ) | ||||||
|  |                 if len(objs) > 1: | ||||||
|  |                     inserted_ids.extend(inserted_id) | ||||||
|  |                 if len(objs) == 1: | ||||||
|  |                     inserted_ids.append(inserted_id) | ||||||
|  |             else: | ||||||
|  |                 self.model._base_manager._insert(item, fields=fields, using=self.db) | ||||||
|  |         return inserted_ids | ||||||
|  |  | ||||||
|     def _clone(self, **kwargs): |     def _clone(self, **kwargs): | ||||||
|         query = self.query.clone() |         query = self.query.clone() | ||||||
|   | |||||||
| @@ -1019,16 +1019,20 @@ class SQLInsertCompiler(SQLCompiler): | |||||||
|         placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows) |         placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows) | ||||||
|  |  | ||||||
|         if self.return_id and self.connection.features.can_return_id_from_insert: |         if self.return_id and self.connection.features.can_return_id_from_insert: | ||||||
|             params = param_rows[0] |             if self.connection.features.can_return_ids_from_bulk_insert: | ||||||
|  |                 result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows)) | ||||||
|  |                 params = param_rows | ||||||
|  |             else: | ||||||
|  |                 result.append("VALUES (%s)" % ", ".join(placeholder_rows[0])) | ||||||
|  |                 params = param_rows[0] | ||||||
|             col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column)) |             col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column)) | ||||||
|             result.append("VALUES (%s)" % ", ".join(placeholder_rows[0])) |  | ||||||
|             r_fmt, r_params = self.connection.ops.return_insert_id() |             r_fmt, r_params = self.connection.ops.return_insert_id() | ||||||
|             # Skip empty r_fmt to allow subclasses to customize behavior for |             # Skip empty r_fmt to allow subclasses to customize behavior for | ||||||
|             # 3rd party backends. Refs #19096. |             # 3rd party backends. Refs #19096. | ||||||
|             if r_fmt: |             if r_fmt: | ||||||
|                 result.append(r_fmt % col) |                 result.append(r_fmt % col) | ||||||
|                 params += r_params |                 params += r_params | ||||||
|             return [(" ".join(result), tuple(params))] |             return [(" ".join(result), tuple(chain.from_iterable(params)))] | ||||||
|  |  | ||||||
|         if can_bulk: |         if can_bulk: | ||||||
|             result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows)) |             result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows)) | ||||||
| @@ -1040,14 +1044,20 @@ class SQLInsertCompiler(SQLCompiler): | |||||||
|             ] |             ] | ||||||
|  |  | ||||||
|     def execute_sql(self, return_id=False): |     def execute_sql(self, return_id=False): | ||||||
|         assert not (return_id and len(self.query.objs) != 1) |         assert not ( | ||||||
|  |             return_id and len(self.query.objs) != 1 and | ||||||
|  |             not self.connection.features.can_return_ids_from_bulk_insert | ||||||
|  |         ) | ||||||
|         self.return_id = return_id |         self.return_id = return_id | ||||||
|         with self.connection.cursor() as cursor: |         with self.connection.cursor() as cursor: | ||||||
|             for sql, params in self.as_sql(): |             for sql, params in self.as_sql(): | ||||||
|                 cursor.execute(sql, params) |                 cursor.execute(sql, params) | ||||||
|             if not (return_id and cursor): |             if not (return_id and cursor): | ||||||
|                 return |                 return | ||||||
|  |             if self.connection.features.can_return_ids_from_bulk_insert and len(self.query.objs) > 1: | ||||||
|  |                 return self.connection.ops.fetch_returned_insert_ids(cursor) | ||||||
|             if self.connection.features.can_return_id_from_insert: |             if self.connection.features.can_return_id_from_insert: | ||||||
|  |                 assert len(self.query.objs) == 1 | ||||||
|                 return self.connection.ops.fetch_returned_insert_id(cursor) |                 return self.connection.ops.fetch_returned_insert_id(cursor) | ||||||
|             return self.connection.ops.last_insert_id(cursor, |             return self.connection.ops.last_insert_id(cursor, | ||||||
|                     self.query.get_meta().db_table, self.query.get_meta().pk.column) |                     self.query.get_meta().db_table, self.query.get_meta().pk.column) | ||||||
|   | |||||||
| @@ -1794,13 +1794,19 @@ This has a number of caveats though: | |||||||
|   ``post_save`` signals will not be sent. |   ``post_save`` signals will not be sent. | ||||||
| * It does not work with child models in a multi-table inheritance scenario. | * It does not work with child models in a multi-table inheritance scenario. | ||||||
| * If the model's primary key is an :class:`~django.db.models.AutoField` it | * If the model's primary key is an :class:`~django.db.models.AutoField` it | ||||||
|   does not retrieve and set the primary key attribute, as ``save()`` does. |   does not retrieve and set the primary key attribute, as ``save()`` does, | ||||||
|  |   unless the database backend supports it (currently PostgreSQL). | ||||||
| * It does not work with many-to-many relationships. | * It does not work with many-to-many relationships. | ||||||
|  |  | ||||||
| .. versionchanged:: 1.9 | .. versionchanged:: 1.9 | ||||||
|  |  | ||||||
|     Support for using ``bulk_create()`` with proxy models was added. |     Support for using ``bulk_create()`` with proxy models was added. | ||||||
|  |  | ||||||
|  | .. versionchanged:: 1.10 | ||||||
|  |  | ||||||
|  |     Support for setting primary keys on objects created using ``bulk_create()`` | ||||||
|  |     when using PostgreSQL was added. | ||||||
|  |  | ||||||
| The ``batch_size`` parameter controls how many objects are created in a single | The ``batch_size`` parameter controls how many objects are created in a single | ||||||
| query. The default is to create all objects in one batch, except for SQLite | query. The default is to create all objects in one batch, except for SQLite | ||||||
| where the default is such that at most 999 variables per query are used. | where the default is such that at most 999 variables per query are used. | ||||||
|   | |||||||
| @@ -203,6 +203,11 @@ Database backends | |||||||
|  |  | ||||||
| * Temporal data subtraction was unified on all backends. | * Temporal data subtraction was unified on all backends. | ||||||
|  |  | ||||||
|  | * If the database supports it, backends can set | ||||||
|  |   ``DatabaseFeatures.can_return_ids_from_bulk_insert=True`` and implement | ||||||
|  |   ``DatabaseOperations.fetch_returned_insert_ids()`` to set primary keys | ||||||
|  |   on objects created using ``QuerySet.bulk_create()``. | ||||||
|  |  | ||||||
| Email | Email | ||||||
| ~~~~~ | ~~~~~ | ||||||
|  |  | ||||||
| @@ -315,6 +320,9 @@ Models | |||||||
| * The :func:`~django.db.models.prefetch_related_objects` function is now a | * The :func:`~django.db.models.prefetch_related_objects` function is now a | ||||||
|   public API. |   public API. | ||||||
|  |  | ||||||
|  | * :meth:`QuerySet.bulk_create() <django.db.models.query.QuerySet.bulk_create>` | ||||||
|  |   sets the primary key on objects when using PostgreSQL. | ||||||
|  |  | ||||||
| Requests and Responses | Requests and Responses | ||||||
| ~~~~~~~~~~~~~~~~~~~~~~ | ~~~~~~~~~~~~~~~~~~~~~~ | ||||||
|  |  | ||||||
|   | |||||||
| @@ -198,3 +198,22 @@ class BulkCreateTests(TestCase): | |||||||
|         ]) |         ]) | ||||||
|         bbb = Restaurant.objects.filter(name="betty's beetroot bar") |         bbb = Restaurant.objects.filter(name="betty's beetroot bar") | ||||||
|         self.assertEqual(bbb.count(), 1) |         self.assertEqual(bbb.count(), 1) | ||||||
|  |  | ||||||
|  |     @skipUnlessDBFeature('can_return_ids_from_bulk_insert') | ||||||
|  |     def test_set_pk_and_insert_single_item(self): | ||||||
|  |         countries = [] | ||||||
|  |         with self.assertNumQueries(1): | ||||||
|  |             countries = Country.objects.bulk_create([self.data[0]]) | ||||||
|  |         self.assertEqual(len(countries), 1) | ||||||
|  |         self.assertEqual(Country.objects.get(pk=countries[0].pk), countries[0]) | ||||||
|  |  | ||||||
|  |     @skipUnlessDBFeature('can_return_ids_from_bulk_insert') | ||||||
|  |     def test_set_pk_and_query_efficiency(self): | ||||||
|  |         countries = [] | ||||||
|  |         with self.assertNumQueries(1): | ||||||
|  |             countries = Country.objects.bulk_create(self.data) | ||||||
|  |         self.assertEqual(len(countries), 4) | ||||||
|  |         self.assertEqual(Country.objects.get(pk=countries[0].pk), countries[0]) | ||||||
|  |         self.assertEqual(Country.objects.get(pk=countries[1].pk), countries[1]) | ||||||
|  |         self.assertEqual(Country.objects.get(pk=countries[2].pk), countries[2]) | ||||||
|  |         self.assertEqual(Country.objects.get(pk=countries[3].pk), countries[3]) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user