1
0
mirror of https://github.com/django/django.git synced 2025-09-15 21:49:24 +00:00

Fixed #27222 -- Refreshed model field values assigned expressions on save().

Removed the can_return_columns_from_insert skip gates on existing
field_defaults tests to confirm the expected number of queries are
performed and that returning field overrides are respected.
This commit is contained in:
Simon Charette 2025-03-19 01:39:19 -04:00 committed by Mariusz Felisiak
parent 55a0073b3b
commit 94680437a4
7 changed files with 123 additions and 63 deletions

View File

@ -1102,6 +1102,11 @@ class Model(AltersData, metaclass=ModelBase):
and f.referenced_fields.intersection(non_pks_non_generated) and f.referenced_fields.intersection(non_pks_non_generated)
) )
] ]
for field, _model, value in values:
if (update_fields is None or field.name in update_fields) and hasattr(
value, "resolve_expression"
):
returning_fields.append(field)
results = self._do_update( results = self._do_update(
base_qs, base_qs,
using, using,
@ -1142,7 +1147,15 @@ class Model(AltersData, metaclass=ModelBase):
for f in meta.local_concrete_fields for f in meta.local_concrete_fields
if not f.generated and (pk_set or f is not meta.auto_field) if not f.generated and (pk_set or f is not meta.auto_field)
] ]
returning_fields = meta.db_returning_fields returning_fields = list(meta.db_returning_fields)
for field in fields:
value = (
getattr(self, field.attname) if raw else field.pre_save(self, False)
)
if hasattr(value, "resolve_expression"):
returning_fields.append(field)
elif field.db_returning:
returning_fields.remove(field)
results = self._do_insert( results = self._do_insert(
cls._base_manager, using, fields, returning_fields, raw cls._base_manager, using, fields, returning_fields, raw
) )
@ -1203,8 +1216,13 @@ class Model(AltersData, metaclass=ModelBase):
) )
def _assign_returned_values(self, returned_values, returning_fields): def _assign_returned_values(self, returned_values, returning_fields):
for value, field in zip(returned_values, returning_fields): returning_fields_iter = iter(returning_fields)
for value, field in zip(returned_values, returning_fields_iter):
setattr(self, field.attname, value) setattr(self, field.attname, value)
# Defer all fields that were meant to be updated with their database
# resolved values but couldn't as they are effectively stale.
for field in returning_fields_iter:
self.__dict__.pop(field.attname, None)
def _prepare_related_fields_for_save(self, operation_name, fields=None): def _prepare_related_fields_for_save(self, operation_name, fields=None):
# Ensure that a model instance without a PK hasn't been assigned to # Ensure that a model instance without a PK hasn't been assigned to

View File

@ -69,8 +69,6 @@ Some examples
# Create a new company using expressions. # Create a new company using expressions.
>>> company = Company.objects.create(name="Google", ticker=Upper(Value("goog"))) >>> company = Company.objects.create(name="Google", ticker=Upper(Value("goog")))
# Be sure to refresh it if you need to access the field.
>>> company.refresh_from_db()
>>> company.ticker >>> company.ticker
'GOOG' 'GOOG'
@ -157,12 +155,6 @@ know about it - it is dealt with entirely by the database. All Python does,
through Django's ``F()`` class, is create the SQL syntax to refer to the field through Django's ``F()`` class, is create the SQL syntax to refer to the field
and describe the operation. and describe the operation.
To access the new value saved this way, the object must be reloaded::
reporter = Reporters.objects.get(pk=reporter.pk)
# Or, more succinctly:
reporter.refresh_from_db()
As well as being used in operations on single instances as above, ``F()`` can As well as being used in operations on single instances as above, ``F()`` can
be used with ``update()`` to perform bulk updates on a ``QuerySet``. This be used with ``update()`` to perform bulk updates on a ``QuerySet``. This
reduces the two queries we were using above - the ``get()`` and the reduces the two queries we were using above - the ``get()`` and the
@ -199,7 +191,6 @@ array-slicing syntax. The indices are 0-based and the ``step`` argument to
>>> writer = Writers.objects.get(name="Priyansh") >>> writer = Writers.objects.get(name="Priyansh")
>>> writer.name = F("name")[1:5] >>> writer.name = F("name")[1:5]
>>> writer.save() >>> writer.save()
>>> writer.refresh_from_db()
>>> writer.name >>> writer.name
'riya' 'riya'
@ -221,23 +212,27 @@ robust: it will only ever update the field based on the value of the field in
the database when the :meth:`~Model.save` or ``update()`` is executed, rather the database when the :meth:`~Model.save` or ``update()`` is executed, rather
than based on its value when the instance was retrieved. than based on its value when the instance was retrieved.
``F()`` assignments persist after ``Model.save()`` ``F()`` assignments are refreshed after ``Model.save()``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
``F()`` objects assigned to model fields persist after saving the model ``F()`` objects assigned to model fields are refreshed from the database on
instance and will be applied on each :meth:`~Model.save`. For example:: :meth:`~Model.save` on backends that support it without incurring a subsequent
query (SQLite, PostgreSQL, and Oracle) and deferred otherwise (MySQL or
MariaDB). For example:
reporter = Reporters.objects.get(name="Tintin") .. code-block:: pycon
reporter.stories_filed = F("stories_filed") + 1
reporter.save()
reporter.name = "Tintin Jr." >>> reporter = Reporters.objects.get(name="Tintin")
reporter.save() >>> reporter.stories_filed = F("stories_filed") + 1
>>> reporter.save()
>>> reporter.stories_filed # This triggers a refresh query on MySQL/MariaDB.
14 # Assuming the database value was 13 when the object was saved.
``stories_filed`` will be updated twice in this case. If it's initially ``1``, .. versionchanged:: 6.0
the final value will be ``3``. This persistence can be avoided by reloading the
model object after saving it, for example, by using In previous versions of Django, ``F()`` objects were not refreshed from the
:meth:`~Model.refresh_from_db`. database on :meth:`~Model.save` which resulted in them being evaluated and
persisted every time the instance was saved.
Using ``F()`` in filters Using ``F()`` in filters
~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -331,7 +331,8 @@ Models
value from the non-null input values. This is supported on SQLite, MySQL, value from the non-null input values. This is supported on SQLite, MySQL,
Oracle, and PostgreSQL 16+. Oracle, and PostgreSQL 16+.
* :class:`~django.db.models.GeneratedField`\s are now refreshed from the * :class:`~django.db.models.GeneratedField`\s and :ref:`fields assigned
expressions <avoiding-race-conditions-using-f>` are now refreshed from the
database after :meth:`~django.db.models.Model.save` on backends that support database after :meth:`~django.db.models.Model.save` on backends that support
the ``RETURNING`` clause (SQLite, PostgreSQL, and Oracle). On backends that the ``RETURNING`` clause (SQLite, PostgreSQL, and Oracle). On backends that
don't support it (MySQL and MariaDB), the fields are marked as deferred to don't support it (MySQL and MariaDB), the fields are marked as deferred to

View File

@ -1,5 +1,6 @@
import inspect import inspect
import threading import threading
import time
from datetime import datetime, timedelta from datetime import datetime, timedelta
from unittest import mock from unittest import mock
@ -12,6 +13,7 @@ from django.db import (
models, models,
transaction, transaction,
) )
from django.db.models.functions import Now
from django.db.models.manager import BaseManager from django.db.models.manager import BaseManager
from django.db.models.query import MAX_GET_RESULTS, EmptyQuerySet from django.db.models.query import MAX_GET_RESULTS, EmptyQuerySet
from django.test import ( from django.test import (
@ -558,6 +560,26 @@ class ModelTest(TestCase):
with self.subTest(case=case): with self.subTest(case=case):
self.assertIs(case._is_pk_set(), True) self.assertIs(case._is_pk_set(), True)
def test_save_expressions(self):
article = Article(pub_date=Now())
article.save()
expected_num_queries = (
0 if connection.features.can_return_columns_from_insert else 1
)
with self.assertNumQueries(expected_num_queries):
article_pub_date = article.pub_date
self.assertIsInstance(article_pub_date, datetime)
# Sleep slightly to ensure a different database level NOW().
time.sleep(0.1)
article.pub_date = Now()
article.save()
expected_num_queries = (
0 if connection.features.can_return_rows_from_update else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertIsInstance(article.pub_date, datetime)
self.assertGreater(article.pub_date, article_pub_date)
class ModelLookupTest(TestCase): class ModelLookupTest(TestCase):
@classmethod @classmethod

View File

@ -420,8 +420,11 @@ class BasicExpressionsTests(TestCase):
# F expressions can be used to update attributes on single objects # F expressions can be used to update attributes on single objects
self.gmbh.num_employees = F("num_employees") + 4 self.gmbh.num_employees = F("num_employees") + 4
self.gmbh.save() self.gmbh.save()
self.gmbh.refresh_from_db() expected_num_queries = (
self.assertEqual(self.gmbh.num_employees, 36) 0 if connection.features.can_return_rows_from_update else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(self.gmbh.num_employees, 36)
def test_new_object_save(self): def test_new_object_save(self):
# We should be able to use Funcs when inserting new data # We should be able to use Funcs when inserting new data
@ -1644,8 +1647,11 @@ class ExpressionsNumericTests(TestCase):
n = Number.objects.create(integer=1, decimal_value=Decimal("0.5")) n = Number.objects.create(integer=1, decimal_value=Decimal("0.5"))
n.decimal_value = F("decimal_value") - Decimal("0.4") n.decimal_value = F("decimal_value") - Decimal("0.4")
n.save() n.save()
n.refresh_from_db() expected_num_queries = (
self.assertEqual(n.decimal_value, Decimal("0.1")) 0 if connection.features.can_return_rows_from_update else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(n.decimal_value, Decimal("0.1"))
class ExpressionOperatorTests(TestCase): class ExpressionOperatorTests(TestCase):

View File

@ -15,13 +15,7 @@ from django.db.models.expressions import (
) )
from django.db.models.functions import Collate from django.db.models.functions import Collate
from django.db.models.lookups import GreaterThan from django.db.models.lookups import GreaterThan
from django.test import ( from django.test import SimpleTestCase, TestCase, override_settings, skipUnlessDBFeature
SimpleTestCase,
TestCase,
override_settings,
skipIfDBFeature,
skipUnlessDBFeature,
)
from django.utils import timezone from django.utils import timezone
from .models import ( from .models import (
@ -44,47 +38,56 @@ class DefaultTests(TestCase):
self.assertEqual(a.headline, "Default headline") self.assertEqual(a.headline, "Default headline")
self.assertLess((now - a.pub_date).seconds, 5) self.assertLess((now - a.pub_date).seconds, 5)
@skipUnlessDBFeature( @skipUnlessDBFeature("supports_expression_defaults")
"can_return_columns_from_insert", "supports_expression_defaults"
)
def test_field_db_defaults_returning(self): def test_field_db_defaults_returning(self):
a = DBArticle() a = DBArticle()
a.save() a.save()
self.assertIsInstance(a.id, int) self.assertIsInstance(a.id, int)
self.assertEqual(a.headline, "Default headline") expected_num_queries = (
self.assertIsInstance(a.pub_date, datetime) 0 if connection.features.can_return_columns_from_insert else 3
self.assertEqual(a.cost, Decimal("3.33")) )
with self.assertNumQueries(expected_num_queries):
self.assertEqual(a.headline, "Default headline")
self.assertIsInstance(a.pub_date, datetime)
self.assertEqual(a.cost, Decimal("3.33"))
@skipIfDBFeature("can_return_columns_from_insert")
@skipUnlessDBFeature("supports_expression_defaults") @skipUnlessDBFeature("supports_expression_defaults")
def test_field_db_defaults_refresh(self): def test_field_db_defaults_refresh(self):
a = DBArticle() a = DBArticle()
a.save() a.save()
a.refresh_from_db() expected_num_queries = (
0 if connection.features.can_return_columns_from_insert else 3
)
self.assertIsInstance(a.id, int) self.assertIsInstance(a.id, int)
self.assertEqual(a.headline, "Default headline") with self.assertNumQueries(expected_num_queries):
self.assertIsInstance(a.pub_date, datetime) self.assertEqual(a.headline, "Default headline")
self.assertEqual(a.cost, Decimal("3.33")) self.assertIsInstance(a.pub_date, datetime)
self.assertEqual(a.cost, Decimal("3.33"))
def test_null_db_default(self): def test_null_db_default(self):
obj1 = DBDefaults.objects.create() obj1 = DBDefaults.objects.create()
if not connection.features.can_return_columns_from_insert: expected_num_queries = (
obj1.refresh_from_db() 0 if connection.features.can_return_columns_from_insert else 1
self.assertEqual(obj1.null, 1.1) )
with self.assertNumQueries(expected_num_queries):
self.assertEqual(obj1.null, 1.1)
obj2 = DBDefaults.objects.create(null=None) obj2 = DBDefaults.objects.create(null=None)
self.assertIsNone(obj2.null) with self.assertNumQueries(0):
self.assertIsNone(obj2.null)
@skipUnlessDBFeature("supports_expression_defaults") @skipUnlessDBFeature("supports_expression_defaults")
@override_settings(USE_TZ=True) @override_settings(USE_TZ=True)
def test_db_default_function(self): def test_db_default_function(self):
m = DBDefaultsFunction.objects.create() m = DBDefaultsFunction.objects.create()
if not connection.features.can_return_columns_from_insert: expected_num_queries = (
m.refresh_from_db() 0 if connection.features.can_return_columns_from_insert else 4
self.assertAlmostEqual(m.number, pi) )
self.assertEqual(m.year, timezone.now().year) with self.assertNumQueries(expected_num_queries):
self.assertAlmostEqual(m.added, pi + 4.5) self.assertAlmostEqual(m.number, pi)
self.assertEqual(m.multiple_subfunctions, 4.5) self.assertEqual(m.year, timezone.now().year)
self.assertAlmostEqual(m.added, pi + 4.5)
self.assertEqual(m.multiple_subfunctions, 4.5)
@skipUnlessDBFeature("insert_test_table_with_defaults") @skipUnlessDBFeature("insert_test_table_with_defaults")
def test_both_default(self): def test_both_default(self):
@ -125,14 +128,15 @@ class DefaultTests(TestCase):
child2 = DBDefaultsFK.objects.create(language_code=parent2) child2 = DBDefaultsFK.objects.create(language_code=parent2)
self.assertEqual(child2.language_code, parent2) self.assertEqual(child2.language_code, parent2)
@skipUnlessDBFeature( @skipUnlessDBFeature("supports_expression_defaults")
"can_return_columns_from_insert", "supports_expression_defaults"
)
def test_case_when_db_default_returning(self): def test_case_when_db_default_returning(self):
m = DBDefaultsFunction.objects.create() m = DBDefaultsFunction.objects.create()
self.assertEqual(m.case_when, 3) expected_num_queries = (
0 if connection.features.can_return_columns_from_insert else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(m.case_when, 3)
@skipIfDBFeature("can_return_columns_from_insert")
@skipUnlessDBFeature("supports_expression_defaults") @skipUnlessDBFeature("supports_expression_defaults")
def test_case_when_db_default_no_returning(self): def test_case_when_db_default_no_returning(self):
m = DBDefaultsFunction.objects.create() m = DBDefaultsFunction.objects.create()

View File

@ -1,5 +1,6 @@
from django.core.exceptions import ObjectNotUpdated from django.core.exceptions import ObjectNotUpdated
from django.db import DatabaseError, transaction from django.db import DatabaseError, connection, transaction
from django.db.models import F
from django.db.models.signals import post_save, pre_save from django.db.models.signals import post_save, pre_save
from django.test import TestCase from django.test import TestCase
@ -308,3 +309,16 @@ class UpdateOnlyFieldsTests(TestCase):
transaction.atomic(), transaction.atomic(),
): ):
obj.save(update_fields=["name"]) obj.save(update_fields=["name"])
def test_update_fields_expression(self):
obj = Person.objects.create(name="Valerie", gender="F", pid=42)
updated_pid = F("pid") + 1
obj.pid = updated_pid
obj.save(update_fields={"gender"})
self.assertIs(obj.pid, updated_pid)
obj.save(update_fields={"pid"})
expected_num_queries = (
0 if connection.features.can_return_rows_from_update else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(obj.pid, 43)