diff --git a/django/db/transaction.py b/django/db/transaction.py index 7827d04e9b..248cb65cb7 100644 --- a/django/db/transaction.py +++ b/django/db/transaction.py @@ -345,16 +345,19 @@ def commit_on_success(using=None): managed(True, using=using) def exiting(exc_value, using): - if exc_value is not None: - if is_dirty(using=using): - rollback(using=using) - else: - if is_dirty(using=using): - try: - commit(using=using) - except: + try: + if exc_value is not None: + if is_dirty(using=using): rollback(using=using) - raise + else: + if is_dirty(using=using): + try: + commit(using=using) + except: + rollback(using=using) + raise + finally: + leave_transaction_management(using=using) return _transaction_func(entering, exiting, using) diff --git a/tests/modeltests/transactions/tests.py b/tests/modeltests/transactions/tests.py index 9deb18382c..00ceed584b 100644 --- a/tests/modeltests/transactions/tests.py +++ b/tests/modeltests/transactions/tests.py @@ -113,6 +113,25 @@ class TransactionTests(TransactionTestCase): remove_comitted_on_success("Alice") self.assertEqual(list(Reporter.objects.all()), []) + @skipUnlessDBFeature('supports_transactions') + def test_commit_on_success_exit(self): + @transaction.autocommit() + def gen_reporter(): + @transaction.commit_on_success + def create_reporter(): + Reporter.objects.create(first_name="Bobby", last_name="Tables") + + create_reporter() + # Much more formal + r = Reporter.objects.get() + r.first_name = "Robert" + r.save() + + gen_reporter() + r = Reporter.objects.get() + self.assertEqual(r.first_name, "Robert") + + @skipUnlessDBFeature('supports_transactions') def test_manually_managed(self): """ @@ -146,6 +165,7 @@ class TransactionTests(TransactionTestCase): using_manually_managed_mistake ) + class TransactionRollbackTests(TransactionTestCase): def execute_bad_sql(self): cursor = connection.cursor() diff --git a/tests/modeltests/transactions/tests_25.py b/tests/modeltests/transactions/tests_25.py index ec3c4d1215..cc2290edff 100644 --- a/tests/modeltests/transactions/tests_25.py +++ b/tests/modeltests/transactions/tests_25.py @@ -78,6 +78,20 @@ class TransactionContextManagerTests(TransactionTestCase): self.assertQuerysetEqual(Reporter.objects.all(), []) + @skipUnlessDBFeature('supports_transactions') + def test_commit_on_success_exit(self): + with transaction.autocommit(): + with transaction.commit_on_success(): + Reporter.objects.create(first_name="Bobby", last_name="Tables") + + # Much more formal + r = Reporter.objects.get() + r.first_name = "Robert" + r.save() + + r = Reporter.objects.get() + self.assertEqual(r.first_name, "Robert") + @skipUnlessDBFeature('supports_transactions') def test_manually_managed(self): """