1
0
mirror of https://github.com/django/django.git synced 2025-09-15 13:39:11 +00:00

Refs #27222 -- Refreshed GeneratedFields values on save() initiated update.

This required implementing UPDATE RETURNING machinery that heavily
borrows from the INSERT one.
This commit is contained in:
Simon Charette 2025-03-19 01:11:34 -04:00 committed by Mariusz Felisiak
parent c48904a225
commit 55a0073b3b
12 changed files with 213 additions and 59 deletions

View File

@ -38,6 +38,7 @@ class BaseDatabaseFeatures:
can_use_chunked_reads = True can_use_chunked_reads = True
can_return_columns_from_insert = False can_return_columns_from_insert = False
can_return_rows_from_bulk_insert = False can_return_rows_from_bulk_insert = False
can_return_rows_from_update = False
has_bulk_insert = True has_bulk_insert = True
uses_savepoints = True uses_savepoints = True
can_release_savepoints = False can_release_savepoints = False

View File

@ -243,6 +243,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
"use_returning_into", True "use_returning_into", True
) )
self.features.can_return_columns_from_insert = use_returning_into self.features.can_return_columns_from_insert = use_returning_into
self.features.can_return_rows_from_update = use_returning_into
@property @property
def is_pool(self): def is_pool(self):

View File

@ -19,6 +19,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
has_select_for_update_of = True has_select_for_update_of = True
select_for_update_of_column = True select_for_update_of_column = True
can_return_columns_from_insert = True can_return_columns_from_insert = True
can_return_rows_from_update = True
supports_subqueries_in_group_by = False supports_subqueries_in_group_by = False
ignores_unnecessary_order_by_in_subqueries = False ignores_unnecessary_order_by_in_subqueries = False
supports_tuple_comparison_against_subquery = False supports_tuple_comparison_against_subquery = False

View File

@ -11,6 +11,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
allows_group_by_selected_pks = True allows_group_by_selected_pks = True
can_return_columns_from_insert = True can_return_columns_from_insert = True
can_return_rows_from_bulk_insert = True can_return_rows_from_bulk_insert = True
can_return_rows_from_update = True
has_real_datatype = True has_real_datatype = True
has_native_uuid_field = True has_native_uuid_field = True
has_native_duration_field = True has_native_duration_field = True

View File

@ -171,3 +171,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
can_return_rows_from_bulk_insert = property( can_return_rows_from_bulk_insert = property(
operator.attrgetter("can_return_columns_from_insert") operator.attrgetter("can_return_columns_from_insert")
) )
can_return_rows_from_update = property(
operator.attrgetter("can_return_columns_from_insert")
)

View File

@ -1094,12 +1094,28 @@ class Model(AltersData, metaclass=ModelBase):
] ]
forced_update = update_fields or force_update forced_update = update_fields or force_update
pk_val = self._get_pk_val(meta) pk_val = self._get_pk_val(meta)
updated = self._do_update( returning_fields = [
base_qs, using, pk_val, values, update_fields, forced_update f
for f in meta.local_concrete_fields
if (
f.generated
and f.referenced_fields.intersection(non_pks_non_generated)
)
]
results = self._do_update(
base_qs,
using,
pk_val,
values,
update_fields,
forced_update,
returning_fields,
) )
if force_update and not updated: if updated := bool(results):
self._assign_returned_values(results[0], returning_fields)
elif force_update:
raise self.NotUpdated("Forced update did not affect any rows.") raise self.NotUpdated("Forced update did not affect any rows.")
if update_fields and not updated: elif update_fields:
raise self.NotUpdated( raise self.NotUpdated(
"Save with update_fields did not affect any rows." "Save with update_fields did not affect any rows."
) )
@ -1131,11 +1147,19 @@ class Model(AltersData, metaclass=ModelBase):
cls._base_manager, using, fields, returning_fields, raw cls._base_manager, using, fields, returning_fields, raw
) )
if results: if results:
for value, field in zip(results[0], returning_fields): self._assign_returned_values(results[0], returning_fields)
setattr(self, field.attname, value)
return updated return updated
def _do_update(self, base_qs, using, pk_val, values, update_fields, forced_update): def _do_update(
self,
base_qs,
using,
pk_val,
values,
update_fields,
forced_update,
returning_fields,
):
""" """
Try to update the model. Return True if the model was updated (if an Try to update the model. Return True if the model was updated (if an
update query was done and a matching row was found in the DB). update query was done and a matching row was found in the DB).
@ -1147,22 +1171,23 @@ class Model(AltersData, metaclass=ModelBase):
# case we just say the update succeeded. Another case ending up # case we just say the update succeeded. Another case ending up
# here is a model with just PK - in that case check that the PK # here is a model with just PK - in that case check that the PK
# still exists. # still exists.
return update_fields is not None or filtered.exists() if update_fields is not None or filtered.exists():
return [()]
return []
if self._meta.select_on_save and not forced_update: if self._meta.select_on_save and not forced_update:
return ( # It may happen that the object is deleted from the DB right after
filtered.exists() # this check, causing the subsequent UPDATE to return zero matching
and # rows. The same result can occur in some rare cases when the
# It may happen that the object is deleted from the DB right # database returns zero despite the UPDATE being executed
# after this check, causing the subsequent UPDATE to return # successfully (a row is matched and updated). In order to
# zero matching rows. The same result can occur in some rare # distinguish these two cases, the object's existence in the
# cases when the database returns zero despite the UPDATE being # database is again checked for if the UPDATE query returns 0.
# executed successfully (a row is matched and updated). In if not filtered.exists():
# order to distinguish these two cases, the object's existence return []
# in the database is again checked for if the UPDATE query if results := filtered._update(values, returning_fields):
# returns 0. return results
(filtered._update(values) > 0 or filtered.exists()) return [()] if filtered.exists() else []
) return filtered._update(values, returning_fields)
return filtered._update(values) > 0
def _do_insert(self, manager, using, fields, returning_fields, raw): def _do_insert(self, manager, using, fields, returning_fields, raw):
""" """
@ -1177,6 +1202,10 @@ class Model(AltersData, metaclass=ModelBase):
raw=raw, raw=raw,
) )
def _assign_returned_values(self, returned_values, returning_fields):
for value, field in zip(returned_values, returning_fields):
setattr(self, field.attname, value)
def _prepare_related_fields_for_save(self, operation_name, fields=None): def _prepare_related_fields_for_save(self, operation_name, fields=None):
# Ensure that a model instance without a PK hasn't been assigned to # Ensure that a model instance without a PK hasn't been assigned to
# a ForeignKey, GenericForeignKey or OneToOneField on this model. If # a ForeignKey, GenericForeignKey or OneToOneField on this model. If

View File

@ -66,6 +66,16 @@ class GeneratedField(Field):
sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END" sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
return sql, params return sql, params
@cached_property
def referenced_fields(self):
resolved_expression = self.expression.resolve_expression(
self._query, allow_joins=False
)
referenced_fields = []
for col in self._query._gen_cols([resolved_expression]):
referenced_fields.append(col.target)
return frozenset(referenced_fields)
def check(self, **kwargs): def check(self, **kwargs):
databases = kwargs.get("databases") or [] databases = kwargs.get("databases") or []
errors = [ errors = [

View File

@ -1306,7 +1306,7 @@ class QuerySet(AltersData):
aupdate.alters_data = True aupdate.alters_data = True
def _update(self, values): def _update(self, values, returning_fields=None):
""" """
A version of update() that accepts field objects instead of field A version of update() that accepts field objects instead of field
names. Used primarily for model saving and not intended for use by names. Used primarily for model saving and not intended for use by
@ -1320,7 +1320,9 @@ class QuerySet(AltersData):
# Clear any annotations so that they won't be present in subqueries. # Clear any annotations so that they won't be present in subqueries.
query.annotations = {} query.annotations = {}
self._result_cache = None self._result_cache = None
return query.get_compiler(self.db).execute_sql(ROW_COUNT) if returning_fields is None:
return query.get_compiler(self.db).execute_sql(ROW_COUNT)
return query.get_compiler(self.db).execute_returning_sql(returning_fields)
_update.alters_data = True _update.alters_data = True
_update.queryset_only = False _update.queryset_only = False

View File

@ -2020,6 +2020,9 @@ class SQLDeleteCompiler(SQLCompiler):
class SQLUpdateCompiler(SQLCompiler): class SQLUpdateCompiler(SQLCompiler):
returning_fields = None
returning_params = ()
def as_sql(self): def as_sql(self):
""" """
Create the SQL for this query. Return the SQL string and list of Create the SQL for this query. Return the SQL string and list of
@ -2087,6 +2090,15 @@ class SQLUpdateCompiler(SQLCompiler):
params = [] params = []
else: else:
result.append("WHERE %s" % where) result.append("WHERE %s" % where)
if self.returning_fields:
# Skip empty r_sql to allow subclasses to customize behavior for
# 3rd party backends. Refs #19096.
r_sql, self.returning_params = self.connection.ops.returning_columns(
self.returning_fields
)
if r_sql:
result.append(r_sql)
params.extend(self.returning_params)
return " ".join(result), tuple(update_params + params) return " ".join(result), tuple(update_params + params)
def execute_sql(self, result_type): def execute_sql(self, result_type):
@ -2110,6 +2122,38 @@ class SQLUpdateCompiler(SQLCompiler):
is_empty = False is_empty = False
return row_count return row_count
def execute_returning_sql(self, returning_fields):
"""
Execute the specified update and return rows of the returned columns
associated with the specified returning_field if the backend supports
it.
"""
if self.query.get_related_updates():
raise NotImplementedError(
"Update returning is not implemented for queries with related updates."
)
if (
not returning_fields
or not self.connection.features.can_return_rows_from_update
):
row_count = self.execute_sql(ROW_COUNT)
return [()] * row_count
self.returning_fields = returning_fields
with self.connection.cursor() as cursor:
sql, params = self.as_sql()
cursor.execute(sql, params)
rows = self.connection.ops.fetch_returned_rows(
cursor, self.returning_params
)
opts = self.query.get_meta()
cols = [field.get_col(opts.db_table) for field in self.returning_fields]
converters = self.get_converters(cols)
if converters:
rows = self.apply_converters(rows, converters)
return list(rows)
def pre_sql_setup(self): def pre_sql_setup(self):
""" """
If the update depends on results from other tables, munge the "where" If the update depends on results from other tables, munge the "where"

View File

@ -1315,12 +1315,6 @@ materialized view.
PostgreSQL only supports persisted columns. Oracle only supports virtual PostgreSQL only supports persisted columns. Oracle only supports virtual
columns. columns.
.. admonition:: Refresh the data
Since the database computes the value, the object must be reloaded to
access the new value after :meth:`~Model.save`, for example, by using
:meth:`~Model.refresh_from_db`.
.. admonition:: Database limitations .. admonition:: Database limitations
There are many database-specific restrictions on generated fields that There are many database-specific restrictions on generated fields that
@ -1338,6 +1332,12 @@ materialized view.
.. _PostgreSQL: https://www.postgresql.org/docs/current/ddl-generated-columns.html .. _PostgreSQL: https://www.postgresql.org/docs/current/ddl-generated-columns.html
.. _SQLite: https://www.sqlite.org/gencol.html#limitations .. _SQLite: https://www.sqlite.org/gencol.html#limitations
.. versionchanged:: 6.0
``GeneratedField``\s are now automatically refreshed from the database on
backends that support it (SQLite, PostgreSQL, and Oracle) and marked as
deferred otherwise.
``GenericIPAddressField`` ``GenericIPAddressField``
------------------------- -------------------------

View File

@ -331,6 +331,12 @@ Models
value from the non-null input values. This is supported on SQLite, MySQL, value from the non-null input values. This is supported on SQLite, MySQL,
Oracle, and PostgreSQL 16+. Oracle, and PostgreSQL 16+.
* :class:`~django.db.models.GeneratedField`\s are now refreshed from the
database after :meth:`~django.db.models.Model.save` on backends that support
the ``RETURNING`` clause (SQLite, PostgreSQL, and Oracle). On backends that
don't support it (MySQL and MariaDB), the fields are marked as deferred to
trigger a refresh on subsequent accesses.
Pagination Pagination
~~~~~~~~~~ ~~~~~~~~~~
@ -420,6 +426,9 @@ backends.
``returning_params`` to be provided just like ``returning_params`` to be provided just like
``fetch_returned_insert_columns()`` did. ``fetch_returned_insert_columns()`` did.
* If the database supports ``UPDATE … RETURNING`` statements, backends can set
``DatabaseFeatures.can_return_rows_from_update=True``.
Dropped support for MariaDB 10.5 Dropped support for MariaDB 10.5
-------------------------------- --------------------------------

View File

@ -173,11 +173,6 @@ class BaseGeneratedFieldTests(SimpleTestCase):
class GeneratedFieldTestMixin: class GeneratedFieldTestMixin:
def _refresh_if_needed(self, m):
if not connection.features.can_return_columns_from_insert:
m.refresh_from_db()
return m
def test_unsaved_error(self): def test_unsaved_error(self):
m = self.base_model(a=1, b=2) m = self.base_model(a=1, b=2)
msg = "Cannot retrieve deferred field 'field' from an unsaved model." msg = "Cannot retrieve deferred field 'field' from an unsaved model."
@ -189,8 +184,11 @@ class GeneratedFieldTestMixin:
# full_clean() ignores GeneratedFields. # full_clean() ignores GeneratedFields.
m.full_clean() m.full_clean()
m.save() m.save()
m = self._refresh_if_needed(m) expected_num_queries = (
self.assertEqual(m.field, 3) 0 if connection.features.can_return_columns_from_insert else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(m.field, 3)
@skipUnlessDBFeature("supports_table_check_constraints") @skipUnlessDBFeature("supports_table_check_constraints")
def test_full_clean_with_check_constraint(self): def test_full_clean_with_check_constraint(self):
@ -199,8 +197,11 @@ class GeneratedFieldTestMixin:
m = self.check_constraint_model(a=2) m = self.check_constraint_model(a=2)
m.full_clean() m.full_clean()
m.save() m.save()
m = self._refresh_if_needed(m) expected_num_queries = (
self.assertEqual(m.a_squared, 4) 0 if connection.features.can_return_columns_from_insert else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(m.a_squared, 4)
m = self.check_constraint_model(a=-1) m = self.check_constraint_model(a=-1)
with self.assertRaises(ValidationError) as cm: with self.assertRaises(ValidationError) as cm:
@ -217,8 +218,11 @@ class GeneratedFieldTestMixin:
m = self.unique_constraint_model(a=2) m = self.unique_constraint_model(a=2)
m.full_clean() m.full_clean()
m.save() m.save()
m = self._refresh_if_needed(m) expected_num_queries = (
self.assertEqual(m.a_squared, 4) 0 if connection.features.can_return_columns_from_insert else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(m.a_squared, 4)
m = self.unique_constraint_model(a=2) m = self.unique_constraint_model(a=2)
with self.assertRaises(ValidationError) as cm: with self.assertRaises(ValidationError) as cm:
@ -230,8 +234,11 @@ class GeneratedFieldTestMixin:
def test_create(self): def test_create(self):
m = self.base_model.objects.create(a=1, b=2) m = self.base_model.objects.create(a=1, b=2)
m = self._refresh_if_needed(m) expected_num_queries = (
self.assertEqual(m.field, 3) 0 if connection.features.can_return_columns_from_insert else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(m.field, 3)
def test_non_nullable_create(self): def test_non_nullable_create(self):
with self.assertRaises(IntegrityError): with self.assertRaises(IntegrityError):
@ -241,26 +248,52 @@ class GeneratedFieldTestMixin:
# Insert. # Insert.
m = self.base_model(a=2, b=4) m = self.base_model(a=2, b=4)
m.save() m.save()
m = self._refresh_if_needed(m) expected_num_queries = (
self.assertEqual(m.field, 6) 0 if connection.features.can_return_columns_from_insert else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(m.field, 6)
# Update. # Update.
m.a = 4 m.a = 4
m.save() m.save()
m.refresh_from_db() expected_num_queries = (
self.assertEqual(m.field, 8) 0 if connection.features.can_return_rows_from_update else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(m.field, 8)
# Update non-dependent field.
self.base_model.objects.filter(pk=m.pk).update(a=6)
m.save(update_fields=["fk"])
with self.assertNumQueries(0):
self.assertEqual(m.field, 8)
# Update dependent field without persisting local changes.
m.save(update_fields=["b"])
with self.assertNumQueries(expected_num_queries):
self.assertEqual(m.field, 10)
# Update dependent field while persisting local changes.
m.a = 8
m.save(update_fields=["a"])
with self.assertNumQueries(expected_num_queries):
self.assertEqual(m.field, 12)
def test_save_model_with_pk(self): def test_save_model_with_pk(self):
m = self.base_model(pk=1, a=1, b=2) m = self.base_model(pk=1, a=1, b=2)
m.save() m.save()
m = self._refresh_if_needed(m) expected_num_queries = (
self.assertEqual(m.field, 3) 0 if connection.features.can_return_columns_from_insert else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(m.field, 3)
def test_save_model_with_foreign_key(self): def test_save_model_with_foreign_key(self):
fk_object = Foo.objects.create(a="abc", d=Decimal("12.34")) fk_object = Foo.objects.create(a="abc", d=Decimal("12.34"))
m = self.base_model(a=1, b=2, fk=fk_object) m = self.base_model(a=1, b=2, fk=fk_object)
m.save() m.save()
m = self._refresh_if_needed(m) expected_num_queries = (
self.assertEqual(m.field, 3) 0 if connection.features.can_return_columns_from_insert else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(m.field, 3)
def test_generated_fields_can_be_deferred(self): def test_generated_fields_can_be_deferred(self):
fk_object = Foo.objects.create(a="abc", d=Decimal("12.34")) fk_object = Foo.objects.create(a="abc", d=Decimal("12.34"))
@ -330,17 +363,23 @@ class GeneratedFieldTestMixin:
def test_model_with_params(self): def test_model_with_params(self):
m = self.params_model.objects.create() m = self.params_model.objects.create()
m = self._refresh_if_needed(m) expected_num_queries = (
self.assertEqual(m.field, "Constant") 0 if connection.features.can_return_columns_from_insert else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(m.field, "Constant")
def test_nullable(self): def test_nullable(self):
m1 = self.nullable_model.objects.create() m1 = self.nullable_model.objects.create()
m1 = self._refresh_if_needed(m1)
none_val = "" if connection.features.interprets_empty_strings_as_nulls else None none_val = "" if connection.features.interprets_empty_strings_as_nulls else None
self.assertEqual(m1.lower_name, none_val) expected_num_queries = (
0 if connection.features.can_return_columns_from_insert else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(m1.lower_name, none_val)
m2 = self.nullable_model.objects.create(name="NaMe") m2 = self.nullable_model.objects.create(name="NaMe")
m2 = self._refresh_if_needed(m2) with self.assertNumQueries(expected_num_queries):
self.assertEqual(m2.lower_name, "name") self.assertEqual(m2.lower_name, "name")
@skipUnlessDBFeature("supports_stored_generated_columns") @skipUnlessDBFeature("supports_stored_generated_columns")
@ -354,8 +393,21 @@ class StoredGeneratedFieldTests(GeneratedFieldTestMixin, TestCase):
def test_create_field_with_db_converters(self): def test_create_field_with_db_converters(self):
obj = GeneratedModelFieldWithConverters.objects.create(field=uuid.uuid4()) obj = GeneratedModelFieldWithConverters.objects.create(field=uuid.uuid4())
obj = self._refresh_if_needed(obj) expected_num_queries = (
self.assertEqual(obj.field, obj.field_copy) 0 if connection.features.can_return_columns_from_insert else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(obj.field, obj.field_copy)
def test_save_field_with_db_converters(self):
obj = GeneratedModelFieldWithConverters.objects.create(field=uuid.uuid4())
obj.field = uuid.uuid4()
expected_num_queries = (
0 if connection.features.can_return_rows_from_update else 1
)
obj.save(update_fields={"field"})
with self.assertNumQueries(expected_num_queries):
self.assertEqual(obj.field, obj.field_copy)
def test_create_with_non_auto_pk(self): def test_create_with_non_auto_pk(self):
obj = GeneratedModelNonAutoPk.objects.create(id=1, a=2) obj = GeneratedModelNonAutoPk.objects.create(id=1, a=2)