diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index 2cf75bd528..2d1de9509e 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -52,6 +52,9 @@ class BaseDatabaseWrapper(object): self._dirty = False # Tracks if the connection is in a transaction managed by 'atomic' self.in_atomic_block = False + # Tracks if the transaction should be rolled back to the next + # available savepoint because of an exception in an inner block. + self.needs_rollback = False # List of savepoints created by 'atomic' self.savepoint_ids = [] # Hack to provide compatibility with legacy transaction management diff --git a/django/db/transaction.py b/django/db/transaction.py index a6eb0662d4..be8981f968 100644 --- a/django/db/transaction.py +++ b/django/db/transaction.py @@ -188,8 +188,11 @@ class Atomic(object): __exit__ commits the transaction or releases the savepoint on normal exit, and rolls back the transaction or to the savepoint on exceptions. + It's possible to disable the creation of savepoints if the goal is to + ensure that some code runs within a transaction without creating overhead. + A stack of savepoints identifiers is maintained as an attribute of the - connection. None denotes a plain transaction. + connection. None denotes the absence of a savepoint. This allows reentrancy even if the same AtomicWrapper is reused. For example, it's possible to define `oa = @atomic('other')` and use `@ao` or @@ -198,8 +201,9 @@ class Atomic(object): Since database connections are thread-local, this is thread-safe. """ - def __init__(self, using): + def __init__(self, using, savepoint): self.using = using + self.savepoint = savepoint def _legacy_enter_transaction_management(self, connection): if not connection.in_atomic_block: @@ -228,9 +232,15 @@ class Atomic(object): "'atomic' cannot be used when autocommit is disabled.") if connection.in_atomic_block: - # We're already in a transaction; create a savepoint. - sid = connection.savepoint() - connection.savepoint_ids.append(sid) + # We're already in a transaction; create a savepoint, unless we + # were told not to or we're already waiting for a rollback. The + # second condition avoids creating useless savepoints and prevents + # overwriting needs_rollback until the rollback is performed. + if self.savepoint and not connection.needs_rollback: + sid = connection.savepoint() + connection.savepoint_ids.append(sid) + else: + connection.savepoint_ids.append(None) else: # We aren't in a transaction yet; create one. # The usual way to start a transaction is to turn autocommit off. @@ -244,13 +254,23 @@ class Atomic(object): else: connection.set_autocommit(False) connection.in_atomic_block = True - connection.savepoint_ids.append(None) + connection.needs_rollback = False def __exit__(self, exc_type, exc_value, traceback): connection = get_connection(self.using) - sid = connection.savepoint_ids.pop() - if exc_value is None: - if sid is None: + if exc_value is None and not connection.needs_rollback: + if connection.savepoint_ids: + # Release savepoint if there is one + sid = connection.savepoint_ids.pop() + if sid is not None: + try: + connection.savepoint_commit(sid) + except DatabaseError: + connection.savepoint_rollback(sid) + # Remove this when the legacy transaction management goes away. + self._legacy_leave_transaction_management(connection) + raise + else: # Commit transaction connection.in_atomic_block = False try: @@ -265,17 +285,19 @@ class Atomic(object): connection.autocommit = True else: connection.set_autocommit(True) - else: - # Release savepoint - try: - connection.savepoint_commit(sid) - except DatabaseError: - connection.savepoint_rollback(sid) - # Remove this when the legacy transaction management goes away. - self._legacy_leave_transaction_management(connection) - raise else: - if sid is None: + # This flag will be set to True again if there isn't a savepoint + # allowing to perform the rollback at this level. + connection.needs_rollback = False + if connection.savepoint_ids: + # Roll back to savepoint if there is one, mark for rollback + # otherwise. + sid = connection.savepoint_ids.pop() + if sid is None: + connection.needs_rollback = True + else: + connection.savepoint_rollback(sid) + else: # Roll back transaction connection.in_atomic_block = False try: @@ -285,9 +307,6 @@ class Atomic(object): connection.autocommit = True else: connection.set_autocommit(True) - else: - # Roll back to savepoint - connection.savepoint_rollback(sid) # Remove this when the legacy transaction management goes away. self._legacy_leave_transaction_management(connection) @@ -301,17 +320,17 @@ class Atomic(object): return inner -def atomic(using=None): +def atomic(using=None, savepoint=True): # Bare decorator: @atomic -- although the first argument is called # `using`, it's actually the function being decorated. if callable(using): - return Atomic(DEFAULT_DB_ALIAS)(using) + return Atomic(DEFAULT_DB_ALIAS, savepoint)(using) # Decorator: @atomic(...) or context manager: with atomic(...): ... else: - return Atomic(using) + return Atomic(using, savepoint) -def atomic_if_autocommit(using=None): +def atomic_if_autocommit(using=None, savepoint=True): # This variant only exists to support the ability to disable transaction # management entirely in the DATABASES setting. It doesn't care about the # autocommit state at run time. @@ -319,7 +338,7 @@ def atomic_if_autocommit(using=None): autocommit = get_connection(db).settings_dict['AUTOCOMMIT'] if autocommit: - return atomic(using) + return atomic(using, savepoint) else: # Bare decorator: @atomic_if_autocommit if callable(using): @@ -447,7 +466,7 @@ def commit_manually(using=None): return _transaction_func(entering, exiting, using) -def commit_on_success_unless_managed(using=None): +def commit_on_success_unless_managed(using=None, savepoint=False): """ Transitory API to preserve backwards-compatibility while refactoring. @@ -455,10 +474,13 @@ def commit_on_success_unless_managed(using=None): simply be replaced by atomic_if_autocommit. Until then, it's necessary to avoid making a commit where Django didn't use to, since entering atomic in managed mode triggers a commmit. + + Unlike atomic, savepoint defaults to False because that's closer to the + legacy behavior. """ connection = get_connection(using) if connection.autocommit or connection.in_atomic_block: - return atomic_if_autocommit(using) + return atomic_if_autocommit(using, savepoint) else: def entering(using): pass diff --git a/docs/topics/db/transactions.txt b/docs/topics/db/transactions.txt index 8cbe0dccd0..d5c22e17f5 100644 --- a/docs/topics/db/transactions.txt +++ b/docs/topics/db/transactions.txt @@ -89,7 +89,7 @@ Controlling transactions explicitly Django provides a single API to control database transactions. -.. function:: atomic(using=None) +.. function:: atomic(using=None, savepoint=True) This function creates an atomic block for writes to the database. (Atomicity is the defining property of database transactions.) @@ -164,6 +164,14 @@ Django provides a single API to control database transactions. - releases or rolls back to the savepoint when exiting an inner block; - commits or rolls back the transaction when exiting the outermost block. + You can disable the creation of savepoints for inner blocks by setting the + ``savepoint`` argument to ``False``. If an exception occurs, Django will + perform the rollback when exiting the first parent block with a savepoint + if there is one, and the outermost block otherwise. Atomicity is still + guaranteed by the outer transaction. This option should only be used if + the overhead of savepoints is noticeable. It has the drawback of breaking + the error handling described above. + .. admonition:: Performance considerations Open transactions have a performance cost for your database server. To diff --git a/tests/transactions/tests.py b/tests/transactions/tests.py index d6cfd8ae95..42a78ad4ba 100644 --- a/tests/transactions/tests.py +++ b/tests/transactions/tests.py @@ -106,6 +106,44 @@ class AtomicTests(TransactionTestCase): raise Exception("Oops, that's his first name") self.assertQuerysetEqual(Reporter.objects.all(), []) + def test_merged_commit_commit(self): + with transaction.atomic(): + Reporter.objects.create(first_name="Tintin") + with transaction.atomic(savepoint=False): + Reporter.objects.create(first_name="Archibald", last_name="Haddock") + self.assertQuerysetEqual(Reporter.objects.all(), + ['', '']) + + def test_merged_commit_rollback(self): + with transaction.atomic(): + Reporter.objects.create(first_name="Tintin") + with six.assertRaisesRegex(self, Exception, "Oops"): + with transaction.atomic(savepoint=False): + Reporter.objects.create(first_name="Haddock") + raise Exception("Oops, that's his last name") + # Writes in the outer block are rolled back too. + self.assertQuerysetEqual(Reporter.objects.all(), []) + + def test_merged_rollback_commit(self): + with six.assertRaisesRegex(self, Exception, "Oops"): + with transaction.atomic(): + Reporter.objects.create(last_name="Tintin") + with transaction.atomic(savepoint=False): + Reporter.objects.create(last_name="Haddock") + raise Exception("Oops, that's his first name") + self.assertQuerysetEqual(Reporter.objects.all(), []) + + def test_merged_rollback_rollback(self): + with six.assertRaisesRegex(self, Exception, "Oops"): + with transaction.atomic(): + Reporter.objects.create(last_name="Tintin") + with six.assertRaisesRegex(self, Exception, "Oops"): + with transaction.atomic(savepoint=False): + Reporter.objects.create(first_name="Haddock") + raise Exception("Oops, that's his last name") + raise Exception("Oops, that's his first name") + self.assertQuerysetEqual(Reporter.objects.all(), []) + def test_reuse_commit_commit(self): atomic = transaction.atomic() with atomic: @@ -171,6 +209,61 @@ class AtomicInsideLegacyTransactionManagementTests(AtomicTests): transaction.leave_transaction_management() +@skipUnless(connection.features.uses_savepoints, + "'atomic' requires transactions and savepoints.") +class AtomicMergeTests(TransactionTestCase): + """Test merging transactions with savepoint=False.""" + + def test_merged_outer_rollback(self): + with transaction.atomic(): + Reporter.objects.create(first_name="Tintin") + with transaction.atomic(savepoint=False): + Reporter.objects.create(first_name="Archibald", last_name="Haddock") + with six.assertRaisesRegex(self, Exception, "Oops"): + with transaction.atomic(savepoint=False): + Reporter.objects.create(first_name="Tournesol") + raise Exception("Oops, that's his last name") + # It wasn't possible to roll back + self.assertEqual(Reporter.objects.count(), 3) + # It wasn't possible to roll back + self.assertEqual(Reporter.objects.count(), 3) + # The outer block must roll back + self.assertQuerysetEqual(Reporter.objects.all(), []) + + def test_merged_inner_savepoint_rollback(self): + with transaction.atomic(): + Reporter.objects.create(first_name="Tintin") + with transaction.atomic(): + Reporter.objects.create(first_name="Archibald", last_name="Haddock") + with six.assertRaisesRegex(self, Exception, "Oops"): + with transaction.atomic(savepoint=False): + Reporter.objects.create(first_name="Tournesol") + raise Exception("Oops, that's his last name") + # It wasn't possible to roll back + self.assertEqual(Reporter.objects.count(), 3) + # The first block with a savepoint must roll back + self.assertEqual(Reporter.objects.count(), 1) + self.assertQuerysetEqual(Reporter.objects.all(), ['']) + + def test_merged_outer_rollback_after_inner_failure_and_inner_success(self): + with transaction.atomic(): + Reporter.objects.create(first_name="Tintin") + # Inner block without a savepoint fails + with six.assertRaisesRegex(self, Exception, "Oops"): + with transaction.atomic(savepoint=False): + Reporter.objects.create(first_name="Haddock") + raise Exception("Oops, that's his last name") + # It wasn't possible to roll back + self.assertEqual(Reporter.objects.count(), 2) + # Inner block with a savepoint succeeds + with transaction.atomic(savepoint=False): + Reporter.objects.create(first_name="Archibald", last_name="Haddock") + # It still wasn't possible to roll back + self.assertEqual(Reporter.objects.count(), 3) + # The outer block must rollback + self.assertQuerysetEqual(Reporter.objects.all(), []) + + @skipUnless(connection.features.uses_savepoints, "'atomic' requires transactions and savepoints.") class AtomicErrorsTests(TransactionTestCase):