mirror of
				https://github.com/django/django.git
				synced 2025-10-24 22:26:08 +00:00 
			
		
		
		
	Fixed #34139 -- Fixed acreate(), aget_or_create(), and aupdate_or_create() methods for related managers.
Bug in 58b27e0dbb.
			
			
This commit is contained in:
		
				
					committed by
					
						 Mariusz Felisiak
						Mariusz Felisiak
					
				
			
			
				
	
			
			
			
						parent
						
							76e37513e2
						
					
				
				
					commit
					7b94847e38
				
			
							
								
								
									
										1
									
								
								AUTHORS
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								AUTHORS
									
									
									
									
									
								
							| @@ -495,6 +495,7 @@ answer newbie questions, and generally made Django that much better: | |||||||
|     John Shaffer <jshaffer2112@gmail.com> |     John Shaffer <jshaffer2112@gmail.com> | ||||||
|     Jökull Sólberg Auðunsson <jokullsolberg@gmail.com> |     Jökull Sólberg Auðunsson <jokullsolberg@gmail.com> | ||||||
|     Jon Dufresne <jon.dufresne@gmail.com> |     Jon Dufresne <jon.dufresne@gmail.com> | ||||||
|  |     Jon Janzen <jon@jonjanzen.com> | ||||||
|     Jonas Haag <jonas@lophus.org> |     Jonas Haag <jonas@lophus.org> | ||||||
|     Jonas Lundberg <jonas.lundberg@gmail.com> |     Jonas Lundberg <jonas.lundberg@gmail.com> | ||||||
|     Jonathan Davis <jonathandavis47780@gmail.com> |     Jonathan Davis <jonathandavis47780@gmail.com> | ||||||
|   | |||||||
| @@ -2,6 +2,8 @@ import functools | |||||||
| import itertools | import itertools | ||||||
| from collections import defaultdict | from collections import defaultdict | ||||||
|  |  | ||||||
|  | from asgiref.sync import sync_to_async | ||||||
|  |  | ||||||
| from django.contrib.contenttypes.models import ContentType | from django.contrib.contenttypes.models import ContentType | ||||||
| from django.core import checks | from django.core import checks | ||||||
| from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist | from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist | ||||||
| @@ -747,6 +749,11 @@ def create_generic_related_manager(superclass, rel): | |||||||
|  |  | ||||||
|         create.alters_data = True |         create.alters_data = True | ||||||
|  |  | ||||||
|  |         async def acreate(self, **kwargs): | ||||||
|  |             return await sync_to_async(self.create)(**kwargs) | ||||||
|  |  | ||||||
|  |         acreate.alters_data = True | ||||||
|  |  | ||||||
|         def get_or_create(self, **kwargs): |         def get_or_create(self, **kwargs): | ||||||
|             kwargs[self.content_type_field_name] = self.content_type |             kwargs[self.content_type_field_name] = self.content_type | ||||||
|             kwargs[self.object_id_field_name] = self.pk_val |             kwargs[self.object_id_field_name] = self.pk_val | ||||||
| @@ -755,6 +762,11 @@ def create_generic_related_manager(superclass, rel): | |||||||
|  |  | ||||||
|         get_or_create.alters_data = True |         get_or_create.alters_data = True | ||||||
|  |  | ||||||
|  |         async def aget_or_create(self, **kwargs): | ||||||
|  |             return await sync_to_async(self.get_or_create)(**kwargs) | ||||||
|  |  | ||||||
|  |         aget_or_create.alters_data = True | ||||||
|  |  | ||||||
|         def update_or_create(self, **kwargs): |         def update_or_create(self, **kwargs): | ||||||
|             kwargs[self.content_type_field_name] = self.content_type |             kwargs[self.content_type_field_name] = self.content_type | ||||||
|             kwargs[self.object_id_field_name] = self.pk_val |             kwargs[self.object_id_field_name] = self.pk_val | ||||||
| @@ -763,4 +775,9 @@ def create_generic_related_manager(superclass, rel): | |||||||
|  |  | ||||||
|         update_or_create.alters_data = True |         update_or_create.alters_data = True | ||||||
|  |  | ||||||
|  |         async def aupdate_or_create(self, **kwargs): | ||||||
|  |             return await sync_to_async(self.update_or_create)(**kwargs) | ||||||
|  |  | ||||||
|  |         aupdate_or_create.alters_data = True | ||||||
|  |  | ||||||
|     return GenericRelatedObjectManager |     return GenericRelatedObjectManager | ||||||
|   | |||||||
| @@ -63,6 +63,8 @@ and two directions (forward and reverse) for a total of six combinations. | |||||||
|    ``ReverseManyToManyDescriptor``, use ``ManyToManyDescriptor`` instead. |    ``ReverseManyToManyDescriptor``, use ``ManyToManyDescriptor`` instead. | ||||||
| """ | """ | ||||||
|  |  | ||||||
|  | from asgiref.sync import sync_to_async | ||||||
|  |  | ||||||
| from django.core.exceptions import FieldError | from django.core.exceptions import FieldError | ||||||
| from django.db import ( | from django.db import ( | ||||||
|     DEFAULT_DB_ALIAS, |     DEFAULT_DB_ALIAS, | ||||||
| @@ -793,6 +795,11 @@ def create_reverse_many_to_one_manager(superclass, rel): | |||||||
|  |  | ||||||
|         create.alters_data = True |         create.alters_data = True | ||||||
|  |  | ||||||
|  |         async def acreate(self, **kwargs): | ||||||
|  |             return await sync_to_async(self.create)(**kwargs) | ||||||
|  |  | ||||||
|  |         acreate.alters_data = True | ||||||
|  |  | ||||||
|         def get_or_create(self, **kwargs): |         def get_or_create(self, **kwargs): | ||||||
|             self._check_fk_val() |             self._check_fk_val() | ||||||
|             kwargs[self.field.name] = self.instance |             kwargs[self.field.name] = self.instance | ||||||
| @@ -801,6 +808,11 @@ def create_reverse_many_to_one_manager(superclass, rel): | |||||||
|  |  | ||||||
|         get_or_create.alters_data = True |         get_or_create.alters_data = True | ||||||
|  |  | ||||||
|  |         async def aget_or_create(self, **kwargs): | ||||||
|  |             return await sync_to_async(self.get_or_create)(**kwargs) | ||||||
|  |  | ||||||
|  |         aget_or_create.alters_data = True | ||||||
|  |  | ||||||
|         def update_or_create(self, **kwargs): |         def update_or_create(self, **kwargs): | ||||||
|             self._check_fk_val() |             self._check_fk_val() | ||||||
|             kwargs[self.field.name] = self.instance |             kwargs[self.field.name] = self.instance | ||||||
| @@ -809,6 +821,11 @@ def create_reverse_many_to_one_manager(superclass, rel): | |||||||
|  |  | ||||||
|         update_or_create.alters_data = True |         update_or_create.alters_data = True | ||||||
|  |  | ||||||
|  |         async def aupdate_or_create(self, **kwargs): | ||||||
|  |             return await sync_to_async(self.update_or_create)(**kwargs) | ||||||
|  |  | ||||||
|  |         aupdate_or_create.alters_data = True | ||||||
|  |  | ||||||
|         # remove() and clear() are only provided if the ForeignKey can have a |         # remove() and clear() are only provided if the ForeignKey can have a | ||||||
|         # value of null. |         # value of null. | ||||||
|         if rel.field.null: |         if rel.field.null: | ||||||
| @@ -1191,6 +1208,13 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): | |||||||
|  |  | ||||||
|         create.alters_data = True |         create.alters_data = True | ||||||
|  |  | ||||||
|  |         async def acreate(self, *, through_defaults=None, **kwargs): | ||||||
|  |             return await sync_to_async(self.create)( | ||||||
|  |                 through_defaults=through_defaults, **kwargs | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         acreate.alters_data = True | ||||||
|  |  | ||||||
|         def get_or_create(self, *, through_defaults=None, **kwargs): |         def get_or_create(self, *, through_defaults=None, **kwargs): | ||||||
|             db = router.db_for_write(self.instance.__class__, instance=self.instance) |             db = router.db_for_write(self.instance.__class__, instance=self.instance) | ||||||
|             obj, created = super(ManyRelatedManager, self.db_manager(db)).get_or_create( |             obj, created = super(ManyRelatedManager, self.db_manager(db)).get_or_create( | ||||||
| @@ -1204,6 +1228,13 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): | |||||||
|  |  | ||||||
|         get_or_create.alters_data = True |         get_or_create.alters_data = True | ||||||
|  |  | ||||||
|  |         async def aget_or_create(self, *, through_defaults=None, **kwargs): | ||||||
|  |             return await sync_to_async(self.get_or_create)( | ||||||
|  |                 through_defaults=through_defaults, **kwargs | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         aget_or_create.alters_data = True | ||||||
|  |  | ||||||
|         def update_or_create(self, *, through_defaults=None, **kwargs): |         def update_or_create(self, *, through_defaults=None, **kwargs): | ||||||
|             db = router.db_for_write(self.instance.__class__, instance=self.instance) |             db = router.db_for_write(self.instance.__class__, instance=self.instance) | ||||||
|             obj, created = super( |             obj, created = super( | ||||||
| @@ -1217,6 +1248,13 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): | |||||||
|  |  | ||||||
|         update_or_create.alters_data = True |         update_or_create.alters_data = True | ||||||
|  |  | ||||||
|  |         async def aupdate_or_create(self, *, through_defaults=None, **kwargs): | ||||||
|  |             return await sync_to_async(self.update_or_create)( | ||||||
|  |                 through_defaults=through_defaults, **kwargs | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         aupdate_or_create.alters_data = True | ||||||
|  |  | ||||||
|         def _get_target_ids(self, target_field_name, objs): |         def _get_target_ids(self, target_field_name, objs): | ||||||
|             """ |             """ | ||||||
|             Return the set of ids of `objs` that the target field references. |             Return the set of ids of `objs` that the target field references. | ||||||
|   | |||||||
| @@ -76,6 +76,9 @@ Related objects reference | |||||||
|         intermediate instance(s). |         intermediate instance(s). | ||||||
|  |  | ||||||
|     .. method:: create(through_defaults=None, **kwargs) |     .. method:: create(through_defaults=None, **kwargs) | ||||||
|  |     .. method:: acreate(through_defaults=None, **kwargs) | ||||||
|  |  | ||||||
|  |         *Asynchronous version*: ``acreate`` | ||||||
|  |  | ||||||
|         Creates a new object, saves it and puts it in the related object set. |         Creates a new object, saves it and puts it in the related object set. | ||||||
|         Returns the newly created object:: |         Returns the newly created object:: | ||||||
| @@ -110,6 +113,10 @@ Related objects reference | |||||||
|         needed. You can use callables as values in the ``through_defaults`` |         needed. You can use callables as values in the ``through_defaults`` | ||||||
|         dictionary. |         dictionary. | ||||||
|  |  | ||||||
|  |         .. versionchanged:: 4.1 | ||||||
|  |  | ||||||
|  |             ``acreate()`` method was added. | ||||||
|  |  | ||||||
|     .. method:: remove(*objs, bulk=True) |     .. method:: remove(*objs, bulk=True) | ||||||
|  |  | ||||||
|         Removes the specified model objects from the related object set:: |         Removes the specified model objects from the related object set:: | ||||||
|   | |||||||
| @@ -16,3 +16,7 @@ Bugfixes | |||||||
|   an empty :meth:`Sitemap.items() <django.contrib.sitemaps.Sitemap.items>` and |   an empty :meth:`Sitemap.items() <django.contrib.sitemaps.Sitemap.items>` and | ||||||
|   a callable :attr:`~django.contrib.sitemaps.Sitemap.lastmod` |   a callable :attr:`~django.contrib.sitemaps.Sitemap.lastmod` | ||||||
|   (:ticket:`34088`). |   (:ticket:`34088`). | ||||||
|  |  | ||||||
|  | * Fixed a bug in Django 4.1 that caused a crash of ``acreate()``, | ||||||
|  |   ``aget_or_create()``, and ``aupdate_or_create()`` asynchronous methods for | ||||||
|  |   related managers (:ticket:`34139`). | ||||||
|   | |||||||
| @@ -9,3 +9,7 @@ class RelatedModel(models.Model): | |||||||
| class SimpleModel(models.Model): | class SimpleModel(models.Model): | ||||||
|     field = models.IntegerField() |     field = models.IntegerField() | ||||||
|     created = models.DateTimeField(default=timezone.now) |     created = models.DateTimeField(default=timezone.now) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ManyToManyModel(models.Model): | ||||||
|  |     simples = models.ManyToManyField("SimpleModel") | ||||||
|   | |||||||
							
								
								
									
										56
									
								
								tests/async/test_async_related_managers.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										56
									
								
								tests/async/test_async_related_managers.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,56 @@ | |||||||
|  | from django.test import TestCase | ||||||
|  |  | ||||||
|  | from .models import ManyToManyModel, SimpleModel | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class AsyncRelatedManagersOperationTest(TestCase): | ||||||
|  |     @classmethod | ||||||
|  |     def setUpTestData(cls): | ||||||
|  |         cls.mtm1 = ManyToManyModel.objects.create() | ||||||
|  |         cls.s1 = SimpleModel.objects.create(field=0) | ||||||
|  |  | ||||||
|  |     async def test_acreate(self): | ||||||
|  |         await self.mtm1.simples.acreate(field=2) | ||||||
|  |         new_simple = await self.mtm1.simples.aget() | ||||||
|  |         self.assertEqual(new_simple.field, 2) | ||||||
|  |  | ||||||
|  |     async def test_acreate_reverse(self): | ||||||
|  |         await self.s1.relatedmodel_set.acreate() | ||||||
|  |         new_relatedmodel = await self.s1.relatedmodel_set.aget() | ||||||
|  |         self.assertEqual(new_relatedmodel.simple, self.s1) | ||||||
|  |  | ||||||
|  |     async def test_aget_or_create(self): | ||||||
|  |         new_simple, created = await self.mtm1.simples.aget_or_create(field=2) | ||||||
|  |         self.assertIs(created, True) | ||||||
|  |         self.assertEqual(await self.mtm1.simples.acount(), 1) | ||||||
|  |         self.assertEqual(new_simple.field, 2) | ||||||
|  |         new_simple, created = await self.mtm1.simples.aget_or_create( | ||||||
|  |             id=new_simple.id, through_defaults={"field": 3} | ||||||
|  |         ) | ||||||
|  |         self.assertIs(created, False) | ||||||
|  |         self.assertEqual(await self.mtm1.simples.acount(), 1) | ||||||
|  |         self.assertEqual(new_simple.field, 2) | ||||||
|  |  | ||||||
|  |     async def test_aget_or_create_reverse(self): | ||||||
|  |         new_relatedmodel, created = await self.s1.relatedmodel_set.aget_or_create() | ||||||
|  |         self.assertIs(created, True) | ||||||
|  |         self.assertEqual(await self.s1.relatedmodel_set.acount(), 1) | ||||||
|  |         self.assertEqual(new_relatedmodel.simple, self.s1) | ||||||
|  |  | ||||||
|  |     async def test_aupdate_or_create(self): | ||||||
|  |         new_simple, created = await self.mtm1.simples.aupdate_or_create(field=2) | ||||||
|  |         self.assertIs(created, True) | ||||||
|  |         self.assertEqual(await self.mtm1.simples.acount(), 1) | ||||||
|  |         self.assertEqual(new_simple.field, 2) | ||||||
|  |         new_simple, created = await self.mtm1.simples.aupdate_or_create( | ||||||
|  |             id=new_simple.id, defaults={"field": 3} | ||||||
|  |         ) | ||||||
|  |         self.assertIs(created, False) | ||||||
|  |         self.assertEqual(await self.mtm1.simples.acount(), 1) | ||||||
|  |         self.assertEqual(new_simple.field, 3) | ||||||
|  |  | ||||||
|  |     async def test_aupdate_or_create_reverse(self): | ||||||
|  |         new_relatedmodel, created = await self.s1.relatedmodel_set.aupdate_or_create() | ||||||
|  |         self.assertIs(created, True) | ||||||
|  |         self.assertEqual(await self.s1.relatedmodel_set.acount(), 1) | ||||||
|  |         self.assertEqual(new_relatedmodel.simple, self.s1) | ||||||
| @@ -45,6 +45,10 @@ class GenericRelationsTests(TestCase): | |||||||
|         # Original list of tags: |         # Original list of tags: | ||||||
|         return obj.tag, obj.content_type.model_class(), obj.object_id |         return obj.tag, obj.content_type.model_class(), obj.object_id | ||||||
|  |  | ||||||
|  |     async def test_generic_async_acreate(self): | ||||||
|  |         await self.bacon.tags.acreate(tag="orange") | ||||||
|  |         self.assertEqual(await self.bacon.tags.acount(), 3) | ||||||
|  |  | ||||||
|     def test_generic_update_or_create_when_created(self): |     def test_generic_update_or_create_when_created(self): | ||||||
|         """ |         """ | ||||||
|         Should be able to use update_or_create from the generic related manager |         Should be able to use update_or_create from the generic related manager | ||||||
| @@ -70,6 +74,18 @@ class GenericRelationsTests(TestCase): | |||||||
|         self.assertEqual(count + 1, self.bacon.tags.count()) |         self.assertEqual(count + 1, self.bacon.tags.count()) | ||||||
|         self.assertEqual(tag.tag, "juicy") |         self.assertEqual(tag.tag, "juicy") | ||||||
|  |  | ||||||
|  |     async def test_generic_async_aupdate_or_create(self): | ||||||
|  |         tag, created = await self.bacon.tags.aupdate_or_create( | ||||||
|  |             id=self.fatty.id, defaults={"tag": "orange"} | ||||||
|  |         ) | ||||||
|  |         self.assertIs(created, False) | ||||||
|  |         self.assertEqual(tag.tag, "orange") | ||||||
|  |         self.assertEqual(await self.bacon.tags.acount(), 2) | ||||||
|  |         tag, created = await self.bacon.tags.aupdate_or_create(tag="pink") | ||||||
|  |         self.assertIs(created, True) | ||||||
|  |         self.assertEqual(await self.bacon.tags.acount(), 3) | ||||||
|  |         self.assertEqual(tag.tag, "pink") | ||||||
|  |  | ||||||
|     def test_generic_get_or_create_when_created(self): |     def test_generic_get_or_create_when_created(self): | ||||||
|         """ |         """ | ||||||
|         Should be able to use get_or_create from the generic related manager |         Should be able to use get_or_create from the generic related manager | ||||||
| @@ -96,6 +112,18 @@ class GenericRelationsTests(TestCase): | |||||||
|         # shouldn't had changed the tag |         # shouldn't had changed the tag | ||||||
|         self.assertEqual(tag.tag, "stinky") |         self.assertEqual(tag.tag, "stinky") | ||||||
|  |  | ||||||
|  |     async def test_generic_async_aget_or_create(self): | ||||||
|  |         tag, created = await self.bacon.tags.aget_or_create( | ||||||
|  |             id=self.fatty.id, defaults={"tag": "orange"} | ||||||
|  |         ) | ||||||
|  |         self.assertIs(created, False) | ||||||
|  |         self.assertEqual(tag.tag, "fatty") | ||||||
|  |         self.assertEqual(await self.bacon.tags.acount(), 2) | ||||||
|  |         tag, created = await self.bacon.tags.aget_or_create(tag="orange") | ||||||
|  |         self.assertIs(created, True) | ||||||
|  |         self.assertEqual(await self.bacon.tags.acount(), 3) | ||||||
|  |         self.assertEqual(tag.tag, "orange") | ||||||
|  |  | ||||||
|     def test_generic_relations_m2m_mimic(self): |     def test_generic_relations_m2m_mimic(self): | ||||||
|         """ |         """ | ||||||
|         Objects with declared GenericRelations can be tagged directly -- the |         Objects with declared GenericRelations can be tagged directly -- the | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user