From bc9be72bdc9bb4dfc7f967ac3856115f0a6166b8 Mon Sep 17 00:00:00 2001 From: Loic Bistuer Date: Sun, 30 Mar 2014 01:57:28 +0700 Subject: [PATCH] Fixed transaction handling for a number of operations on related objects. Thanks Anssi and Aymeric for the reviews. Refs #21174. --- django/contrib/contenttypes/fields.py | 25 +++-- django/db/models/fields/related.py | 144 +++++++++++++++----------- docs/releases/1.8.txt | 14 ++- tests/many_to_many/tests.py | 5 +- tests/multiple_database/tests.py | 15 ++- 5 files changed, 124 insertions(+), 79 deletions(-) diff --git a/django/contrib/contenttypes/fields.py b/django/contrib/contenttypes/fields.py index b551d66b13..7e91957546 100644 --- a/django/contrib/contenttypes/fields.py +++ b/django/contrib/contenttypes/fields.py @@ -394,9 +394,12 @@ class ReverseGenericRelatedObjectsDescriptor(object): def __set__(self, instance, value): manager = self.__get__(instance) - manager.clear() - for obj in value: - manager.add(obj) + + db = router.db_for_write(manager.model, instance=manager.instance) + with transaction.atomic(using=db, savepoint=False): + manager.clear() + for obj in value: + manager.add(obj) def create_generic_related_manager(superclass): @@ -474,12 +477,14 @@ def create_generic_related_manager(superclass): self.prefetch_cache_name) def add(self, *objs): - for obj in objs: - if not isinstance(obj, self.model): - raise TypeError("'%s' instance expected" % self.model._meta.object_name) - setattr(obj, self.content_type_field_name, self.content_type) - setattr(obj, self.object_id_field_name, self.pk_val) - obj.save() + db = router.db_for_write(self.model, instance=self.instance) + with transaction.atomic(using=db, savepoint=False): + for obj in objs: + if not isinstance(obj, self.model): + raise TypeError("'%s' instance expected" % self.model._meta.object_name) + setattr(obj, self.content_type_field_name, self.content_type) + setattr(obj, self.object_id_field_name, self.pk_val) + obj.save() add.alters_data = True def 
remove(self, *objs, **kwargs): @@ -498,6 +503,8 @@ def create_generic_related_manager(superclass): db = router.db_for_write(self.model, instance=self.instance) queryset = queryset.using(db) if bulk: + # `QuerySet.delete()` creates its own atomic block which + # contains the `pre_delete` and `post_delete` signal handlers. queryset.delete() else: with transaction.atomic(using=db, savepoint=False): diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index ef48ff538b..523b64a7e2 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -735,6 +735,7 @@ def create_foreign_related_manager(superclass, rel_field, rel_model): db = router.db_for_write(self.model, instance=self.instance) queryset = queryset.using(db) if bulk: + # `QuerySet.update()` is intrinsically atomic. queryset.update(**{rel_field.name: None}) else: with transaction.atomic(using=db, savepoint=False): @@ -763,11 +764,14 @@ class ForeignRelatedObjectsDescriptor(object): def __set__(self, instance, value): manager = self.__get__(instance) - # If the foreign key can support nulls, then completely clear the related set. - # Otherwise, just move the named objects into the set. - if self.related.field.null: - manager.clear() - manager.add(*value) + + db = router.db_for_write(manager.model, instance=manager.instance) + with transaction.atomic(using=db, savepoint=False): + # If the foreign key can support nulls, then completely clear the related set. + # Otherwise, just move the named objects into the set. + if self.related.field.null: + manager.clear() + manager.add(*value) @cached_property def related_manager_cls(self): @@ -901,11 +905,14 @@ def create_many_related_manager(superclass, rel): "Cannot use add() on a ManyToManyField which specifies an intermediary model. Use %s.%s's Manager instead." 
% (opts.app_label, opts.object_name) ) - self._add_items(self.source_field_name, self.target_field_name, *objs) - # If this is a symmetrical m2m relation to self, add the mirror entry in the m2m table - if self.symmetrical: - self._add_items(self.target_field_name, self.source_field_name, *objs) + db = router.db_for_write(self.through, instance=self.instance) + with transaction.atomic(using=db, savepoint=False): + self._add_items(self.source_field_name, self.target_field_name, *objs) + + # If this is a symmetrical m2m relation to self, add the mirror entry in the m2m table + if self.symmetrical: + self._add_items(self.target_field_name, self.source_field_name, *objs) add.alters_data = True def remove(self, *objs): @@ -920,17 +927,17 @@ def create_many_related_manager(superclass, rel): def clear(self): db = router.db_for_write(self.through, instance=self.instance) + with transaction.atomic(using=db, savepoint=False): + signals.m2m_changed.send(sender=self.through, action="pre_clear", + instance=self.instance, reverse=self.reverse, + model=self.model, pk_set=None, using=db) - signals.m2m_changed.send(sender=self.through, action="pre_clear", - instance=self.instance, reverse=self.reverse, - model=self.model, pk_set=None, using=db) + filters = self._build_remove_filters(super(ManyRelatedManager, self).get_queryset().using(db)) + self.through._default_manager.using(db).filter(filters).delete() - filters = self._build_remove_filters(super(ManyRelatedManager, self).get_queryset().using(db)) - self.through._default_manager.using(db).filter(filters).delete() - - signals.m2m_changed.send(sender=self.through, action="post_clear", - instance=self.instance, reverse=self.reverse, - model=self.model, pk_set=None, using=db) + signals.m2m_changed.send(sender=self.through, action="post_clear", + instance=self.instance, reverse=self.reverse, + model=self.model, pk_set=None, using=db) clear.alters_data = True def create(self, **kwargs): @@ -990,35 +997,39 @@ def 
create_many_related_manager(superclass, rel): ) else: new_ids.add(obj) + db = router.db_for_write(self.through, instance=self.instance) - vals = self.through._default_manager.using(db).values_list(target_field_name, flat=True) - vals = vals.filter(**{ - source_field_name: self.related_val[0], - '%s__in' % target_field_name: new_ids, - }) + vals = (self.through._default_manager.using(db) + .values_list(target_field_name, flat=True) + .filter(**{ + source_field_name: self.related_val[0], + '%s__in' % target_field_name: new_ids, + })) new_ids = new_ids - set(vals) - if self.reverse or source_field_name == self.source_field_name: - # Don't send the signal when we are inserting the - # duplicate data row for symmetrical reverse entries. - signals.m2m_changed.send(sender=self.through, action='pre_add', - instance=self.instance, reverse=self.reverse, - model=self.model, pk_set=new_ids, using=db) - # Add the ones that aren't there already - self.through._default_manager.using(db).bulk_create([ - self.through(**{ - '%s_id' % source_field_name: self.related_val[0], - '%s_id' % target_field_name: obj_id, - }) - for obj_id in new_ids - ]) + with transaction.atomic(using=db, savepoint=False): + if self.reverse or source_field_name == self.source_field_name: + # Don't send the signal when we are inserting the + # duplicate data row for symmetrical reverse entries. + signals.m2m_changed.send(sender=self.through, action='pre_add', + instance=self.instance, reverse=self.reverse, + model=self.model, pk_set=new_ids, using=db) - if self.reverse or source_field_name == self.source_field_name: - # Don't send the signal when we are inserting the - # duplicate data row for symmetrical reverse entries. 
- signals.m2m_changed.send(sender=self.through, action='post_add', - instance=self.instance, reverse=self.reverse, - model=self.model, pk_set=new_ids, using=db) + # Add the ones that aren't there already + self.through._default_manager.using(db).bulk_create([ + self.through(**{ + '%s_id' % source_field_name: self.related_val[0], + '%s_id' % target_field_name: obj_id, + }) + for obj_id in new_ids + ]) + + if self.reverse or source_field_name == self.source_field_name: + # Don't send the signal when we are inserting the + # duplicate data row for symmetrical reverse entries. + signals.m2m_changed.send(sender=self.through, action='post_add', + instance=self.instance, reverse=self.reverse, + model=self.model, pk_set=new_ids, using=db) def _remove_items(self, source_field_name, target_field_name, *objs): # source_field_name: the PK colname in join table for the source object @@ -1037,23 +1048,23 @@ def create_many_related_manager(superclass, rel): old_ids.add(obj) db = router.db_for_write(self.through, instance=self.instance) + with transaction.atomic(using=db, savepoint=False): + # Send a signal to the other end if need be. + signals.m2m_changed.send(sender=self.through, action="pre_remove", + instance=self.instance, reverse=self.reverse, + model=self.model, pk_set=old_ids, using=db) + target_model_qs = super(ManyRelatedManager, self).get_queryset() + if target_model_qs._has_filters(): + old_vals = target_model_qs.using(db).filter(**{ + '%s__in' % self.target_field.related_field.attname: old_ids}) + else: + old_vals = old_ids + filters = self._build_remove_filters(old_vals) + self.through._default_manager.using(db).filter(filters).delete() - # Send a signal to the other end if need be. 
- signals.m2m_changed.send(sender=self.through, action="pre_remove", - instance=self.instance, reverse=self.reverse, - model=self.model, pk_set=old_ids, using=db) - target_model_qs = super(ManyRelatedManager, self).get_queryset() - if target_model_qs._has_filters(): - old_vals = target_model_qs.using(db).filter(**{ - '%s__in' % self.target_field.related_field.attname: old_ids}) - else: - old_vals = old_ids - filters = self._build_remove_filters(old_vals) - self.through._default_manager.using(db).filter(filters).delete() - - signals.m2m_changed.send(sender=self.through, action="post_remove", - instance=self.instance, reverse=self.reverse, - model=self.model, pk_set=old_ids, using=db) + signals.m2m_changed.send(sender=self.through, action="post_remove", + instance=self.instance, reverse=self.reverse, + model=self.model, pk_set=old_ids, using=db) return ManyRelatedManager @@ -1103,8 +1114,11 @@ class ManyRelatedObjectsDescriptor(object): raise AttributeError("Cannot set values on a ManyToManyField which specifies an intermediary model. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name)) manager = self.__get__(instance) - manager.clear() - manager.add(*value) + + db = router.db_for_write(manager.through, instance=manager.instance) + with transaction.atomic(using=db, savepoint=False): + manager.clear() + manager.add(*value) class ReverseManyRelatedObjectsDescriptor(object): @@ -1157,11 +1171,15 @@ class ReverseManyRelatedObjectsDescriptor(object): raise AttributeError("Cannot set values on a ManyToManyField which specifies an intermediary model. Use %s.%s's Manager instead." 
% (opts.app_label, opts.object_name)) manager = self.__get__(instance) + # clear() can change expected output of 'value' queryset, we force evaluation # of queryset before clear; ticket #19816 value = tuple(value) - manager.clear() - manager.add(*value) + + db = router.db_for_write(manager.through, instance=manager.instance) + with transaction.atomic(using=db, savepoint=False): + manager.clear() + manager.add(*value) class ForeignObjectRel(object): diff --git a/docs/releases/1.8.txt b/docs/releases/1.8.txt index 1f56c3109e..789c1db535 100644 --- a/docs/releases/1.8.txt +++ b/docs/releases/1.8.txt @@ -168,7 +168,19 @@ Backwards incompatible changes in 1.8 deprecation timeline for a given feature, its removal may appear as a backwards incompatible change. -... +* Some operations on related objects such as + :meth:`~django.db.models.fields.related.RelatedManager.add()` or + :ref:`direct assignment<direct-assignment>` ran multiple data modifying + queries without wrapping them in transactions. To reduce the risk of data + corruption, all data modifying methods that affect multiple related objects + (i.e. ``add()``, ``remove()``, ``clear()``, and + :ref:`direct assignment<direct-assignment>`) now perform their data modifying + queries from within a transaction, provided your database supports + transactions. + + This has one backwards incompatible side effect: signal handlers triggered + from these methods are now executed within the method's transaction, and + any exception in a signal handler will prevent the whole operation. 
Miscellaneous ~~~~~~~~~~~~~ diff --git a/tests/many_to_many/tests.py b/tests/many_to_many/tests.py index 4293fc0c42..545d50a90e 100644 --- a/tests/many_to_many/tests.py +++ b/tests/many_to_many/tests.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals +from django.db import transaction from django.test import TestCase from django.utils import six @@ -54,7 +55,9 @@ class ManyToManyTests(TestCase): # Adding an object of the wrong type raises TypeError with six.assertRaisesRegex(self, TypeError, "'Publication' instance expected, got