mirror of
https://github.com/django/django.git
synced 2025-06-07 12:39:12 +00:00
Fixed #32406 Support UPDATE RETURNING
Implement a new QuerySet.update_returning() that executes a QuerySet.update() and returns modified rows on databases that support UPDATE RETURNING (SQLite, PostgreSQL, Oracle)
This commit is contained in:
parent
51d703a27f
commit
c3a0ff4ede
@ -38,6 +38,7 @@ class BaseDatabaseFeatures:
|
|||||||
can_use_chunked_reads = True
|
can_use_chunked_reads = True
|
||||||
can_return_columns_from_insert = False
|
can_return_columns_from_insert = False
|
||||||
can_return_rows_from_bulk_insert = False
|
can_return_rows_from_bulk_insert = False
|
||||||
|
can_return_columns_from_update = False
|
||||||
has_bulk_insert = True
|
has_bulk_insert = True
|
||||||
uses_savepoints = True
|
uses_savepoints = True
|
||||||
can_release_savepoints = False
|
can_release_savepoints = False
|
||||||
|
@ -28,6 +28,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
|||||||
supports_frame_range_fixed_distance = True
|
supports_frame_range_fixed_distance = True
|
||||||
supports_update_conflicts = True
|
supports_update_conflicts = True
|
||||||
delete_can_self_reference_subquery = False
|
delete_can_self_reference_subquery = False
|
||||||
|
can_return_columns_from_update = False
|
||||||
create_test_procedure_without_params_sql = """
|
create_test_procedure_without_params_sql = """
|
||||||
CREATE PROCEDURE test_procedure ()
|
CREATE PROCEDURE test_procedure ()
|
||||||
BEGIN
|
BEGIN
|
||||||
|
@ -16,6 +16,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
|||||||
has_select_for_update_of = True
|
has_select_for_update_of = True
|
||||||
select_for_update_of_column = True
|
select_for_update_of_column = True
|
||||||
can_return_columns_from_insert = True
|
can_return_columns_from_insert = True
|
||||||
|
can_return_columns_from_update = True
|
||||||
supports_subqueries_in_group_by = False
|
supports_subqueries_in_group_by = False
|
||||||
ignores_unnecessary_order_by_in_subqueries = False
|
ignores_unnecessary_order_by_in_subqueries = False
|
||||||
supports_transactions = True
|
supports_transactions = True
|
||||||
|
@ -11,6 +11,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
|||||||
allows_group_by_selected_pks = True
|
allows_group_by_selected_pks = True
|
||||||
can_return_columns_from_insert = True
|
can_return_columns_from_insert = True
|
||||||
can_return_rows_from_bulk_insert = True
|
can_return_rows_from_bulk_insert = True
|
||||||
|
can_return_columns_from_update = True
|
||||||
has_real_datatype = True
|
has_real_datatype = True
|
||||||
has_native_uuid_field = True
|
has_native_uuid_field = True
|
||||||
has_native_duration_field = True
|
has_native_duration_field = True
|
||||||
|
@ -152,3 +152,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
|||||||
can_return_rows_from_bulk_insert = property(
|
can_return_rows_from_bulk_insert = property(
|
||||||
operator.attrgetter("can_return_columns_from_insert")
|
operator.attrgetter("can_return_columns_from_insert")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def can_return_columns_from_update(self):
|
||||||
|
return Database.sqlite_version_info >= (3, 35)
|
@ -1213,10 +1213,9 @@ class QuerySet(AltersData):
|
|||||||
|
|
||||||
_raw_delete.alters_data = True
|
_raw_delete.alters_data = True
|
||||||
|
|
||||||
def update(self, **kwargs):
|
def _update_query(self, **kwargs):
|
||||||
"""
|
"""
|
||||||
Update all elements in the current QuerySet, setting all the given
|
Prepare a query for update
|
||||||
fields to the appropriate values.
|
|
||||||
"""
|
"""
|
||||||
self._not_support_combined_queries("update")
|
self._not_support_combined_queries("update")
|
||||||
if self.query.is_sliced:
|
if self.query.is_sliced:
|
||||||
@ -1247,7 +1246,32 @@ class QuerySet(AltersData):
|
|||||||
|
|
||||||
# Clear any annotations so that they won't be present in subqueries.
|
# Clear any annotations so that they won't be present in subqueries.
|
||||||
query.annotations = {}
|
query.annotations = {}
|
||||||
with transaction.mark_for_rollback_on_error(using=self.db):
|
return query
|
||||||
|
|
||||||
|
def update_returning(self, **kwargs):
|
||||||
|
"""
|
||||||
|
Update all elements in the current QuerySet, setting all the given
|
||||||
|
fields to the appropriate values. Returns a QuerySet containing all
|
||||||
|
updated objects.
|
||||||
|
"""
|
||||||
|
clone = self._chain()
|
||||||
|
clone.query = self._update_query(**kwargs)
|
||||||
|
clone.query.update_returning = True
|
||||||
|
return clone
|
||||||
|
|
||||||
|
update_returning.alters_data = True
|
||||||
|
|
||||||
|
async def aupdate_returning(self, **kwargs):
|
||||||
|
return await sync_to_async(self.update_returning)(**kwargs)
|
||||||
|
|
||||||
|
aupdate_returning.alters_data = True
|
||||||
|
|
||||||
|
def update(self, **kwargs):
|
||||||
|
"""
|
||||||
|
Update all elements in the current QuerySet, setting all the given
|
||||||
|
fields to the appropriate values.
|
||||||
|
"""
|
||||||
|
query = self._update_query(**kwargs)
|
||||||
rows = query.get_compiler(self.db).execute_sql(CURSOR)
|
rows = query.get_compiler(self.db).execute_sql(CURSOR)
|
||||||
self._result_cache = None
|
self._result_cache = None
|
||||||
return rows
|
return rows
|
||||||
|
@ -21,7 +21,7 @@ from django.db.models.sql.constants import (
|
|||||||
)
|
)
|
||||||
from django.db.models.sql.query import Query, get_order_dir
|
from django.db.models.sql.query import Query, get_order_dir
|
||||||
from django.db.models.sql.where import AND
|
from django.db.models.sql.where import AND
|
||||||
from django.db.transaction import TransactionManagementError
|
from django.db.transaction import TransactionManagementError, mark_for_rollback_on_error
|
||||||
from django.utils.functional import cached_property
|
from django.utils.functional import cached_property
|
||||||
from django.utils.hashable import make_hashable
|
from django.utils.hashable import make_hashable
|
||||||
from django.utils.regex_helper import _lazy_re_compile
|
from django.utils.regex_helper import _lazy_re_compile
|
||||||
@ -1918,6 +1918,12 @@ class SQLUpdateCompiler(SQLCompiler):
|
|||||||
Create the SQL for this query. Return the SQL string and list of
|
Create the SQL for this query. Return the SQL string and list of
|
||||||
parameters.
|
parameters.
|
||||||
"""
|
"""
|
||||||
|
if (
|
||||||
|
self.query.update_returning
|
||||||
|
and not self.connection.features.can_return_columns_from_update
|
||||||
|
):
|
||||||
|
raise NotSupportedError("This backend does not support UPDATE RETURNING")
|
||||||
|
|
||||||
self.pre_sql_setup()
|
self.pre_sql_setup()
|
||||||
if not self.query.values:
|
if not self.query.values:
|
||||||
return "", ()
|
return "", ()
|
||||||
@ -1975,16 +1981,36 @@ class SQLUpdateCompiler(SQLCompiler):
|
|||||||
params = []
|
params = []
|
||||||
else:
|
else:
|
||||||
result.append("WHERE %s" % where)
|
result.append("WHERE %s" % where)
|
||||||
|
|
||||||
|
if self.query.update_returning:
|
||||||
|
select_mask = self.query.get_select_mask()
|
||||||
|
if self.query.default_cols:
|
||||||
|
cols = self.get_default_columns(select_mask)
|
||||||
|
else:
|
||||||
|
cols = self.query.select
|
||||||
|
result.append(
|
||||||
|
"RETURNING %s" % (", ".join(col.target.column for col in cols))
|
||||||
|
)
|
||||||
|
|
||||||
return " ".join(result), tuple(update_params + params)
|
return " ".join(result), tuple(update_params + params)
|
||||||
|
|
||||||
def execute_sql(self, result_type):
|
def execute_sql(
|
||||||
|
self, result_type=MULTI, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Execute the specified update. Return the number of rows affected by
|
Execute the specified update. Return the number of rows affected by
|
||||||
the primary update query. The "primary update query" is the first
|
the primary update query. The "primary update query" is the first
|
||||||
non-empty query that is executed. Row counts for any subsequent,
|
non-empty query that is executed. Row counts for any subsequent,
|
||||||
related queries are not available.
|
related queries are not available.
|
||||||
"""
|
"""
|
||||||
cursor = super().execute_sql(result_type)
|
with mark_for_rollback_on_error(using=self.using):
|
||||||
|
cursor = super().execute_sql(
|
||||||
|
result_type=result_type,
|
||||||
|
chunked_fetch=chunked_fetch,
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
)
|
||||||
|
if self.query.update_returning:
|
||||||
|
return cursor
|
||||||
try:
|
try:
|
||||||
rows = cursor.rowcount if cursor else 0
|
rows = cursor.rowcount if cursor else 0
|
||||||
is_empty = cursor is None
|
is_empty = cursor is None
|
||||||
@ -2007,6 +2033,10 @@ class SQLUpdateCompiler(SQLCompiler):
|
|||||||
this point so that they don't change as a result of the progressive
|
this point so that they don't change as a result of the progressive
|
||||||
updates.
|
updates.
|
||||||
"""
|
"""
|
||||||
|
if self.query.update_returning:
|
||||||
|
self.has_extra_select = False
|
||||||
|
self.setup_query()
|
||||||
|
|
||||||
refcounts_before = self.query.alias_refcount.copy()
|
refcounts_before = self.query.alias_refcount.copy()
|
||||||
# Ensure base table is in the query
|
# Ensure base table is in the query
|
||||||
self.query.get_initial_alias()
|
self.query.get_initial_alias()
|
||||||
|
@ -276,6 +276,8 @@ class Query(BaseExpression):
|
|||||||
|
|
||||||
explain_info = None
|
explain_info = None
|
||||||
|
|
||||||
|
update_returning = False
|
||||||
|
|
||||||
def __init__(self, model, alias_cols=True):
|
def __init__(self, model, alias_cols=True):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.alias_refcount = {}
|
self.alias_refcount = {}
|
||||||
|
@ -59,12 +59,14 @@ class UpdateQuery(Query):
|
|||||||
Run on initialization and at the end of chaining. Any attributes that
|
Run on initialization and at the end of chaining. Any attributes that
|
||||||
would normally be set in __init__() should go here instead.
|
would normally be set in __init__() should go here instead.
|
||||||
"""
|
"""
|
||||||
|
if not hasattr(self, "values"):
|
||||||
self.values = []
|
self.values = []
|
||||||
self.related_ids = None
|
self.related_ids = None
|
||||||
self.related_updates = {}
|
self.related_updates = {}
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
obj = super().clone()
|
obj = super().clone()
|
||||||
|
obj.values = list(self.values)
|
||||||
obj.related_updates = self.related_updates.copy()
|
obj.related_updates = self.related_updates.copy()
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
@ -779,6 +779,12 @@ class ReturningModel(models.Model):
|
|||||||
created = CreatedField(editable=False)
|
created = CreatedField(editable=False)
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateReturningModel(models.Model):
|
||||||
|
key = models.CharField(max_length=10)
|
||||||
|
content = models.CharField(max_length=10)
|
||||||
|
hits = models.IntegerField(default=0)
|
||||||
|
|
||||||
|
|
||||||
class NonIntegerPKReturningModel(models.Model):
|
class NonIntegerPKReturningModel(models.Model):
|
||||||
created = CreatedField(editable=False, primary_key=True)
|
created = CreatedField(editable=False, primary_key=True)
|
||||||
|
|
||||||
|
95
tests/queries/test_update_returning.py
Normal file
95
tests/queries/test_update_returning.py
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
from django.db.models import F
|
||||||
|
from django.test import TestCase, skipUnlessDBFeature
|
||||||
|
|
||||||
|
from .models import UpdateReturningModel
|
||||||
|
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("can_return_columns_from_update")
|
||||||
|
class UpdateReturningTests(TestCase):
|
||||||
|
def test_update_returning_single(self):
|
||||||
|
obj = UpdateReturningModel.objects.create(key="key1", content="content", hits=0)
|
||||||
|
|
||||||
|
updated = (
|
||||||
|
UpdateReturningModel.objects.filter(key=obj.key)
|
||||||
|
.update_returning(hits=F("hits") + 1)
|
||||||
|
.get()
|
||||||
|
)
|
||||||
|
|
||||||
|
assert updated.pk == obj.pk
|
||||||
|
assert updated.key == obj.key
|
||||||
|
assert updated.content == obj.content
|
||||||
|
assert updated.hits == obj.hits + 1
|
||||||
|
|
||||||
|
def test_update_returning_multiple(self):
|
||||||
|
UpdateReturningModel.objects.create(key="key1", content="content", hits=1)
|
||||||
|
UpdateReturningModel.objects.create(key="key2", content="content", hits=2)
|
||||||
|
UpdateReturningModel.objects.create(key="key3", content="content", hits=3)
|
||||||
|
|
||||||
|
updated = UpdateReturningModel.objects.filter(hits__gt=1).update_returning(
|
||||||
|
hits=F("hits") + 10
|
||||||
|
)
|
||||||
|
assert len(updated) == 2
|
||||||
|
for obj in updated:
|
||||||
|
assert obj.hits in (12, 13)
|
||||||
|
|
||||||
|
def test_update_returning_only(self):
|
||||||
|
obj = UpdateReturningModel.objects.create(key="key1", content="content", hits=0)
|
||||||
|
|
||||||
|
updated = (
|
||||||
|
UpdateReturningModel.objects.filter(key=obj.key)
|
||||||
|
.update_returning(hits=F("hits") + 1)
|
||||||
|
.only("hits")
|
||||||
|
.get()
|
||||||
|
)
|
||||||
|
|
||||||
|
assert updated.pk == obj.pk
|
||||||
|
assert updated.key == obj.key
|
||||||
|
assert updated.content == obj.content
|
||||||
|
assert updated.hits == obj.hits + 1, updated.hits
|
||||||
|
|
||||||
|
def test_update_returning_defer(self):
|
||||||
|
obj = UpdateReturningModel.objects.create(key="key1", content="content", hits=0)
|
||||||
|
|
||||||
|
updated = (
|
||||||
|
UpdateReturningModel.objects.filter(key=obj.key)
|
||||||
|
.update_returning(hits=F("hits") + 1)
|
||||||
|
.defer("hits", "content")
|
||||||
|
.get()
|
||||||
|
)
|
||||||
|
|
||||||
|
assert updated.pk == obj.pk
|
||||||
|
assert updated.key == obj.key
|
||||||
|
assert updated.content == obj.content
|
||||||
|
assert updated.hits == obj.hits + 1, updated.hits
|
||||||
|
|
||||||
|
def test_update_returning_values(self):
|
||||||
|
UpdateReturningModel.objects.create(key="key1", content="content", hits=1)
|
||||||
|
UpdateReturningModel.objects.create(key="key2", content="content", hits=2)
|
||||||
|
UpdateReturningModel.objects.create(key="key3", content="content", hits=3)
|
||||||
|
|
||||||
|
updated = UpdateReturningModel.objects.update_returning(
|
||||||
|
hits=F("hits") + 1
|
||||||
|
).values("pk", "hits")
|
||||||
|
|
||||||
|
updated = list(updated)
|
||||||
|
assert len(updated) == 3
|
||||||
|
updated.sort(key=lambda x: x["pk"])
|
||||||
|
assert updated == [
|
||||||
|
{"pk": 1, "hits": 2},
|
||||||
|
{"pk": 2, "hits": 3},
|
||||||
|
{"pk": 3, "hits": 4},
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_update_returning_values_list(self):
|
||||||
|
UpdateReturningModel.objects.create(key="key1", content="content", hits=1)
|
||||||
|
UpdateReturningModel.objects.create(key="key2", content="content", hits=2)
|
||||||
|
UpdateReturningModel.objects.create(key="key3", content="content", hits=3)
|
||||||
|
|
||||||
|
updated = UpdateReturningModel.objects.update_returning(
|
||||||
|
hits=F("hits") + 1
|
||||||
|
).values_list("hits", flat=True)
|
||||||
|
|
||||||
|
updated = list(updated)
|
||||||
|
assert len(updated) == 3
|
||||||
|
updated.sort(key=lambda x: x)
|
||||||
|
assert updated == [2, 3, 4]
|
Loading…
x
Reference in New Issue
Block a user