mirror of
				https://github.com/django/django.git
				synced 2025-10-31 09:41:08 +00:00 
			
		
		
		
	Fixed #31685 -- Added support for updating conflicts to QuerySet.bulk_create().
Thanks Florian Apolloner, Chris Jerdonek, Hannes Ljungberg, Nick Pope, and Mariusz Felisiak for reviews.
This commit is contained in:
		
				
					committed by
					
						 Mariusz Felisiak
						Mariusz Felisiak
					
				
			
			
				
	
			
			
			
						parent
						
							ba9de2e74e
						
					
				
				
					commit
					0f6946495a
				
			| @@ -271,6 +271,10 @@ class BaseDatabaseFeatures: | |||||||
|     # Does the backend support ignoring constraint or uniqueness errors during |     # Does the backend support ignoring constraint or uniqueness errors during | ||||||
|     # INSERT? |     # INSERT? | ||||||
|     supports_ignore_conflicts = True |     supports_ignore_conflicts = True | ||||||
|  |     # Does the backend support updating rows on constraint or uniqueness errors | ||||||
|  |     # during INSERT? | ||||||
|  |     supports_update_conflicts = False | ||||||
|  |     supports_update_conflicts_with_target = False | ||||||
|  |  | ||||||
|     # Does this backend require casting the results of CASE expressions used |     # Does this backend require casting the results of CASE expressions used | ||||||
|     # in UPDATE statements to ensure the expression has the correct type? |     # in UPDATE statements to ensure the expression has the correct type? | ||||||
|   | |||||||
| @@ -717,8 +717,8 @@ class BaseDatabaseOperations: | |||||||
|             raise ValueError('Unknown options: %s' % ', '.join(sorted(options.keys()))) |             raise ValueError('Unknown options: %s' % ', '.join(sorted(options.keys()))) | ||||||
|         return self.explain_prefix |         return self.explain_prefix | ||||||
|  |  | ||||||
|     def insert_statement(self, ignore_conflicts=False): |     def insert_statement(self, on_conflict=None): | ||||||
|         return 'INSERT INTO' |         return 'INSERT INTO' | ||||||
|  |  | ||||||
|     def ignore_conflicts_suffix_sql(self, ignore_conflicts=None): |     def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields): | ||||||
|         return '' |         return '' | ||||||
|   | |||||||
| @@ -24,6 +24,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): | |||||||
|     supports_select_difference = False |     supports_select_difference = False | ||||||
|     supports_slicing_ordering_in_compound = True |     supports_slicing_ordering_in_compound = True | ||||||
|     supports_index_on_text_field = False |     supports_index_on_text_field = False | ||||||
|  |     supports_update_conflicts = True | ||||||
|     create_test_procedure_without_params_sql = """ |     create_test_procedure_without_params_sql = """ | ||||||
|         CREATE PROCEDURE test_procedure () |         CREATE PROCEDURE test_procedure () | ||||||
|         BEGIN |         BEGIN | ||||||
|   | |||||||
| @@ -4,6 +4,7 @@ from django.conf import settings | |||||||
| from django.db.backends.base.operations import BaseDatabaseOperations | from django.db.backends.base.operations import BaseDatabaseOperations | ||||||
| from django.db.backends.utils import split_tzname_delta | from django.db.backends.utils import split_tzname_delta | ||||||
| from django.db.models import Exists, ExpressionWrapper, Lookup | from django.db.models import Exists, ExpressionWrapper, Lookup | ||||||
|  | from django.db.models.constants import OnConflict | ||||||
| from django.utils import timezone | from django.utils import timezone | ||||||
| from django.utils.encoding import force_str | from django.utils.encoding import force_str | ||||||
|  |  | ||||||
| @@ -365,8 +366,10 @@ class DatabaseOperations(BaseDatabaseOperations): | |||||||
|         match_option = 'c' if lookup_type == 'regex' else 'i' |         match_option = 'c' if lookup_type == 'regex' else 'i' | ||||||
|         return "REGEXP_LIKE(%%s, %%s, '%s')" % match_option |         return "REGEXP_LIKE(%%s, %%s, '%s')" % match_option | ||||||
|  |  | ||||||
|     def insert_statement(self, ignore_conflicts=False): |     def insert_statement(self, on_conflict=None): | ||||||
|         return 'INSERT IGNORE INTO' if ignore_conflicts else super().insert_statement(ignore_conflicts) |         if on_conflict == OnConflict.IGNORE: | ||||||
|  |             return 'INSERT IGNORE INTO' | ||||||
|  |         return super().insert_statement(on_conflict=on_conflict) | ||||||
|  |  | ||||||
|     def lookup_cast(self, lookup_type, internal_type=None): |     def lookup_cast(self, lookup_type, internal_type=None): | ||||||
|         lookup = '%s' |         lookup = '%s' | ||||||
| @@ -388,3 +391,27 @@ class DatabaseOperations(BaseDatabaseOperations): | |||||||
|         if getattr(expression, 'conditional', False): |         if getattr(expression, 'conditional', False): | ||||||
|             return False |             return False | ||||||
|         return super().conditional_expression_supported_in_where_clause(expression) |         return super().conditional_expression_supported_in_where_clause(expression) | ||||||
|  |  | ||||||
|  |     def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields): | ||||||
|  |         if on_conflict == OnConflict.UPDATE: | ||||||
|  |             conflict_suffix_sql = 'ON DUPLICATE KEY UPDATE %(fields)s' | ||||||
|  |             field_sql = '%(field)s = VALUES(%(field)s)' | ||||||
|  |             # The use of VALUES() is deprecated in MySQL 8.0.20+. Instead, use | ||||||
|  |             # aliases for the new row and its columns available in MySQL | ||||||
|  |             # 8.0.19+. | ||||||
|  |             if not self.connection.mysql_is_mariadb: | ||||||
|  |                 if self.connection.mysql_version >= (8, 0, 19): | ||||||
|  |                     conflict_suffix_sql = f'AS new {conflict_suffix_sql}' | ||||||
|  |                     field_sql = '%(field)s = new.%(field)s' | ||||||
|  |             # VALUES() was renamed to VALUE() in MariaDB 10.3.3+. | ||||||
|  |             elif self.connection.mysql_version >= (10, 3, 3): | ||||||
|  |                 field_sql = '%(field)s = VALUE(%(field)s)' | ||||||
|  |  | ||||||
|  |             fields = ', '.join([ | ||||||
|  |                 field_sql % {'field': field} | ||||||
|  |                 for field in map(self.quote_name, update_fields) | ||||||
|  |             ]) | ||||||
|  |             return conflict_suffix_sql % {'fields': fields} | ||||||
|  |         return super().on_conflict_suffix_sql( | ||||||
|  |             fields, on_conflict, update_fields, unique_fields, | ||||||
|  |         ) | ||||||
|   | |||||||
| @@ -57,6 +57,8 @@ class DatabaseFeatures(BaseDatabaseFeatures): | |||||||
|     supports_deferrable_unique_constraints = True |     supports_deferrable_unique_constraints = True | ||||||
|     has_json_operators = True |     has_json_operators = True | ||||||
|     json_key_contains_list_matching_requires_list = True |     json_key_contains_list_matching_requires_list = True | ||||||
|  |     supports_update_conflicts = True | ||||||
|  |     supports_update_conflicts_with_target = True | ||||||
|     test_collations = { |     test_collations = { | ||||||
|         'non_default': 'sv-x-icu', |         'non_default': 'sv-x-icu', | ||||||
|         'swedish_ci': 'sv-x-icu', |         'swedish_ci': 'sv-x-icu', | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ from psycopg2.extras import Inet | |||||||
| from django.conf import settings | from django.conf import settings | ||||||
| from django.db.backends.base.operations import BaseDatabaseOperations | from django.db.backends.base.operations import BaseDatabaseOperations | ||||||
| from django.db.backends.utils import split_tzname_delta | from django.db.backends.utils import split_tzname_delta | ||||||
|  | from django.db.models.constants import OnConflict | ||||||
|  |  | ||||||
|  |  | ||||||
| class DatabaseOperations(BaseDatabaseOperations): | class DatabaseOperations(BaseDatabaseOperations): | ||||||
| @@ -272,5 +273,17 @@ class DatabaseOperations(BaseDatabaseOperations): | |||||||
|             prefix += ' (%s)' % ', '.join('%s %s' % i for i in extra.items()) |             prefix += ' (%s)' % ', '.join('%s %s' % i for i in extra.items()) | ||||||
|         return prefix |         return prefix | ||||||
|  |  | ||||||
|     def ignore_conflicts_suffix_sql(self, ignore_conflicts=None): |     def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields): | ||||||
|         return 'ON CONFLICT DO NOTHING' if ignore_conflicts else super().ignore_conflicts_suffix_sql(ignore_conflicts) |         if on_conflict == OnConflict.IGNORE: | ||||||
|  |             return 'ON CONFLICT DO NOTHING' | ||||||
|  |         if on_conflict == OnConflict.UPDATE: | ||||||
|  |             return 'ON CONFLICT(%s) DO UPDATE SET %s' % ( | ||||||
|  |                 ', '.join(map(self.quote_name, unique_fields)), | ||||||
|  |                 ', '.join([ | ||||||
|  |                     f'{field} = EXCLUDED.{field}' | ||||||
|  |                     for field in map(self.quote_name, update_fields) | ||||||
|  |                 ]), | ||||||
|  |             ) | ||||||
|  |         return super().on_conflict_suffix_sql( | ||||||
|  |             fields, on_conflict, update_fields, unique_fields, | ||||||
|  |         ) | ||||||
|   | |||||||
| @@ -40,6 +40,8 @@ class DatabaseFeatures(BaseDatabaseFeatures): | |||||||
|     supports_order_by_nulls_modifier = Database.sqlite_version_info >= (3, 30, 0) |     supports_order_by_nulls_modifier = Database.sqlite_version_info >= (3, 30, 0) | ||||||
|     order_by_nulls_first = True |     order_by_nulls_first = True | ||||||
|     supports_json_field_contains = False |     supports_json_field_contains = False | ||||||
|  |     supports_update_conflicts = Database.sqlite_version_info >= (3, 24, 0) | ||||||
|  |     supports_update_conflicts_with_target = supports_update_conflicts | ||||||
|     test_collations = { |     test_collations = { | ||||||
|         'ci': 'nocase', |         'ci': 'nocase', | ||||||
|         'cs': 'binary', |         'cs': 'binary', | ||||||
|   | |||||||
| @@ -8,6 +8,7 @@ from django.conf import settings | |||||||
| from django.core.exceptions import FieldError | from django.core.exceptions import FieldError | ||||||
| from django.db import DatabaseError, NotSupportedError, models | from django.db import DatabaseError, NotSupportedError, models | ||||||
| from django.db.backends.base.operations import BaseDatabaseOperations | from django.db.backends.base.operations import BaseDatabaseOperations | ||||||
|  | from django.db.models.constants import OnConflict | ||||||
| from django.db.models.expressions import Col | from django.db.models.expressions import Col | ||||||
| from django.utils import timezone | from django.utils import timezone | ||||||
| from django.utils.dateparse import parse_date, parse_datetime, parse_time | from django.utils.dateparse import parse_date, parse_datetime, parse_time | ||||||
| @@ -370,8 +371,10 @@ class DatabaseOperations(BaseDatabaseOperations): | |||||||
|             return 'django_time_diff(%s, %s)' % (lhs_sql, rhs_sql), params |             return 'django_time_diff(%s, %s)' % (lhs_sql, rhs_sql), params | ||||||
|         return 'django_timestamp_diff(%s, %s)' % (lhs_sql, rhs_sql), params |         return 'django_timestamp_diff(%s, %s)' % (lhs_sql, rhs_sql), params | ||||||
|  |  | ||||||
|     def insert_statement(self, ignore_conflicts=False): |     def insert_statement(self, on_conflict=None): | ||||||
|         return 'INSERT OR IGNORE INTO' if ignore_conflicts else super().insert_statement(ignore_conflicts) |         if on_conflict == OnConflict.IGNORE: | ||||||
|  |             return 'INSERT OR IGNORE INTO' | ||||||
|  |         return super().insert_statement(on_conflict=on_conflict) | ||||||
|  |  | ||||||
|     def return_insert_columns(self, fields): |     def return_insert_columns(self, fields): | ||||||
|         # SQLite < 3.35 doesn't support an INSERT...RETURNING statement. |         # SQLite < 3.35 doesn't support an INSERT...RETURNING statement. | ||||||
| @@ -384,3 +387,19 @@ class DatabaseOperations(BaseDatabaseOperations): | |||||||
|             ) for field in fields |             ) for field in fields | ||||||
|         ] |         ] | ||||||
|         return 'RETURNING %s' % ', '.join(columns), () |         return 'RETURNING %s' % ', '.join(columns), () | ||||||
|  |  | ||||||
|  |     def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields): | ||||||
|  |         if ( | ||||||
|  |             on_conflict == OnConflict.UPDATE and | ||||||
|  |             self.connection.features.supports_update_conflicts_with_target | ||||||
|  |         ): | ||||||
|  |             return 'ON CONFLICT(%s) DO UPDATE SET %s' % ( | ||||||
|  |                 ', '.join(map(self.quote_name, unique_fields)), | ||||||
|  |                 ', '.join([ | ||||||
|  |                     f'{field} = EXCLUDED.{field}' | ||||||
|  |                     for field in map(self.quote_name, update_fields) | ||||||
|  |                 ]), | ||||||
|  |             ) | ||||||
|  |         return super().on_conflict_suffix_sql( | ||||||
|  |             fields, on_conflict, update_fields, unique_fields, | ||||||
|  |         ) | ||||||
|   | |||||||
| @@ -1,6 +1,12 @@ | |||||||
| """ | """ | ||||||
| Constants used across the ORM in general. | Constants used across the ORM in general. | ||||||
| """ | """ | ||||||
|  | from enum import Enum | ||||||
|  |  | ||||||
| # Separator used to split filter strings apart. | # Separator used to split filter strings apart. | ||||||
| LOOKUP_SEP = '__' | LOOKUP_SEP = '__' | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class OnConflict(Enum): | ||||||
|  |     IGNORE = 'ignore' | ||||||
|  |     UPDATE = 'update' | ||||||
|   | |||||||
| @@ -15,7 +15,7 @@ from django.db import ( | |||||||
|     router, transaction, |     router, transaction, | ||||||
| ) | ) | ||||||
| from django.db.models import AutoField, DateField, DateTimeField, sql | from django.db.models import AutoField, DateField, DateTimeField, sql | ||||||
| from django.db.models.constants import LOOKUP_SEP | from django.db.models.constants import LOOKUP_SEP, OnConflict | ||||||
| from django.db.models.deletion import Collector | from django.db.models.deletion import Collector | ||||||
| from django.db.models.expressions import Case, Expression, F, Ref, Value, When | from django.db.models.expressions import Case, Expression, F, Ref, Value, When | ||||||
| from django.db.models.functions import Cast, Trunc | from django.db.models.functions import Cast, Trunc | ||||||
| @@ -466,7 +466,69 @@ class QuerySet: | |||||||
|                 obj.pk = obj._meta.pk.get_pk_value_on_save(obj) |                 obj.pk = obj._meta.pk.get_pk_value_on_save(obj) | ||||||
|             obj._prepare_related_fields_for_save(operation_name='bulk_create') |             obj._prepare_related_fields_for_save(operation_name='bulk_create') | ||||||
|  |  | ||||||
|     def bulk_create(self, objs, batch_size=None, ignore_conflicts=False): |     def _check_bulk_create_options(self, ignore_conflicts, update_conflicts, update_fields, unique_fields): | ||||||
|  |         if ignore_conflicts and update_conflicts: | ||||||
|  |             raise ValueError( | ||||||
|  |                 'ignore_conflicts and update_conflicts are mutually exclusive.' | ||||||
|  |             ) | ||||||
|  |         db_features = connections[self.db].features | ||||||
|  |         if ignore_conflicts: | ||||||
|  |             if not db_features.supports_ignore_conflicts: | ||||||
|  |                 raise NotSupportedError( | ||||||
|  |                     'This database backend does not support ignoring conflicts.' | ||||||
|  |                 ) | ||||||
|  |             return OnConflict.IGNORE | ||||||
|  |         elif update_conflicts: | ||||||
|  |             if not db_features.supports_update_conflicts: | ||||||
|  |                 raise NotSupportedError( | ||||||
|  |                     'This database backend does not support updating conflicts.' | ||||||
|  |                 ) | ||||||
|  |             if not update_fields: | ||||||
|  |                 raise ValueError( | ||||||
|  |                     'Fields that will be updated when a row insertion fails ' | ||||||
|  |                     'on conflicts must be provided.' | ||||||
|  |                 ) | ||||||
|  |             if unique_fields and not db_features.supports_update_conflicts_with_target: | ||||||
|  |                 raise NotSupportedError( | ||||||
|  |                     'This database backend does not support updating ' | ||||||
|  |                     'conflicts with specifying unique fields that can trigger ' | ||||||
|  |                     'the upsert.' | ||||||
|  |                 ) | ||||||
|  |             if not unique_fields and db_features.supports_update_conflicts_with_target: | ||||||
|  |                 raise ValueError( | ||||||
|  |                     'Unique fields that can trigger the upsert must be ' | ||||||
|  |                     'provided.' | ||||||
|  |                 ) | ||||||
|  |             # Updating primary keys and non-concrete fields is forbidden. | ||||||
|  |             update_fields = [self.model._meta.get_field(name) for name in update_fields] | ||||||
|  |             if any(not f.concrete or f.many_to_many for f in update_fields): | ||||||
|  |                 raise ValueError( | ||||||
|  |                     'bulk_create() can only be used with concrete fields in ' | ||||||
|  |                     'update_fields.' | ||||||
|  |                 ) | ||||||
|  |             if any(f.primary_key for f in update_fields): | ||||||
|  |                 raise ValueError( | ||||||
|  |                     'bulk_create() cannot be used with primary keys in ' | ||||||
|  |                     'update_fields.' | ||||||
|  |                 ) | ||||||
|  |             if unique_fields: | ||||||
|  |                 # Primary key is allowed in unique_fields. | ||||||
|  |                 unique_fields = [ | ||||||
|  |                     self.model._meta.get_field(name) | ||||||
|  |                     for name in unique_fields if name != 'pk' | ||||||
|  |                 ] | ||||||
|  |                 if any(not f.concrete or f.many_to_many for f in unique_fields): | ||||||
|  |                     raise ValueError( | ||||||
|  |                         'bulk_create() can only be used with concrete fields ' | ||||||
|  |                         'in unique_fields.' | ||||||
|  |                     ) | ||||||
|  |             return OnConflict.UPDATE | ||||||
|  |         return None | ||||||
|  |  | ||||||
|  |     def bulk_create( | ||||||
|  |         self, objs, batch_size=None, ignore_conflicts=False, | ||||||
|  |         update_conflicts=False, update_fields=None, unique_fields=None, | ||||||
|  |     ): | ||||||
|         """ |         """ | ||||||
|         Insert each of the instances into the database. Do *not* call |         Insert each of the instances into the database. Do *not* call | ||||||
|         save() on each of the instances, do not send any pre/post_save |         save() on each of the instances, do not send any pre/post_save | ||||||
| @@ -497,6 +559,12 @@ class QuerySet: | |||||||
|                 raise ValueError("Can't bulk create a multi-table inherited model") |                 raise ValueError("Can't bulk create a multi-table inherited model") | ||||||
|         if not objs: |         if not objs: | ||||||
|             return objs |             return objs | ||||||
|  |         on_conflict = self._check_bulk_create_options( | ||||||
|  |             ignore_conflicts, | ||||||
|  |             update_conflicts, | ||||||
|  |             update_fields, | ||||||
|  |             unique_fields, | ||||||
|  |         ) | ||||||
|         self._for_write = True |         self._for_write = True | ||||||
|         opts = self.model._meta |         opts = self.model._meta | ||||||
|         fields = opts.concrete_fields |         fields = opts.concrete_fields | ||||||
| @@ -506,7 +574,12 @@ class QuerySet: | |||||||
|             objs_with_pk, objs_without_pk = partition(lambda o: o.pk is None, objs) |             objs_with_pk, objs_without_pk = partition(lambda o: o.pk is None, objs) | ||||||
|             if objs_with_pk: |             if objs_with_pk: | ||||||
|                 returned_columns = self._batched_insert( |                 returned_columns = self._batched_insert( | ||||||
|                     objs_with_pk, fields, batch_size, ignore_conflicts=ignore_conflicts, |                     objs_with_pk, | ||||||
|  |                     fields, | ||||||
|  |                     batch_size, | ||||||
|  |                     on_conflict=on_conflict, | ||||||
|  |                     update_fields=update_fields, | ||||||
|  |                     unique_fields=unique_fields, | ||||||
|                 ) |                 ) | ||||||
|                 for obj_with_pk, results in zip(objs_with_pk, returned_columns): |                 for obj_with_pk, results in zip(objs_with_pk, returned_columns): | ||||||
|                     for result, field in zip(results, opts.db_returning_fields): |                     for result, field in zip(results, opts.db_returning_fields): | ||||||
| @@ -518,10 +591,15 @@ class QuerySet: | |||||||
|             if objs_without_pk: |             if objs_without_pk: | ||||||
|                 fields = [f for f in fields if not isinstance(f, AutoField)] |                 fields = [f for f in fields if not isinstance(f, AutoField)] | ||||||
|                 returned_columns = self._batched_insert( |                 returned_columns = self._batched_insert( | ||||||
|                     objs_without_pk, fields, batch_size, ignore_conflicts=ignore_conflicts, |                     objs_without_pk, | ||||||
|  |                     fields, | ||||||
|  |                     batch_size, | ||||||
|  |                     on_conflict=on_conflict, | ||||||
|  |                     update_fields=update_fields, | ||||||
|  |                     unique_fields=unique_fields, | ||||||
|                 ) |                 ) | ||||||
|                 connection = connections[self.db] |                 connection = connections[self.db] | ||||||
|                 if connection.features.can_return_rows_from_bulk_insert and not ignore_conflicts: |                 if connection.features.can_return_rows_from_bulk_insert and on_conflict is None: | ||||||
|                     assert len(returned_columns) == len(objs_without_pk) |                     assert len(returned_columns) == len(objs_without_pk) | ||||||
|                 for obj_without_pk, results in zip(objs_without_pk, returned_columns): |                 for obj_without_pk, results in zip(objs_without_pk, returned_columns): | ||||||
|                     for result, field in zip(results, opts.db_returning_fields): |                     for result, field in zip(results, opts.db_returning_fields): | ||||||
| @@ -1293,7 +1371,10 @@ class QuerySet: | |||||||
|     # PRIVATE METHODS # |     # PRIVATE METHODS # | ||||||
|     ################### |     ################### | ||||||
|  |  | ||||||
|     def _insert(self, objs, fields, returning_fields=None, raw=False, using=None, ignore_conflicts=False): |     def _insert( | ||||||
|  |         self, objs, fields, returning_fields=None, raw=False, using=None, | ||||||
|  |         on_conflict=None, update_fields=None, unique_fields=None, | ||||||
|  |     ): | ||||||
|         """ |         """ | ||||||
|         Insert a new record for the given model. This provides an interface to |         Insert a new record for the given model. This provides an interface to | ||||||
|         the InsertQuery class and is how Model.save() is implemented. |         the InsertQuery class and is how Model.save() is implemented. | ||||||
| @@ -1301,33 +1382,45 @@ class QuerySet: | |||||||
|         self._for_write = True |         self._for_write = True | ||||||
|         if using is None: |         if using is None: | ||||||
|             using = self.db |             using = self.db | ||||||
|         query = sql.InsertQuery(self.model, ignore_conflicts=ignore_conflicts) |         query = sql.InsertQuery( | ||||||
|  |             self.model, | ||||||
|  |             on_conflict=on_conflict, | ||||||
|  |             update_fields=update_fields, | ||||||
|  |             unique_fields=unique_fields, | ||||||
|  |         ) | ||||||
|         query.insert_values(fields, objs, raw=raw) |         query.insert_values(fields, objs, raw=raw) | ||||||
|         return query.get_compiler(using=using).execute_sql(returning_fields) |         return query.get_compiler(using=using).execute_sql(returning_fields) | ||||||
|     _insert.alters_data = True |     _insert.alters_data = True | ||||||
|     _insert.queryset_only = False |     _insert.queryset_only = False | ||||||
|  |  | ||||||
|     def _batched_insert(self, objs, fields, batch_size, ignore_conflicts=False): |     def _batched_insert( | ||||||
|  |         self, objs, fields, batch_size, on_conflict=None, update_fields=None, | ||||||
|  |         unique_fields=None, | ||||||
|  |     ): | ||||||
|         """ |         """ | ||||||
|         Helper method for bulk_create() to insert objs one batch at a time. |         Helper method for bulk_create() to insert objs one batch at a time. | ||||||
|         """ |         """ | ||||||
|         connection = connections[self.db] |         connection = connections[self.db] | ||||||
|         if ignore_conflicts and not connection.features.supports_ignore_conflicts: |  | ||||||
|             raise NotSupportedError('This database backend does not support ignoring conflicts.') |  | ||||||
|         ops = connection.ops |         ops = connection.ops | ||||||
|         max_batch_size = max(ops.bulk_batch_size(fields, objs), 1) |         max_batch_size = max(ops.bulk_batch_size(fields, objs), 1) | ||||||
|         batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size |         batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size | ||||||
|         inserted_rows = [] |         inserted_rows = [] | ||||||
|         bulk_return = connection.features.can_return_rows_from_bulk_insert |         bulk_return = connection.features.can_return_rows_from_bulk_insert | ||||||
|         for item in [objs[i:i + batch_size] for i in range(0, len(objs), batch_size)]: |         for item in [objs[i:i + batch_size] for i in range(0, len(objs), batch_size)]: | ||||||
|             if bulk_return and not ignore_conflicts: |             if bulk_return and on_conflict is None: | ||||||
|                 inserted_rows.extend(self._insert( |                 inserted_rows.extend(self._insert( | ||||||
|                     item, fields=fields, using=self.db, |                     item, fields=fields, using=self.db, | ||||||
|                     returning_fields=self.model._meta.db_returning_fields, |                     returning_fields=self.model._meta.db_returning_fields, | ||||||
|                     ignore_conflicts=ignore_conflicts, |  | ||||||
|                 )) |                 )) | ||||||
|             else: |             else: | ||||||
|                 self._insert(item, fields=fields, using=self.db, ignore_conflicts=ignore_conflicts) |                 self._insert( | ||||||
|  |                     item, | ||||||
|  |                     fields=fields, | ||||||
|  |                     using=self.db, | ||||||
|  |                     on_conflict=on_conflict, | ||||||
|  |                     update_fields=update_fields, | ||||||
|  |                     unique_fields=unique_fields, | ||||||
|  |                 ) | ||||||
|         return inserted_rows |         return inserted_rows | ||||||
|  |  | ||||||
|     def _chain(self): |     def _chain(self): | ||||||
|   | |||||||
| @@ -1387,7 +1387,9 @@ class SQLInsertCompiler(SQLCompiler): | |||||||
|         # going to be column names (so we can avoid the extra overhead). |         # going to be column names (so we can avoid the extra overhead). | ||||||
|         qn = self.connection.ops.quote_name |         qn = self.connection.ops.quote_name | ||||||
|         opts = self.query.get_meta() |         opts = self.query.get_meta() | ||||||
|         insert_statement = self.connection.ops.insert_statement(ignore_conflicts=self.query.ignore_conflicts) |         insert_statement = self.connection.ops.insert_statement( | ||||||
|  |             on_conflict=self.query.on_conflict, | ||||||
|  |         ) | ||||||
|         result = ['%s %s' % (insert_statement, qn(opts.db_table))] |         result = ['%s %s' % (insert_statement, qn(opts.db_table))] | ||||||
|         fields = self.query.fields or [opts.pk] |         fields = self.query.fields or [opts.pk] | ||||||
|         result.append('(%s)' % ', '.join(qn(f.column) for f in fields)) |         result.append('(%s)' % ', '.join(qn(f.column) for f in fields)) | ||||||
| @@ -1410,8 +1412,11 @@ class SQLInsertCompiler(SQLCompiler): | |||||||
|  |  | ||||||
|         placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows) |         placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows) | ||||||
|  |  | ||||||
|         ignore_conflicts_suffix_sql = self.connection.ops.ignore_conflicts_suffix_sql( |         on_conflict_suffix_sql = self.connection.ops.on_conflict_suffix_sql( | ||||||
|             ignore_conflicts=self.query.ignore_conflicts |             fields, | ||||||
|  |             self.query.on_conflict, | ||||||
|  |             self.query.update_fields, | ||||||
|  |             self.query.unique_fields, | ||||||
|         ) |         ) | ||||||
|         if self.returning_fields and self.connection.features.can_return_columns_from_insert: |         if self.returning_fields and self.connection.features.can_return_columns_from_insert: | ||||||
|             if self.connection.features.can_return_rows_from_bulk_insert: |             if self.connection.features.can_return_rows_from_bulk_insert: | ||||||
| @@ -1420,8 +1425,8 @@ class SQLInsertCompiler(SQLCompiler): | |||||||
|             else: |             else: | ||||||
|                 result.append("VALUES (%s)" % ", ".join(placeholder_rows[0])) |                 result.append("VALUES (%s)" % ", ".join(placeholder_rows[0])) | ||||||
|                 params = [param_rows[0]] |                 params = [param_rows[0]] | ||||||
|             if ignore_conflicts_suffix_sql: |             if on_conflict_suffix_sql: | ||||||
|                 result.append(ignore_conflicts_suffix_sql) |                 result.append(on_conflict_suffix_sql) | ||||||
|             # Skip empty r_sql to allow subclasses to customize behavior for |             # Skip empty r_sql to allow subclasses to customize behavior for | ||||||
|             # 3rd party backends. Refs #19096. |             # 3rd party backends. Refs #19096. | ||||||
|             r_sql, self.returning_params = self.connection.ops.return_insert_columns(self.returning_fields) |             r_sql, self.returning_params = self.connection.ops.return_insert_columns(self.returning_fields) | ||||||
| @@ -1432,12 +1437,12 @@ class SQLInsertCompiler(SQLCompiler): | |||||||
|  |  | ||||||
|         if can_bulk: |         if can_bulk: | ||||||
|             result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows)) |             result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows)) | ||||||
|             if ignore_conflicts_suffix_sql: |             if on_conflict_suffix_sql: | ||||||
|                 result.append(ignore_conflicts_suffix_sql) |                 result.append(on_conflict_suffix_sql) | ||||||
|             return [(" ".join(result), tuple(p for ps in param_rows for p in ps))] |             return [(" ".join(result), tuple(p for ps in param_rows for p in ps))] | ||||||
|         else: |         else: | ||||||
|             if ignore_conflicts_suffix_sql: |             if on_conflict_suffix_sql: | ||||||
|                 result.append(ignore_conflicts_suffix_sql) |                 result.append(on_conflict_suffix_sql) | ||||||
|             return [ |             return [ | ||||||
|                 (" ".join(result + ["VALUES (%s)" % ", ".join(p)]), vals) |                 (" ".join(result + ["VALUES (%s)" % ", ".join(p)]), vals) | ||||||
|                 for p, vals in zip(placeholder_rows, param_rows) |                 for p, vals in zip(placeholder_rows, param_rows) | ||||||
|   | |||||||
| @@ -138,11 +138,13 @@ class UpdateQuery(Query): | |||||||
| class InsertQuery(Query): | class InsertQuery(Query): | ||||||
|     compiler = 'SQLInsertCompiler' |     compiler = 'SQLInsertCompiler' | ||||||
|  |  | ||||||
|     def __init__(self, *args, ignore_conflicts=False, **kwargs): |     def __init__(self, *args, on_conflict=None, update_fields=None, unique_fields=None, **kwargs): | ||||||
|         super().__init__(*args, **kwargs) |         super().__init__(*args, **kwargs) | ||||||
|         self.fields = [] |         self.fields = [] | ||||||
|         self.objs = [] |         self.objs = [] | ||||||
|         self.ignore_conflicts = ignore_conflicts |         self.on_conflict = on_conflict | ||||||
|  |         self.update_fields = update_fields or [] | ||||||
|  |         self.unique_fields = unique_fields or [] | ||||||
|  |  | ||||||
|     def insert_values(self, fields, objs, raw=False): |     def insert_values(self, fields, objs, raw=False): | ||||||
|         self.fields = fields |         self.fields = fields | ||||||
|   | |||||||
| @@ -2155,7 +2155,7 @@ exists in the database, an :exc:`~django.db.IntegrityError` is raised. | |||||||
| ``bulk_create()`` | ``bulk_create()`` | ||||||
| ~~~~~~~~~~~~~~~~~ | ~~~~~~~~~~~~~~~~~ | ||||||
|  |  | ||||||
| .. method:: bulk_create(objs, batch_size=None, ignore_conflicts=False) | .. method:: bulk_create(objs, batch_size=None, ignore_conflicts=False, update_conflicts=False, update_fields=None, unique_fields=None) | ||||||
|  |  | ||||||
| This method inserts the provided list of objects into the database in an | This method inserts the provided list of objects into the database in an | ||||||
| efficient manner (generally only 1 query, no matter how many objects there | efficient manner (generally only 1 query, no matter how many objects there | ||||||
| @@ -2198,9 +2198,17 @@ where the default is such that at most 999 variables per query are used. | |||||||
|  |  | ||||||
| On databases that support it (all but Oracle), setting the ``ignore_conflicts`` | On databases that support it (all but Oracle), setting the ``ignore_conflicts`` | ||||||
| parameter to ``True`` tells the database to ignore failure to insert any rows | parameter to ``True`` tells the database to ignore failure to insert any rows | ||||||
| that fail constraints such as duplicate unique values. Enabling this parameter | that fail constraints such as duplicate unique values. | ||||||
| disables setting the primary key on each model instance (if the database |  | ||||||
| normally supports it). | On databases that support it (all except Oracle and SQLite < 3.24), setting the | ||||||
|  | ``update_conflicts`` parameter to ``True``, tells the database to update | ||||||
|  | ``update_fields`` when a row insertion fails on conflicts. On PostgreSQL and | ||||||
|  | SQLite, in addition to ``update_fields``, a list of ``unique_fields`` that may | ||||||
|  | be in conflict must be provided. | ||||||
|  |  | ||||||
|  | Enabling the ``ignore_conflicts`` or ``update_conflicts`` parameter disable | ||||||
|  | setting the primary key on each model instance (if the database normally | ||||||
|  | support it). | ||||||
|  |  | ||||||
| .. warning:: | .. warning:: | ||||||
|  |  | ||||||
| @@ -2217,6 +2225,12 @@ normally supports it). | |||||||
|  |  | ||||||
|     Support for the fetching primary key attributes on SQLite 3.35+ was added. |     Support for the fetching primary key attributes on SQLite 3.35+ was added. | ||||||
|  |  | ||||||
|  | .. versionchanged:: 4.1 | ||||||
|  |  | ||||||
|  |     The ``update_conflicts``, ``update_fields``, and ``unique_fields`` | ||||||
|  |     parameters were added to support updating fields when a row insertion fails | ||||||
|  |     on conflict. | ||||||
|  |  | ||||||
| ``bulk_update()`` | ``bulk_update()`` | ||||||
| ~~~~~~~~~~~~~~~~~ | ~~~~~~~~~~~~~~~~~ | ||||||
|  |  | ||||||
|   | |||||||
| @@ -232,6 +232,10 @@ Models | |||||||
|   in order to reduce the number of failed requests, e.g. after database server |   in order to reduce the number of failed requests, e.g. after database server | ||||||
|   restart. |   restart. | ||||||
|  |  | ||||||
|  | * :meth:`.QuerySet.bulk_create` now supports updating fields when a row | ||||||
|  |   insertion fails uniqueness constraints. This is supported on MariaDB, MySQL, | ||||||
|  |   PostgreSQL, and SQLite 3.24+. | ||||||
|  |  | ||||||
| Requests and Responses | Requests and Responses | ||||||
| ~~~~~~~~~~~~~~~~~~~~~~ | ~~~~~~~~~~~~~~~~~~~~~~ | ||||||
|  |  | ||||||
| @@ -298,6 +302,14 @@ backends. | |||||||
| * ``DatabaseIntrospection.get_key_columns()`` is removed. Use | * ``DatabaseIntrospection.get_key_columns()`` is removed. Use | ||||||
|   ``DatabaseIntrospection.get_relations()`` instead. |   ``DatabaseIntrospection.get_relations()`` instead. | ||||||
|  |  | ||||||
|  | * ``DatabaseOperations.ignore_conflicts_suffix_sql()`` method is replaced by | ||||||
|  |   ``DatabaseOperations.on_conflict_suffix_sql()`` that accepts the ``fields``, | ||||||
|  |   ``on_conflict``, ``update_fields``, and ``unique_fields`` arguments. | ||||||
|  |  | ||||||
|  | * The ``ignore_conflicts`` argument of the | ||||||
|  |   ``DatabaseOperations.insert_statement()`` method is replaced by | ||||||
|  |   ``on_conflict`` that accepts ``django.db.models.constants.OnConflict``. | ||||||
|  |  | ||||||
| Dropped support for MariaDB 10.2 | Dropped support for MariaDB 10.2 | ||||||
| -------------------------------- | -------------------------------- | ||||||
|  |  | ||||||
|   | |||||||
| @@ -16,6 +16,14 @@ class Country(models.Model): | |||||||
|     iso_two_letter = models.CharField(max_length=2) |     iso_two_letter = models.CharField(max_length=2) | ||||||
|     description = models.TextField() |     description = models.TextField() | ||||||
|  |  | ||||||
|  |     class Meta: | ||||||
|  |         constraints = [ | ||||||
|  |             models.UniqueConstraint( | ||||||
|  |                 fields=['iso_two_letter', 'name'], | ||||||
|  |                 name='country_name_iso_unique', | ||||||
|  |             ), | ||||||
|  |         ] | ||||||
|  |  | ||||||
|  |  | ||||||
| class ProxyCountry(Country): | class ProxyCountry(Country): | ||||||
|     class Meta: |     class Meta: | ||||||
| @@ -58,6 +66,13 @@ class State(models.Model): | |||||||
| class TwoFields(models.Model): | class TwoFields(models.Model): | ||||||
|     f1 = models.IntegerField(unique=True) |     f1 = models.IntegerField(unique=True) | ||||||
|     f2 = models.IntegerField(unique=True) |     f2 = models.IntegerField(unique=True) | ||||||
|  |     name = models.CharField(max_length=15, null=True) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class UpsertConflict(models.Model): | ||||||
|  |     number = models.IntegerField(unique=True) | ||||||
|  |     rank = models.IntegerField() | ||||||
|  |     name = models.CharField(max_length=15) | ||||||
|  |  | ||||||
|  |  | ||||||
| class NoFields(models.Model): | class NoFields(models.Model): | ||||||
| @@ -103,3 +118,9 @@ class NullableFields(models.Model): | |||||||
|     text_field = models.TextField(null=True, default='text') |     text_field = models.TextField(null=True, default='text') | ||||||
|     url_field = models.URLField(null=True, default='/') |     url_field = models.URLField(null=True, default='/') | ||||||
|     uuid_field = models.UUIDField(null=True, default=uuid.uuid4) |     uuid_field = models.UUIDField(null=True, default=uuid.uuid4) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class RelatedModel(models.Model): | ||||||
|  |     name = models.CharField(max_length=15, null=True) | ||||||
|  |     country = models.OneToOneField(Country, models.CASCADE, primary_key=True) | ||||||
|  |     big_auto_fields = models.ManyToManyField(BigAutoFieldModel) | ||||||
|   | |||||||
| @@ -1,7 +1,11 @@ | |||||||
| from math import ceil | from math import ceil | ||||||
| from operator import attrgetter | from operator import attrgetter | ||||||
|  |  | ||||||
| from django.db import IntegrityError, NotSupportedError, connection | from django.core.exceptions import FieldDoesNotExist | ||||||
|  | from django.db import ( | ||||||
|  |     IntegrityError, NotSupportedError, OperationalError, ProgrammingError, | ||||||
|  |     connection, | ||||||
|  | ) | ||||||
| from django.db.models import FileField, Value | from django.db.models import FileField, Value | ||||||
| from django.db.models.functions import Lower | from django.db.models.functions import Lower | ||||||
| from django.test import ( | from django.test import ( | ||||||
| @@ -11,7 +15,8 @@ from django.test import ( | |||||||
| from .models import ( | from .models import ( | ||||||
|     BigAutoFieldModel, Country, NoFields, NullableFields, Pizzeria, |     BigAutoFieldModel, Country, NoFields, NullableFields, Pizzeria, | ||||||
|     ProxyCountry, ProxyMultiCountry, ProxyMultiProxyCountry, ProxyProxyCountry, |     ProxyCountry, ProxyMultiCountry, ProxyMultiProxyCountry, ProxyProxyCountry, | ||||||
|     Restaurant, SmallAutoFieldModel, State, TwoFields, |     RelatedModel, Restaurant, SmallAutoFieldModel, State, TwoFields, | ||||||
|  |     UpsertConflict, | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -53,10 +58,10 @@ class BulkCreateTests(TestCase): | |||||||
|     @skipUnlessDBFeature('has_bulk_insert') |     @skipUnlessDBFeature('has_bulk_insert') | ||||||
|     def test_long_and_short_text(self): |     def test_long_and_short_text(self): | ||||||
|         Country.objects.bulk_create([ |         Country.objects.bulk_create([ | ||||||
|             Country(description='a' * 4001), |             Country(description='a' * 4001, iso_two_letter='A'), | ||||||
|             Country(description='a'), |             Country(description='a', iso_two_letter='B'), | ||||||
|             Country(description='Ж' * 2001), |             Country(description='Ж' * 2001, iso_two_letter='C'), | ||||||
|             Country(description='Ж'), |             Country(description='Ж', iso_two_letter='D'), | ||||||
|         ]) |         ]) | ||||||
|         self.assertEqual(Country.objects.count(), 4) |         self.assertEqual(Country.objects.count(), 4) | ||||||
|  |  | ||||||
| @@ -218,7 +223,7 @@ class BulkCreateTests(TestCase): | |||||||
|  |  | ||||||
|     @skipUnlessDBFeature('has_bulk_insert') |     @skipUnlessDBFeature('has_bulk_insert') | ||||||
|     def test_explicit_batch_size_respects_max_batch_size(self): |     def test_explicit_batch_size_respects_max_batch_size(self): | ||||||
|         objs = [Country() for i in range(1000)] |         objs = [Country(name=f'Country {i}') for i in range(1000)] | ||||||
|         fields = ['name', 'iso_two_letter', 'description'] |         fields = ['name', 'iso_two_letter', 'description'] | ||||||
|         max_batch_size = max(connection.ops.bulk_batch_size(fields, objs), 1) |         max_batch_size = max(connection.ops.bulk_batch_size(fields, objs), 1) | ||||||
|         with self.assertNumQueries(ceil(len(objs) / max_batch_size)): |         with self.assertNumQueries(ceil(len(objs) / max_batch_size)): | ||||||
| @@ -352,3 +357,276 @@ class BulkCreateTests(TestCase): | |||||||
|         msg = 'Batch size must be a positive integer.' |         msg = 'Batch size must be a positive integer.' | ||||||
|         with self.assertRaisesMessage(ValueError, msg): |         with self.assertRaisesMessage(ValueError, msg): | ||||||
|             Country.objects.bulk_create([], batch_size=-1) |             Country.objects.bulk_create([], batch_size=-1) | ||||||
|  |  | ||||||
|  |     @skipIfDBFeature('supports_update_conflicts') | ||||||
|  |     def test_update_conflicts_unsupported(self): | ||||||
|  |         msg = 'This database backend does not support updating conflicts.' | ||||||
|  |         with self.assertRaisesMessage(NotSupportedError, msg): | ||||||
|  |             Country.objects.bulk_create(self.data, update_conflicts=True) | ||||||
|  |  | ||||||
|  |     @skipUnlessDBFeature('supports_ignore_conflicts', 'supports_update_conflicts') | ||||||
|  |     def test_ignore_update_conflicts_exclusive(self): | ||||||
|  |         msg = 'ignore_conflicts and update_conflicts are mutually exclusive' | ||||||
|  |         with self.assertRaisesMessage(ValueError, msg): | ||||||
|  |             Country.objects.bulk_create( | ||||||
|  |                 self.data, | ||||||
|  |                 ignore_conflicts=True, | ||||||
|  |                 update_conflicts=True, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     @skipUnlessDBFeature('supports_update_conflicts') | ||||||
|  |     def test_update_conflicts_no_update_fields(self): | ||||||
|  |         msg = ( | ||||||
|  |             'Fields that will be updated when a row insertion fails on ' | ||||||
|  |             'conflicts must be provided.' | ||||||
|  |         ) | ||||||
|  |         with self.assertRaisesMessage(ValueError, msg): | ||||||
|  |             Country.objects.bulk_create(self.data, update_conflicts=True) | ||||||
|  |  | ||||||
|  |     @skipUnlessDBFeature('supports_update_conflicts') | ||||||
|  |     @skipIfDBFeature('supports_update_conflicts_with_target') | ||||||
|  |     def test_update_conflicts_unique_field_unsupported(self): | ||||||
|  |         msg = ( | ||||||
|  |             'This database backend does not support updating conflicts with ' | ||||||
|  |             'specifying unique fields that can trigger the upsert.' | ||||||
|  |         ) | ||||||
|  |         with self.assertRaisesMessage(NotSupportedError, msg): | ||||||
|  |             TwoFields.objects.bulk_create( | ||||||
|  |                 [TwoFields(f1=1, f2=1), TwoFields(f1=2, f2=2)], | ||||||
|  |                 update_conflicts=True, | ||||||
|  |                 update_fields=['f2'], | ||||||
|  |                 unique_fields=['f1'], | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     @skipUnlessDBFeature('supports_update_conflicts') | ||||||
|  |     def test_update_conflicts_nonexistent_update_fields(self): | ||||||
|  |         unique_fields = None | ||||||
|  |         if connection.features.supports_update_conflicts_with_target: | ||||||
|  |             unique_fields = ['f1'] | ||||||
|  |         msg = "TwoFields has no field named 'nonexistent'" | ||||||
|  |         with self.assertRaisesMessage(FieldDoesNotExist, msg): | ||||||
|  |             TwoFields.objects.bulk_create( | ||||||
|  |                 [TwoFields(f1=1, f2=1), TwoFields(f1=2, f2=2)], | ||||||
|  |                 update_conflicts=True, | ||||||
|  |                 update_fields=['nonexistent'], | ||||||
|  |                 unique_fields=unique_fields, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     @skipUnlessDBFeature( | ||||||
|  |         'supports_update_conflicts', 'supports_update_conflicts_with_target', | ||||||
|  |     ) | ||||||
|  |     def test_update_conflicts_unique_fields_required(self): | ||||||
|  |         msg = 'Unique fields that can trigger the upsert must be provided.' | ||||||
|  |         with self.assertRaisesMessage(ValueError, msg): | ||||||
|  |             TwoFields.objects.bulk_create( | ||||||
|  |                 [TwoFields(f1=1, f2=1), TwoFields(f1=2, f2=2)], | ||||||
|  |                 update_conflicts=True, | ||||||
|  |                 update_fields=['f1'], | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     @skipUnlessDBFeature( | ||||||
|  |         'supports_update_conflicts', 'supports_update_conflicts_with_target', | ||||||
|  |     ) | ||||||
|  |     def test_update_conflicts_invalid_update_fields(self): | ||||||
|  |         msg = ( | ||||||
|  |             'bulk_create() can only be used with concrete fields in ' | ||||||
|  |             'update_fields.' | ||||||
|  |         ) | ||||||
|  |         # Reverse one-to-one relationship. | ||||||
|  |         with self.assertRaisesMessage(ValueError, msg): | ||||||
|  |             Country.objects.bulk_create( | ||||||
|  |                 self.data, | ||||||
|  |                 update_conflicts=True, | ||||||
|  |                 update_fields=['relatedmodel'], | ||||||
|  |                 unique_fields=['pk'], | ||||||
|  |             ) | ||||||
|  |         # Many-to-many relationship. | ||||||
|  |         with self.assertRaisesMessage(ValueError, msg): | ||||||
|  |             RelatedModel.objects.bulk_create( | ||||||
|  |                 [RelatedModel(country=self.data[0])], | ||||||
|  |                 update_conflicts=True, | ||||||
|  |                 update_fields=['big_auto_fields'], | ||||||
|  |                 unique_fields=['country'], | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     @skipUnlessDBFeature( | ||||||
|  |         'supports_update_conflicts', 'supports_update_conflicts_with_target', | ||||||
|  |     ) | ||||||
|  |     def test_update_conflicts_pk_in_update_fields(self): | ||||||
|  |         msg = 'bulk_create() cannot be used with primary keys in update_fields.' | ||||||
|  |         with self.assertRaisesMessage(ValueError, msg): | ||||||
|  |             BigAutoFieldModel.objects.bulk_create( | ||||||
|  |                 [BigAutoFieldModel()], | ||||||
|  |                 update_conflicts=True, | ||||||
|  |                 update_fields=['id'], | ||||||
|  |                 unique_fields=['id'], | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     @skipUnlessDBFeature( | ||||||
|  |         'supports_update_conflicts', 'supports_update_conflicts_with_target', | ||||||
|  |     ) | ||||||
|  |     def test_update_conflicts_invalid_unique_fields(self): | ||||||
|  |         msg = ( | ||||||
|  |             'bulk_create() can only be used with concrete fields in ' | ||||||
|  |             'unique_fields.' | ||||||
|  |         ) | ||||||
|  |         # Reverse one-to-one relationship. | ||||||
|  |         with self.assertRaisesMessage(ValueError, msg): | ||||||
|  |             Country.objects.bulk_create( | ||||||
|  |                 self.data, | ||||||
|  |                 update_conflicts=True, | ||||||
|  |                 update_fields=['name'], | ||||||
|  |                 unique_fields=['relatedmodel'], | ||||||
|  |             ) | ||||||
|  |         # Many-to-many relationship. | ||||||
|  |         with self.assertRaisesMessage(ValueError, msg): | ||||||
|  |             RelatedModel.objects.bulk_create( | ||||||
|  |                 [RelatedModel(country=self.data[0])], | ||||||
|  |                 update_conflicts=True, | ||||||
|  |                 update_fields=['name'], | ||||||
|  |                 unique_fields=['big_auto_fields'], | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     def _test_update_conflicts_two_fields(self, unique_fields): | ||||||
|  |         TwoFields.objects.bulk_create([ | ||||||
|  |             TwoFields(f1=1, f2=1, name='a'), | ||||||
|  |             TwoFields(f1=2, f2=2, name='b'), | ||||||
|  |         ]) | ||||||
|  |         self.assertEqual(TwoFields.objects.count(), 2) | ||||||
|  |  | ||||||
|  |         conflicting_objects = [ | ||||||
|  |             TwoFields(f1=1, f2=1, name='c'), | ||||||
|  |             TwoFields(f1=2, f2=2, name='d'), | ||||||
|  |         ] | ||||||
|  |         TwoFields.objects.bulk_create( | ||||||
|  |             conflicting_objects, | ||||||
|  |             update_conflicts=True, | ||||||
|  |             unique_fields=unique_fields, | ||||||
|  |             update_fields=['name'], | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(TwoFields.objects.count(), 2) | ||||||
|  |         self.assertCountEqual(TwoFields.objects.values('f1', 'f2', 'name'), [ | ||||||
|  |             {'f1': 1, 'f2': 1, 'name': 'c'}, | ||||||
|  |             {'f1': 2, 'f2': 2, 'name': 'd'}, | ||||||
|  |         ]) | ||||||
|  |  | ||||||
|  |     @skipUnlessDBFeature('supports_update_conflicts', 'supports_update_conflicts_with_target') | ||||||
|  |     def test_update_conflicts_two_fields_unique_fields_first(self): | ||||||
|  |         self._test_update_conflicts_two_fields(['f1']) | ||||||
|  |  | ||||||
|  |     @skipUnlessDBFeature('supports_update_conflicts', 'supports_update_conflicts_with_target') | ||||||
|  |     def test_update_conflicts_two_fields_unique_fields_second(self): | ||||||
|  |         self._test_update_conflicts_two_fields(['f2']) | ||||||
|  |  | ||||||
|  |     @skipUnlessDBFeature('supports_update_conflicts', 'supports_update_conflicts_with_target') | ||||||
|  |     def test_update_conflicts_two_fields_unique_fields_both(self): | ||||||
|  |         with self.assertRaises((OperationalError, ProgrammingError)): | ||||||
|  |             self._test_update_conflicts_two_fields(['f1', 'f2']) | ||||||
|  |  | ||||||
|  |     @skipUnlessDBFeature('supports_update_conflicts') | ||||||
|  |     @skipIfDBFeature('supports_update_conflicts_with_target') | ||||||
|  |     def test_update_conflicts_two_fields_no_unique_fields(self): | ||||||
|  |         self._test_update_conflicts_two_fields([]) | ||||||
|  |  | ||||||
|  |     def _test_update_conflicts_unique_two_fields(self, unique_fields): | ||||||
|  |         Country.objects.bulk_create(self.data) | ||||||
|  |         self.assertEqual(Country.objects.count(), 4) | ||||||
|  |  | ||||||
|  |         new_data = [ | ||||||
|  |             # Conflicting countries. | ||||||
|  |             Country(name='Germany', iso_two_letter='DE', description=( | ||||||
|  |                 'Germany is a country in Central Europe.' | ||||||
|  |             )), | ||||||
|  |             Country(name='Czech Republic', iso_two_letter='CZ', description=( | ||||||
|  |                 'The Czech Republic is a landlocked country in Central Europe.' | ||||||
|  |             )), | ||||||
|  |             # New countries. | ||||||
|  |             Country(name='Australia', iso_two_letter='AU'), | ||||||
|  |             Country(name='Japan', iso_two_letter='JP', description=( | ||||||
|  |                 'Japan is an island country in East Asia.' | ||||||
|  |             )), | ||||||
|  |         ] | ||||||
|  |         Country.objects.bulk_create( | ||||||
|  |             new_data, | ||||||
|  |             update_conflicts=True, | ||||||
|  |             update_fields=['description'], | ||||||
|  |             unique_fields=unique_fields, | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(Country.objects.count(), 6) | ||||||
|  |         self.assertCountEqual(Country.objects.values('iso_two_letter', 'description'), [ | ||||||
|  |             {'iso_two_letter': 'US', 'description': ''}, | ||||||
|  |             {'iso_two_letter': 'NL', 'description': ''}, | ||||||
|  |             {'iso_two_letter': 'DE', 'description': ( | ||||||
|  |                 'Germany is a country in Central Europe.' | ||||||
|  |             )}, | ||||||
|  |             {'iso_two_letter': 'CZ', 'description': ( | ||||||
|  |                 'The Czech Republic is a landlocked country in Central Europe.' | ||||||
|  |             )}, | ||||||
|  |             {'iso_two_letter': 'AU', 'description': ''}, | ||||||
|  |             {'iso_two_letter': 'JP', 'description': ( | ||||||
|  |                 'Japan is an island country in East Asia.' | ||||||
|  |             )}, | ||||||
|  |         ]) | ||||||
|  |  | ||||||
|  |     @skipUnlessDBFeature('supports_update_conflicts', 'supports_update_conflicts_with_target') | ||||||
|  |     def test_update_conflicts_unique_two_fields_unique_fields_both(self): | ||||||
|  |         self._test_update_conflicts_unique_two_fields(['iso_two_letter', 'name']) | ||||||
|  |  | ||||||
|  |     @skipUnlessDBFeature('supports_update_conflicts', 'supports_update_conflicts_with_target') | ||||||
|  |     def test_update_conflicts_unique_two_fields_unique_fields_one(self): | ||||||
|  |         with self.assertRaises((OperationalError, ProgrammingError)): | ||||||
|  |             self._test_update_conflicts_unique_two_fields(['iso_two_letter']) | ||||||
|  |  | ||||||
|  |     @skipUnlessDBFeature('supports_update_conflicts') | ||||||
|  |     @skipIfDBFeature('supports_update_conflicts_with_target') | ||||||
|  |     def test_update_conflicts_unique_two_fields_unique_no_unique_fields(self): | ||||||
|  |         self._test_update_conflicts_unique_two_fields([]) | ||||||
|  |  | ||||||
|  |     def _test_update_conflicts(self, unique_fields): | ||||||
|  |         UpsertConflict.objects.bulk_create([ | ||||||
|  |             UpsertConflict(number=1, rank=1, name='John'), | ||||||
|  |             UpsertConflict(number=2, rank=2, name='Mary'), | ||||||
|  |             UpsertConflict(number=3, rank=3, name='Hannah'), | ||||||
|  |         ]) | ||||||
|  |         self.assertEqual(UpsertConflict.objects.count(), 3) | ||||||
|  |  | ||||||
|  |         conflicting_objects = [ | ||||||
|  |             UpsertConflict(number=1, rank=4, name='Steve'), | ||||||
|  |             UpsertConflict(number=2, rank=2, name='Olivia'), | ||||||
|  |             UpsertConflict(number=3, rank=1, name='Hannah'), | ||||||
|  |         ] | ||||||
|  |         UpsertConflict.objects.bulk_create( | ||||||
|  |             conflicting_objects, | ||||||
|  |             update_conflicts=True, | ||||||
|  |             update_fields=['name', 'rank'], | ||||||
|  |             unique_fields=unique_fields, | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(UpsertConflict.objects.count(), 3) | ||||||
|  |         self.assertCountEqual(UpsertConflict.objects.values('number', 'rank', 'name'), [ | ||||||
|  |             {'number': 1, 'rank': 4, 'name': 'Steve'}, | ||||||
|  |             {'number': 2, 'rank': 2, 'name': 'Olivia'}, | ||||||
|  |             {'number': 3, 'rank': 1, 'name': 'Hannah'}, | ||||||
|  |         ]) | ||||||
|  |  | ||||||
|  |         UpsertConflict.objects.bulk_create( | ||||||
|  |             conflicting_objects + [UpsertConflict(number=4, rank=4, name='Mark')], | ||||||
|  |             update_conflicts=True, | ||||||
|  |             update_fields=['name', 'rank'], | ||||||
|  |             unique_fields=unique_fields, | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(UpsertConflict.objects.count(), 4) | ||||||
|  |         self.assertCountEqual(UpsertConflict.objects.values('number', 'rank', 'name'), [ | ||||||
|  |             {'number': 1, 'rank': 4, 'name': 'Steve'}, | ||||||
|  |             {'number': 2, 'rank': 2, 'name': 'Olivia'}, | ||||||
|  |             {'number': 3, 'rank': 1, 'name': 'Hannah'}, | ||||||
|  |             {'number': 4, 'rank': 4, 'name': 'Mark'}, | ||||||
|  |         ]) | ||||||
|  |  | ||||||
|  |     @skipUnlessDBFeature('supports_update_conflicts', 'supports_update_conflicts_with_target') | ||||||
|  |     def test_update_conflicts_unique_fields(self): | ||||||
|  |         self._test_update_conflicts(unique_fields=['number']) | ||||||
|  |  | ||||||
|  |     @skipUnlessDBFeature('supports_update_conflicts') | ||||||
|  |     @skipIfDBFeature('supports_update_conflicts_with_target') | ||||||
|  |     def test_update_conflicts_no_unique_fields(self): | ||||||
|  |         self._test_update_conflicts([]) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user