From a40b0103bccb8216c944188d329d8ea5eceb7e92 Mon Sep 17 00:00:00 2001 From: Akash Kumar Sen Date: Thu, 22 Jun 2023 18:23:11 +0530 Subject: [PATCH] Fixed #30382 -- Allowed specifying parent classes in force_insert of Model.save(). --- django/db/models/base.py | 33 +++++++++++-- docs/ref/models/instances.txt | 17 +++++++ docs/releases/5.0.txt | 4 ++ tests/extra_regress/models.py | 2 +- tests/force_insert_update/models.py | 10 ++++ tests/force_insert_update/tests.py | 77 ++++++++++++++++++++++++++++- 6 files changed, 138 insertions(+), 5 deletions(-) diff --git a/django/db/models/base.py b/django/db/models/base.py index 0c4a5ddcfc..0711ec0d61 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -832,6 +832,26 @@ class Model(AltersData, metaclass=ModelBase): asave.alters_data = True + @classmethod + def _validate_force_insert(cls, force_insert): + if force_insert is False: + return () + if force_insert is True: + return (cls,) + if not isinstance(force_insert, tuple): + raise TypeError("force_insert must be a bool or tuple.") + for member in force_insert: + if not isinstance(member, ModelBase): + raise TypeError( + f"Invalid force_insert member. {member!r} must be a model subclass." + ) + if not issubclass(cls, member): + raise TypeError( + f"Invalid force_insert member. {member.__qualname__} must be a " + f"base of {cls.__qualname__}." + ) + return force_insert + def save_base( self, raw=False, @@ -873,7 +893,11 @@ class Model(AltersData, metaclass=ModelBase): with context_manager: parent_inserted = False if not raw: - parent_inserted = self._save_parents(cls, using, update_fields) + # Validate force insert only when parents are inserted. + force_insert = self._validate_force_insert(force_insert) + parent_inserted = self._save_parents( + cls, using, update_fields, force_insert + ) updated = self._save_table( raw, cls, @@ -900,7 +924,9 @@ class Model(AltersData, metaclass=ModelBase): save_base.alters_data = True - def _save_parents(self, cls, using, update_fields, updated_parents=None): + def _save_parents( + self, cls, using, update_fields, force_insert, updated_parents=None + ): """Save all the parents of cls using values from self.""" meta = cls._meta inserted = False @@ -919,13 +945,14 @@ class Model(AltersData, metaclass=ModelBase): cls=parent, using=using, update_fields=update_fields, + force_insert=force_insert, updated_parents=updated_parents, ) updated = self._save_table( cls=parent, using=using, update_fields=update_fields, - force_insert=parent_inserted, + force_insert=parent_inserted or issubclass(parent, force_insert), ) if not updated: inserted = True diff --git a/docs/ref/models/instances.txt b/docs/ref/models/instances.txt index 346ae55130..6bfd521aaa 100644 --- a/docs/ref/models/instances.txt +++ b/docs/ref/models/instances.txt @@ -589,6 +589,18 @@ row. In these cases you can pass the ``force_insert=True`` or Passing both parameters is an error: you cannot both insert *and* update at the same time! +When using :ref:`multi-table inheritance `, it's also +possible to provide a tuple of parent classes to ``force_insert`` in order to +force ``INSERT`` statements for each base. For example:: + + Restaurant(pk=1, name="Bob's Cafe").save(force_insert=(Place,)) + + Restaurant(pk=1, name="Bob's Cafe", rating=4).save(force_insert=(Place, Rating)) + +You can pass ``force_insert=(models.Model,)`` to force an ``INSERT`` statement +for all parents. By default, ``force_insert=True`` only forces the insertion of +a new row for the current model. + It should be very rare that you'll need to use these parameters. Django will almost always do the right thing and trying to override that will lead to errors that are difficult to track down. This feature is for advanced use @@ -596,6 +608,11 @@ only. Using ``update_fields`` will force an update similarly to ``force_update``. +.. versionchanged:: 5.0 + + Support for passing a tuple of parent classes to ``force_insert`` was + added. + .. _ref-models-field-updates-using-f-expressions: Updating attributes based on existing fields diff --git a/docs/releases/5.0.txt b/docs/releases/5.0.txt index 60de6f81e7..f5f4ecd668 100644 --- a/docs/releases/5.0.txt +++ b/docs/releases/5.0.txt @@ -335,6 +335,10 @@ Models :ref:`Choices classes ` directly instead of requiring expansion with the ``choices`` attribute. +* The :ref:`force_insert ` argument of + :meth:`.Model.save` now allows specifying a tuple of parent classes that must + be forced to be inserted. + Pagination ~~~~~~~~~~ diff --git a/tests/extra_regress/models.py b/tests/extra_regress/models.py index 2122bae7ae..4111d439b5 100644 --- a/tests/extra_regress/models.py +++ b/tests/extra_regress/models.py @@ -10,7 +10,7 @@ class RevisionableModel(models.Model): title = models.CharField(blank=True, max_length=255) when = models.DateTimeField(default=datetime.datetime.now) - def save(self, *args, force_insert=None, force_update=None, **kwargs): + def save(self, *args, force_insert=False, force_update=False, **kwargs): super().save( *args, force_insert=force_insert, force_update=force_update, **kwargs ) diff --git a/tests/force_insert_update/models.py b/tests/force_insert_update/models.py index 586be12f13..b95b197454 100644 --- a/tests/force_insert_update/models.py +++ b/tests/force_insert_update/models.py @@ -30,3 +30,13 @@ class SubSubCounter(SubCounter): class WithCustomPK(models.Model): name = models.IntegerField(primary_key=True) value = models.IntegerField() + + +class OtherSubCounter(Counter): + other_counter_ptr = models.OneToOneField( + Counter, primary_key=True, parent_link=True, on_delete=models.CASCADE + ) + + +class DiamondSubSubCounter(SubCounter, OtherSubCounter): + pass diff --git a/tests/force_insert_update/tests.py b/tests/force_insert_update/tests.py index e2fefd3269..cc223cf3ea 100644 --- a/tests/force_insert_update/tests.py +++ b/tests/force_insert_update/tests.py @@ -1,9 +1,11 @@ -from django.db import DatabaseError, IntegrityError, transaction +from django.db import DatabaseError, IntegrityError, models, transaction from django.test import TestCase from .models import ( Counter, + DiamondSubSubCounter, InheritedCounter, + OtherSubCounter, ProxyCounter, SubCounter, SubSubCounter, @@ -76,6 +78,29 @@ class InheritanceTests(TestCase): class ForceInsertInheritanceTests(TestCase): + def test_force_insert_not_bool_or_tuple(self): + msg = "force_insert must be a bool or tuple." + with self.assertRaisesMessage(TypeError, msg), transaction.atomic(): + Counter().save(force_insert=1) + with self.assertRaisesMessage(TypeError, msg), transaction.atomic(): + Counter().save(force_insert="test") + with self.assertRaisesMessage(TypeError, msg), transaction.atomic(): + Counter().save(force_insert=[]) + + def test_force_insert_not_model(self): + msg = f"Invalid force_insert member. {object!r} must be a model subclass." + with self.assertRaisesMessage(TypeError, msg), transaction.atomic(): + Counter().save(force_insert=(object,)) + instance = Counter() + msg = f"Invalid force_insert member. {instance!r} must be a model subclass." + with self.assertRaisesMessage(TypeError, msg), transaction.atomic(): + Counter().save(force_insert=(instance,)) + + def test_force_insert_not_base(self): + msg = "Invalid force_insert member. SubCounter must be a base of Counter." + with self.assertRaisesMessage(TypeError, msg): + Counter().save(force_insert=(SubCounter,)) + def test_force_insert_false(self): with self.assertNumQueries(3): obj = SubCounter.objects.create(pk=1, value=0) @@ -87,6 +112,10 @@ class ForceInsertInheritanceTests(TestCase): SubCounter(pk=obj.pk, value=2).save(force_insert=False) obj.refresh_from_db() self.assertEqual(obj.value, 2) + with self.assertNumQueries(2): + SubCounter(pk=obj.pk, value=3).save(force_insert=()) + obj.refresh_from_db() + self.assertEqual(obj.value, 3) def test_force_insert_false_with_existing_parent(self): parent = Counter.objects.create(pk=1, value=1) @@ -96,13 +125,59 @@ class ForceInsertInheritanceTests(TestCase): def test_force_insert_parent(self): with self.assertNumQueries(3): SubCounter(pk=1, value=1).save(force_insert=True) + # Force insert a new parent and don't UPDATE first. + with self.assertNumQueries(2): + SubCounter(pk=2, value=1).save(force_insert=(Counter,)) + with self.assertNumQueries(2): + SubCounter(pk=3, value=1).save(force_insert=(models.Model,)) def test_force_insert_with_grandparent(self): with self.assertNumQueries(4): SubSubCounter(pk=1, value=1).save(force_insert=True) + # Force insert parents on all levels and don't UPDATE first. + with self.assertNumQueries(3): + SubSubCounter(pk=2, value=1).save(force_insert=(models.Model,)) + with self.assertNumQueries(3): + SubSubCounter(pk=3, value=1).save(force_insert=(Counter,)) + # Force insert only the last parent. + with self.assertNumQueries(4): + SubSubCounter(pk=4, value=1).save(force_insert=(SubCounter,)) def test_force_insert_with_existing_grandparent(self): # Force insert only the last child. grandparent = Counter.objects.create(pk=1, value=1) with self.assertNumQueries(4): SubSubCounter(pk=grandparent.pk, value=1).save(force_insert=True) + # Force insert a parent, and don't force insert a grandparent. + grandparent = Counter.objects.create(pk=2, value=1) + with self.assertNumQueries(3): + SubSubCounter(pk=grandparent.pk, value=1).save(force_insert=(SubCounter,)) + # Force insert parents on all levels, grandparent conflicts. + grandparent = Counter.objects.create(pk=3, value=1) + with self.assertRaises(IntegrityError), transaction.atomic(): + SubSubCounter(pk=grandparent.pk, value=1).save(force_insert=(Counter,)) + + def test_force_insert_diamond_mti(self): + # Force insert all parents. + with self.assertNumQueries(4): + DiamondSubSubCounter(pk=1, value=1).save( + force_insert=(Counter, SubCounter, OtherSubCounter) + ) + with self.assertNumQueries(4): + DiamondSubSubCounter(pk=2, value=1).save(force_insert=(models.Model,)) + # Force insert parents, and don't force insert a common grandparent. + with self.assertNumQueries(5): + DiamondSubSubCounter(pk=3, value=1).save( + force_insert=(SubCounter, OtherSubCounter) + ) + grandparent = Counter.objects.create(pk=4, value=1) + with self.assertNumQueries(4): + DiamondSubSubCounter(pk=grandparent.pk, value=1).save( + force_insert=(SubCounter, OtherSubCounter), + ) + # Force insert all parents, grandparent conflicts. + grandparent = Counter.objects.create(pk=5, value=1) + with self.assertRaises(IntegrityError), transaction.atomic(): + DiamondSubSubCounter(pk=grandparent.pk, value=1).save( + force_insert=(models.Model,) + )