mirror of
				https://github.com/django/django.git
				synced 2025-10-24 22:26:08 +00:00 
			
		
		
		
	Fixed #21134 -- Prevented queries in broken transactions.
Squashed commit of the following: commit 63ddb271a44df389b2c302e421fc17b7f0529755 Author: Aymeric Augustin <aymeric.augustin@m4x.org> Date: Sun Sep 29 22:51:00 2013 +0200 Clarified interactions between atomic and exceptions. commit 2899ec299228217c876ba3aa4024e523a41c8504 Author: Aymeric Augustin <aymeric.augustin@m4x.org> Date: Sun Sep 22 22:45:32 2013 +0200 Fixed TransactionManagementError in tests. Previous commit introduced an additional check to prevent running queries in transactions that will be rolled back, which triggered a few failures in the tests. In practice using transaction.atomic instead of the low-level savepoint APIs was enough to fix the problems. commit 4a639b059ea80aeb78f7f160a7d4b9f609b9c238 Author: Aymeric Augustin <aymeric.augustin@m4x.org> Date: Tue Sep 24 22:24:17 2013 +0200 Allowed nesting constraint_checks_disabled inside atomic. Since MySQL handles transactions loosely, this isn't a problem. commit 2a4ab1cb6e83391ff7e25d08479e230ca564bfef Author: Aymeric Augustin <aymeric.augustin@m4x.org> Date: Sat Sep 21 18:43:12 2013 +0200 Prevented running queries in transactions that will be rolled back. This avoids a counter-intuitive behavior in an edge case on databases with non-atomic transaction semantics. It prevents using savepoint_rollback() inside an atomic block without calling set_rollback(False) first, which is backwards-incompatible in tests. Refs #21134. commit 8e3db393853c7ac64a445b66e57f3620a3fde7b0 Author: Aymeric Augustin <aymeric.augustin@m4x.org> Date: Sun Sep 22 22:14:17 2013 +0200 Replaced manual savepoints by atomic blocks. This ensures the rollback flag is handled consistently in internal APIs.
This commit is contained in:
		| @@ -58,12 +58,11 @@ class SessionStore(SessionBase): | |||||||
|             expire_date=self.get_expiry_date() |             expire_date=self.get_expiry_date() | ||||||
|         ) |         ) | ||||||
|         using = router.db_for_write(Session, instance=obj) |         using = router.db_for_write(Session, instance=obj) | ||||||
|         sid = transaction.savepoint(using=using) |  | ||||||
|         try: |         try: | ||||||
|             obj.save(force_insert=must_create, using=using) |             with transaction.atomic(using=using): | ||||||
|  |                 obj.save(force_insert=must_create, using=using) | ||||||
|         except IntegrityError: |         except IntegrityError: | ||||||
|             if must_create: |             if must_create: | ||||||
|                 transaction.savepoint_rollback(sid, using=using) |  | ||||||
|                 raise CreateError |                 raise CreateError | ||||||
|             raise |             raise | ||||||
|  |  | ||||||
|   | |||||||
| @@ -361,6 +361,12 @@ class BaseDatabaseWrapper(object): | |||||||
|             raise TransactionManagementError( |             raise TransactionManagementError( | ||||||
|                 "This is forbidden when an 'atomic' block is active.") |                 "This is forbidden when an 'atomic' block is active.") | ||||||
|  |  | ||||||
|  |     def validate_no_broken_transaction(self): | ||||||
|  |         if self.needs_rollback: | ||||||
|  |             raise TransactionManagementError( | ||||||
|  |                 "An error occurred in the current transaction. You can't " | ||||||
|  |                 "execute queries until the end of the 'atomic' block.") | ||||||
|  |  | ||||||
|     def abort(self): |     def abort(self): | ||||||
|         """ |         """ | ||||||
|         Roll back any ongoing transaction and clean the transaction state |         Roll back any ongoing transaction and clean the transaction state | ||||||
| @@ -638,6 +644,9 @@ class BaseDatabaseFeatures(object): | |||||||
|     # when autocommit is disabled? http://bugs.python.org/issue8145#msg109965 |     # when autocommit is disabled? http://bugs.python.org/issue8145#msg109965 | ||||||
|     autocommits_when_autocommit_is_off = False |     autocommits_when_autocommit_is_off = False | ||||||
|  |  | ||||||
|  |     # Does the backend prevent running SQL queries in broken transactions? | ||||||
|  |     atomic_transactions = True | ||||||
|  |  | ||||||
|     # Can we roll back DDL in a transaction? |     # Can we roll back DDL in a transaction? | ||||||
|     can_rollback_ddl = False |     can_rollback_ddl = False | ||||||
|  |  | ||||||
|   | |||||||
| @@ -172,6 +172,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): | |||||||
|     requires_explicit_null_ordering_when_grouping = True |     requires_explicit_null_ordering_when_grouping = True | ||||||
|     allows_primary_key_0 = False |     allows_primary_key_0 = False | ||||||
|     uses_savepoints = True |     uses_savepoints = True | ||||||
|  |     atomic_transactions = False | ||||||
|     supports_check_constraints = False |     supports_check_constraints = False | ||||||
|  |  | ||||||
|     def __init__(self, connection): |     def __init__(self, connection): | ||||||
| @@ -484,7 +485,13 @@ class DatabaseWrapper(BaseDatabaseWrapper): | |||||||
|         """ |         """ | ||||||
|         Re-enable foreign key checks after they have been disabled. |         Re-enable foreign key checks after they have been disabled. | ||||||
|         """ |         """ | ||||||
|         self.cursor().execute('SET foreign_key_checks=1') |         # Override needs_rollback in case constraint_checks_disabled is | ||||||
|  |         # nested inside transaction.atomic. | ||||||
|  |         self.needs_rollback, needs_rollback = False, self.needs_rollback | ||||||
|  |         try: | ||||||
|  |             self.cursor().execute('SET foreign_key_checks=1') | ||||||
|  |         finally: | ||||||
|  |             self.needs_rollback = needs_rollback | ||||||
|  |  | ||||||
|     def check_constraints(self, table_names=None): |     def check_constraints(self, table_names=None): | ||||||
|         """ |         """ | ||||||
|   | |||||||
| @@ -96,6 +96,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): | |||||||
|     has_bulk_insert = True |     has_bulk_insert = True | ||||||
|     supports_tablespaces = True |     supports_tablespaces = True | ||||||
|     supports_sequence_reset = False |     supports_sequence_reset = False | ||||||
|  |     atomic_transactions = False | ||||||
|     supports_combined_alters = False |     supports_combined_alters = False | ||||||
|     max_index_name_length = 30 |     max_index_name_length = 30 | ||||||
|     nulls_order_largest = True |     nulls_order_largest = True | ||||||
|   | |||||||
| @@ -105,6 +105,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): | |||||||
|     supports_foreign_keys = False |     supports_foreign_keys = False | ||||||
|     supports_check_constraints = False |     supports_check_constraints = False | ||||||
|     autocommits_when_autocommit_is_off = True |     autocommits_when_autocommit_is_off = True | ||||||
|  |     atomic_transactions = False | ||||||
|     supports_paramstyle_pyformat = False |     supports_paramstyle_pyformat = False | ||||||
|     supports_sequence_reset = False |     supports_sequence_reset = False | ||||||
|  |  | ||||||
|   | |||||||
| @@ -19,14 +19,9 @@ class CursorWrapper(object): | |||||||
|         self.cursor = cursor |         self.cursor = cursor | ||||||
|         self.db = db |         self.db = db | ||||||
|  |  | ||||||
|     SET_DIRTY_ATTRS = frozenset(['execute', 'executemany', 'callproc']) |     WRAP_ERROR_ATTRS = frozenset(['fetchone', 'fetchmany', 'fetchall', 'nextset']) | ||||||
|     WRAP_ERROR_ATTRS = frozenset([ |  | ||||||
|         'callproc', 'close', 'execute', 'executemany', |  | ||||||
|         'fetchone', 'fetchmany', 'fetchall', 'nextset']) |  | ||||||
|  |  | ||||||
|     def __getattr__(self, attr): |     def __getattr__(self, attr): | ||||||
|         if attr in CursorWrapper.SET_DIRTY_ATTRS: |  | ||||||
|             self.db.set_dirty() |  | ||||||
|         cursor_attr = getattr(self.cursor, attr) |         cursor_attr = getattr(self.cursor, attr) | ||||||
|         if attr in CursorWrapper.WRAP_ERROR_ATTRS: |         if attr in CursorWrapper.WRAP_ERROR_ATTRS: | ||||||
|             return self.db.wrap_database_errors(cursor_attr) |             return self.db.wrap_database_errors(cursor_attr) | ||||||
| @@ -44,18 +39,42 @@ class CursorWrapper(object): | |||||||
|         # specific behavior. |         # specific behavior. | ||||||
|         self.close() |         self.close() | ||||||
|  |  | ||||||
|  |     # The following methods cannot be implemented in __getattr__, because the | ||||||
|  |     # code must run when the method is invoked, not just when it is accessed. | ||||||
|  |  | ||||||
|  |     def callproc(self, procname, params=None): | ||||||
|  |         self.db.validate_no_broken_transaction() | ||||||
|  |         self.db.set_dirty() | ||||||
|  |         with self.db.wrap_database_errors: | ||||||
|  |             if params is None: | ||||||
|  |                 return self.cursor.callproc(procname) | ||||||
|  |             else: | ||||||
|  |                 return self.cursor.callproc(procname, params) | ||||||
|  |  | ||||||
|  |     def execute(self, sql, params=None): | ||||||
|  |         self.db.validate_no_broken_transaction() | ||||||
|  |         self.db.set_dirty() | ||||||
|  |         with self.db.wrap_database_errors: | ||||||
|  |             if params is None: | ||||||
|  |                 return self.cursor.execute(sql) | ||||||
|  |             else: | ||||||
|  |                 return self.cursor.execute(sql, params) | ||||||
|  |  | ||||||
|  |     def executemany(self, sql, param_list): | ||||||
|  |         self.db.validate_no_broken_transaction() | ||||||
|  |         self.db.set_dirty() | ||||||
|  |         with self.db.wrap_database_errors: | ||||||
|  |             return self.cursor.executemany(sql, param_list) | ||||||
|  |  | ||||||
|  |  | ||||||
| class CursorDebugWrapper(CursorWrapper): | class CursorDebugWrapper(CursorWrapper): | ||||||
|  |  | ||||||
|  |     # XXX callproc isn't instrumented at this time. | ||||||
|  |  | ||||||
|     def execute(self, sql, params=None): |     def execute(self, sql, params=None): | ||||||
|         self.db.set_dirty() |  | ||||||
|         start = time() |         start = time() | ||||||
|         try: |         try: | ||||||
|             with self.db.wrap_database_errors: |             return super(CursorDebugWrapper, self).execute(sql, params) | ||||||
|                 if params is None: |  | ||||||
|                     # params default might be backend specific |  | ||||||
|                     return self.cursor.execute(sql) |  | ||||||
|                 return self.cursor.execute(sql, params) |  | ||||||
|         finally: |         finally: | ||||||
|             stop = time() |             stop = time() | ||||||
|             duration = stop - start |             duration = stop - start | ||||||
| @@ -69,11 +88,9 @@ class CursorDebugWrapper(CursorWrapper): | |||||||
|             ) |             ) | ||||||
|  |  | ||||||
|     def executemany(self, sql, param_list): |     def executemany(self, sql, param_list): | ||||||
|         self.db.set_dirty() |  | ||||||
|         start = time() |         start = time() | ||||||
|         try: |         try: | ||||||
|             with self.db.wrap_database_errors: |             return super(CursorDebugWrapper, self).executemany(sql, param_list) | ||||||
|                 return self.cursor.executemany(sql, param_list) |  | ||||||
|         finally: |         finally: | ||||||
|             stop = time() |             stop = time() | ||||||
|             duration = stop - start |             duration = stop - start | ||||||
|   | |||||||
| @@ -436,14 +436,9 @@ class QuerySet(object): | |||||||
|         for k, v in six.iteritems(defaults): |         for k, v in six.iteritems(defaults): | ||||||
|             setattr(obj, k, v) |             setattr(obj, k, v) | ||||||
|  |  | ||||||
|         sid = transaction.savepoint(using=self.db) |         with transaction.atomic(using=self.db): | ||||||
|         try: |  | ||||||
|             obj.save(using=self.db) |             obj.save(using=self.db) | ||||||
|             transaction.savepoint_commit(sid, using=self.db) |         return obj, False | ||||||
|             return obj, False |  | ||||||
|         except DatabaseError: |  | ||||||
|             transaction.savepoint_rollback(sid, using=self.db) |  | ||||||
|             six.reraise(*sys.exc_info()) |  | ||||||
|  |  | ||||||
|     def _create_object_from_params(self, lookup, params): |     def _create_object_from_params(self, lookup, params): | ||||||
|         """ |         """ | ||||||
| @@ -451,19 +446,16 @@ class QuerySet(object): | |||||||
|         Used by get_or_create and update_or_create |         Used by get_or_create and update_or_create | ||||||
|         """ |         """ | ||||||
|         obj = self.model(**params) |         obj = self.model(**params) | ||||||
|         sid = transaction.savepoint(using=self.db) |  | ||||||
|         try: |         try: | ||||||
|             obj.save(force_insert=True, using=self.db) |             with transaction.atomic(using=self.db): | ||||||
|             transaction.savepoint_commit(sid, using=self.db) |                 obj.save(force_insert=True, using=self.db) | ||||||
|             return obj, True |             return obj, True | ||||||
|         except DatabaseError as e: |         except IntegrityError: | ||||||
|             transaction.savepoint_rollback(sid, using=self.db) |  | ||||||
|             exc_info = sys.exc_info() |             exc_info = sys.exc_info() | ||||||
|             if isinstance(e, IntegrityError): |             try: | ||||||
|                 try: |                 return self.get(**lookup), False | ||||||
|                     return self.get(**lookup), False |             except self.model.DoesNotExist: | ||||||
|                 except self.model.DoesNotExist: |                 pass | ||||||
|                     pass |  | ||||||
|             six.reraise(*exc_info) |             six.reraise(*exc_info) | ||||||
|  |  | ||||||
|     def _extract_model_params(self, defaults, **kwargs): |     def _extract_model_params(self, defaults, **kwargs): | ||||||
|   | |||||||
| @@ -16,14 +16,15 @@ import warnings | |||||||
|  |  | ||||||
| from functools import wraps | from functools import wraps | ||||||
|  |  | ||||||
| from django.db import connections, DatabaseError, DEFAULT_DB_ALIAS | from django.db import ( | ||||||
|  |         connections, DEFAULT_DB_ALIAS, | ||||||
|  |         DatabaseError, ProgrammingError) | ||||||
| from django.utils.decorators import available_attrs | from django.utils.decorators import available_attrs | ||||||
|  |  | ||||||
|  |  | ||||||
| class TransactionManagementError(Exception): | class TransactionManagementError(ProgrammingError): | ||||||
|     """ |     """ | ||||||
|     This exception is thrown when something bad happens with transaction |     This exception is thrown when transaction management is used improperly. | ||||||
|     management. |  | ||||||
|     """ |     """ | ||||||
|     pass |     pass | ||||||
|  |  | ||||||
|   | |||||||
| @@ -163,20 +163,31 @@ Django provides a single API to control database transactions. | |||||||
|     called, so the exception handler can also operate on the database if |     called, so the exception handler can also operate on the database if | ||||||
|     necessary. |     necessary. | ||||||
|  |  | ||||||
|     .. admonition:: Don't catch database exceptions inside ``atomic``! |     .. admonition:: Avoid catching exceptions inside ``atomic``! | ||||||
|  |  | ||||||
|         If you catch :exc:`~django.db.DatabaseError` or a subclass such as |         When exiting an ``atomic`` block, Django looks at whether it's exited | ||||||
|         :exc:`~django.db.IntegrityError` inside an ``atomic`` block, you will |         normally or with an exception to determine whether to commit or roll | ||||||
|         hide from Django the fact that an error has occurred and that the |         back. If you catch and handle exceptions inside an ``atomic`` block, | ||||||
|         transaction is broken. At this point, Django's behavior is unspecified |         you may hide from Django the fact that a problem has happened. This | ||||||
|         and database-dependent. It will usually result in a rollback, which |         can result in unexpected behavior. | ||||||
|         may break your expectations, since you caught the exception. |  | ||||||
|  |         This is mostly a concern for :exc:`~django.db.DatabaseError` and its | ||||||
|  |         subclasses such as :exc:`~django.db.IntegrityError`. After such an | ||||||
|  |         error, the transaction is broken and Django will perform a rollback at | ||||||
|  |         the end of the ``atomic`` block. If you attempt to run database | ||||||
|  |         queries before the rollback happens, Django will raise a | ||||||
|  |         :class:`~django.db.transaction.TransactionManagementError`. You may | ||||||
|  |         also encounter this behavior when an ORM-related signal handler raises | ||||||
|  |         an exception. | ||||||
|  |  | ||||||
|         The correct way to catch database errors is around an ``atomic`` block |         The correct way to catch database errors is around an ``atomic`` block | ||||||
|         as shown above. If necessary, add an extra ``atomic`` block for this |         as shown above. If necessary, add an extra ``atomic`` block for this | ||||||
|         purpose -- it's cheap! This pattern is useful to delimit explicitly |         purpose. This pattern has another advantage: it delimits explicitly | ||||||
|         which operations will be rolled back if an exception occurs. |         which operations will be rolled back if an exception occurs. | ||||||
|  |  | ||||||
|  |         If you catch exceptions raised by raw SQL queries, Django's behavior | ||||||
|  |         is unspecified and database-dependent. | ||||||
|  |  | ||||||
|     In order to guarantee atomicity, ``atomic`` disables some APIs. Attempting |     In order to guarantee atomicity, ``atomic`` disables some APIs. Attempting | ||||||
|     to commit, roll back, or change the autocommit state of the database |     to commit, roll back, or change the autocommit state of the database | ||||||
|     connection within an ``atomic`` block will raise an exception. |     connection within an ``atomic`` block will raise an exception. | ||||||
|   | |||||||
| @@ -149,11 +149,9 @@ class CustomPKTests(TestCase): | |||||||
|         Employee.objects.create( |         Employee.objects.create( | ||||||
|             employee_code=123, first_name="Frank", last_name="Jones" |             employee_code=123, first_name="Frank", last_name="Jones" | ||||||
|         ) |         ) | ||||||
|         sid = transaction.savepoint() |         with self.assertRaises(IntegrityError): | ||||||
|         self.assertRaises(IntegrityError, |             with transaction.atomic(): | ||||||
|             Employee.objects.create, employee_code=123, first_name="Fred", last_name="Jones" |                 Employee.objects.create(employee_code=123, first_name="Fred", last_name="Jones") | ||||||
|         ) |  | ||||||
|         transaction.savepoint_rollback(sid) |  | ||||||
|  |  | ||||||
|     def test_custom_field_pk(self): |     def test_custom_field_pk(self): | ||||||
|         # Regression for #10785 -- Custom fields can be used for primary keys. |         # Regression for #10785 -- Custom fields can be used for primary keys. | ||||||
| @@ -175,8 +173,6 @@ class CustomPKTests(TestCase): | |||||||
|     def test_required_pk(self): |     def test_required_pk(self): | ||||||
|         # The primary key must be specified, so an error is raised if you |         # The primary key must be specified, so an error is raised if you | ||||||
|         # try to create an object without it. |         # try to create an object without it. | ||||||
|         sid = transaction.savepoint() |         with self.assertRaises(IntegrityError): | ||||||
|         self.assertRaises(IntegrityError, |             with transaction.atomic(): | ||||||
|             Employee.objects.create, first_name="Tom", last_name="Smith" |                 Employee.objects.create(first_name="Tom", last_name="Smith") | ||||||
|         ) |  | ||||||
|         transaction.savepoint_rollback(sid) |  | ||||||
|   | |||||||
| @@ -2,6 +2,7 @@ from __future__ import unicode_literals | |||||||
|  |  | ||||||
| from django.core.exceptions import FieldError | from django.core.exceptions import FieldError | ||||||
| from django.db.models import F | from django.db.models import F | ||||||
|  | from django.db import transaction | ||||||
| from django.test import TestCase | from django.test import TestCase | ||||||
| from django.utils import six | from django.utils import six | ||||||
|  |  | ||||||
| @@ -185,11 +186,11 @@ class ExpressionsTests(TestCase): | |||||||
|             "foo", |             "foo", | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         self.assertRaises(FieldError, |         with transaction.atomic(): | ||||||
|             lambda: Company.objects.exclude( |             with self.assertRaises(FieldError): | ||||||
|                 ceo__firstname=F('point_of_contact__firstname') |                 Company.objects.exclude( | ||||||
|             ).update(name=F('point_of_contact__lastname')) |                     ceo__firstname=F('point_of_contact__firstname') | ||||||
|         ) |                 ).update(name=F('point_of_contact__lastname')) | ||||||
|  |  | ||||||
|         # F expressions can be used to update attributes on single objects |         # F expressions can be used to update attributes on single objects | ||||||
|         test_gmbh = Company.objects.get(name="Test GmbH") |         test_gmbh = Company.objects.get(name="Test GmbH") | ||||||
|   | |||||||
| @@ -21,24 +21,29 @@ class ForceTests(TestCase): | |||||||
|         # Won't work because force_update and force_insert are mutually |         # Won't work because force_update and force_insert are mutually | ||||||
|         # exclusive |         # exclusive | ||||||
|         c.value = 4 |         c.value = 4 | ||||||
|         self.assertRaises(ValueError, c.save, force_insert=True, force_update=True) |         with self.assertRaises(ValueError): | ||||||
|  |             c.save(force_insert=True, force_update=True) | ||||||
|  |  | ||||||
|         # Try to update something that doesn't have a primary key in the first |         # Try to update something that doesn't have a primary key in the first | ||||||
|         # place. |         # place. | ||||||
|         c1 = Counter(name="two", value=2) |         c1 = Counter(name="two", value=2) | ||||||
|         self.assertRaises(ValueError, c1.save, force_update=True) |         with self.assertRaises(ValueError): | ||||||
|  |             with transaction.atomic(): | ||||||
|  |                 c1.save(force_update=True) | ||||||
|         c1.save(force_insert=True) |         c1.save(force_insert=True) | ||||||
|  |  | ||||||
|         # Won't work because we can't insert a pk of the same value. |         # Won't work because we can't insert a pk of the same value. | ||||||
|         sid = transaction.savepoint() |  | ||||||
|         c.value = 5 |         c.value = 5 | ||||||
|         self.assertRaises(IntegrityError, c.save, force_insert=True) |         with self.assertRaises(IntegrityError): | ||||||
|         transaction.savepoint_rollback(sid) |             with transaction.atomic(): | ||||||
|  |                 c.save(force_insert=True) | ||||||
|  |  | ||||||
|         # Trying to update should still fail, even with manual primary keys, if |         # Trying to update should still fail, even with manual primary keys, if | ||||||
|         # the data isn't in the database already. |         # the data isn't in the database already. | ||||||
|         obj = WithCustomPK(name=1, value=1) |         obj = WithCustomPK(name=1, value=1) | ||||||
|         self.assertRaises(DatabaseError, obj.save, force_update=True) |         with self.assertRaises(DatabaseError): | ||||||
|  |             with transaction.atomic(): | ||||||
|  |                 obj.save(force_update=True) | ||||||
|  |  | ||||||
|  |  | ||||||
| class InheritanceTests(TestCase): | class InheritanceTests(TestCase): | ||||||
|   | |||||||
| @@ -118,7 +118,7 @@ class OneToOneTests(TestCase): | |||||||
|         self.assertEqual(repr(o1.multimodel), '<MultiModel: Multimodel x1>') |         self.assertEqual(repr(o1.multimodel), '<MultiModel: Multimodel x1>') | ||||||
|         # This will fail because each one-to-one field must be unique (and |         # This will fail because each one-to-one field must be unique (and | ||||||
|         # link2=o1 was used for x1, above). |         # link2=o1 was used for x1, above). | ||||||
|         sid = transaction.savepoint() |  | ||||||
|         mm = MultiModel(link1=self.p2, link2=o1, name="x1") |         mm = MultiModel(link1=self.p2, link2=o1, name="x1") | ||||||
|         self.assertRaises(IntegrityError, mm.save) |         with self.assertRaises(IntegrityError): | ||||||
|         transaction.savepoint_rollback(sid) |             with transaction.atomic(): | ||||||
|  |                 mm.save() | ||||||
|   | |||||||
| @@ -4,7 +4,7 @@ import sys | |||||||
| from unittest import skipIf, skipUnless | from unittest import skipIf, skipUnless | ||||||
|  |  | ||||||
| from django.db import connection, transaction, DatabaseError, IntegrityError | from django.db import connection, transaction, DatabaseError, IntegrityError | ||||||
| from django.test import TransactionTestCase, skipUnlessDBFeature | from django.test import TransactionTestCase, skipIfDBFeature, skipUnlessDBFeature | ||||||
| from django.test.utils import IgnoreDeprecationWarningsMixin | from django.test.utils import IgnoreDeprecationWarningsMixin | ||||||
| from django.utils import six | from django.utils import six | ||||||
|  |  | ||||||
| @@ -204,10 +204,10 @@ class AtomicTests(TransactionTestCase): | |||||||
|                 with transaction.atomic(savepoint=False): |                 with transaction.atomic(savepoint=False): | ||||||
|                     connection.cursor().execute( |                     connection.cursor().execute( | ||||||
|                             "SELECT no_such_col FROM transactions_reporter") |                             "SELECT no_such_col FROM transactions_reporter") | ||||||
|             transaction.savepoint_rollback(sid) |             # prevent atomic from rolling back since we're recovering manually | ||||||
|             # atomic block should rollback, but prevent it, as we just did it. |  | ||||||
|             self.assertTrue(transaction.get_rollback()) |             self.assertTrue(transaction.get_rollback()) | ||||||
|             transaction.set_rollback(False) |             transaction.set_rollback(False) | ||||||
|  |             transaction.savepoint_rollback(sid) | ||||||
|         self.assertQuerysetEqual(Reporter.objects.all(), ['<Reporter: Tintin>']) |         self.assertQuerysetEqual(Reporter.objects.all(), ['<Reporter: Tintin>']) | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -267,11 +267,19 @@ class AtomicMergeTests(TransactionTestCase): | |||||||
|                     with transaction.atomic(savepoint=False): |                     with transaction.atomic(savepoint=False): | ||||||
|                         Reporter.objects.create(first_name="Calculus") |                         Reporter.objects.create(first_name="Calculus") | ||||||
|                         raise Exception("Oops, that's his last name") |                         raise Exception("Oops, that's his last name") | ||||||
|                 # It wasn't possible to roll back |                 # The third insert couldn't be roll back. Temporarily mark the | ||||||
|  |                 # connection as not needing rollback to check it. | ||||||
|  |                 self.assertTrue(transaction.get_rollback()) | ||||||
|  |                 transaction.set_rollback(False) | ||||||
|                 self.assertEqual(Reporter.objects.count(), 3) |                 self.assertEqual(Reporter.objects.count(), 3) | ||||||
|             # It wasn't possible to roll back |                 transaction.set_rollback(True) | ||||||
|  |             # The second insert couldn't be roll back. Temporarily mark the | ||||||
|  |             # connection as not needing rollback to check it. | ||||||
|  |             self.assertTrue(transaction.get_rollback()) | ||||||
|  |             transaction.set_rollback(False) | ||||||
|             self.assertEqual(Reporter.objects.count(), 3) |             self.assertEqual(Reporter.objects.count(), 3) | ||||||
|         # The outer block must roll back |             transaction.set_rollback(True) | ||||||
|  |         # The first block has a savepoint and must roll back. | ||||||
|         self.assertQuerysetEqual(Reporter.objects.all(), []) |         self.assertQuerysetEqual(Reporter.objects.all(), []) | ||||||
|  |  | ||||||
|     def test_merged_inner_savepoint_rollback(self): |     def test_merged_inner_savepoint_rollback(self): | ||||||
| @@ -283,36 +291,22 @@ class AtomicMergeTests(TransactionTestCase): | |||||||
|                     with transaction.atomic(savepoint=False): |                     with transaction.atomic(savepoint=False): | ||||||
|                         Reporter.objects.create(first_name="Calculus") |                         Reporter.objects.create(first_name="Calculus") | ||||||
|                         raise Exception("Oops, that's his last name") |                         raise Exception("Oops, that's his last name") | ||||||
|                 # It wasn't possible to roll back |                 # The third insert couldn't be roll back. Temporarily mark the | ||||||
|  |                 # connection as not needing rollback to check it. | ||||||
|  |                 self.assertTrue(transaction.get_rollback()) | ||||||
|  |                 transaction.set_rollback(False) | ||||||
|                 self.assertEqual(Reporter.objects.count(), 3) |                 self.assertEqual(Reporter.objects.count(), 3) | ||||||
|             # The first block with a savepoint must roll back |                 transaction.set_rollback(True) | ||||||
|  |             # The second block has a savepoint and must roll back. | ||||||
|             self.assertEqual(Reporter.objects.count(), 1) |             self.assertEqual(Reporter.objects.count(), 1) | ||||||
|         self.assertQuerysetEqual(Reporter.objects.all(), ['<Reporter: Tintin>']) |         self.assertQuerysetEqual(Reporter.objects.all(), ['<Reporter: Tintin>']) | ||||||
|  |  | ||||||
|     def test_merged_outer_rollback_after_inner_failure_and_inner_success(self): |  | ||||||
|         with transaction.atomic(): |  | ||||||
|             Reporter.objects.create(first_name="Tintin") |  | ||||||
|             # Inner block without a savepoint fails |  | ||||||
|             with six.assertRaisesRegex(self, Exception, "Oops"): |  | ||||||
|                 with transaction.atomic(savepoint=False): |  | ||||||
|                     Reporter.objects.create(first_name="Haddock") |  | ||||||
|                     raise Exception("Oops, that's his last name") |  | ||||||
|             # It wasn't possible to roll back |  | ||||||
|             self.assertEqual(Reporter.objects.count(), 2) |  | ||||||
|             # Inner block with a savepoint succeeds |  | ||||||
|             with transaction.atomic(savepoint=False): |  | ||||||
|                 Reporter.objects.create(first_name="Archibald", last_name="Haddock") |  | ||||||
|             # It still wasn't possible to roll back |  | ||||||
|             self.assertEqual(Reporter.objects.count(), 3) |  | ||||||
|         # The outer block must rollback |  | ||||||
|         self.assertQuerysetEqual(Reporter.objects.all(), []) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @skipUnless(connection.features.uses_savepoints, | @skipUnless(connection.features.uses_savepoints, | ||||||
|         "'atomic' requires transactions and savepoints.") |         "'atomic' requires transactions and savepoints.") | ||||||
| class AtomicErrorsTests(TransactionTestCase): | class AtomicErrorsTests(TransactionTestCase): | ||||||
|  |  | ||||||
|     available_apps = [] |     available_apps = ['transactions'] | ||||||
|  |  | ||||||
|     def test_atomic_prevents_setting_autocommit(self): |     def test_atomic_prevents_setting_autocommit(self): | ||||||
|         autocommit = transaction.get_autocommit() |         autocommit = transaction.get_autocommit() | ||||||
| @@ -336,6 +330,29 @@ class AtomicErrorsTests(TransactionTestCase): | |||||||
|             with self.assertRaises(transaction.TransactionManagementError): |             with self.assertRaises(transaction.TransactionManagementError): | ||||||
|                 transaction.leave_transaction_management() |                 transaction.leave_transaction_management() | ||||||
|  |  | ||||||
|  |     def test_atomic_prevents_queries_in_broken_transaction(self): | ||||||
|  |         r1 = Reporter.objects.create(first_name="Archibald", last_name="Haddock") | ||||||
|  |         with transaction.atomic(): | ||||||
|  |             r2 = Reporter(first_name="Cuthbert", last_name="Calculus", id=r1.id) | ||||||
|  |             with self.assertRaises(IntegrityError): | ||||||
|  |                 r2.save(force_insert=True) | ||||||
|  |             # The transaction is marked as needing rollback. | ||||||
|  |             with self.assertRaises(transaction.TransactionManagementError): | ||||||
|  |                 r2.save(force_update=True) | ||||||
|  |         self.assertEqual(Reporter.objects.get(pk=r1.pk).last_name, "Haddock") | ||||||
|  |  | ||||||
|  |     @skipIfDBFeature('atomic_transactions') | ||||||
|  |     def test_atomic_allows_queries_after_fixing_transaction(self): | ||||||
|  |         r1 = Reporter.objects.create(first_name="Archibald", last_name="Haddock") | ||||||
|  |         with transaction.atomic(): | ||||||
|  |             r2 = Reporter(first_name="Cuthbert", last_name="Calculus", id=r1.id) | ||||||
|  |             with self.assertRaises(IntegrityError): | ||||||
|  |                 r2.save(force_insert=True) | ||||||
|  |             # Mark the transaction as no longer needing rollback. | ||||||
|  |             transaction.set_rollback(False) | ||||||
|  |             r2.save(force_update=True) | ||||||
|  |         self.assertEqual(Reporter.objects.get(pk=r1.pk).last_name, "Calculus") | ||||||
|  |  | ||||||
|  |  | ||||||
| class AtomicMiscTests(TransactionTestCase): | class AtomicMiscTests(TransactionTestCase): | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user