mirror of
				https://github.com/django/django.git
				synced 2025-10-25 14:46:09 +00:00 
			
		
		
		
	Fixed #28344 -- Allowed customizing queryset in Model.refresh_from_db()/arefresh_from_db().
The from_queryset parameter can be used to: - use a custom Manager - lock the row until the end of transaction - select additional related objects
This commit is contained in:
		
				
					committed by
					
						 Mariusz Felisiak
						Mariusz Felisiak
					
				
			
			
				
	
			
			
			
						parent
						
							f3d10546a8
						
					
				
				
					commit
					f92641a636
				
			| @@ -673,7 +673,7 @@ class Model(AltersData, metaclass=ModelBase): | ||||
|             if f.attname not in self.__dict__ | ||||
|         } | ||||
|  | ||||
|     def refresh_from_db(self, using=None, fields=None): | ||||
|     def refresh_from_db(self, using=None, fields=None, from_queryset=None): | ||||
|         """ | ||||
|         Reload field values from the database. | ||||
|  | ||||
| @@ -705,10 +705,13 @@ class Model(AltersData, metaclass=ModelBase): | ||||
|                     "are not allowed in fields." % LOOKUP_SEP | ||||
|                 ) | ||||
|  | ||||
|         if from_queryset is None: | ||||
|             hints = {"instance": self} | ||||
|         db_instance_qs = self.__class__._base_manager.db_manager( | ||||
|             using, hints=hints | ||||
|         ).filter(pk=self.pk) | ||||
|             from_queryset = self.__class__._base_manager.db_manager(using, hints=hints) | ||||
|         elif using is not None: | ||||
|             from_queryset = from_queryset.using(using) | ||||
|  | ||||
|         db_instance_qs = from_queryset.filter(pk=self.pk) | ||||
|  | ||||
|         # Use provided fields, if not set then reload all non-deferred fields. | ||||
|         deferred_fields = self.get_deferred_fields() | ||||
| @@ -729,8 +732,11 @@ class Model(AltersData, metaclass=ModelBase): | ||||
|                 # This field wasn't refreshed - skip ahead. | ||||
|                 continue | ||||
|             setattr(self, field.attname, getattr(db_instance, field.attname)) | ||||
|             # Clear cached foreign keys. | ||||
|             if field.is_relation and field.is_cached(self): | ||||
|             # Clear or copy cached foreign keys. | ||||
|             if field.is_relation: | ||||
|                 if field.is_cached(db_instance): | ||||
|                     field.set_cached_value(self, field.get_cached_value(db_instance)) | ||||
|                 elif field.is_cached(self): | ||||
|                     field.delete_cached_value(self) | ||||
|  | ||||
|         # Clear cached relations. | ||||
| @@ -745,8 +751,10 @@ class Model(AltersData, metaclass=ModelBase): | ||||
|  | ||||
|         self._state.db = db_instance._state.db | ||||
|  | ||||
|     async def arefresh_from_db(self, using=None, fields=None): | ||||
|         return await sync_to_async(self.refresh_from_db)(using=using, fields=fields) | ||||
|     async def arefresh_from_db(self, using=None, fields=None, from_queryset=None): | ||||
|         return await sync_to_async(self.refresh_from_db)( | ||||
|             using=using, fields=fields, from_queryset=from_queryset | ||||
|         ) | ||||
|  | ||||
|     def serializable_value(self, field_name): | ||||
|         """ | ||||
|   | ||||
| @@ -142,8 +142,8 @@ value from the database: | ||||
|     >>> del obj.field | ||||
|     >>> obj.field  # Loads the field from the database | ||||
|  | ||||
| .. method:: Model.refresh_from_db(using=None, fields=None) | ||||
| .. method:: Model.arefresh_from_db(using=None, fields=None) | ||||
| .. method:: Model.refresh_from_db(using=None, fields=None, from_queryset=None) | ||||
| .. method:: Model.arefresh_from_db(using=None, fields=None, from_queryset=None) | ||||
|  | ||||
| *Asynchronous version*: ``arefresh_from_db()`` | ||||
|  | ||||
| @@ -197,6 +197,27 @@ all of the instance's fields when a deferred field is reloaded:: | ||||
|                     fields = fields.union(deferred_fields) | ||||
|             super().refresh_from_db(using, fields, **kwargs) | ||||
|  | ||||
| The ``from_queryset`` argument allows using a different queryset than the one | ||||
| created from :attr:`~django.db.models.Model._base_manager`. It gives you more | ||||
| control over how the model is reloaded. For example, when your model uses soft | ||||
| deletion you can make ``refresh_from_db()`` to take this into account:: | ||||
|  | ||||
|     obj.refresh_from_db(from_queryset=MyModel.active_objects.all()) | ||||
|  | ||||
| You can cache related objects that otherwise would be cleared from the reloaded | ||||
| instance:: | ||||
|  | ||||
|     obj.refresh_from_db(from_queryset=MyModel.objects.select_related("related_field")) | ||||
|  | ||||
| You can lock the row until the end of transaction before reloading a model's | ||||
| values:: | ||||
|  | ||||
|     obj.refresh_from_db(from_queryset=MyModel.objects.select_for_update()) | ||||
|  | ||||
| .. versionchanged:: 5.1 | ||||
|  | ||||
|     The ``from_queryset`` argument was added. | ||||
|  | ||||
| .. method:: Model.get_deferred_fields() | ||||
|  | ||||
| A helper method that returns a set containing the attribute names of all those | ||||
|   | ||||
| @@ -208,6 +208,11 @@ Models | ||||
|   :class:`~django.contrib.postgres.fields.ArrayField` can now be :ref:`sliced | ||||
|   <slicing-using-f>`. | ||||
|  | ||||
| * The new ``from_queryset`` argument of :meth:`.Model.refresh_from_db` and | ||||
|   :meth:`.Model.arefresh_from_db`  allows customizing the queryset used to | ||||
|   reload a model's value. This can be used to lock the row before reloading or | ||||
|   to select related objects. | ||||
|  | ||||
| Requests and Responses | ||||
| ~~~~~~~~~~~~~~~~~~~~~~ | ||||
|  | ||||
|   | ||||
| @@ -23,3 +23,14 @@ class AsyncModelOperationTest(TestCase): | ||||
|         await SimpleModel.objects.filter(pk=self.s1.pk).aupdate(field=20) | ||||
|         await self.s1.arefresh_from_db() | ||||
|         self.assertEqual(self.s1.field, 20) | ||||
|  | ||||
|     async def test_arefresh_from_db_from_queryset(self): | ||||
|         await SimpleModel.objects.filter(pk=self.s1.pk).aupdate(field=20) | ||||
|         with self.assertRaises(SimpleModel.DoesNotExist): | ||||
|             await self.s1.arefresh_from_db( | ||||
|                 from_queryset=SimpleModel.objects.filter(field=0) | ||||
|             ) | ||||
|         await self.s1.arefresh_from_db( | ||||
|             from_queryset=SimpleModel.objects.filter(field__gt=0) | ||||
|         ) | ||||
|         self.assertEqual(self.s1.field, 20) | ||||
|   | ||||
| @@ -4,7 +4,14 @@ from datetime import datetime, timedelta | ||||
| from unittest import mock | ||||
|  | ||||
| from django.core.exceptions import MultipleObjectsReturned, ObjectDoesNotExist | ||||
| from django.db import DEFAULT_DB_ALIAS, DatabaseError, connections, models | ||||
| from django.db import ( | ||||
|     DEFAULT_DB_ALIAS, | ||||
|     DatabaseError, | ||||
|     connection, | ||||
|     connections, | ||||
|     models, | ||||
|     transaction, | ||||
| ) | ||||
| from django.db.models.manager import BaseManager | ||||
| from django.db.models.query import MAX_GET_RESULTS, EmptyQuerySet | ||||
| from django.test import ( | ||||
| @@ -13,7 +20,8 @@ from django.test import ( | ||||
|     TransactionTestCase, | ||||
|     skipUnlessDBFeature, | ||||
| ) | ||||
| from django.test.utils import ignore_warnings | ||||
| from django.test.utils import CaptureQueriesContext, ignore_warnings | ||||
| from django.utils.connection import ConnectionDoesNotExist | ||||
| from django.utils.deprecation import RemovedInDjango60Warning | ||||
| from django.utils.translation import gettext_lazy | ||||
|  | ||||
| @@ -1003,3 +1011,47 @@ class ModelRefreshTests(TestCase): | ||||
|         # Cache was cleared and new results are available. | ||||
|         self.assertCountEqual(a2_prefetched.selfref_set.all(), [s]) | ||||
|         self.assertCountEqual(a2_prefetched.cited.all(), [s]) | ||||
|  | ||||
|     @skipUnlessDBFeature("has_select_for_update") | ||||
|     def test_refresh_for_update(self): | ||||
|         a = Article.objects.create(pub_date=datetime.now()) | ||||
|         for_update_sql = connection.ops.for_update_sql() | ||||
|  | ||||
|         with transaction.atomic(), CaptureQueriesContext(connection) as ctx: | ||||
|             a.refresh_from_db(from_queryset=Article.objects.select_for_update()) | ||||
|         self.assertTrue( | ||||
|             any(for_update_sql in query["sql"] for query in ctx.captured_queries) | ||||
|         ) | ||||
|  | ||||
|     def test_refresh_with_related(self): | ||||
|         a = Article.objects.create(pub_date=datetime.now()) | ||||
|         fa = FeaturedArticle.objects.create(article=a) | ||||
|  | ||||
|         from_queryset = FeaturedArticle.objects.select_related("article") | ||||
|         with self.assertNumQueries(1): | ||||
|             fa.refresh_from_db(from_queryset=from_queryset) | ||||
|             self.assertEqual(fa.article.pub_date, a.pub_date) | ||||
|         with self.assertNumQueries(2): | ||||
|             fa.refresh_from_db() | ||||
|             self.assertEqual(fa.article.pub_date, a.pub_date) | ||||
|  | ||||
|     def test_refresh_overwrites_queryset_using(self): | ||||
|         a = Article.objects.create(pub_date=datetime.now()) | ||||
|  | ||||
|         from_queryset = Article.objects.using("nonexistent") | ||||
|         with self.assertRaises(ConnectionDoesNotExist): | ||||
|             a.refresh_from_db(from_queryset=from_queryset) | ||||
|         a.refresh_from_db(using="default", from_queryset=from_queryset) | ||||
|  | ||||
|     def test_refresh_overwrites_queryset_fields(self): | ||||
|         a = Article.objects.create(pub_date=datetime.now()) | ||||
|         headline = "headline" | ||||
|         Article.objects.filter(pk=a.pk).update(headline=headline) | ||||
|  | ||||
|         from_queryset = Article.objects.only("pub_date") | ||||
|         with self.assertNumQueries(1): | ||||
|             a.refresh_from_db(from_queryset=from_queryset) | ||||
|             self.assertNotEqual(a.headline, headline) | ||||
|         with self.assertNumQueries(1): | ||||
|             a.refresh_from_db(fields=["headline"], from_queryset=from_queryset) | ||||
|             self.assertEqual(a.headline, headline) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user