1
0
mirror of https://github.com/django/django.git synced 2024-12-31 21:46:05 +00:00

Refs #29444 -- Allowed returning multiple fields from INSERT statements on PostgreSQL.

Thanks Florian Apolloner, Tim Graham, Simon Charette, Nick Pope, and
Mariusz Felisiak for reviews.
This commit is contained in:
Johannes Hoppe 2019-07-24 08:42:41 +02:00 committed by Mariusz Felisiak
parent 736e7d44de
commit 7254f1138d
16 changed files with 209 additions and 89 deletions

View File

@ -23,6 +23,7 @@ class BaseDatabaseFeatures:
can_use_chunked_reads = True can_use_chunked_reads = True
can_return_columns_from_insert = False can_return_columns_from_insert = False
can_return_multiple_columns_from_insert = False
can_return_rows_from_bulk_insert = False can_return_rows_from_bulk_insert = False
has_bulk_insert = True has_bulk_insert = True
uses_savepoints = True uses_savepoints = True

View File

@ -176,13 +176,12 @@ class BaseDatabaseOperations:
else: else:
return ['DISTINCT'], [] return ['DISTINCT'], []
def fetch_returned_insert_id(self, cursor): def fetch_returned_insert_columns(self, cursor):
""" """
Given a cursor object that has just performed an INSERT...RETURNING Given a cursor object that has just performed an INSERT...RETURNING
statement into a table that has an auto-incrementing ID, return the statement into a table, return the newly created data.
newly created ID.
""" """
return cursor.fetchone()[0] return cursor.fetchone()
def field_cast_sql(self, db_type, internal_type): def field_cast_sql(self, db_type, internal_type):
""" """
@ -314,12 +313,11 @@ class BaseDatabaseOperations:
""" """
return value return value
def return_insert_id(self, field): def return_insert_columns(self, fields):
""" """
For backends that support returning the last insert ID as part of an For backends that support returning columns as part of an insert query,
insert query, return the SQL and params to append to the INSERT query. return the SQL and params to append to the INSERT query. The returned
The returned fragment should contain a format string to hold the fragment should contain a format string to hold the appropriate column.
appropriate column.
""" """
pass pass

View File

@ -248,7 +248,7 @@ END;
def deferrable_sql(self): def deferrable_sql(self):
return " DEFERRABLE INITIALLY DEFERRED" return " DEFERRABLE INITIALLY DEFERRED"
def fetch_returned_insert_id(self, cursor): def fetch_returned_insert_columns(self, cursor):
value = cursor._insert_id_var.getvalue() value = cursor._insert_id_var.getvalue()
if value is None or value == []: if value is None or value == []:
# cx_Oracle < 6.3 returns None, >= 6.3 returns empty list. # cx_Oracle < 6.3 returns None, >= 6.3 returns empty list.
@ -258,7 +258,7 @@ END;
'Oracle OCI library (see https://code.djangoproject.com/ticket/28859).' 'Oracle OCI library (see https://code.djangoproject.com/ticket/28859).'
) )
# cx_Oracle < 7 returns value, >= 7 returns list with single value. # cx_Oracle < 7 returns value, >= 7 returns list with single value.
return value[0] if isinstance(value, list) else value return value if isinstance(value, list) else [value]
def field_cast_sql(self, db_type, internal_type): def field_cast_sql(self, db_type, internal_type):
if db_type and db_type.endswith('LOB'): if db_type and db_type.endswith('LOB'):
@ -341,8 +341,14 @@ END;
match_option = "'i'" match_option = "'i'"
return 'REGEXP_LIKE(%%s, %%s, %s)' % match_option return 'REGEXP_LIKE(%%s, %%s, %s)' % match_option
def return_insert_id(self, field): def return_insert_columns(self, fields):
return 'RETURNING %s INTO %%s', (InsertVar(field),) if not fields:
return '', ()
sql = 'RETURNING %s.%s INTO %%s' % (
self.quote_name(fields[0].model._meta.db_table),
self.quote_name(fields[0].column),
)
return sql, (InsertVar(fields[0]),)
def __foreign_key_constraints(self, table_name, recursive): def __foreign_key_constraints(self, table_name, recursive):
with self.connection.cursor() as cursor: with self.connection.cursor() as cursor:

View File

@ -8,6 +8,7 @@ from django.utils.functional import cached_property
class DatabaseFeatures(BaseDatabaseFeatures): class DatabaseFeatures(BaseDatabaseFeatures):
allows_group_by_selected_pks = True allows_group_by_selected_pks = True
can_return_columns_from_insert = True can_return_columns_from_insert = True
can_return_multiple_columns_from_insert = True
can_return_rows_from_bulk_insert = True can_return_rows_from_bulk_insert = True
has_real_datatype = True has_real_datatype = True
has_native_uuid_field = True has_native_uuid_field = True

View File

@ -76,13 +76,12 @@ class DatabaseOperations(BaseDatabaseOperations):
def deferrable_sql(self): def deferrable_sql(self):
return " DEFERRABLE INITIALLY DEFERRED" return " DEFERRABLE INITIALLY DEFERRED"
def fetch_returned_insert_ids(self, cursor): def fetch_returned_insert_rows(self, cursor):
""" """
Given a cursor object that has just performed an INSERT...RETURNING Given a cursor object that has just performed an INSERT...RETURNING
statement into a table that has an auto-incrementing ID, return the statement into a table, return the tuple of returned data.
list of newly created IDs.
""" """
return [item[0] for item in cursor.fetchall()] return cursor.fetchall()
def lookup_cast(self, lookup_type, internal_type=None): def lookup_cast(self, lookup_type, internal_type=None):
lookup = '%s' lookup = '%s'
@ -236,8 +235,16 @@ class DatabaseOperations(BaseDatabaseOperations):
return cursor.query.decode() return cursor.query.decode()
return None return None
def return_insert_id(self, field): def return_insert_columns(self, fields):
return "RETURNING %s", () if not fields:
return '', ()
columns = [
'%s.%s' % (
self.quote_name(field.model._meta.db_table),
self.quote_name(field.column),
) for field in fields
]
return 'RETURNING %s' % ', '.join(columns), ()
def bulk_insert_sql(self, fields, placeholder_rows): def bulk_insert_sql(self, fields, placeholder_rows):
placeholder_rows_sql = (", ".join(row) for row in placeholder_rows) placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)

View File

@ -876,10 +876,10 @@ class Model(metaclass=ModelBase):
if not pk_set: if not pk_set:
fields = [f for f in fields if f is not meta.auto_field] fields = [f for f in fields if f is not meta.auto_field]
update_pk = meta.auto_field and not pk_set returning_fields = meta.db_returning_fields
result = self._do_insert(cls._base_manager, using, fields, update_pk, raw) results = self._do_insert(cls._base_manager, using, fields, returning_fields, raw)
if update_pk: for result, field in zip(results, returning_fields):
setattr(self, meta.pk.attname, result) setattr(self, field.attname, result)
return updated return updated
def _do_update(self, base_qs, using, pk_val, values, update_fields, forced_update): def _do_update(self, base_qs, using, pk_val, values, update_fields, forced_update):
@ -909,13 +909,15 @@ class Model(metaclass=ModelBase):
) )
return filtered._update(values) > 0 return filtered._update(values) > 0
def _do_insert(self, manager, using, fields, update_pk, raw): def _do_insert(self, manager, using, fields, returning_fields, raw):
""" """
Do an INSERT. If update_pk is defined then this method should return Do an INSERT. If returning_fields is defined then this method should
the new pk for the model. return the newly created data for the model.
""" """
return manager._insert([self], fields=fields, return_id=update_pk, return manager._insert(
using=using, raw=raw) [self], fields=fields, returning_fields=returning_fields,
using=using, raw=raw,
)
def delete(self, using=None, keep_parents=False): def delete(self, using=None, keep_parents=False):
using = using or router.db_for_write(self.__class__, instance=self) using = using or router.db_for_write(self.__class__, instance=self)

View File

@ -735,6 +735,14 @@ class Field(RegisterLookupMixin):
def db_tablespace(self): def db_tablespace(self):
return self._db_tablespace or settings.DEFAULT_INDEX_TABLESPACE return self._db_tablespace or settings.DEFAULT_INDEX_TABLESPACE
@property
def db_returning(self):
"""
Private API intended only to be used by Django itself. Currently only
the PostgreSQL backend supports returning multiple fields on a model.
"""
return False
def set_attributes_from_name(self, name): def set_attributes_from_name(self, name):
self.name = self.name or name self.name = self.name or name
self.attname, self.column = self.get_attname_column() self.attname, self.column = self.get_attname_column()
@ -2311,6 +2319,7 @@ class UUIDField(Field):
class AutoFieldMixin: class AutoFieldMixin:
db_returning = True
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
kwargs['blank'] = True kwargs['blank'] = True

View File

@ -842,3 +842,14 @@ class Options:
if isinstance(attr, property): if isinstance(attr, property):
names.append(name) names.append(name)
return frozenset(names) return frozenset(names)
@cached_property
def db_returning_fields(self):
"""
Private API intended only to be used by Django itself.
Fields to be returned after a database insert.
"""
return [
field for field in self._get_fields(forward=True, reverse=False, include_parents=PROXY_PARENTS)
if getattr(field, 'db_returning', False)
]

View File

@ -470,23 +470,33 @@ class QuerySet:
return objs return objs
self._for_write = True self._for_write = True
connection = connections[self.db] connection = connections[self.db]
fields = self.model._meta.concrete_fields opts = self.model._meta
fields = opts.concrete_fields
objs = list(objs) objs = list(objs)
self._populate_pk_values(objs) self._populate_pk_values(objs)
with transaction.atomic(using=self.db, savepoint=False): with transaction.atomic(using=self.db, savepoint=False):
objs_with_pk, objs_without_pk = partition(lambda o: o.pk is None, objs) objs_with_pk, objs_without_pk = partition(lambda o: o.pk is None, objs)
if objs_with_pk: if objs_with_pk:
self._batched_insert(objs_with_pk, fields, batch_size, ignore_conflicts=ignore_conflicts) returned_columns = self._batched_insert(
objs_with_pk, fields, batch_size, ignore_conflicts=ignore_conflicts,
)
for obj_with_pk, results in zip(objs_with_pk, returned_columns):
for result, field in zip(results, opts.db_returning_fields):
if field != opts.pk:
setattr(obj_with_pk, field.attname, result)
for obj_with_pk in objs_with_pk: for obj_with_pk in objs_with_pk:
obj_with_pk._state.adding = False obj_with_pk._state.adding = False
obj_with_pk._state.db = self.db obj_with_pk._state.db = self.db
if objs_without_pk: if objs_without_pk:
fields = [f for f in fields if not isinstance(f, AutoField)] fields = [f for f in fields if not isinstance(f, AutoField)]
ids = self._batched_insert(objs_without_pk, fields, batch_size, ignore_conflicts=ignore_conflicts) returned_columns = self._batched_insert(
objs_without_pk, fields, batch_size, ignore_conflicts=ignore_conflicts,
)
if connection.features.can_return_rows_from_bulk_insert and not ignore_conflicts: if connection.features.can_return_rows_from_bulk_insert and not ignore_conflicts:
assert len(ids) == len(objs_without_pk) assert len(returned_columns) == len(objs_without_pk)
for obj_without_pk, pk in zip(objs_without_pk, ids): for obj_without_pk, results in zip(objs_without_pk, returned_columns):
obj_without_pk.pk = pk for result, field in zip(results, opts.db_returning_fields):
setattr(obj_without_pk, field.attname, result)
obj_without_pk._state.adding = False obj_without_pk._state.adding = False
obj_without_pk._state.db = self.db obj_without_pk._state.db = self.db
@ -1181,7 +1191,7 @@ class QuerySet:
# PRIVATE METHODS # # PRIVATE METHODS #
################### ###################
def _insert(self, objs, fields, return_id=False, raw=False, using=None, ignore_conflicts=False): def _insert(self, objs, fields, returning_fields=None, raw=False, using=None, ignore_conflicts=False):
""" """
Insert a new record for the given model. This provides an interface to Insert a new record for the given model. This provides an interface to
the InsertQuery class and is how Model.save() is implemented. the InsertQuery class and is how Model.save() is implemented.
@ -1191,7 +1201,7 @@ class QuerySet:
using = self.db using = self.db
query = sql.InsertQuery(self.model, ignore_conflicts=ignore_conflicts) query = sql.InsertQuery(self.model, ignore_conflicts=ignore_conflicts)
query.insert_values(fields, objs, raw=raw) query.insert_values(fields, objs, raw=raw)
return query.get_compiler(using=using).execute_sql(return_id) return query.get_compiler(using=using).execute_sql(returning_fields)
_insert.alters_data = True _insert.alters_data = True
_insert.queryset_only = False _insert.queryset_only = False
@ -1203,21 +1213,22 @@ class QuerySet:
raise NotSupportedError('This database backend does not support ignoring conflicts.') raise NotSupportedError('This database backend does not support ignoring conflicts.')
ops = connections[self.db].ops ops = connections[self.db].ops
batch_size = (batch_size or max(ops.bulk_batch_size(fields, objs), 1)) batch_size = (batch_size or max(ops.bulk_batch_size(fields, objs), 1))
inserted_ids = [] inserted_rows = []
bulk_return = connections[self.db].features.can_return_rows_from_bulk_insert bulk_return = connections[self.db].features.can_return_rows_from_bulk_insert
for item in [objs[i:i + batch_size] for i in range(0, len(objs), batch_size)]: for item in [objs[i:i + batch_size] for i in range(0, len(objs), batch_size)]:
if bulk_return and not ignore_conflicts: if bulk_return and not ignore_conflicts:
inserted_id = self._insert( inserted_columns = self._insert(
item, fields=fields, using=self.db, return_id=True, item, fields=fields, using=self.db,
returning_fields=self.model._meta.db_returning_fields,
ignore_conflicts=ignore_conflicts, ignore_conflicts=ignore_conflicts,
) )
if isinstance(inserted_id, list): if isinstance(inserted_columns, list):
inserted_ids.extend(inserted_id) inserted_rows.extend(inserted_columns)
else: else:
inserted_ids.append(inserted_id) inserted_rows.append(inserted_columns)
else: else:
self._insert(item, fields=fields, using=self.db, ignore_conflicts=ignore_conflicts) self._insert(item, fields=fields, using=self.db, ignore_conflicts=ignore_conflicts)
return inserted_ids return inserted_rows
def _chain(self, **kwargs): def _chain(self, **kwargs):
""" """

View File

@ -1159,7 +1159,7 @@ class SQLCompiler:
class SQLInsertCompiler(SQLCompiler): class SQLInsertCompiler(SQLCompiler):
return_id = False returning_fields = None
def field_as_sql(self, field, val): def field_as_sql(self, field, val):
""" """
@ -1290,14 +1290,14 @@ class SQLInsertCompiler(SQLCompiler):
# queries and generate their own placeholders. Doing that isn't # queries and generate their own placeholders. Doing that isn't
# necessary and it should be possible to use placeholders and # necessary and it should be possible to use placeholders and
# expressions in bulk inserts too. # expressions in bulk inserts too.
can_bulk = (not self.return_id and self.connection.features.has_bulk_insert) can_bulk = (not self.returning_fields and self.connection.features.has_bulk_insert)
placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows) placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows)
ignore_conflicts_suffix_sql = self.connection.ops.ignore_conflicts_suffix_sql( ignore_conflicts_suffix_sql = self.connection.ops.ignore_conflicts_suffix_sql(
ignore_conflicts=self.query.ignore_conflicts ignore_conflicts=self.query.ignore_conflicts
) )
if self.return_id and self.connection.features.can_return_columns_from_insert: if self.returning_fields and self.connection.features.can_return_columns_from_insert:
if self.connection.features.can_return_rows_from_bulk_insert: if self.connection.features.can_return_rows_from_bulk_insert:
result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows)) result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
params = param_rows params = param_rows
@ -1306,12 +1306,11 @@ class SQLInsertCompiler(SQLCompiler):
params = [param_rows[0]] params = [param_rows[0]]
if ignore_conflicts_suffix_sql: if ignore_conflicts_suffix_sql:
result.append(ignore_conflicts_suffix_sql) result.append(ignore_conflicts_suffix_sql)
col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column)) # Skip empty r_sql to allow subclasses to customize behavior for
r_fmt, r_params = self.connection.ops.return_insert_id(opts.pk)
# Skip empty r_fmt to allow subclasses to customize behavior for
# 3rd party backends. Refs #19096. # 3rd party backends. Refs #19096.
if r_fmt: r_sql, r_params = self.connection.ops.return_insert_columns(self.returning_fields)
result.append(r_fmt % col) if r_sql:
result.append(r_sql)
params += [r_params] params += [r_params]
return [(" ".join(result), tuple(chain.from_iterable(params)))] return [(" ".join(result), tuple(chain.from_iterable(params)))]
@ -1328,25 +1327,33 @@ class SQLInsertCompiler(SQLCompiler):
for p, vals in zip(placeholder_rows, param_rows) for p, vals in zip(placeholder_rows, param_rows)
] ]
def execute_sql(self, return_id=False): def execute_sql(self, returning_fields=None):
assert not ( assert not (
return_id and len(self.query.objs) != 1 and returning_fields and len(self.query.objs) != 1 and
not self.connection.features.can_return_rows_from_bulk_insert not self.connection.features.can_return_rows_from_bulk_insert
) )
self.return_id = return_id self.returning_fields = returning_fields
with self.connection.cursor() as cursor: with self.connection.cursor() as cursor:
for sql, params in self.as_sql(): for sql, params in self.as_sql():
cursor.execute(sql, params) cursor.execute(sql, params)
if not return_id: if not self.returning_fields:
return return []
if self.connection.features.can_return_rows_from_bulk_insert and len(self.query.objs) > 1: if self.connection.features.can_return_rows_from_bulk_insert and len(self.query.objs) > 1:
return self.connection.ops.fetch_returned_insert_ids(cursor) return self.connection.ops.fetch_returned_insert_rows(cursor)
if self.connection.features.can_return_columns_from_insert: if self.connection.features.can_return_columns_from_insert:
if (
len(self.returning_fields) > 1 and
not self.connection.features.can_return_multiple_columns_from_insert
):
raise NotSupportedError(
'Returning multiple columns from INSERT statements is '
'not supported on this database backend.'
)
assert len(self.query.objs) == 1 assert len(self.query.objs) == 1
return self.connection.ops.fetch_returned_insert_id(cursor) return self.connection.ops.fetch_returned_insert_columns(cursor)
return self.connection.ops.last_insert_id( return [self.connection.ops.last_insert_id(
cursor, self.query.get_meta().db_table, self.query.get_meta().pk.column cursor, self.query.get_meta().db_table, self.query.get_meta().pk.column
) )]
class SQLDeleteCompiler(SQLCompiler): class SQLDeleteCompiler(SQLCompiler):

View File

@ -448,14 +448,20 @@ backends.
:class:`~django.db.models.DateTimeField` in ``datetime_cast_date_sql()``, :class:`~django.db.models.DateTimeField` in ``datetime_cast_date_sql()``,
``datetime_extract_sql()``, etc. ``datetime_extract_sql()``, etc.
* ``DatabaseOperations.return_insert_id()`` now requires an additional
``field`` argument with the model field.
* Entries for ``AutoField``, ``BigAutoField``, and ``SmallAutoField`` are added * Entries for ``AutoField``, ``BigAutoField``, and ``SmallAutoField`` are added
to ``DatabaseOperations.integer_field_ranges`` to support the integer range to ``DatabaseOperations.integer_field_ranges`` to support the integer range
validators on these field types. Third-party backends may need to customize validators on these field types. Third-party backends may need to customize
the default entries. the default entries.
* ``DatabaseOperations.fetch_returned_insert_id()`` is replaced by
``fetch_returned_insert_columns()`` which returns a list of values returned
by the ``INSERT … RETURNING`` statement, instead of a single value.
* ``DatabaseOperations.return_insert_id()`` is replaced by
``return_insert_columns()`` that accepts a ``fields``
argument, which is an iterable of fields to be returned after insert. Usually
this is only the auto-generated primary key.
:mod:`django.contrib.admin` :mod:`django.contrib.admin`
--------------------------- ---------------------------

View File

@ -5,10 +5,6 @@ from django.contrib.contenttypes.models import ContentType
from django.db import models from django.db import models
class NonIntegerAutoField(models.Model):
creation_datetime = models.DateTimeField(primary_key=True)
class Square(models.Model): class Square(models.Model):
root = models.IntegerField() root = models.IntegerField()
square = models.PositiveIntegerField() square = models.PositiveIntegerField()

View File

@ -1,4 +1,3 @@
import datetime
import unittest import unittest
from django.db import connection from django.db import connection
@ -6,7 +5,7 @@ from django.db.models.fields import BooleanField, NullBooleanField
from django.db.utils import DatabaseError from django.db.utils import DatabaseError
from django.test import TransactionTestCase from django.test import TransactionTestCase
from ..models import NonIntegerAutoField, Square from ..models import Square
@unittest.skipUnless(connection.vendor == 'oracle', 'Oracle tests') @unittest.skipUnless(connection.vendor == 'oracle', 'Oracle tests')
@ -96,23 +95,3 @@ class TransactionalTests(TransactionTestCase):
self.assertIn('ORA-01017', context.exception.args[0].message) self.assertIn('ORA-01017', context.exception.args[0].message)
finally: finally:
connection.settings_dict['PASSWORD'] = old_password connection.settings_dict['PASSWORD'] = old_password
def test_non_integer_auto_field(self):
with connection.cursor() as cursor:
# Create trigger that fill non-integer auto field.
cursor.execute("""
CREATE OR REPLACE TRIGGER "TRG_FILL_CREATION_DATETIME"
BEFORE INSERT ON "BACKENDS_NONINTEGERAUTOFIELD"
FOR EACH ROW
BEGIN
:NEW.CREATION_DATETIME := SYSTIMESTAMP;
END;
""")
try:
NonIntegerAutoField._meta.auto_field = NonIntegerAutoField.creation_datetime
obj = NonIntegerAutoField.objects.create()
self.assertIsNotNone(obj.creation_datetime)
self.assertIsInstance(obj.creation_datetime, datetime.datetime)
finally:
with connection.cursor() as cursor:
cursor.execute('DROP TRIGGER "TRG_FILL_CREATION_DATETIME"')

View File

@ -279,3 +279,8 @@ class PropertyNamesTests(SimpleTestCase):
# Instance only descriptors don't appear in _property_names. # Instance only descriptors don't appear in _property_names.
self.assertEqual(AbstractPerson().test_instance_only_descriptor, 1) self.assertEqual(AbstractPerson().test_instance_only_descriptor, 1)
self.assertEqual(AbstractPerson._meta._property_names, frozenset(['pk', 'test_property'])) self.assertEqual(AbstractPerson._meta._property_names, frozenset(['pk', 'test_property']))
class ReturningFieldsTests(SimpleTestCase):
def test_pk(self):
self.assertEqual(Relation._meta.db_returning_fields, [Relation._meta.pk])

View File

@ -4,6 +4,7 @@ Various complex queries that have been problematic in the past.
import threading import threading
from django.db import models from django.db import models
from django.db.models.functions import Now
class DumbCategory(models.Model): class DumbCategory(models.Model):
@ -730,3 +731,19 @@ class RelatedIndividual(models.Model):
class CustomDbColumn(models.Model): class CustomDbColumn(models.Model):
custom_column = models.IntegerField(db_column='custom_name', null=True) custom_column = models.IntegerField(db_column='custom_name', null=True)
ip_address = models.GenericIPAddressField(null=True) ip_address = models.GenericIPAddressField(null=True)
class CreatedField(models.DateTimeField):
db_returning = True
def __init__(self, *args, **kwargs):
kwargs.setdefault('default', Now)
super().__init__(*args, **kwargs)
class ReturningModel(models.Model):
created = CreatedField(editable=False)
class NonIntegerPKReturningModel(models.Model):
created = CreatedField(editable=False, primary_key=True)

View File

@ -0,0 +1,64 @@
import datetime
from django.db import NotSupportedError, connection
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
from django.test.utils import CaptureQueriesContext
from .models import DumbCategory, NonIntegerPKReturningModel, ReturningModel
@skipUnlessDBFeature('can_return_columns_from_insert')
class ReturningValuesTests(TestCase):
def test_insert_returning(self):
with CaptureQueriesContext(connection) as captured_queries:
DumbCategory.objects.create()
self.assertIn(
'RETURNING %s.%s' % (
connection.ops.quote_name(DumbCategory._meta.db_table),
connection.ops.quote_name(DumbCategory._meta.get_field('id').column),
),
captured_queries[-1]['sql'],
)
def test_insert_returning_non_integer(self):
obj = NonIntegerPKReturningModel.objects.create()
self.assertTrue(obj.created)
self.assertIsInstance(obj.created, datetime.datetime)
@skipUnlessDBFeature('can_return_multiple_columns_from_insert')
def test_insert_returning_multiple(self):
with CaptureQueriesContext(connection) as captured_queries:
obj = ReturningModel.objects.create()
table_name = connection.ops.quote_name(ReturningModel._meta.db_table)
self.assertIn(
'RETURNING %s.%s, %s.%s' % (
table_name,
connection.ops.quote_name(ReturningModel._meta.get_field('id').column),
table_name,
connection.ops.quote_name(ReturningModel._meta.get_field('created').column),
),
captured_queries[-1]['sql'],
)
self.assertTrue(obj.pk)
self.assertIsInstance(obj.created, datetime.datetime)
@skipIfDBFeature('can_return_multiple_columns_from_insert')
def test_insert_returning_multiple_not_supported(self):
msg = (
'Returning multiple columns from INSERT statements is '
'not supported on this database backend.'
)
with self.assertRaisesMessage(NotSupportedError, msg):
ReturningModel.objects.create()
@skipUnlessDBFeature(
'can_return_rows_from_bulk_insert',
'can_return_multiple_columns_from_insert',
)
def test_bulk_insert(self):
objs = [ReturningModel(), ReturningModel(pk=2 ** 11), ReturningModel()]
ReturningModel.objects.bulk_create(objs)
for obj in objs:
with self.subTest(obj=obj):
self.assertTrue(obj.pk)
self.assertIsInstance(obj.created, datetime.datetime)