mirror of
https://github.com/django/django.git
synced 2025-09-15 21:49:24 +00:00
Fixed #27222 -- Refreshed model field values assigned expressions on save().
Removed the can_return_columns_from_insert skip gates on existing field_defaults tests to confirm the expected number of queries are performed and that returning field overrides are respected.
This commit is contained in:
parent
55a0073b3b
commit
94680437a4
@ -1102,6 +1102,11 @@ class Model(AltersData, metaclass=ModelBase):
|
||||
and f.referenced_fields.intersection(non_pks_non_generated)
|
||||
)
|
||||
]
|
||||
for field, _model, value in values:
|
||||
if (update_fields is None or field.name in update_fields) and hasattr(
|
||||
value, "resolve_expression"
|
||||
):
|
||||
returning_fields.append(field)
|
||||
results = self._do_update(
|
||||
base_qs,
|
||||
using,
|
||||
@ -1142,7 +1147,15 @@ class Model(AltersData, metaclass=ModelBase):
|
||||
for f in meta.local_concrete_fields
|
||||
if not f.generated and (pk_set or f is not meta.auto_field)
|
||||
]
|
||||
returning_fields = meta.db_returning_fields
|
||||
returning_fields = list(meta.db_returning_fields)
|
||||
for field in fields:
|
||||
value = (
|
||||
getattr(self, field.attname) if raw else field.pre_save(self, False)
|
||||
)
|
||||
if hasattr(value, "resolve_expression"):
|
||||
returning_fields.append(field)
|
||||
elif field.db_returning:
|
||||
returning_fields.remove(field)
|
||||
results = self._do_insert(
|
||||
cls._base_manager, using, fields, returning_fields, raw
|
||||
)
|
||||
@ -1203,8 +1216,13 @@ class Model(AltersData, metaclass=ModelBase):
|
||||
)
|
||||
|
||||
def _assign_returned_values(self, returned_values, returning_fields):
|
||||
for value, field in zip(returned_values, returning_fields):
|
||||
returning_fields_iter = iter(returning_fields)
|
||||
for value, field in zip(returned_values, returning_fields_iter):
|
||||
setattr(self, field.attname, value)
|
||||
# Defer all fields that were meant to be updated with their database
|
||||
# resolved values but couldn't as they are effectively stale.
|
||||
for field in returning_fields_iter:
|
||||
self.__dict__.pop(field.attname, None)
|
||||
|
||||
def _prepare_related_fields_for_save(self, operation_name, fields=None):
|
||||
# Ensure that a model instance without a PK hasn't been assigned to
|
||||
|
@ -69,8 +69,6 @@ Some examples
|
||||
|
||||
# Create a new company using expressions.
|
||||
>>> company = Company.objects.create(name="Google", ticker=Upper(Value("goog")))
|
||||
# Be sure to refresh it if you need to access the field.
|
||||
>>> company.refresh_from_db()
|
||||
>>> company.ticker
|
||||
'GOOG'
|
||||
|
||||
@ -157,12 +155,6 @@ know about it - it is dealt with entirely by the database. All Python does,
|
||||
through Django's ``F()`` class, is create the SQL syntax to refer to the field
|
||||
and describe the operation.
|
||||
|
||||
To access the new value saved this way, the object must be reloaded::
|
||||
|
||||
reporter = Reporters.objects.get(pk=reporter.pk)
|
||||
# Or, more succinctly:
|
||||
reporter.refresh_from_db()
|
||||
|
||||
As well as being used in operations on single instances as above, ``F()`` can
|
||||
be used with ``update()`` to perform bulk updates on a ``QuerySet``. This
|
||||
reduces the two queries we were using above - the ``get()`` and the
|
||||
@ -199,7 +191,6 @@ array-slicing syntax. The indices are 0-based and the ``step`` argument to
|
||||
>>> writer = Writers.objects.get(name="Priyansh")
|
||||
>>> writer.name = F("name")[1:5]
|
||||
>>> writer.save()
|
||||
>>> writer.refresh_from_db()
|
||||
>>> writer.name
|
||||
'riya'
|
||||
|
||||
@ -221,23 +212,27 @@ robust: it will only ever update the field based on the value of the field in
|
||||
the database when the :meth:`~Model.save` or ``update()`` is executed, rather
|
||||
than based on its value when the instance was retrieved.
|
||||
|
||||
``F()`` assignments persist after ``Model.save()``
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
``F()`` assignments are refreshed after ``Model.save()``
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
``F()`` objects assigned to model fields persist after saving the model
|
||||
instance and will be applied on each :meth:`~Model.save`. For example::
|
||||
``F()`` objects assigned to model fields are refreshed from the database on
|
||||
:meth:`~Model.save` on backends that support it without incurring a subsequent
|
||||
query (SQLite, PostgreSQL, and Oracle) and deferred otherwise (MySQL or
|
||||
MariaDB). For example:
|
||||
|
||||
reporter = Reporters.objects.get(name="Tintin")
|
||||
reporter.stories_filed = F("stories_filed") + 1
|
||||
reporter.save()
|
||||
.. code-block:: pycon
|
||||
|
||||
reporter.name = "Tintin Jr."
|
||||
reporter.save()
|
||||
>>> reporter = Reporters.objects.get(name="Tintin")
|
||||
>>> reporter.stories_filed = F("stories_filed") + 1
|
||||
>>> reporter.save()
|
||||
>>> reporter.stories_filed # This triggers a refresh query on MySQL/MariaDB.
|
||||
14 # Assuming the database value was 13 when the object was saved.
|
||||
|
||||
``stories_filed`` will be updated twice in this case. If it's initially ``1``,
|
||||
the final value will be ``3``. This persistence can be avoided by reloading the
|
||||
model object after saving it, for example, by using
|
||||
:meth:`~Model.refresh_from_db`.
|
||||
.. versionchanged:: 6.0
|
||||
|
||||
In previous versions of Django, ``F()`` objects were not refreshed from the
|
||||
database on :meth:`~Model.save` which resulted in them being evaluated and
|
||||
persisted every time the instance was saved.
|
||||
|
||||
Using ``F()`` in filters
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
@ -331,7 +331,8 @@ Models
|
||||
value from the non-null input values. This is supported on SQLite, MySQL,
|
||||
Oracle, and PostgreSQL 16+.
|
||||
|
||||
* :class:`~django.db.models.GeneratedField`\s are now refreshed from the
|
||||
* :class:`~django.db.models.GeneratedField`\s and :ref:`fields assigned
|
||||
expressions <avoiding-race-conditions-using-f>` are now refreshed from the
|
||||
database after :meth:`~django.db.models.Model.save` on backends that support
|
||||
the ``RETURNING`` clause (SQLite, PostgreSQL, and Oracle). On backends that
|
||||
don't support it (MySQL and MariaDB), the fields are marked as deferred to
|
||||
|
@ -1,5 +1,6 @@
|
||||
import inspect
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from unittest import mock
|
||||
|
||||
@ -12,6 +13,7 @@ from django.db import (
|
||||
models,
|
||||
transaction,
|
||||
)
|
||||
from django.db.models.functions import Now
|
||||
from django.db.models.manager import BaseManager
|
||||
from django.db.models.query import MAX_GET_RESULTS, EmptyQuerySet
|
||||
from django.test import (
|
||||
@ -558,6 +560,26 @@ class ModelTest(TestCase):
|
||||
with self.subTest(case=case):
|
||||
self.assertIs(case._is_pk_set(), True)
|
||||
|
||||
def test_save_expressions(self):
|
||||
article = Article(pub_date=Now())
|
||||
article.save()
|
||||
expected_num_queries = (
|
||||
0 if connection.features.can_return_columns_from_insert else 1
|
||||
)
|
||||
with self.assertNumQueries(expected_num_queries):
|
||||
article_pub_date = article.pub_date
|
||||
self.assertIsInstance(article_pub_date, datetime)
|
||||
# Sleep slightly to ensure a different database level NOW().
|
||||
time.sleep(0.1)
|
||||
article.pub_date = Now()
|
||||
article.save()
|
||||
expected_num_queries = (
|
||||
0 if connection.features.can_return_rows_from_update else 1
|
||||
)
|
||||
with self.assertNumQueries(expected_num_queries):
|
||||
self.assertIsInstance(article.pub_date, datetime)
|
||||
self.assertGreater(article.pub_date, article_pub_date)
|
||||
|
||||
|
||||
class ModelLookupTest(TestCase):
|
||||
@classmethod
|
||||
|
@ -420,8 +420,11 @@ class BasicExpressionsTests(TestCase):
|
||||
# F expressions can be used to update attributes on single objects
|
||||
self.gmbh.num_employees = F("num_employees") + 4
|
||||
self.gmbh.save()
|
||||
self.gmbh.refresh_from_db()
|
||||
self.assertEqual(self.gmbh.num_employees, 36)
|
||||
expected_num_queries = (
|
||||
0 if connection.features.can_return_rows_from_update else 1
|
||||
)
|
||||
with self.assertNumQueries(expected_num_queries):
|
||||
self.assertEqual(self.gmbh.num_employees, 36)
|
||||
|
||||
def test_new_object_save(self):
|
||||
# We should be able to use Funcs when inserting new data
|
||||
@ -1644,8 +1647,11 @@ class ExpressionsNumericTests(TestCase):
|
||||
n = Number.objects.create(integer=1, decimal_value=Decimal("0.5"))
|
||||
n.decimal_value = F("decimal_value") - Decimal("0.4")
|
||||
n.save()
|
||||
n.refresh_from_db()
|
||||
self.assertEqual(n.decimal_value, Decimal("0.1"))
|
||||
expected_num_queries = (
|
||||
0 if connection.features.can_return_rows_from_update else 1
|
||||
)
|
||||
with self.assertNumQueries(expected_num_queries):
|
||||
self.assertEqual(n.decimal_value, Decimal("0.1"))
|
||||
|
||||
|
||||
class ExpressionOperatorTests(TestCase):
|
||||
|
@ -15,13 +15,7 @@ from django.db.models.expressions import (
|
||||
)
|
||||
from django.db.models.functions import Collate
|
||||
from django.db.models.lookups import GreaterThan
|
||||
from django.test import (
|
||||
SimpleTestCase,
|
||||
TestCase,
|
||||
override_settings,
|
||||
skipIfDBFeature,
|
||||
skipUnlessDBFeature,
|
||||
)
|
||||
from django.test import SimpleTestCase, TestCase, override_settings, skipUnlessDBFeature
|
||||
from django.utils import timezone
|
||||
|
||||
from .models import (
|
||||
@ -44,47 +38,56 @@ class DefaultTests(TestCase):
|
||||
self.assertEqual(a.headline, "Default headline")
|
||||
self.assertLess((now - a.pub_date).seconds, 5)
|
||||
|
||||
@skipUnlessDBFeature(
|
||||
"can_return_columns_from_insert", "supports_expression_defaults"
|
||||
)
|
||||
@skipUnlessDBFeature("supports_expression_defaults")
|
||||
def test_field_db_defaults_returning(self):
|
||||
a = DBArticle()
|
||||
a.save()
|
||||
self.assertIsInstance(a.id, int)
|
||||
self.assertEqual(a.headline, "Default headline")
|
||||
self.assertIsInstance(a.pub_date, datetime)
|
||||
self.assertEqual(a.cost, Decimal("3.33"))
|
||||
expected_num_queries = (
|
||||
0 if connection.features.can_return_columns_from_insert else 3
|
||||
)
|
||||
with self.assertNumQueries(expected_num_queries):
|
||||
self.assertEqual(a.headline, "Default headline")
|
||||
self.assertIsInstance(a.pub_date, datetime)
|
||||
self.assertEqual(a.cost, Decimal("3.33"))
|
||||
|
||||
@skipIfDBFeature("can_return_columns_from_insert")
|
||||
@skipUnlessDBFeature("supports_expression_defaults")
|
||||
def test_field_db_defaults_refresh(self):
|
||||
a = DBArticle()
|
||||
a.save()
|
||||
a.refresh_from_db()
|
||||
expected_num_queries = (
|
||||
0 if connection.features.can_return_columns_from_insert else 3
|
||||
)
|
||||
self.assertIsInstance(a.id, int)
|
||||
self.assertEqual(a.headline, "Default headline")
|
||||
self.assertIsInstance(a.pub_date, datetime)
|
||||
self.assertEqual(a.cost, Decimal("3.33"))
|
||||
with self.assertNumQueries(expected_num_queries):
|
||||
self.assertEqual(a.headline, "Default headline")
|
||||
self.assertIsInstance(a.pub_date, datetime)
|
||||
self.assertEqual(a.cost, Decimal("3.33"))
|
||||
|
||||
def test_null_db_default(self):
|
||||
obj1 = DBDefaults.objects.create()
|
||||
if not connection.features.can_return_columns_from_insert:
|
||||
obj1.refresh_from_db()
|
||||
self.assertEqual(obj1.null, 1.1)
|
||||
expected_num_queries = (
|
||||
0 if connection.features.can_return_columns_from_insert else 1
|
||||
)
|
||||
with self.assertNumQueries(expected_num_queries):
|
||||
self.assertEqual(obj1.null, 1.1)
|
||||
|
||||
obj2 = DBDefaults.objects.create(null=None)
|
||||
self.assertIsNone(obj2.null)
|
||||
with self.assertNumQueries(0):
|
||||
self.assertIsNone(obj2.null)
|
||||
|
||||
@skipUnlessDBFeature("supports_expression_defaults")
|
||||
@override_settings(USE_TZ=True)
|
||||
def test_db_default_function(self):
|
||||
m = DBDefaultsFunction.objects.create()
|
||||
if not connection.features.can_return_columns_from_insert:
|
||||
m.refresh_from_db()
|
||||
self.assertAlmostEqual(m.number, pi)
|
||||
self.assertEqual(m.year, timezone.now().year)
|
||||
self.assertAlmostEqual(m.added, pi + 4.5)
|
||||
self.assertEqual(m.multiple_subfunctions, 4.5)
|
||||
expected_num_queries = (
|
||||
0 if connection.features.can_return_columns_from_insert else 4
|
||||
)
|
||||
with self.assertNumQueries(expected_num_queries):
|
||||
self.assertAlmostEqual(m.number, pi)
|
||||
self.assertEqual(m.year, timezone.now().year)
|
||||
self.assertAlmostEqual(m.added, pi + 4.5)
|
||||
self.assertEqual(m.multiple_subfunctions, 4.5)
|
||||
|
||||
@skipUnlessDBFeature("insert_test_table_with_defaults")
|
||||
def test_both_default(self):
|
||||
@ -125,14 +128,15 @@ class DefaultTests(TestCase):
|
||||
child2 = DBDefaultsFK.objects.create(language_code=parent2)
|
||||
self.assertEqual(child2.language_code, parent2)
|
||||
|
||||
@skipUnlessDBFeature(
|
||||
"can_return_columns_from_insert", "supports_expression_defaults"
|
||||
)
|
||||
@skipUnlessDBFeature("supports_expression_defaults")
|
||||
def test_case_when_db_default_returning(self):
|
||||
m = DBDefaultsFunction.objects.create()
|
||||
self.assertEqual(m.case_when, 3)
|
||||
expected_num_queries = (
|
||||
0 if connection.features.can_return_columns_from_insert else 1
|
||||
)
|
||||
with self.assertNumQueries(expected_num_queries):
|
||||
self.assertEqual(m.case_when, 3)
|
||||
|
||||
@skipIfDBFeature("can_return_columns_from_insert")
|
||||
@skipUnlessDBFeature("supports_expression_defaults")
|
||||
def test_case_when_db_default_no_returning(self):
|
||||
m = DBDefaultsFunction.objects.create()
|
||||
|
@ -1,5 +1,6 @@
|
||||
from django.core.exceptions import ObjectNotUpdated
|
||||
from django.db import DatabaseError, transaction
|
||||
from django.db import DatabaseError, connection, transaction
|
||||
from django.db.models import F
|
||||
from django.db.models.signals import post_save, pre_save
|
||||
from django.test import TestCase
|
||||
|
||||
@ -308,3 +309,16 @@ class UpdateOnlyFieldsTests(TestCase):
|
||||
transaction.atomic(),
|
||||
):
|
||||
obj.save(update_fields=["name"])
|
||||
|
||||
def test_update_fields_expression(self):
|
||||
obj = Person.objects.create(name="Valerie", gender="F", pid=42)
|
||||
updated_pid = F("pid") + 1
|
||||
obj.pid = updated_pid
|
||||
obj.save(update_fields={"gender"})
|
||||
self.assertIs(obj.pid, updated_pid)
|
||||
obj.save(update_fields={"pid"})
|
||||
expected_num_queries = (
|
||||
0 if connection.features.can_return_rows_from_update else 1
|
||||
)
|
||||
with self.assertNumQueries(expected_num_queries):
|
||||
self.assertEqual(obj.pid, 43)
|
||||
|
Loading…
x
Reference in New Issue
Block a user