1
0
mirror of https://github.com/django/django.git synced 2025-09-15 21:49:24 +00:00

Fixed #27222 -- Refreshed model field values assigned expressions on save().

Removed the can_return_columns_from_insert skip gates on existing
field_defaults tests to confirm the expected number of queries are
performed and that returning field overrides are respected.
This commit is contained in:
Simon Charette 2025-03-19 01:39:19 -04:00 committed by Mariusz Felisiak
parent 55a0073b3b
commit 94680437a4
7 changed files with 123 additions and 63 deletions

View File

@ -1102,6 +1102,11 @@ class Model(AltersData, metaclass=ModelBase):
and f.referenced_fields.intersection(non_pks_non_generated)
)
]
for field, _model, value in values:
if (update_fields is None or field.name in update_fields) and hasattr(
value, "resolve_expression"
):
returning_fields.append(field)
results = self._do_update(
base_qs,
using,
@ -1142,7 +1147,15 @@ class Model(AltersData, metaclass=ModelBase):
for f in meta.local_concrete_fields
if not f.generated and (pk_set or f is not meta.auto_field)
]
returning_fields = meta.db_returning_fields
returning_fields = list(meta.db_returning_fields)
for field in fields:
value = (
getattr(self, field.attname) if raw else field.pre_save(self, False)
)
if hasattr(value, "resolve_expression"):
returning_fields.append(field)
elif field.db_returning:
returning_fields.remove(field)
results = self._do_insert(
cls._base_manager, using, fields, returning_fields, raw
)
@ -1203,8 +1216,13 @@ class Model(AltersData, metaclass=ModelBase):
)
def _assign_returned_values(self, returned_values, returning_fields):
for value, field in zip(returned_values, returning_fields):
returning_fields_iter = iter(returning_fields)
for value, field in zip(returned_values, returning_fields_iter):
setattr(self, field.attname, value)
# Defer all fields that were meant to be updated with their database
# resolved values but couldn't as they are effectively stale.
for field in returning_fields_iter:
self.__dict__.pop(field.attname, None)
def _prepare_related_fields_for_save(self, operation_name, fields=None):
# Ensure that a model instance without a PK hasn't been assigned to

View File

@ -69,8 +69,6 @@ Some examples
# Create a new company using expressions.
>>> company = Company.objects.create(name="Google", ticker=Upper(Value("goog")))
# Be sure to refresh it if you need to access the field.
>>> company.refresh_from_db()
>>> company.ticker
'GOOG'
@ -157,12 +155,6 @@ know about it - it is dealt with entirely by the database. All Python does,
through Django's ``F()`` class, is create the SQL syntax to refer to the field
and describe the operation.
To access the new value saved this way, the object must be reloaded::
reporter = Reporters.objects.get(pk=reporter.pk)
# Or, more succinctly:
reporter.refresh_from_db()
As well as being used in operations on single instances as above, ``F()`` can
be used with ``update()`` to perform bulk updates on a ``QuerySet``. This
reduces the two queries we were using above - the ``get()`` and the
@ -199,7 +191,6 @@ array-slicing syntax. The indices are 0-based and the ``step`` argument to
>>> writer = Writers.objects.get(name="Priyansh")
>>> writer.name = F("name")[1:5]
>>> writer.save()
>>> writer.refresh_from_db()
>>> writer.name
'riya'
@ -221,23 +212,27 @@ robust: it will only ever update the field based on the value of the field in
the database when the :meth:`~Model.save` or ``update()`` is executed, rather
than based on its value when the instance was retrieved.
``F()`` assignments persist after ``Model.save()``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
``F()`` assignments are refreshed after ``Model.save()``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
``F()`` objects assigned to model fields persist after saving the model
instance and will be applied on each :meth:`~Model.save`. For example::
``F()`` objects assigned to model fields are refreshed from the database on
:meth:`~Model.save` on backends that support it without incurring a subsequent
query (SQLite, PostgreSQL, and Oracle) and deferred otherwise (MySQL or
MariaDB). For example:
reporter = Reporters.objects.get(name="Tintin")
reporter.stories_filed = F("stories_filed") + 1
reporter.save()
.. code-block:: pycon
reporter.name = "Tintin Jr."
reporter.save()
>>> reporter = Reporters.objects.get(name="Tintin")
>>> reporter.stories_filed = F("stories_filed") + 1
>>> reporter.save()
>>> reporter.stories_filed # This triggers a refresh query on MySQL/MariaDB.
14 # Assuming the database value was 13 when the object was saved.
``stories_filed`` will be updated twice in this case. If it's initially ``1``,
the final value will be ``3``. This persistence can be avoided by reloading the
model object after saving it, for example, by using
:meth:`~Model.refresh_from_db`.
.. versionchanged:: 6.0
In previous versions of Django, ``F()`` objects were not refreshed from the
database on :meth:`~Model.save` which resulted in them being evaluated and
persisted every time the instance was saved.
Using ``F()`` in filters
~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -331,7 +331,8 @@ Models
value from the non-null input values. This is supported on SQLite, MySQL,
Oracle, and PostgreSQL 16+.
* :class:`~django.db.models.GeneratedField`\s are now refreshed from the
* :class:`~django.db.models.GeneratedField`\s and :ref:`fields assigned
expressions <avoiding-race-conditions-using-f>` are now refreshed from the
database after :meth:`~django.db.models.Model.save` on backends that support
the ``RETURNING`` clause (SQLite, PostgreSQL, and Oracle). On backends that
don't support it (MySQL and MariaDB), the fields are marked as deferred to

View File

@ -1,5 +1,6 @@
import inspect
import threading
import time
from datetime import datetime, timedelta
from unittest import mock
@ -12,6 +13,7 @@ from django.db import (
models,
transaction,
)
from django.db.models.functions import Now
from django.db.models.manager import BaseManager
from django.db.models.query import MAX_GET_RESULTS, EmptyQuerySet
from django.test import (
@ -558,6 +560,26 @@ class ModelTest(TestCase):
with self.subTest(case=case):
self.assertIs(case._is_pk_set(), True)
def test_save_expressions(self):
article = Article(pub_date=Now())
article.save()
expected_num_queries = (
0 if connection.features.can_return_columns_from_insert else 1
)
with self.assertNumQueries(expected_num_queries):
article_pub_date = article.pub_date
self.assertIsInstance(article_pub_date, datetime)
# Sleep slightly to ensure a different database level NOW().
time.sleep(0.1)
article.pub_date = Now()
article.save()
expected_num_queries = (
0 if connection.features.can_return_rows_from_update else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertIsInstance(article.pub_date, datetime)
self.assertGreater(article.pub_date, article_pub_date)
class ModelLookupTest(TestCase):
@classmethod

View File

@ -420,7 +420,10 @@ class BasicExpressionsTests(TestCase):
# F expressions can be used to update attributes on single objects
self.gmbh.num_employees = F("num_employees") + 4
self.gmbh.save()
self.gmbh.refresh_from_db()
expected_num_queries = (
0 if connection.features.can_return_rows_from_update else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(self.gmbh.num_employees, 36)
def test_new_object_save(self):
@ -1644,7 +1647,10 @@ class ExpressionsNumericTests(TestCase):
n = Number.objects.create(integer=1, decimal_value=Decimal("0.5"))
n.decimal_value = F("decimal_value") - Decimal("0.4")
n.save()
n.refresh_from_db()
expected_num_queries = (
0 if connection.features.can_return_rows_from_update else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(n.decimal_value, Decimal("0.1"))

View File

@ -15,13 +15,7 @@ from django.db.models.expressions import (
)
from django.db.models.functions import Collate
from django.db.models.lookups import GreaterThan
from django.test import (
SimpleTestCase,
TestCase,
override_settings,
skipIfDBFeature,
skipUnlessDBFeature,
)
from django.test import SimpleTestCase, TestCase, override_settings, skipUnlessDBFeature
from django.utils import timezone
from .models import (
@ -44,43 +38,52 @@ class DefaultTests(TestCase):
self.assertEqual(a.headline, "Default headline")
self.assertLess((now - a.pub_date).seconds, 5)
@skipUnlessDBFeature(
"can_return_columns_from_insert", "supports_expression_defaults"
)
@skipUnlessDBFeature("supports_expression_defaults")
def test_field_db_defaults_returning(self):
a = DBArticle()
a.save()
self.assertIsInstance(a.id, int)
expected_num_queries = (
0 if connection.features.can_return_columns_from_insert else 3
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(a.headline, "Default headline")
self.assertIsInstance(a.pub_date, datetime)
self.assertEqual(a.cost, Decimal("3.33"))
@skipIfDBFeature("can_return_columns_from_insert")
@skipUnlessDBFeature("supports_expression_defaults")
def test_field_db_defaults_refresh(self):
a = DBArticle()
a.save()
a.refresh_from_db()
expected_num_queries = (
0 if connection.features.can_return_columns_from_insert else 3
)
self.assertIsInstance(a.id, int)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(a.headline, "Default headline")
self.assertIsInstance(a.pub_date, datetime)
self.assertEqual(a.cost, Decimal("3.33"))
def test_null_db_default(self):
obj1 = DBDefaults.objects.create()
if not connection.features.can_return_columns_from_insert:
obj1.refresh_from_db()
expected_num_queries = (
0 if connection.features.can_return_columns_from_insert else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(obj1.null, 1.1)
obj2 = DBDefaults.objects.create(null=None)
with self.assertNumQueries(0):
self.assertIsNone(obj2.null)
@skipUnlessDBFeature("supports_expression_defaults")
@override_settings(USE_TZ=True)
def test_db_default_function(self):
m = DBDefaultsFunction.objects.create()
if not connection.features.can_return_columns_from_insert:
m.refresh_from_db()
expected_num_queries = (
0 if connection.features.can_return_columns_from_insert else 4
)
with self.assertNumQueries(expected_num_queries):
self.assertAlmostEqual(m.number, pi)
self.assertEqual(m.year, timezone.now().year)
self.assertAlmostEqual(m.added, pi + 4.5)
@ -125,14 +128,15 @@ class DefaultTests(TestCase):
child2 = DBDefaultsFK.objects.create(language_code=parent2)
self.assertEqual(child2.language_code, parent2)
@skipUnlessDBFeature(
"can_return_columns_from_insert", "supports_expression_defaults"
)
@skipUnlessDBFeature("supports_expression_defaults")
def test_case_when_db_default_returning(self):
m = DBDefaultsFunction.objects.create()
expected_num_queries = (
0 if connection.features.can_return_columns_from_insert else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(m.case_when, 3)
@skipIfDBFeature("can_return_columns_from_insert")
@skipUnlessDBFeature("supports_expression_defaults")
def test_case_when_db_default_no_returning(self):
m = DBDefaultsFunction.objects.create()

View File

@ -1,5 +1,6 @@
from django.core.exceptions import ObjectNotUpdated
from django.db import DatabaseError, transaction
from django.db import DatabaseError, connection, transaction
from django.db.models import F
from django.db.models.signals import post_save, pre_save
from django.test import TestCase
@ -308,3 +309,16 @@ class UpdateOnlyFieldsTests(TestCase):
transaction.atomic(),
):
obj.save(update_fields=["name"])
def test_update_fields_expression(self):
obj = Person.objects.create(name="Valerie", gender="F", pid=42)
updated_pid = F("pid") + 1
obj.pid = updated_pid
obj.save(update_fields={"gender"})
self.assertIs(obj.pid, updated_pid)
obj.save(update_fields={"pid"})
expected_num_queries = (
0 if connection.features.can_return_rows_from_update else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(obj.pid, 43)