1
0
mirror of https://github.com/django/django.git synced 2025-04-06 14:36:41 +00:00

Fixed #32406 Support UPDATE RETURNING

Implement a new QuerySet.update_returning() that executes a
QuerySet.update() and returns modified rows on databases that support
UPDATE RETURNING (SQLite, PostgreSQL, Oracle)
This commit is contained in:
Aivars Kalvans 2023-09-22 20:10:49 +03:00
parent 51d703a27f
commit c3a0ff4ede
11 changed files with 188 additions and 21 deletions

View File

@ -38,6 +38,7 @@ class BaseDatabaseFeatures:
can_use_chunked_reads = True
can_return_columns_from_insert = False
can_return_rows_from_bulk_insert = False
can_return_columns_from_update = False
has_bulk_insert = True
uses_savepoints = True
can_release_savepoints = False

View File

@ -28,6 +28,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_frame_range_fixed_distance = True
supports_update_conflicts = True
delete_can_self_reference_subquery = False
can_return_columns_from_update = False
create_test_procedure_without_params_sql = """
CREATE PROCEDURE test_procedure ()
BEGIN

View File

@ -16,6 +16,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
has_select_for_update_of = True
select_for_update_of_column = True
can_return_columns_from_insert = True
can_return_columns_from_update = True
supports_subqueries_in_group_by = False
ignores_unnecessary_order_by_in_subqueries = False
supports_transactions = True

View File

@ -11,6 +11,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
allows_group_by_selected_pks = True
can_return_columns_from_insert = True
can_return_rows_from_bulk_insert = True
can_return_columns_from_update = True
has_real_datatype = True
has_native_uuid_field = True
has_native_duration_field = True

View File

@ -152,3 +152,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
can_return_rows_from_bulk_insert = property(
operator.attrgetter("can_return_columns_from_insert")
)
@cached_property
def can_return_columns_from_update(self):
return Database.sqlite_version_info >= (3, 35)

View File

@ -1213,10 +1213,9 @@ class QuerySet(AltersData):
_raw_delete.alters_data = True
def update(self, **kwargs):
def _update_query(self, **kwargs):
"""
Update all elements in the current QuerySet, setting all the given
fields to the appropriate values.
Prepare a query for update
"""
self._not_support_combined_queries("update")
if self.query.is_sliced:
@ -1247,8 +1246,33 @@ class QuerySet(AltersData):
# Clear any annotations so that they won't be present in subqueries.
query.annotations = {}
with transaction.mark_for_rollback_on_error(using=self.db):
rows = query.get_compiler(self.db).execute_sql(CURSOR)
return query
def update_returning(self, **kwargs):
"""
Update all elements in the current QuerySet, setting all the given
fields to the appropriate values. Returns a QuerySet containing all
updated objects.
"""
clone = self._chain()
clone.query = self._update_query(**kwargs)
clone.query.update_returning = True
return clone
update_returning.alters_data = True
async def aupdate_returning(self, **kwargs):
return await sync_to_async(self.update_returning)(**kwargs)
aupdate_returning.alters_data = True
def update(self, **kwargs):
"""
Update all elements in the current QuerySet, setting all the given
fields to the appropriate values.
"""
query = self._update_query(**kwargs)
rows = query.get_compiler(self.db).execute_sql(CURSOR)
self._result_cache = None
return rows

View File

@ -21,7 +21,7 @@ from django.db.models.sql.constants import (
)
from django.db.models.sql.query import Query, get_order_dir
from django.db.models.sql.where import AND
from django.db.transaction import TransactionManagementError
from django.db.transaction import TransactionManagementError, mark_for_rollback_on_error
from django.utils.functional import cached_property
from django.utils.hashable import make_hashable
from django.utils.regex_helper import _lazy_re_compile
@ -1918,6 +1918,12 @@ class SQLUpdateCompiler(SQLCompiler):
Create the SQL for this query. Return the SQL string and list of
parameters.
"""
if (
self.query.update_returning
and not self.connection.features.can_return_columns_from_update
):
raise NotSupportedError("This backend does not support UPDATE RETURNING")
self.pre_sql_setup()
if not self.query.values:
return "", ()
@ -1975,28 +1981,48 @@ class SQLUpdateCompiler(SQLCompiler):
params = []
else:
result.append("WHERE %s" % where)
if self.query.update_returning:
select_mask = self.query.get_select_mask()
if self.query.default_cols:
cols = self.get_default_columns(select_mask)
else:
cols = self.query.select
result.append(
"RETURNING %s" % (", ".join(col.target.column for col in cols))
)
return " ".join(result), tuple(update_params + params)
def execute_sql(self, result_type):
def execute_sql(
self, result_type=MULTI, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE
):
"""
Execute the specified update. Return the number of rows affected by
the primary update query. The "primary update query" is the first
non-empty query that is executed. Row counts for any subsequent,
related queries are not available.
"""
cursor = super().execute_sql(result_type)
try:
rows = cursor.rowcount if cursor else 0
is_empty = cursor is None
finally:
if cursor:
cursor.close()
for query in self.query.get_related_updates():
aux_rows = query.get_compiler(self.using).execute_sql(result_type)
if is_empty and aux_rows:
rows = aux_rows
is_empty = False
return rows
with mark_for_rollback_on_error(using=self.using):
cursor = super().execute_sql(
result_type=result_type,
chunked_fetch=chunked_fetch,
chunk_size=chunk_size,
)
if self.query.update_returning:
return cursor
try:
rows = cursor.rowcount if cursor else 0
is_empty = cursor is None
finally:
if cursor:
cursor.close()
for query in self.query.get_related_updates():
aux_rows = query.get_compiler(self.using).execute_sql(result_type)
if is_empty and aux_rows:
rows = aux_rows
is_empty = False
return rows
def pre_sql_setup(self):
"""
@ -2007,6 +2033,10 @@ class SQLUpdateCompiler(SQLCompiler):
this point so that they don't change as a result of the progressive
updates.
"""
if self.query.update_returning:
self.has_extra_select = False
self.setup_query()
refcounts_before = self.query.alias_refcount.copy()
# Ensure base table is in the query
self.query.get_initial_alias()

View File

@ -276,6 +276,8 @@ class Query(BaseExpression):
explain_info = None
update_returning = False
def __init__(self, model, alias_cols=True):
self.model = model
self.alias_refcount = {}

View File

@ -59,12 +59,14 @@ class UpdateQuery(Query):
Run on initialization and at the end of chaining. Any attributes that
would normally be set in __init__() should go here instead.
"""
self.values = []
if not hasattr(self, "values"):
self.values = []
self.related_ids = None
self.related_updates = {}
def clone(self):
obj = super().clone()
obj.values = list(self.values)
obj.related_updates = self.related_updates.copy()
return obj

View File

@ -779,6 +779,12 @@ class ReturningModel(models.Model):
created = CreatedField(editable=False)
class UpdateReturningModel(models.Model):
key = models.CharField(max_length=10)
content = models.CharField(max_length=10)
hits = models.IntegerField(default=0)
class NonIntegerPKReturningModel(models.Model):
created = CreatedField(editable=False, primary_key=True)

View File

@ -0,0 +1,95 @@
from django.db.models import F
from django.test import TestCase, skipUnlessDBFeature
from .models import UpdateReturningModel
@skipUnlessDBFeature("can_return_columns_from_update")
class UpdateReturningTests(TestCase):
def test_update_returning_single(self):
obj = UpdateReturningModel.objects.create(key="key1", content="content", hits=0)
updated = (
UpdateReturningModel.objects.filter(key=obj.key)
.update_returning(hits=F("hits") + 1)
.get()
)
assert updated.pk == obj.pk
assert updated.key == obj.key
assert updated.content == obj.content
assert updated.hits == obj.hits + 1
def test_update_returning_multiple(self):
UpdateReturningModel.objects.create(key="key1", content="content", hits=1)
UpdateReturningModel.objects.create(key="key2", content="content", hits=2)
UpdateReturningModel.objects.create(key="key3", content="content", hits=3)
updated = UpdateReturningModel.objects.filter(hits__gt=1).update_returning(
hits=F("hits") + 10
)
assert len(updated) == 2
for obj in updated:
assert obj.hits in (12, 13)
def test_update_returning_only(self):
obj = UpdateReturningModel.objects.create(key="key1", content="content", hits=0)
updated = (
UpdateReturningModel.objects.filter(key=obj.key)
.update_returning(hits=F("hits") + 1)
.only("hits")
.get()
)
assert updated.pk == obj.pk
assert updated.key == obj.key
assert updated.content == obj.content
assert updated.hits == obj.hits + 1, updated.hits
def test_update_returning_defer(self):
obj = UpdateReturningModel.objects.create(key="key1", content="content", hits=0)
updated = (
UpdateReturningModel.objects.filter(key=obj.key)
.update_returning(hits=F("hits") + 1)
.defer("hits", "content")
.get()
)
assert updated.pk == obj.pk
assert updated.key == obj.key
assert updated.content == obj.content
assert updated.hits == obj.hits + 1, updated.hits
def test_update_returning_values(self):
UpdateReturningModel.objects.create(key="key1", content="content", hits=1)
UpdateReturningModel.objects.create(key="key2", content="content", hits=2)
UpdateReturningModel.objects.create(key="key3", content="content", hits=3)
updated = UpdateReturningModel.objects.update_returning(
hits=F("hits") + 1
).values("pk", "hits")
updated = list(updated)
assert len(updated) == 3
updated.sort(key=lambda x: x["pk"])
assert updated == [
{"pk": 1, "hits": 2},
{"pk": 2, "hits": 3},
{"pk": 3, "hits": 4},
]
def test_update_returning_values_list(self):
UpdateReturningModel.objects.create(key="key1", content="content", hits=1)
UpdateReturningModel.objects.create(key="key2", content="content", hits=2)
UpdateReturningModel.objects.create(key="key3", content="content", hits=3)
updated = UpdateReturningModel.objects.update_returning(
hits=F("hits") + 1
).values_list("hits", flat=True)
updated = list(updated)
assert len(updated) == 3
updated.sort(key=lambda x: x)
assert updated == [2, 3, 4]