mirror of
				https://github.com/django/django.git
				synced 2025-10-29 08:36:09 +00:00 
			
		
		
		
	Fixed #31685 -- Added support for updating conflicts to QuerySet.bulk_create().
Thanks Florian Apolloner, Chris Jerdonek, Hannes Ljungberg, Nick Pope, and Mariusz Felisiak for reviews.
This commit is contained in:
		
				
					committed by
					
						 Mariusz Felisiak
						Mariusz Felisiak
					
				
			
			
				
	
			
			
			
						parent
						
							ba9de2e74e
						
					
				
				
					commit
					0f6946495a
				
			| @@ -15,7 +15,7 @@ from django.db import ( | ||||
|     router, transaction, | ||||
| ) | ||||
| from django.db.models import AutoField, DateField, DateTimeField, sql | ||||
| from django.db.models.constants import LOOKUP_SEP | ||||
| from django.db.models.constants import LOOKUP_SEP, OnConflict | ||||
| from django.db.models.deletion import Collector | ||||
| from django.db.models.expressions import Case, Expression, F, Ref, Value, When | ||||
| from django.db.models.functions import Cast, Trunc | ||||
| @@ -466,7 +466,69 @@ class QuerySet: | ||||
|                 obj.pk = obj._meta.pk.get_pk_value_on_save(obj) | ||||
|             obj._prepare_related_fields_for_save(operation_name='bulk_create') | ||||
|  | ||||
|     def bulk_create(self, objs, batch_size=None, ignore_conflicts=False): | ||||
|     def _check_bulk_create_options(self, ignore_conflicts, update_conflicts, update_fields, unique_fields): | ||||
|         if ignore_conflicts and update_conflicts: | ||||
|             raise ValueError( | ||||
|                 'ignore_conflicts and update_conflicts are mutually exclusive.' | ||||
|             ) | ||||
|         db_features = connections[self.db].features | ||||
|         if ignore_conflicts: | ||||
|             if not db_features.supports_ignore_conflicts: | ||||
|                 raise NotSupportedError( | ||||
|                     'This database backend does not support ignoring conflicts.' | ||||
|                 ) | ||||
|             return OnConflict.IGNORE | ||||
|         elif update_conflicts: | ||||
|             if not db_features.supports_update_conflicts: | ||||
|                 raise NotSupportedError( | ||||
|                     'This database backend does not support updating conflicts.' | ||||
|                 ) | ||||
|             if not update_fields: | ||||
|                 raise ValueError( | ||||
|                     'Fields that will be updated when a row insertion fails ' | ||||
|                     'on conflicts must be provided.' | ||||
|                 ) | ||||
|             if unique_fields and not db_features.supports_update_conflicts_with_target: | ||||
|                 raise NotSupportedError( | ||||
|                     'This database backend does not support updating ' | ||||
|                     'conflicts with specifying unique fields that can trigger ' | ||||
|                     'the upsert.' | ||||
|                 ) | ||||
|             if not unique_fields and db_features.supports_update_conflicts_with_target: | ||||
|                 raise ValueError( | ||||
|                     'Unique fields that can trigger the upsert must be ' | ||||
|                     'provided.' | ||||
|                 ) | ||||
|             # Updating primary keys and non-concrete fields is forbidden. | ||||
|             update_fields = [self.model._meta.get_field(name) for name in update_fields] | ||||
|             if any(not f.concrete or f.many_to_many for f in update_fields): | ||||
|                 raise ValueError( | ||||
|                     'bulk_create() can only be used with concrete fields in ' | ||||
|                     'update_fields.' | ||||
|                 ) | ||||
|             if any(f.primary_key for f in update_fields): | ||||
|                 raise ValueError( | ||||
|                     'bulk_create() cannot be used with primary keys in ' | ||||
|                     'update_fields.' | ||||
|                 ) | ||||
|             if unique_fields: | ||||
|                 # Primary key is allowed in unique_fields. | ||||
|                 unique_fields = [ | ||||
|                     self.model._meta.get_field(name) | ||||
|                     for name in unique_fields if name != 'pk' | ||||
|                 ] | ||||
|                 if any(not f.concrete or f.many_to_many for f in unique_fields): | ||||
|                     raise ValueError( | ||||
|                         'bulk_create() can only be used with concrete fields ' | ||||
|                         'in unique_fields.' | ||||
|                     ) | ||||
|             return OnConflict.UPDATE | ||||
|         return None | ||||
|  | ||||
|     def bulk_create( | ||||
|         self, objs, batch_size=None, ignore_conflicts=False, | ||||
|         update_conflicts=False, update_fields=None, unique_fields=None, | ||||
|     ): | ||||
|         """ | ||||
|         Insert each of the instances into the database. Do *not* call | ||||
|         save() on each of the instances, do not send any pre/post_save | ||||
| @@ -497,6 +559,12 @@ class QuerySet: | ||||
|                 raise ValueError("Can't bulk create a multi-table inherited model") | ||||
|         if not objs: | ||||
|             return objs | ||||
|         on_conflict = self._check_bulk_create_options( | ||||
|             ignore_conflicts, | ||||
|             update_conflicts, | ||||
|             update_fields, | ||||
|             unique_fields, | ||||
|         ) | ||||
|         self._for_write = True | ||||
|         opts = self.model._meta | ||||
|         fields = opts.concrete_fields | ||||
| @@ -506,7 +574,12 @@ class QuerySet: | ||||
|             objs_with_pk, objs_without_pk = partition(lambda o: o.pk is None, objs) | ||||
|             if objs_with_pk: | ||||
|                 returned_columns = self._batched_insert( | ||||
|                     objs_with_pk, fields, batch_size, ignore_conflicts=ignore_conflicts, | ||||
|                     objs_with_pk, | ||||
|                     fields, | ||||
|                     batch_size, | ||||
|                     on_conflict=on_conflict, | ||||
|                     update_fields=update_fields, | ||||
|                     unique_fields=unique_fields, | ||||
|                 ) | ||||
|                 for obj_with_pk, results in zip(objs_with_pk, returned_columns): | ||||
|                     for result, field in zip(results, opts.db_returning_fields): | ||||
| @@ -518,10 +591,15 @@ class QuerySet: | ||||
|             if objs_without_pk: | ||||
|                 fields = [f for f in fields if not isinstance(f, AutoField)] | ||||
|                 returned_columns = self._batched_insert( | ||||
|                     objs_without_pk, fields, batch_size, ignore_conflicts=ignore_conflicts, | ||||
|                     objs_without_pk, | ||||
|                     fields, | ||||
|                     batch_size, | ||||
|                     on_conflict=on_conflict, | ||||
|                     update_fields=update_fields, | ||||
|                     unique_fields=unique_fields, | ||||
|                 ) | ||||
|                 connection = connections[self.db] | ||||
|                 if connection.features.can_return_rows_from_bulk_insert and not ignore_conflicts: | ||||
|                 if connection.features.can_return_rows_from_bulk_insert and on_conflict is None: | ||||
|                     assert len(returned_columns) == len(objs_without_pk) | ||||
|                 for obj_without_pk, results in zip(objs_without_pk, returned_columns): | ||||
|                     for result, field in zip(results, opts.db_returning_fields): | ||||
| @@ -1293,7 +1371,10 @@ class QuerySet: | ||||
|     # PRIVATE METHODS # | ||||
|     ################### | ||||
|  | ||||
|     def _insert(self, objs, fields, returning_fields=None, raw=False, using=None, ignore_conflicts=False): | ||||
|     def _insert( | ||||
|         self, objs, fields, returning_fields=None, raw=False, using=None, | ||||
|         on_conflict=None, update_fields=None, unique_fields=None, | ||||
|     ): | ||||
|         """ | ||||
|         Insert a new record for the given model. This provides an interface to | ||||
|         the InsertQuery class and is how Model.save() is implemented. | ||||
| @@ -1301,33 +1382,45 @@ class QuerySet: | ||||
|         self._for_write = True | ||||
|         if using is None: | ||||
|             using = self.db | ||||
|         query = sql.InsertQuery(self.model, ignore_conflicts=ignore_conflicts) | ||||
|         query = sql.InsertQuery( | ||||
|             self.model, | ||||
|             on_conflict=on_conflict, | ||||
|             update_fields=update_fields, | ||||
|             unique_fields=unique_fields, | ||||
|         ) | ||||
|         query.insert_values(fields, objs, raw=raw) | ||||
|         return query.get_compiler(using=using).execute_sql(returning_fields) | ||||
|     _insert.alters_data = True | ||||
|     _insert.queryset_only = False | ||||
|  | ||||
|     def _batched_insert(self, objs, fields, batch_size, ignore_conflicts=False): | ||||
|     def _batched_insert( | ||||
|         self, objs, fields, batch_size, on_conflict=None, update_fields=None, | ||||
|         unique_fields=None, | ||||
|     ): | ||||
|         """ | ||||
|         Helper method for bulk_create() to insert objs one batch at a time. | ||||
|         """ | ||||
|         connection = connections[self.db] | ||||
|         if ignore_conflicts and not connection.features.supports_ignore_conflicts: | ||||
|             raise NotSupportedError('This database backend does not support ignoring conflicts.') | ||||
|         ops = connection.ops | ||||
|         max_batch_size = max(ops.bulk_batch_size(fields, objs), 1) | ||||
|         batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size | ||||
|         inserted_rows = [] | ||||
|         bulk_return = connection.features.can_return_rows_from_bulk_insert | ||||
|         for item in [objs[i:i + batch_size] for i in range(0, len(objs), batch_size)]: | ||||
|             if bulk_return and not ignore_conflicts: | ||||
|             if bulk_return and on_conflict is None: | ||||
|                 inserted_rows.extend(self._insert( | ||||
|                     item, fields=fields, using=self.db, | ||||
|                     returning_fields=self.model._meta.db_returning_fields, | ||||
|                     ignore_conflicts=ignore_conflicts, | ||||
|                 )) | ||||
|             else: | ||||
|                 self._insert(item, fields=fields, using=self.db, ignore_conflicts=ignore_conflicts) | ||||
|                 self._insert( | ||||
|                     item, | ||||
|                     fields=fields, | ||||
|                     using=self.db, | ||||
|                     on_conflict=on_conflict, | ||||
|                     update_fields=update_fields, | ||||
|                     unique_fields=unique_fields, | ||||
|                 ) | ||||
|         return inserted_rows | ||||
|  | ||||
|     def _chain(self): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user