diff --git a/django/db/models/base.py b/django/db/models/base.py index 93e53bde95..518cfc44a2 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -1102,6 +1102,11 @@ class Model(AltersData, metaclass=ModelBase): and f.referenced_fields.intersection(non_pks_non_generated) ) ] + for field, _model, value in values: + if (update_fields is None or field.name in update_fields) and hasattr( + value, "resolve_expression" + ): + returning_fields.append(field) results = self._do_update( base_qs, using, @@ -1142,7 +1147,15 @@ class Model(AltersData, metaclass=ModelBase): for f in meta.local_concrete_fields if not f.generated and (pk_set or f is not meta.auto_field) ] - returning_fields = meta.db_returning_fields + returning_fields = list(meta.db_returning_fields) + for field in fields: + value = ( + getattr(self, field.attname) if raw else field.pre_save(self, False) + ) + if hasattr(value, "resolve_expression"): + returning_fields.append(field) + elif field.db_returning: + returning_fields.remove(field) results = self._do_insert( cls._base_manager, using, fields, returning_fields, raw ) @@ -1203,8 +1216,13 @@ class Model(AltersData, metaclass=ModelBase): ) def _assign_returned_values(self, returned_values, returning_fields): - for value, field in zip(returned_values, returning_fields): + returning_fields_iter = iter(returning_fields) + for value, field in zip(returned_values, returning_fields_iter): setattr(self, field.attname, value) + # Defer all fields that were meant to be updated with their database + # resolved values but couldn't as they are effectively stale. + for field in returning_fields_iter: + self.__dict__.pop(field.attname, None) def _prepare_related_fields_for_save(self, operation_name, fields=None): # Ensure that a model instance without a PK hasn't been assigned to diff --git a/docs/ref/models/expressions.txt b/docs/ref/models/expressions.txt index 77e8b165da..a1b8984a9b 100644 --- a/docs/ref/models/expressions.txt +++ b/docs/ref/models/expressions.txt @@ -69,8 +69,6 @@ Some examples # Create a new company using expressions. >>> company = Company.objects.create(name="Google", ticker=Upper(Value("goog"))) - # Be sure to refresh it if you need to access the field. - >>> company.refresh_from_db() >>> company.ticker 'GOOG' @@ -157,12 +155,6 @@ know about it - it is dealt with entirely by the database. All Python does, through Django's ``F()`` class, is create the SQL syntax to refer to the field and describe the operation. -To access the new value saved this way, the object must be reloaded:: - - reporter = Reporters.objects.get(pk=reporter.pk) - # Or, more succinctly: - reporter.refresh_from_db() - As well as being used in operations on single instances as above, ``F()`` can be used with ``update()`` to perform bulk updates on a ``QuerySet``. This reduces the two queries we were using above - the ``get()`` and the @@ -199,7 +191,6 @@ array-slicing syntax. The indices are 0-based and the ``step`` argument to >>> writer = Writers.objects.get(name="Priyansh") >>> writer.name = F("name")[1:5] >>> writer.save() - >>> writer.refresh_from_db() >>> writer.name 'riya' @@ -221,23 +212,27 @@ robust: it will only ever update the field based on the value of the field in the database when the :meth:`~Model.save` or ``update()`` is executed, rather than based on its value when the instance was retrieved. -``F()`` assignments persist after ``Model.save()`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +``F()`` assignments are refreshed after ``Model.save()`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -``F()`` objects assigned to model fields persist after saving the model -instance and will be applied on each :meth:`~Model.save`. For example:: +``F()`` objects assigned to model fields are refreshed from the database on +:meth:`~Model.save` on backends that support it without incurring a subsequent +query (SQLite, PostgreSQL, and Oracle) and deferred otherwise (MySQL or +MariaDB). For example: - reporter = Reporters.objects.get(name="Tintin") - reporter.stories_filed = F("stories_filed") + 1 - reporter.save() +.. code-block:: pycon - reporter.name = "Tintin Jr." - reporter.save() + >>> reporter = Reporters.objects.get(name="Tintin") + >>> reporter.stories_filed = F("stories_filed") + 1 + >>> reporter.save() + >>> reporter.stories_filed # This triggers a refresh query on MySQL/MariaDB. + 14 # Assuming the database value was 13 when the object was saved. -``stories_filed`` will be updated twice in this case. If it's initially ``1``, -the final value will be ``3``. This persistence can be avoided by reloading the -model object after saving it, for example, by using -:meth:`~Model.refresh_from_db`. +.. versionchanged:: 6.0 + + In previous versions of Django, ``F()`` objects were not refreshed from the + database on :meth:`~Model.save` which resulted in them being evaluated and + persisted every time the instance was saved. Using ``F()`` in filters ~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/releases/6.0.txt b/docs/releases/6.0.txt index 0d5565176f..adfac83b8d 100644 --- a/docs/releases/6.0.txt +++ b/docs/releases/6.0.txt @@ -331,7 +331,8 @@ Models value from the non-null input values. This is supported on SQLite, MySQL, Oracle, and PostgreSQL 16+. -* :class:`~django.db.models.GeneratedField`\s are now refreshed from the +* :class:`~django.db.models.GeneratedField`\s and :ref:`fields assigned + expressions ` are now refreshed from the database after :meth:`~django.db.models.Model.save` on backends that support the ``RETURNING`` clause (SQLite, PostgreSQL, and Oracle). On backends that don't support it (MySQL and MariaDB), the fields are marked as deferred to diff --git a/tests/basic/tests.py b/tests/basic/tests.py index f8ec2715f6..38e7278210 100644 --- a/tests/basic/tests.py +++ b/tests/basic/tests.py @@ -1,5 +1,6 @@ import inspect import threading +import time from datetime import datetime, timedelta from unittest import mock @@ -12,6 +13,7 @@ from django.db import ( models, transaction, ) +from django.db.models.functions import Now from django.db.models.manager import BaseManager from django.db.models.query import MAX_GET_RESULTS, EmptyQuerySet from django.test import ( @@ -558,6 +560,26 @@ class ModelTest(TestCase): with self.subTest(case=case): self.assertIs(case._is_pk_set(), True) + def test_save_expressions(self): + article = Article(pub_date=Now()) + article.save() + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 1 + ) + with self.assertNumQueries(expected_num_queries): + article_pub_date = article.pub_date + self.assertIsInstance(article_pub_date, datetime) + # Sleep slightly to ensure a different database level NOW(). + time.sleep(0.1) + article.pub_date = Now() + article.save() + expected_num_queries = ( + 0 if connection.features.can_return_rows_from_update else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertIsInstance(article.pub_date, datetime) + self.assertGreater(article.pub_date, article_pub_date) + class ModelLookupTest(TestCase): @classmethod diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py index 27d88be621..6f18321aa7 100644 --- a/tests/expressions/tests.py +++ b/tests/expressions/tests.py @@ -420,8 +420,11 @@ class BasicExpressionsTests(TestCase): # F expressions can be used to update attributes on single objects self.gmbh.num_employees = F("num_employees") + 4 self.gmbh.save() - self.gmbh.refresh_from_db() - self.assertEqual(self.gmbh.num_employees, 36) + expected_num_queries = ( + 0 if connection.features.can_return_rows_from_update else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(self.gmbh.num_employees, 36) def test_new_object_save(self): # We should be able to use Funcs when inserting new data @@ -1644,8 +1647,11 @@ class ExpressionsNumericTests(TestCase): n = Number.objects.create(integer=1, decimal_value=Decimal("0.5")) n.decimal_value = F("decimal_value") - Decimal("0.4") n.save() - n.refresh_from_db() - self.assertEqual(n.decimal_value, Decimal("0.1")) + expected_num_queries = ( + 0 if connection.features.can_return_rows_from_update else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(n.decimal_value, Decimal("0.1")) class ExpressionOperatorTests(TestCase): diff --git a/tests/field_defaults/tests.py b/tests/field_defaults/tests.py index e914adfc51..7f85d946f6 100644 --- a/tests/field_defaults/tests.py +++ b/tests/field_defaults/tests.py @@ -15,13 +15,7 @@ from django.db.models.expressions import ( ) from django.db.models.functions import Collate from django.db.models.lookups import GreaterThan -from django.test import ( - SimpleTestCase, - TestCase, - override_settings, - skipIfDBFeature, - skipUnlessDBFeature, -) +from django.test import SimpleTestCase, TestCase, override_settings, skipUnlessDBFeature from django.utils import timezone from .models import ( @@ -44,47 +38,56 @@ class DefaultTests(TestCase): self.assertEqual(a.headline, "Default headline") self.assertLess((now - a.pub_date).seconds, 5) - @skipUnlessDBFeature( - "can_return_columns_from_insert", "supports_expression_defaults" - ) + @skipUnlessDBFeature("supports_expression_defaults") def test_field_db_defaults_returning(self): a = DBArticle() a.save() self.assertIsInstance(a.id, int) - self.assertEqual(a.headline, "Default headline") - self.assertIsInstance(a.pub_date, datetime) - self.assertEqual(a.cost, Decimal("3.33")) + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 3 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(a.headline, "Default headline") + self.assertIsInstance(a.pub_date, datetime) + self.assertEqual(a.cost, Decimal("3.33")) - @skipIfDBFeature("can_return_columns_from_insert") @skipUnlessDBFeature("supports_expression_defaults") def test_field_db_defaults_refresh(self): a = DBArticle() a.save() - a.refresh_from_db() + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 3 + ) self.assertIsInstance(a.id, int) - self.assertEqual(a.headline, "Default headline") - self.assertIsInstance(a.pub_date, datetime) - self.assertEqual(a.cost, Decimal("3.33")) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(a.headline, "Default headline") + self.assertIsInstance(a.pub_date, datetime) + self.assertEqual(a.cost, Decimal("3.33")) def test_null_db_default(self): obj1 = DBDefaults.objects.create() - if not connection.features.can_return_columns_from_insert: - obj1.refresh_from_db() - self.assertEqual(obj1.null, 1.1) + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(obj1.null, 1.1) obj2 = DBDefaults.objects.create(null=None) - self.assertIsNone(obj2.null) + with self.assertNumQueries(0): + self.assertIsNone(obj2.null) @skipUnlessDBFeature("supports_expression_defaults") @override_settings(USE_TZ=True) def test_db_default_function(self): m = DBDefaultsFunction.objects.create() - if not connection.features.can_return_columns_from_insert: - m.refresh_from_db() - self.assertAlmostEqual(m.number, pi) - self.assertEqual(m.year, timezone.now().year) - self.assertAlmostEqual(m.added, pi + 4.5) - self.assertEqual(m.multiple_subfunctions, 4.5) + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 4 + ) + with self.assertNumQueries(expected_num_queries): + self.assertAlmostEqual(m.number, pi) + self.assertEqual(m.year, timezone.now().year) + self.assertAlmostEqual(m.added, pi + 4.5) + self.assertEqual(m.multiple_subfunctions, 4.5) @skipUnlessDBFeature("insert_test_table_with_defaults") def test_both_default(self): @@ -125,14 +128,15 @@ class DefaultTests(TestCase): child2 = DBDefaultsFK.objects.create(language_code=parent2) self.assertEqual(child2.language_code, parent2) - @skipUnlessDBFeature( - "can_return_columns_from_insert", "supports_expression_defaults" - ) + @skipUnlessDBFeature("supports_expression_defaults") def test_case_when_db_default_returning(self): m = DBDefaultsFunction.objects.create() - self.assertEqual(m.case_when, 3) + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(m.case_when, 3) - @skipIfDBFeature("can_return_columns_from_insert") @skipUnlessDBFeature("supports_expression_defaults") def test_case_when_db_default_no_returning(self): m = DBDefaultsFunction.objects.create() diff --git a/tests/update_only_fields/tests.py b/tests/update_only_fields/tests.py index 9595c767eb..1c7ef88832 100644 --- a/tests/update_only_fields/tests.py +++ b/tests/update_only_fields/tests.py @@ -1,5 +1,6 @@ from django.core.exceptions import ObjectNotUpdated -from django.db import DatabaseError, transaction +from django.db import DatabaseError, connection, transaction +from django.db.models import F from django.db.models.signals import post_save, pre_save from django.test import TestCase @@ -308,3 +309,16 @@ class UpdateOnlyFieldsTests(TestCase): transaction.atomic(), ): obj.save(update_fields=["name"]) + + def test_update_fields_expression(self): + obj = Person.objects.create(name="Valerie", gender="F", pid=42) + updated_pid = F("pid") + 1 + obj.pid = updated_pid + obj.save(update_fields={"gender"}) + self.assertIs(obj.pid, updated_pid) + obj.save(update_fields={"pid"}) + expected_num_queries = ( + 0 if connection.features.can_return_rows_from_update else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(obj.pid, 43)