diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py index 0009de7a60..425abf8540 100644 --- a/django/db/backends/base/features.py +++ b/django/db/backends/base/features.py @@ -23,6 +23,7 @@ class BaseDatabaseFeatures: can_use_chunked_reads = True can_return_columns_from_insert = False + can_return_multiple_columns_from_insert = False can_return_rows_from_bulk_insert = False has_bulk_insert = True uses_savepoints = True diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py index a0c84a8ff4..51370ef2ac 100644 --- a/django/db/backends/base/operations.py +++ b/django/db/backends/base/operations.py @@ -176,13 +176,12 @@ class BaseDatabaseOperations: else: return ['DISTINCT'], [] - def fetch_returned_insert_id(self, cursor): + def fetch_returned_insert_columns(self, cursor): """ Given a cursor object that has just performed an INSERT...RETURNING - statement into a table that has an auto-incrementing ID, return the - newly created ID. + statement into a table, return the newly created data. """ - return cursor.fetchone()[0] + return cursor.fetchone() def field_cast_sql(self, db_type, internal_type): """ @@ -314,12 +313,11 @@ class BaseDatabaseOperations: """ return value - def return_insert_id(self, field): + def return_insert_columns(self, fields): """ - For backends that support returning the last insert ID as part of an - insert query, return the SQL and params to append to the INSERT query. - The returned fragment should contain a format string to hold the - appropriate column. + For backends that support returning columns as part of an insert query, + return the SQL and params to append to the INSERT query. The returned + fragment should contain a format string to hold the appropriate column. """ pass diff --git a/django/db/backends/oracle/operations.py b/django/db/backends/oracle/operations.py index 95126d37be..df3710f66b 100644 --- a/django/db/backends/oracle/operations.py +++ b/django/db/backends/oracle/operations.py @@ -248,7 +248,7 @@ END; def deferrable_sql(self): return " DEFERRABLE INITIALLY DEFERRED" - def fetch_returned_insert_id(self, cursor): + def fetch_returned_insert_columns(self, cursor): value = cursor._insert_id_var.getvalue() if value is None or value == []: # cx_Oracle < 6.3 returns None, >= 6.3 returns empty list. @@ -258,7 +258,7 @@ END; 'Oracle OCI library (see https://code.djangoproject.com/ticket/28859).' ) # cx_Oracle < 7 returns value, >= 7 returns list with single value. - return value[0] if isinstance(value, list) else value + return value if isinstance(value, list) else [value] def field_cast_sql(self, db_type, internal_type): if db_type and db_type.endswith('LOB'): @@ -341,8 +341,14 @@ END; match_option = "'i'" return 'REGEXP_LIKE(%%s, %%s, %s)' % match_option - def return_insert_id(self, field): - return 'RETURNING %s INTO %%s', (InsertVar(field),) + def return_insert_columns(self, fields): + if not fields: + return '', () + sql = 'RETURNING %s.%s INTO %%s' % ( + self.quote_name(fields[0].model._meta.db_table), + self.quote_name(fields[0].column), + ) + return sql, (InsertVar(fields[0]),) def __foreign_key_constraints(self, table_name, recursive): with self.connection.cursor() as cursor: diff --git a/django/db/backends/postgresql/features.py b/django/db/backends/postgresql/features.py index 907ba136fb..866a0f98e7 100644 --- a/django/db/backends/postgresql/features.py +++ b/django/db/backends/postgresql/features.py @@ -8,6 +8,7 @@ from django.utils.functional import cached_property class DatabaseFeatures(BaseDatabaseFeatures): allows_group_by_selected_pks = True can_return_columns_from_insert = True + can_return_multiple_columns_from_insert = True can_return_rows_from_bulk_insert = True has_real_datatype = True has_native_uuid_field = True diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py index 61bac5e55a..fe5b208c6a 100644 --- a/django/db/backends/postgresql/operations.py +++ b/django/db/backends/postgresql/operations.py @@ -76,13 +76,12 @@ class DatabaseOperations(BaseDatabaseOperations): def deferrable_sql(self): return " DEFERRABLE INITIALLY DEFERRED" - def fetch_returned_insert_ids(self, cursor): + def fetch_returned_insert_rows(self, cursor): """ Given a cursor object that has just performed an INSERT...RETURNING - statement into a table that has an auto-incrementing ID, return the - list of newly created IDs. + statement into a table, return the tuple of returned data. """ - return [item[0] for item in cursor.fetchall()] + return cursor.fetchall() def lookup_cast(self, lookup_type, internal_type=None): lookup = '%s' @@ -236,8 +235,16 @@ class DatabaseOperations(BaseDatabaseOperations): return cursor.query.decode() return None - def return_insert_id(self, field): - return "RETURNING %s", () + def return_insert_columns(self, fields): + if not fields: + return '', () + columns = [ + '%s.%s' % ( + self.quote_name(field.model._meta.db_table), + self.quote_name(field.column), + ) for field in fields + ] + return 'RETURNING %s' % ', '.join(columns), () def bulk_insert_sql(self, fields, placeholder_rows): placeholder_rows_sql = (", ".join(row) for row in placeholder_rows) diff --git a/django/db/models/base.py b/django/db/models/base.py index ae27d3691a..4f3145ebc2 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -876,10 +876,10 @@ class Model(metaclass=ModelBase): if not pk_set: fields = [f for f in fields if f is not meta.auto_field] - update_pk = meta.auto_field and not pk_set - result = self._do_insert(cls._base_manager, using, fields, update_pk, raw) - if update_pk: - setattr(self, meta.pk.attname, result) + returning_fields = meta.db_returning_fields + results = self._do_insert(cls._base_manager, using, fields, returning_fields, raw) + for result, field in zip(results, returning_fields): + setattr(self, field.attname, result) return updated def _do_update(self, base_qs, using, pk_val, values, update_fields, forced_update): @@ -909,13 +909,15 @@ class Model(metaclass=ModelBase): ) return filtered._update(values) > 0 - def _do_insert(self, manager, using, fields, update_pk, raw): + def _do_insert(self, manager, using, fields, returning_fields, raw): """ - Do an INSERT. If update_pk is defined then this method should return - the new pk for the model. + Do an INSERT. If returning_fields is defined then this method should + return the newly created data for the model. """ - return manager._insert([self], fields=fields, return_id=update_pk, - using=using, raw=raw) + return manager._insert( + [self], fields=fields, returning_fields=returning_fields, + using=using, raw=raw, + ) def delete(self, using=None, keep_parents=False): using = using or router.db_for_write(self.__class__, instance=self) diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index b55d41bb85..d073324745 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -735,6 +735,14 @@ class Field(RegisterLookupMixin): def db_tablespace(self): return self._db_tablespace or settings.DEFAULT_INDEX_TABLESPACE + @property + def db_returning(self): + """ + Private API intended only to be used by Django itself. Currently only + the PostgreSQL backend supports returning multiple fields on a model. + """ + return False + def set_attributes_from_name(self, name): self.name = self.name or name self.attname, self.column = self.get_attname_column() @@ -2311,6 +2319,7 @@ class UUIDField(Field): class AutoFieldMixin: + db_returning = True def __init__(self, *args, **kwargs): kwargs['blank'] = True diff --git a/django/db/models/options.py b/django/db/models/options.py index 1f11e26d87..baa0c875b2 100644 --- a/django/db/models/options.py +++ b/django/db/models/options.py @@ -842,3 +842,14 @@ class Options: if isinstance(attr, property): names.append(name) return frozenset(names) + + @cached_property + def db_returning_fields(self): + """ + Private API intended only to be used by Django itself. + Fields to be returned after a database insert. + """ + return [ + field for field in self._get_fields(forward=True, reverse=False, include_parents=PROXY_PARENTS) + if getattr(field, 'db_returning', False) + ] diff --git a/django/db/models/query.py b/django/db/models/query.py index ab4f3fc534..180f4a41fc 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -470,23 +470,33 @@ class QuerySet: return objs self._for_write = True connection = connections[self.db] - fields = self.model._meta.concrete_fields + opts = self.model._meta + fields = opts.concrete_fields objs = list(objs) self._populate_pk_values(objs) with transaction.atomic(using=self.db, savepoint=False): objs_with_pk, objs_without_pk = partition(lambda o: o.pk is None, objs) if objs_with_pk: - self._batched_insert(objs_with_pk, fields, batch_size, ignore_conflicts=ignore_conflicts) + returned_columns = self._batched_insert( + objs_with_pk, fields, batch_size, ignore_conflicts=ignore_conflicts, + ) + for obj_with_pk, results in zip(objs_with_pk, returned_columns): + for result, field in zip(results, opts.db_returning_fields): + if field != opts.pk: + setattr(obj_with_pk, field.attname, result) for obj_with_pk in objs_with_pk: obj_with_pk._state.adding = False obj_with_pk._state.db = self.db if objs_without_pk: fields = [f for f in fields if not isinstance(f, AutoField)] - ids = self._batched_insert(objs_without_pk, fields, batch_size, ignore_conflicts=ignore_conflicts) + returned_columns = self._batched_insert( + objs_without_pk, fields, batch_size, ignore_conflicts=ignore_conflicts, + ) if connection.features.can_return_rows_from_bulk_insert and not ignore_conflicts: - assert len(ids) == len(objs_without_pk) - for obj_without_pk, pk in zip(objs_without_pk, ids): - obj_without_pk.pk = pk + assert len(returned_columns) == len(objs_without_pk) + for obj_without_pk, results in zip(objs_without_pk, returned_columns): + for result, field in zip(results, opts.db_returning_fields): + setattr(obj_without_pk, field.attname, result) obj_without_pk._state.adding = False obj_without_pk._state.db = self.db @@ -1181,7 +1191,7 @@ class QuerySet: # PRIVATE METHODS # ################### - def _insert(self, objs, fields, return_id=False, raw=False, using=None, ignore_conflicts=False): + def _insert(self, objs, fields, returning_fields=None, raw=False, using=None, ignore_conflicts=False): """ Insert a new record for the given model. This provides an interface to the InsertQuery class and is how Model.save() is implemented. @@ -1191,7 +1201,7 @@ class QuerySet: using = self.db query = sql.InsertQuery(self.model, ignore_conflicts=ignore_conflicts) query.insert_values(fields, objs, raw=raw) - return query.get_compiler(using=using).execute_sql(return_id) + return query.get_compiler(using=using).execute_sql(returning_fields) _insert.alters_data = True _insert.queryset_only = False @@ -1203,21 +1213,22 @@ class QuerySet: raise NotSupportedError('This database backend does not support ignoring conflicts.') ops = connections[self.db].ops batch_size = (batch_size or max(ops.bulk_batch_size(fields, objs), 1)) - inserted_ids = [] + inserted_rows = [] bulk_return = connections[self.db].features.can_return_rows_from_bulk_insert for item in [objs[i:i + batch_size] for i in range(0, len(objs), batch_size)]: if bulk_return and not ignore_conflicts: - inserted_id = self._insert( - item, fields=fields, using=self.db, return_id=True, + inserted_columns = self._insert( + item, fields=fields, using=self.db, + returning_fields=self.model._meta.db_returning_fields, ignore_conflicts=ignore_conflicts, ) - if isinstance(inserted_id, list): - inserted_ids.extend(inserted_id) + if isinstance(inserted_columns, list): + inserted_rows.extend(inserted_columns) else: - inserted_ids.append(inserted_id) + inserted_rows.append(inserted_columns) else: self._insert(item, fields=fields, using=self.db, ignore_conflicts=ignore_conflicts) - return inserted_ids + return inserted_rows def _chain(self, **kwargs): """ diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 77e023b92f..5193362a92 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -1159,7 +1159,7 @@ class SQLCompiler: class SQLInsertCompiler(SQLCompiler): - return_id = False + returning_fields = None def field_as_sql(self, field, val): """ @@ -1290,14 +1290,14 @@ class SQLInsertCompiler(SQLCompiler): # queries and generate their own placeholders. Doing that isn't # necessary and it should be possible to use placeholders and # expressions in bulk inserts too. - can_bulk = (not self.return_id and self.connection.features.has_bulk_insert) + can_bulk = (not self.returning_fields and self.connection.features.has_bulk_insert) placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows) ignore_conflicts_suffix_sql = self.connection.ops.ignore_conflicts_suffix_sql( ignore_conflicts=self.query.ignore_conflicts ) - if self.return_id and self.connection.features.can_return_columns_from_insert: + if self.returning_fields and self.connection.features.can_return_columns_from_insert: if self.connection.features.can_return_rows_from_bulk_insert: result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows)) params = param_rows @@ -1306,12 +1306,11 @@ class SQLInsertCompiler(SQLCompiler): params = [param_rows[0]] if ignore_conflicts_suffix_sql: result.append(ignore_conflicts_suffix_sql) - col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column)) - r_fmt, r_params = self.connection.ops.return_insert_id(opts.pk) - # Skip empty r_fmt to allow subclasses to customize behavior for + # Skip empty r_sql to allow subclasses to customize behavior for # 3rd party backends. Refs #19096. - if r_fmt: - result.append(r_fmt % col) + r_sql, r_params = self.connection.ops.return_insert_columns(self.returning_fields) + if r_sql: + result.append(r_sql) params += [r_params] return [(" ".join(result), tuple(chain.from_iterable(params)))] @@ -1328,25 +1327,33 @@ class SQLInsertCompiler(SQLCompiler): for p, vals in zip(placeholder_rows, param_rows) ] - def execute_sql(self, return_id=False): + def execute_sql(self, returning_fields=None): assert not ( - return_id and len(self.query.objs) != 1 and + returning_fields and len(self.query.objs) != 1 and not self.connection.features.can_return_rows_from_bulk_insert ) - self.return_id = return_id + self.returning_fields = returning_fields with self.connection.cursor() as cursor: for sql, params in self.as_sql(): cursor.execute(sql, params) - if not return_id: - return + if not self.returning_fields: + return [] if self.connection.features.can_return_rows_from_bulk_insert and len(self.query.objs) > 1: - return self.connection.ops.fetch_returned_insert_ids(cursor) + return self.connection.ops.fetch_returned_insert_rows(cursor) if self.connection.features.can_return_columns_from_insert: + if ( + len(self.returning_fields) > 1 and + not self.connection.features.can_return_multiple_columns_from_insert + ): + raise NotSupportedError( + 'Returning multiple columns from INSERT statements is ' + 'not supported on this database backend.' + ) assert len(self.query.objs) == 1 - return self.connection.ops.fetch_returned_insert_id(cursor) - return self.connection.ops.last_insert_id( + return self.connection.ops.fetch_returned_insert_columns(cursor) + return [self.connection.ops.last_insert_id( cursor, self.query.get_meta().db_table, self.query.get_meta().pk.column - ) + )] class SQLDeleteCompiler(SQLCompiler): diff --git a/docs/releases/3.0.txt b/docs/releases/3.0.txt index 1fc64a442d..11fb46b6a5 100644 --- a/docs/releases/3.0.txt +++ b/docs/releases/3.0.txt @@ -448,14 +448,20 @@ backends. :class:`~django.db.models.DateTimeField` in ``datetime_cast_date_sql()``, ``datetime_extract_sql()``, etc. -* ``DatabaseOperations.return_insert_id()`` now requires an additional - ``field`` argument with the model field. - * Entries for ``AutoField``, ``BigAutoField``, and ``SmallAutoField`` are added to ``DatabaseOperations.integer_field_ranges`` to support the integer range validators on these field types. Third-party backends may need to customize the default entries. +* ``DatabaseOperations.fetch_returned_insert_id()`` is replaced by + ``fetch_returned_insert_columns()`` which returns a list of values returned + by the ``INSERT … RETURNING`` statement, instead of a single value. + +* ``DatabaseOperations.return_insert_id()`` is replaced by + ``return_insert_columns()`` that accepts a ``fields`` + argument, which is an iterable of fields to be returned after insert. Usually + this is only the auto-generated primary key. + :mod:`django.contrib.admin` --------------------------- diff --git a/tests/backends/models.py b/tests/backends/models.py index a2c8616cc6..1fa8d44e63 100644 --- a/tests/backends/models.py +++ b/tests/backends/models.py @@ -5,10 +5,6 @@ from django.contrib.contenttypes.models import ContentType from django.db import models -class NonIntegerAutoField(models.Model): - creation_datetime = models.DateTimeField(primary_key=True) - - class Square(models.Model): root = models.IntegerField() square = models.PositiveIntegerField() diff --git a/tests/backends/oracle/tests.py b/tests/backends/oracle/tests.py index 30d981da69..a0d49854d9 100644 --- a/tests/backends/oracle/tests.py +++ b/tests/backends/oracle/tests.py @@ -1,4 +1,3 @@ -import datetime import unittest from django.db import connection @@ -6,7 +5,7 @@ from django.db.models.fields import BooleanField, NullBooleanField from django.db.utils import DatabaseError from django.test import TransactionTestCase -from ..models import NonIntegerAutoField, Square +from ..models import Square @unittest.skipUnless(connection.vendor == 'oracle', 'Oracle tests') @@ -96,23 +95,3 @@ class TransactionalTests(TransactionTestCase): self.assertIn('ORA-01017', context.exception.args[0].message) finally: connection.settings_dict['PASSWORD'] = old_password - - def test_non_integer_auto_field(self): - with connection.cursor() as cursor: - # Create trigger that fill non-integer auto field. - cursor.execute(""" - CREATE OR REPLACE TRIGGER "TRG_FILL_CREATION_DATETIME" - BEFORE INSERT ON "BACKENDS_NONINTEGERAUTOFIELD" - FOR EACH ROW - BEGIN - :NEW.CREATION_DATETIME := SYSTIMESTAMP; - END; - """) - try: - NonIntegerAutoField._meta.auto_field = NonIntegerAutoField.creation_datetime - obj = NonIntegerAutoField.objects.create() - self.assertIsNotNone(obj.creation_datetime) - self.assertIsInstance(obj.creation_datetime, datetime.datetime) - finally: - with connection.cursor() as cursor: - cursor.execute('DROP TRIGGER "TRG_FILL_CREATION_DATETIME"') diff --git a/tests/model_meta/tests.py b/tests/model_meta/tests.py index 7867f4c620..0a3049c6dc 100644 --- a/tests/model_meta/tests.py +++ b/tests/model_meta/tests.py @@ -279,3 +279,8 @@ class PropertyNamesTests(SimpleTestCase): # Instance only descriptors don't appear in _property_names. self.assertEqual(AbstractPerson().test_instance_only_descriptor, 1) self.assertEqual(AbstractPerson._meta._property_names, frozenset(['pk', 'test_property'])) + + +class ReturningFieldsTests(SimpleTestCase): + def test_pk(self): + self.assertEqual(Relation._meta.db_returning_fields, [Relation._meta.pk]) diff --git a/tests/queries/models.py b/tests/queries/models.py index f228a61a70..e9eec5718d 100644 --- a/tests/queries/models.py +++ b/tests/queries/models.py @@ -4,6 +4,7 @@ Various complex queries that have been problematic in the past. import threading from django.db import models +from django.db.models.functions import Now class DumbCategory(models.Model): @@ -730,3 +731,19 @@ class RelatedIndividual(models.Model): class CustomDbColumn(models.Model): custom_column = models.IntegerField(db_column='custom_name', null=True) ip_address = models.GenericIPAddressField(null=True) + + +class CreatedField(models.DateTimeField): + db_returning = True + + def __init__(self, *args, **kwargs): + kwargs.setdefault('default', Now) + super().__init__(*args, **kwargs) + + +class ReturningModel(models.Model): + created = CreatedField(editable=False) + + +class NonIntegerPKReturningModel(models.Model): + created = CreatedField(editable=False, primary_key=True) diff --git a/tests/queries/test_db_returning.py b/tests/queries/test_db_returning.py new file mode 100644 index 0000000000..af9d041393 --- /dev/null +++ b/tests/queries/test_db_returning.py @@ -0,0 +1,64 @@ +import datetime + +from django.db import NotSupportedError, connection +from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature +from django.test.utils import CaptureQueriesContext + +from .models import DumbCategory, NonIntegerPKReturningModel, ReturningModel + + +@skipUnlessDBFeature('can_return_columns_from_insert') +class ReturningValuesTests(TestCase): + def test_insert_returning(self): + with CaptureQueriesContext(connection) as captured_queries: + DumbCategory.objects.create() + self.assertIn( + 'RETURNING %s.%s' % ( + connection.ops.quote_name(DumbCategory._meta.db_table), + connection.ops.quote_name(DumbCategory._meta.get_field('id').column), + ), + captured_queries[-1]['sql'], + ) + + def test_insert_returning_non_integer(self): + obj = NonIntegerPKReturningModel.objects.create() + self.assertTrue(obj.created) + self.assertIsInstance(obj.created, datetime.datetime) + + @skipUnlessDBFeature('can_return_multiple_columns_from_insert') + def test_insert_returning_multiple(self): + with CaptureQueriesContext(connection) as captured_queries: + obj = ReturningModel.objects.create() + table_name = connection.ops.quote_name(ReturningModel._meta.db_table) + self.assertIn( + 'RETURNING %s.%s, %s.%s' % ( + table_name, + connection.ops.quote_name(ReturningModel._meta.get_field('id').column), + table_name, + connection.ops.quote_name(ReturningModel._meta.get_field('created').column), + ), + captured_queries[-1]['sql'], + ) + self.assertTrue(obj.pk) + self.assertIsInstance(obj.created, datetime.datetime) + + @skipIfDBFeature('can_return_multiple_columns_from_insert') + def test_insert_returning_multiple_not_supported(self): + msg = ( + 'Returning multiple columns from INSERT statements is ' + 'not supported on this database backend.' + ) + with self.assertRaisesMessage(NotSupportedError, msg): + ReturningModel.objects.create() + + @skipUnlessDBFeature( + 'can_return_rows_from_bulk_insert', + 'can_return_multiple_columns_from_insert', + ) + def test_bulk_insert(self): + objs = [ReturningModel(), ReturningModel(pk=2 ** 11), ReturningModel()] + ReturningModel.objects.bulk_create(objs) + for obj in objs: + with self.subTest(obj=obj): + self.assertTrue(obj.pk) + self.assertIsInstance(obj.created, datetime.datetime)