From 28522c3c8d5eb581347aececc3ac61c134528114 Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Tue, 25 Jun 2024 17:12:10 +0100 Subject: [PATCH] Fixed #35554, Refs #35060 -- Corrected deprecated *args parsing in Model.save()/asave(). The transitional logic added to deprecate the usage of *args for Model.save()/asave() introduced two issues that this branch fixes: * Passing extra positional arguments no longer raised TypeError. * Passing a positional but empty update_fields would save all fields. Co-authored-by: Natalia <124304+nessita@users.noreply.github.com> --- django/db/models/base.py | 52 ++++++++++++++------------- tests/basic/tests.py | 58 +++++++++++++++++++++++++++++++ tests/update_only_fields/tests.py | 24 +++++++++++++ 3 files changed, 110 insertions(+), 24 deletions(-) diff --git a/django/db/models/base.py b/django/db/models/base.py index cd300e47bc..dcfdd6eade 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -803,18 +803,20 @@ class Model(AltersData, metaclass=ModelBase): RemovedInDjango60Warning, stacklevel=2, ) - for arg, attr in zip( - args, ["force_insert", "force_update", "using", "update_fields"] - ): - if arg: - if attr == "force_insert": - force_insert = arg - elif attr == "force_update": - force_update = arg - elif attr == "using": - using = arg - else: - update_fields = arg + total_len_args = len(args) + 1 # include self + if total_len_args > 5: + # Recreate the proper TypeError message from Python. + raise TypeError( + "Model.save() takes from 1 to 5 positional arguments but " + f"{total_len_args} were given" + ) + force_insert = args[0] + try: + force_update = args[1] + using = args[2] + update_fields = args[3] + except IndexError: + pass self._prepare_related_fields_for_save(operation_name="save") @@ -888,18 +890,20 @@ class Model(AltersData, metaclass=ModelBase): RemovedInDjango60Warning, stacklevel=2, ) - for arg, attr in zip( - args, ["force_insert", "force_update", "using", "update_fields"] - ): - if arg: - if attr == "force_insert": - force_insert = arg - elif attr == "force_update": - force_update = arg - elif attr == "using": - using = arg - else: - update_fields = arg + total_len_args = len(args) + 1 # include self + if total_len_args > 5: + # Recreate the proper TypeError message from Python. + raise TypeError( + "Model.asave() takes from 1 to 5 positional arguments but " + f"{total_len_args} were given" + ) + force_insert = args[0] + try: + force_update = args[1] + using = args[2] + update_fields = args[3] + except IndexError: + pass return await sync_to_async(self.save)( force_insert=force_insert, diff --git a/tests/basic/tests.py b/tests/basic/tests.py index 38fb9ca200..4e89febed2 100644 --- a/tests/basic/tests.py +++ b/tests/basic/tests.py @@ -210,6 +210,35 @@ class ModelInstanceCreationTests(TestCase): a.save(False, False, None, None) self.assertEqual(Article.objects.count(), 1) + def test_save_deprecation_positional_arguments_used(self): + a = Article() + fields = ["headline"] + with ( + self.assertWarns(RemovedInDjango60Warning), + mock.patch.object(a, "save_base") as mock_save_base, + ): + a.save(None, 1, 2, fields) + self.assertEqual( + mock_save_base.mock_calls, + [ + mock.call( + using=2, + force_insert=None, + force_update=1, + update_fields=frozenset(fields), + ) + ], + ) + + def test_save_too_many_positional_arguments(self): + a = Article() + msg = "Model.save() takes from 1 to 5 positional arguments but 6 were given" + with ( + self.assertWarns(RemovedInDjango60Warning), + self.assertRaisesMessage(TypeError, msg), + ): + a.save(False, False, None, None, None) + async def test_asave_deprecation(self): a = Article(headline="original", pub_date=datetime(2014, 5, 16)) msg = "Passing positional arguments to asave() is deprecated" @@ -217,6 +246,35 @@ class ModelInstanceCreationTests(TestCase): await a.asave(False, False, None, None) self.assertEqual(await Article.objects.acount(), 1) + async def test_asave_deprecation_positional_arguments_used(self): + a = Article() + fields = ["headline"] + with ( + self.assertWarns(RemovedInDjango60Warning), + mock.patch.object(a, "save_base") as mock_save_base, + ): + await a.asave(None, 1, 2, fields) + self.assertEqual( + mock_save_base.mock_calls, + [ + mock.call( + using=2, + force_insert=None, + force_update=1, + update_fields=frozenset(fields), + ) + ], + ) + + async def test_asave_too_many_positional_arguments(self): + a = Article() + msg = "Model.asave() takes from 1 to 5 positional arguments but 6 were given" + with ( + self.assertWarns(RemovedInDjango60Warning), + self.assertRaisesMessage(TypeError, msg), + ): + await a.asave(False, False, None, None, None) + @ignore_warnings(category=RemovedInDjango60Warning) def test_save_positional_arguments(self): a = Article.objects.create(headline="original", pub_date=datetime(2014, 5, 16)) diff --git a/tests/update_only_fields/tests.py b/tests/update_only_fields/tests.py index 6c23ae27d8..816112bc33 100644 --- a/tests/update_only_fields/tests.py +++ b/tests/update_only_fields/tests.py @@ -1,5 +1,6 @@ from django.db.models.signals import post_save, pre_save from django.test import TestCase +from django.utils.deprecation import RemovedInDjango60Warning from .models import Account, Employee, Person, Profile, ProxyEmployee @@ -256,6 +257,29 @@ class UpdateOnlyFieldsTests(TestCase): pre_save.disconnect(pre_save_receiver) post_save.disconnect(post_save_receiver) + def test_empty_update_fields_positional_save(self): + s = Person.objects.create(name="Sara", gender="F") + + msg = "Passing positional arguments to save() is deprecated" + with ( + self.assertWarnsMessage(RemovedInDjango60Warning, msg), + self.assertNumQueries(0), + ): + s.save(False, False, None, []) + + async def test_empty_update_fields_positional_asave(self): + s = await Person.objects.acreate(name="Sara", gender="F") + # Workaround for a lack of async assertNumQueries. + s.name = "Other" + + msg = "Passing positional arguments to asave() is deprecated" + with self.assertWarnsMessage(RemovedInDjango60Warning, msg): + await s.asave(False, False, None, []) + + # No save occurred for an empty update_fields. + await s.arefresh_from_db() + self.assertEqual(s.name, "Sara") + def test_num_queries_inheritance(self): s = Employee.objects.create(name="Sara", gender="F") s.employee_num = 1