mirror of
https://github.com/django/django.git
synced 2025-04-06 14:36:41 +00:00
Fixed #32406 Support UPDATE RETURNING
Implement a new QuerySet.update_returning() that executes a QuerySet.update() and returns modified rows on databases that support UPDATE RETURNING (SQLite, PostgreSQL, Oracle)
This commit is contained in:
parent
51d703a27f
commit
c3a0ff4ede
@ -38,6 +38,7 @@ class BaseDatabaseFeatures:
|
||||
can_use_chunked_reads = True
|
||||
can_return_columns_from_insert = False
|
||||
can_return_rows_from_bulk_insert = False
|
||||
can_return_columns_from_update = False
|
||||
has_bulk_insert = True
|
||||
uses_savepoints = True
|
||||
can_release_savepoints = False
|
||||
|
@ -28,6 +28,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
||||
supports_frame_range_fixed_distance = True
|
||||
supports_update_conflicts = True
|
||||
delete_can_self_reference_subquery = False
|
||||
can_return_columns_from_update = False
|
||||
create_test_procedure_without_params_sql = """
|
||||
CREATE PROCEDURE test_procedure ()
|
||||
BEGIN
|
||||
|
@ -16,6 +16,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
||||
has_select_for_update_of = True
|
||||
select_for_update_of_column = True
|
||||
can_return_columns_from_insert = True
|
||||
can_return_columns_from_update = True
|
||||
supports_subqueries_in_group_by = False
|
||||
ignores_unnecessary_order_by_in_subqueries = False
|
||||
supports_transactions = True
|
||||
|
@ -11,6 +11,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
||||
allows_group_by_selected_pks = True
|
||||
can_return_columns_from_insert = True
|
||||
can_return_rows_from_bulk_insert = True
|
||||
can_return_columns_from_update = True
|
||||
has_real_datatype = True
|
||||
has_native_uuid_field = True
|
||||
has_native_duration_field = True
|
||||
|
@ -152,3 +152,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
||||
can_return_rows_from_bulk_insert = property(
|
||||
operator.attrgetter("can_return_columns_from_insert")
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def can_return_columns_from_update(self):
|
||||
return Database.sqlite_version_info >= (3, 35)
|
@ -1213,10 +1213,9 @@ class QuerySet(AltersData):
|
||||
|
||||
_raw_delete.alters_data = True
|
||||
|
||||
def update(self, **kwargs):
|
||||
def _update_query(self, **kwargs):
|
||||
"""
|
||||
Update all elements in the current QuerySet, setting all the given
|
||||
fields to the appropriate values.
|
||||
Prepare a query for update
|
||||
"""
|
||||
self._not_support_combined_queries("update")
|
||||
if self.query.is_sliced:
|
||||
@ -1247,8 +1246,33 @@ class QuerySet(AltersData):
|
||||
|
||||
# Clear any annotations so that they won't be present in subqueries.
|
||||
query.annotations = {}
|
||||
with transaction.mark_for_rollback_on_error(using=self.db):
|
||||
rows = query.get_compiler(self.db).execute_sql(CURSOR)
|
||||
return query
|
||||
|
||||
def update_returning(self, **kwargs):
|
||||
"""
|
||||
Update all elements in the current QuerySet, setting all the given
|
||||
fields to the appropriate values. Returns a QuerySet containing all
|
||||
updated objects.
|
||||
"""
|
||||
clone = self._chain()
|
||||
clone.query = self._update_query(**kwargs)
|
||||
clone.query.update_returning = True
|
||||
return clone
|
||||
|
||||
update_returning.alters_data = True
|
||||
|
||||
async def aupdate_returning(self, **kwargs):
|
||||
return await sync_to_async(self.update_returning)(**kwargs)
|
||||
|
||||
aupdate_returning.alters_data = True
|
||||
|
||||
def update(self, **kwargs):
|
||||
"""
|
||||
Update all elements in the current QuerySet, setting all the given
|
||||
fields to the appropriate values.
|
||||
"""
|
||||
query = self._update_query(**kwargs)
|
||||
rows = query.get_compiler(self.db).execute_sql(CURSOR)
|
||||
self._result_cache = None
|
||||
return rows
|
||||
|
||||
|
@ -21,7 +21,7 @@ from django.db.models.sql.constants import (
|
||||
)
|
||||
from django.db.models.sql.query import Query, get_order_dir
|
||||
from django.db.models.sql.where import AND
|
||||
from django.db.transaction import TransactionManagementError
|
||||
from django.db.transaction import TransactionManagementError, mark_for_rollback_on_error
|
||||
from django.utils.functional import cached_property
|
||||
from django.utils.hashable import make_hashable
|
||||
from django.utils.regex_helper import _lazy_re_compile
|
||||
@ -1918,6 +1918,12 @@ class SQLUpdateCompiler(SQLCompiler):
|
||||
Create the SQL for this query. Return the SQL string and list of
|
||||
parameters.
|
||||
"""
|
||||
if (
|
||||
self.query.update_returning
|
||||
and not self.connection.features.can_return_columns_from_update
|
||||
):
|
||||
raise NotSupportedError("This backend does not support UPDATE RETURNING")
|
||||
|
||||
self.pre_sql_setup()
|
||||
if not self.query.values:
|
||||
return "", ()
|
||||
@ -1975,28 +1981,48 @@ class SQLUpdateCompiler(SQLCompiler):
|
||||
params = []
|
||||
else:
|
||||
result.append("WHERE %s" % where)
|
||||
|
||||
if self.query.update_returning:
|
||||
select_mask = self.query.get_select_mask()
|
||||
if self.query.default_cols:
|
||||
cols = self.get_default_columns(select_mask)
|
||||
else:
|
||||
cols = self.query.select
|
||||
result.append(
|
||||
"RETURNING %s" % (", ".join(col.target.column for col in cols))
|
||||
)
|
||||
|
||||
return " ".join(result), tuple(update_params + params)
|
||||
|
||||
def execute_sql(self, result_type):
|
||||
def execute_sql(
|
||||
self, result_type=MULTI, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE
|
||||
):
|
||||
"""
|
||||
Execute the specified update. Return the number of rows affected by
|
||||
the primary update query. The "primary update query" is the first
|
||||
non-empty query that is executed. Row counts for any subsequent,
|
||||
related queries are not available.
|
||||
"""
|
||||
cursor = super().execute_sql(result_type)
|
||||
try:
|
||||
rows = cursor.rowcount if cursor else 0
|
||||
is_empty = cursor is None
|
||||
finally:
|
||||
if cursor:
|
||||
cursor.close()
|
||||
for query in self.query.get_related_updates():
|
||||
aux_rows = query.get_compiler(self.using).execute_sql(result_type)
|
||||
if is_empty and aux_rows:
|
||||
rows = aux_rows
|
||||
is_empty = False
|
||||
return rows
|
||||
with mark_for_rollback_on_error(using=self.using):
|
||||
cursor = super().execute_sql(
|
||||
result_type=result_type,
|
||||
chunked_fetch=chunked_fetch,
|
||||
chunk_size=chunk_size,
|
||||
)
|
||||
if self.query.update_returning:
|
||||
return cursor
|
||||
try:
|
||||
rows = cursor.rowcount if cursor else 0
|
||||
is_empty = cursor is None
|
||||
finally:
|
||||
if cursor:
|
||||
cursor.close()
|
||||
for query in self.query.get_related_updates():
|
||||
aux_rows = query.get_compiler(self.using).execute_sql(result_type)
|
||||
if is_empty and aux_rows:
|
||||
rows = aux_rows
|
||||
is_empty = False
|
||||
return rows
|
||||
|
||||
def pre_sql_setup(self):
|
||||
"""
|
||||
@ -2007,6 +2033,10 @@ class SQLUpdateCompiler(SQLCompiler):
|
||||
this point so that they don't change as a result of the progressive
|
||||
updates.
|
||||
"""
|
||||
if self.query.update_returning:
|
||||
self.has_extra_select = False
|
||||
self.setup_query()
|
||||
|
||||
refcounts_before = self.query.alias_refcount.copy()
|
||||
# Ensure base table is in the query
|
||||
self.query.get_initial_alias()
|
||||
|
@ -276,6 +276,8 @@ class Query(BaseExpression):
|
||||
|
||||
explain_info = None
|
||||
|
||||
update_returning = False
|
||||
|
||||
def __init__(self, model, alias_cols=True):
|
||||
self.model = model
|
||||
self.alias_refcount = {}
|
||||
|
@ -59,12 +59,14 @@ class UpdateQuery(Query):
|
||||
Run on initialization and at the end of chaining. Any attributes that
|
||||
would normally be set in __init__() should go here instead.
|
||||
"""
|
||||
self.values = []
|
||||
if not hasattr(self, "values"):
|
||||
self.values = []
|
||||
self.related_ids = None
|
||||
self.related_updates = {}
|
||||
|
||||
def clone(self):
|
||||
obj = super().clone()
|
||||
obj.values = list(self.values)
|
||||
obj.related_updates = self.related_updates.copy()
|
||||
return obj
|
||||
|
||||
|
@ -779,6 +779,12 @@ class ReturningModel(models.Model):
|
||||
created = CreatedField(editable=False)
|
||||
|
||||
|
||||
class UpdateReturningModel(models.Model):
|
||||
key = models.CharField(max_length=10)
|
||||
content = models.CharField(max_length=10)
|
||||
hits = models.IntegerField(default=0)
|
||||
|
||||
|
||||
class NonIntegerPKReturningModel(models.Model):
|
||||
created = CreatedField(editable=False, primary_key=True)
|
||||
|
||||
|
95
tests/queries/test_update_returning.py
Normal file
95
tests/queries/test_update_returning.py
Normal file
@ -0,0 +1,95 @@
|
||||
from django.db.models import F
|
||||
from django.test import TestCase, skipUnlessDBFeature
|
||||
|
||||
from .models import UpdateReturningModel
|
||||
|
||||
|
||||
@skipUnlessDBFeature("can_return_columns_from_update")
|
||||
class UpdateReturningTests(TestCase):
|
||||
def test_update_returning_single(self):
|
||||
obj = UpdateReturningModel.objects.create(key="key1", content="content", hits=0)
|
||||
|
||||
updated = (
|
||||
UpdateReturningModel.objects.filter(key=obj.key)
|
||||
.update_returning(hits=F("hits") + 1)
|
||||
.get()
|
||||
)
|
||||
|
||||
assert updated.pk == obj.pk
|
||||
assert updated.key == obj.key
|
||||
assert updated.content == obj.content
|
||||
assert updated.hits == obj.hits + 1
|
||||
|
||||
def test_update_returning_multiple(self):
|
||||
UpdateReturningModel.objects.create(key="key1", content="content", hits=1)
|
||||
UpdateReturningModel.objects.create(key="key2", content="content", hits=2)
|
||||
UpdateReturningModel.objects.create(key="key3", content="content", hits=3)
|
||||
|
||||
updated = UpdateReturningModel.objects.filter(hits__gt=1).update_returning(
|
||||
hits=F("hits") + 10
|
||||
)
|
||||
assert len(updated) == 2
|
||||
for obj in updated:
|
||||
assert obj.hits in (12, 13)
|
||||
|
||||
def test_update_returning_only(self):
|
||||
obj = UpdateReturningModel.objects.create(key="key1", content="content", hits=0)
|
||||
|
||||
updated = (
|
||||
UpdateReturningModel.objects.filter(key=obj.key)
|
||||
.update_returning(hits=F("hits") + 1)
|
||||
.only("hits")
|
||||
.get()
|
||||
)
|
||||
|
||||
assert updated.pk == obj.pk
|
||||
assert updated.key == obj.key
|
||||
assert updated.content == obj.content
|
||||
assert updated.hits == obj.hits + 1, updated.hits
|
||||
|
||||
def test_update_returning_defer(self):
|
||||
obj = UpdateReturningModel.objects.create(key="key1", content="content", hits=0)
|
||||
|
||||
updated = (
|
||||
UpdateReturningModel.objects.filter(key=obj.key)
|
||||
.update_returning(hits=F("hits") + 1)
|
||||
.defer("hits", "content")
|
||||
.get()
|
||||
)
|
||||
|
||||
assert updated.pk == obj.pk
|
||||
assert updated.key == obj.key
|
||||
assert updated.content == obj.content
|
||||
assert updated.hits == obj.hits + 1, updated.hits
|
||||
|
||||
def test_update_returning_values(self):
|
||||
UpdateReturningModel.objects.create(key="key1", content="content", hits=1)
|
||||
UpdateReturningModel.objects.create(key="key2", content="content", hits=2)
|
||||
UpdateReturningModel.objects.create(key="key3", content="content", hits=3)
|
||||
|
||||
updated = UpdateReturningModel.objects.update_returning(
|
||||
hits=F("hits") + 1
|
||||
).values("pk", "hits")
|
||||
|
||||
updated = list(updated)
|
||||
assert len(updated) == 3
|
||||
updated.sort(key=lambda x: x["pk"])
|
||||
assert updated == [
|
||||
{"pk": 1, "hits": 2},
|
||||
{"pk": 2, "hits": 3},
|
||||
{"pk": 3, "hits": 4},
|
||||
]
|
||||
|
||||
def test_update_returning_values_list(self):
|
||||
UpdateReturningModel.objects.create(key="key1", content="content", hits=1)
|
||||
UpdateReturningModel.objects.create(key="key2", content="content", hits=2)
|
||||
UpdateReturningModel.objects.create(key="key3", content="content", hits=3)
|
||||
|
||||
updated = UpdateReturningModel.objects.update_returning(
|
||||
hits=F("hits") + 1
|
||||
).values_list("hits", flat=True)
|
||||
|
||||
updated = list(updated)
|
||||
assert len(updated) == 3
|
||||
updated.sort(key=lambda x: x)
|
||||
assert updated == [2, 3, 4]
|
Loading…
x
Reference in New Issue
Block a user