1
0
mirror of https://github.com/django/django.git synced 2025-06-07 12:39:12 +00:00

Fixed #32406 Support UPDATE RETURNING

Implement a new QuerySet.update_returning() that executes a
QuerySet.update() and returns modified rows on databases that support
UPDATE RETURNING (SQLite, PostgreSQL, Oracle)
This commit is contained in:
Aivars Kalvans 2023-09-22 20:10:49 +03:00
parent 51d703a27f
commit c3a0ff4ede
11 changed files with 188 additions and 21 deletions

View File

@ -38,6 +38,7 @@ class BaseDatabaseFeatures:
can_use_chunked_reads = True can_use_chunked_reads = True
can_return_columns_from_insert = False can_return_columns_from_insert = False
can_return_rows_from_bulk_insert = False can_return_rows_from_bulk_insert = False
can_return_columns_from_update = False
has_bulk_insert = True has_bulk_insert = True
uses_savepoints = True uses_savepoints = True
can_release_savepoints = False can_release_savepoints = False

View File

@ -28,6 +28,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_frame_range_fixed_distance = True supports_frame_range_fixed_distance = True
supports_update_conflicts = True supports_update_conflicts = True
delete_can_self_reference_subquery = False delete_can_self_reference_subquery = False
can_return_columns_from_update = False
create_test_procedure_without_params_sql = """ create_test_procedure_without_params_sql = """
CREATE PROCEDURE test_procedure () CREATE PROCEDURE test_procedure ()
BEGIN BEGIN

View File

@ -16,6 +16,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
has_select_for_update_of = True has_select_for_update_of = True
select_for_update_of_column = True select_for_update_of_column = True
can_return_columns_from_insert = True can_return_columns_from_insert = True
can_return_columns_from_update = True
supports_subqueries_in_group_by = False supports_subqueries_in_group_by = False
ignores_unnecessary_order_by_in_subqueries = False ignores_unnecessary_order_by_in_subqueries = False
supports_transactions = True supports_transactions = True

View File

@ -11,6 +11,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
allows_group_by_selected_pks = True allows_group_by_selected_pks = True
can_return_columns_from_insert = True can_return_columns_from_insert = True
can_return_rows_from_bulk_insert = True can_return_rows_from_bulk_insert = True
can_return_columns_from_update = True
has_real_datatype = True has_real_datatype = True
has_native_uuid_field = True has_native_uuid_field = True
has_native_duration_field = True has_native_duration_field = True

View File

@ -152,3 +152,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
can_return_rows_from_bulk_insert = property( can_return_rows_from_bulk_insert = property(
operator.attrgetter("can_return_columns_from_insert") operator.attrgetter("can_return_columns_from_insert")
) )
@cached_property
def can_return_columns_from_update(self):
return Database.sqlite_version_info >= (3, 35)

View File

@ -1213,10 +1213,9 @@ class QuerySet(AltersData):
_raw_delete.alters_data = True _raw_delete.alters_data = True
def update(self, **kwargs): def _update_query(self, **kwargs):
""" """
Update all elements in the current QuerySet, setting all the given Prepare a query for update
fields to the appropriate values.
""" """
self._not_support_combined_queries("update") self._not_support_combined_queries("update")
if self.query.is_sliced: if self.query.is_sliced:
@ -1247,7 +1246,32 @@ class QuerySet(AltersData):
# Clear any annotations so that they won't be present in subqueries. # Clear any annotations so that they won't be present in subqueries.
query.annotations = {} query.annotations = {}
with transaction.mark_for_rollback_on_error(using=self.db): return query
def update_returning(self, **kwargs):
"""
Update all elements in the current QuerySet, setting all the given
fields to the appropriate values. Returns a QuerySet containing all
updated objects.
"""
clone = self._chain()
clone.query = self._update_query(**kwargs)
clone.query.update_returning = True
return clone
update_returning.alters_data = True
async def aupdate_returning(self, **kwargs):
return await sync_to_async(self.update_returning)(**kwargs)
aupdate_returning.alters_data = True
def update(self, **kwargs):
"""
Update all elements in the current QuerySet, setting all the given
fields to the appropriate values.
"""
query = self._update_query(**kwargs)
rows = query.get_compiler(self.db).execute_sql(CURSOR) rows = query.get_compiler(self.db).execute_sql(CURSOR)
self._result_cache = None self._result_cache = None
return rows return rows

View File

@ -21,7 +21,7 @@ from django.db.models.sql.constants import (
) )
from django.db.models.sql.query import Query, get_order_dir from django.db.models.sql.query import Query, get_order_dir
from django.db.models.sql.where import AND from django.db.models.sql.where import AND
from django.db.transaction import TransactionManagementError from django.db.transaction import TransactionManagementError, mark_for_rollback_on_error
from django.utils.functional import cached_property from django.utils.functional import cached_property
from django.utils.hashable import make_hashable from django.utils.hashable import make_hashable
from django.utils.regex_helper import _lazy_re_compile from django.utils.regex_helper import _lazy_re_compile
@ -1918,6 +1918,12 @@ class SQLUpdateCompiler(SQLCompiler):
Create the SQL for this query. Return the SQL string and list of Create the SQL for this query. Return the SQL string and list of
parameters. parameters.
""" """
if (
self.query.update_returning
and not self.connection.features.can_return_columns_from_update
):
raise NotSupportedError("This backend does not support UPDATE RETURNING")
self.pre_sql_setup() self.pre_sql_setup()
if not self.query.values: if not self.query.values:
return "", () return "", ()
@ -1975,16 +1981,36 @@ class SQLUpdateCompiler(SQLCompiler):
params = [] params = []
else: else:
result.append("WHERE %s" % where) result.append("WHERE %s" % where)
if self.query.update_returning:
select_mask = self.query.get_select_mask()
if self.query.default_cols:
cols = self.get_default_columns(select_mask)
else:
cols = self.query.select
result.append(
"RETURNING %s" % (", ".join(col.target.column for col in cols))
)
return " ".join(result), tuple(update_params + params) return " ".join(result), tuple(update_params + params)
def execute_sql(self, result_type): def execute_sql(
self, result_type=MULTI, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE
):
""" """
Execute the specified update. Return the number of rows affected by Execute the specified update. Return the number of rows affected by
the primary update query. The "primary update query" is the first the primary update query. The "primary update query" is the first
non-empty query that is executed. Row counts for any subsequent, non-empty query that is executed. Row counts for any subsequent,
related queries are not available. related queries are not available.
""" """
cursor = super().execute_sql(result_type) with mark_for_rollback_on_error(using=self.using):
cursor = super().execute_sql(
result_type=result_type,
chunked_fetch=chunked_fetch,
chunk_size=chunk_size,
)
if self.query.update_returning:
return cursor
try: try:
rows = cursor.rowcount if cursor else 0 rows = cursor.rowcount if cursor else 0
is_empty = cursor is None is_empty = cursor is None
@ -2007,6 +2033,10 @@ class SQLUpdateCompiler(SQLCompiler):
this point so that they don't change as a result of the progressive this point so that they don't change as a result of the progressive
updates. updates.
""" """
if self.query.update_returning:
self.has_extra_select = False
self.setup_query()
refcounts_before = self.query.alias_refcount.copy() refcounts_before = self.query.alias_refcount.copy()
# Ensure base table is in the query # Ensure base table is in the query
self.query.get_initial_alias() self.query.get_initial_alias()

View File

@ -276,6 +276,8 @@ class Query(BaseExpression):
explain_info = None explain_info = None
update_returning = False
def __init__(self, model, alias_cols=True): def __init__(self, model, alias_cols=True):
self.model = model self.model = model
self.alias_refcount = {} self.alias_refcount = {}

View File

@ -59,12 +59,14 @@ class UpdateQuery(Query):
Run on initialization and at the end of chaining. Any attributes that Run on initialization and at the end of chaining. Any attributes that
would normally be set in __init__() should go here instead. would normally be set in __init__() should go here instead.
""" """
if not hasattr(self, "values"):
self.values = [] self.values = []
self.related_ids = None self.related_ids = None
self.related_updates = {} self.related_updates = {}
def clone(self): def clone(self):
obj = super().clone() obj = super().clone()
obj.values = list(self.values)
obj.related_updates = self.related_updates.copy() obj.related_updates = self.related_updates.copy()
return obj return obj

View File

@ -779,6 +779,12 @@ class ReturningModel(models.Model):
created = CreatedField(editable=False) created = CreatedField(editable=False)
class UpdateReturningModel(models.Model):
key = models.CharField(max_length=10)
content = models.CharField(max_length=10)
hits = models.IntegerField(default=0)
class NonIntegerPKReturningModel(models.Model): class NonIntegerPKReturningModel(models.Model):
created = CreatedField(editable=False, primary_key=True) created = CreatedField(editable=False, primary_key=True)

View File

@ -0,0 +1,95 @@
from django.db.models import F
from django.test import TestCase, skipUnlessDBFeature
from .models import UpdateReturningModel
@skipUnlessDBFeature("can_return_columns_from_update")
class UpdateReturningTests(TestCase):
def test_update_returning_single(self):
obj = UpdateReturningModel.objects.create(key="key1", content="content", hits=0)
updated = (
UpdateReturningModel.objects.filter(key=obj.key)
.update_returning(hits=F("hits") + 1)
.get()
)
assert updated.pk == obj.pk
assert updated.key == obj.key
assert updated.content == obj.content
assert updated.hits == obj.hits + 1
def test_update_returning_multiple(self):
UpdateReturningModel.objects.create(key="key1", content="content", hits=1)
UpdateReturningModel.objects.create(key="key2", content="content", hits=2)
UpdateReturningModel.objects.create(key="key3", content="content", hits=3)
updated = UpdateReturningModel.objects.filter(hits__gt=1).update_returning(
hits=F("hits") + 10
)
assert len(updated) == 2
for obj in updated:
assert obj.hits in (12, 13)
def test_update_returning_only(self):
obj = UpdateReturningModel.objects.create(key="key1", content="content", hits=0)
updated = (
UpdateReturningModel.objects.filter(key=obj.key)
.update_returning(hits=F("hits") + 1)
.only("hits")
.get()
)
assert updated.pk == obj.pk
assert updated.key == obj.key
assert updated.content == obj.content
assert updated.hits == obj.hits + 1, updated.hits
def test_update_returning_defer(self):
obj = UpdateReturningModel.objects.create(key="key1", content="content", hits=0)
updated = (
UpdateReturningModel.objects.filter(key=obj.key)
.update_returning(hits=F("hits") + 1)
.defer("hits", "content")
.get()
)
assert updated.pk == obj.pk
assert updated.key == obj.key
assert updated.content == obj.content
assert updated.hits == obj.hits + 1, updated.hits
def test_update_returning_values(self):
UpdateReturningModel.objects.create(key="key1", content="content", hits=1)
UpdateReturningModel.objects.create(key="key2", content="content", hits=2)
UpdateReturningModel.objects.create(key="key3", content="content", hits=3)
updated = UpdateReturningModel.objects.update_returning(
hits=F("hits") + 1
).values("pk", "hits")
updated = list(updated)
assert len(updated) == 3
updated.sort(key=lambda x: x["pk"])
assert updated == [
{"pk": 1, "hits": 2},
{"pk": 2, "hits": 3},
{"pk": 3, "hits": 4},
]
def test_update_returning_values_list(self):
UpdateReturningModel.objects.create(key="key1", content="content", hits=1)
UpdateReturningModel.objects.create(key="key2", content="content", hits=2)
UpdateReturningModel.objects.create(key="key3", content="content", hits=3)
updated = UpdateReturningModel.objects.update_returning(
hits=F("hits") + 1
).values_list("hits", flat=True)
updated = list(updated)
assert len(updated) == 3
updated.sort(key=lambda x: x)
assert updated == [2, 3, 4]