mirror of
				https://github.com/django/django.git
				synced 2025-10-26 07:06:08 +00:00 
			
		
		
		
	Fixed #30171 -- Fixed DatabaseError in servers tests.
Made DatabaseWrapper thread sharing logic reentrant. Used a reference
counting like scheme to allow nested uses.
The error appeared after 8c775391b7.
			
			
This commit is contained in:
		| @@ -1,4 +1,5 @@ | |||||||
| import copy | import copy | ||||||
|  | import threading | ||||||
| import time | import time | ||||||
| import warnings | import warnings | ||||||
| from collections import deque | from collections import deque | ||||||
| @@ -43,8 +44,7 @@ class BaseDatabaseWrapper: | |||||||
|  |  | ||||||
|     queries_limit = 9000 |     queries_limit = 9000 | ||||||
|  |  | ||||||
|     def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS, |     def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS): | ||||||
|                  allow_thread_sharing=False): |  | ||||||
|         # Connection related attributes. |         # Connection related attributes. | ||||||
|         # The underlying database connection. |         # The underlying database connection. | ||||||
|         self.connection = None |         self.connection = None | ||||||
| @@ -80,7 +80,8 @@ class BaseDatabaseWrapper: | |||||||
|         self.errors_occurred = False |         self.errors_occurred = False | ||||||
|  |  | ||||||
|         # Thread-safety related attributes. |         # Thread-safety related attributes. | ||||||
|         self.allow_thread_sharing = allow_thread_sharing |         self._thread_sharing_lock = threading.Lock() | ||||||
|  |         self._thread_sharing_count = 0 | ||||||
|         self._thread_ident = _thread.get_ident() |         self._thread_ident = _thread.get_ident() | ||||||
|  |  | ||||||
|         # A list of no-argument functions to run when the transaction commits. |         # A list of no-argument functions to run when the transaction commits. | ||||||
| @@ -515,12 +516,27 @@ class BaseDatabaseWrapper: | |||||||
|  |  | ||||||
|     # ##### Thread safety handling ##### |     # ##### Thread safety handling ##### | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def allow_thread_sharing(self): | ||||||
|  |         with self._thread_sharing_lock: | ||||||
|  |             return self._thread_sharing_count > 0 | ||||||
|  |  | ||||||
|  |     def inc_thread_sharing(self): | ||||||
|  |         with self._thread_sharing_lock: | ||||||
|  |             self._thread_sharing_count += 1 | ||||||
|  |  | ||||||
|  |     def dec_thread_sharing(self): | ||||||
|  |         with self._thread_sharing_lock: | ||||||
|  |             if self._thread_sharing_count <= 0: | ||||||
|  |                 raise RuntimeError('Cannot decrement the thread sharing count below zero.') | ||||||
|  |             self._thread_sharing_count -= 1 | ||||||
|  |  | ||||||
|     def validate_thread_sharing(self): |     def validate_thread_sharing(self): | ||||||
|         """ |         """ | ||||||
|         Validate that the connection isn't accessed by another thread than the |         Validate that the connection isn't accessed by another thread than the | ||||||
|         one which originally created it, unless the connection was explicitly |         one which originally created it, unless the connection was explicitly | ||||||
|         authorized to be shared between threads (via the `allow_thread_sharing` |         authorized to be shared between threads (via the `inc_thread_sharing()` | ||||||
|         property). Raise an exception if the validation fails. |         method). Raise an exception if the validation fails. | ||||||
|         """ |         """ | ||||||
|         if not (self.allow_thread_sharing or self._thread_ident == _thread.get_ident()): |         if not (self.allow_thread_sharing or self._thread_ident == _thread.get_ident()): | ||||||
|             raise DatabaseError( |             raise DatabaseError( | ||||||
| @@ -589,11 +605,7 @@ class BaseDatabaseWrapper: | |||||||
|         potential child threads while (or after) the test database is destroyed. |         potential child threads while (or after) the test database is destroyed. | ||||||
|         Refs #10868, #17786, #16969. |         Refs #10868, #17786, #16969. | ||||||
|         """ |         """ | ||||||
|         return self.__class__( |         return self.__class__({**self.settings_dict, 'NAME': None}, alias=NO_DB_ALIAS) | ||||||
|             {**self.settings_dict, 'NAME': None}, |  | ||||||
|             alias=NO_DB_ALIAS, |  | ||||||
|             allow_thread_sharing=False, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def schema_editor(self, *args, **kwargs): |     def schema_editor(self, *args, **kwargs): | ||||||
|         """ |         """ | ||||||
| @@ -635,7 +647,7 @@ class BaseDatabaseWrapper: | |||||||
|         finally: |         finally: | ||||||
|             self.execute_wrappers.pop() |             self.execute_wrappers.pop() | ||||||
|  |  | ||||||
|     def copy(self, alias=None, allow_thread_sharing=None): |     def copy(self, alias=None): | ||||||
|         """ |         """ | ||||||
|         Return a copy of this connection. |         Return a copy of this connection. | ||||||
|  |  | ||||||
| @@ -644,6 +656,4 @@ class BaseDatabaseWrapper: | |||||||
|         settings_dict = copy.deepcopy(self.settings_dict) |         settings_dict = copy.deepcopy(self.settings_dict) | ||||||
|         if alias is None: |         if alias is None: | ||||||
|             alias = self.alias |             alias = self.alias | ||||||
|         if allow_thread_sharing is None: |         return type(self)(settings_dict, alias) | ||||||
|             allow_thread_sharing = self.allow_thread_sharing |  | ||||||
|         return type(self)(settings_dict, alias, allow_thread_sharing) |  | ||||||
|   | |||||||
| @@ -277,7 +277,6 @@ class DatabaseWrapper(BaseDatabaseWrapper): | |||||||
|                     return self.__class__( |                     return self.__class__( | ||||||
|                         {**self.settings_dict, 'NAME': connection.settings_dict['NAME']}, |                         {**self.settings_dict, 'NAME': connection.settings_dict['NAME']}, | ||||||
|                         alias=self.alias, |                         alias=self.alias, | ||||||
|                         allow_thread_sharing=False, |  | ||||||
|                     ) |                     ) | ||||||
|         return nodb_connection |         return nodb_connection | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1442,7 +1442,7 @@ class LiveServerTestCase(TransactionTestCase): | |||||||
|             # the server thread. |             # the server thread. | ||||||
|             if conn.vendor == 'sqlite' and conn.is_in_memory_db(): |             if conn.vendor == 'sqlite' and conn.is_in_memory_db(): | ||||||
|                 # Explicitly enable thread-shareability for this connection |                 # Explicitly enable thread-shareability for this connection | ||||||
|                 conn.allow_thread_sharing = True |                 conn.inc_thread_sharing() | ||||||
|                 connections_override[conn.alias] = conn |                 connections_override[conn.alias] = conn | ||||||
|  |  | ||||||
|         cls._live_server_modified_settings = modify_settings( |         cls._live_server_modified_settings = modify_settings( | ||||||
| @@ -1478,10 +1478,9 @@ class LiveServerTestCase(TransactionTestCase): | |||||||
|             # Terminate the live server's thread |             # Terminate the live server's thread | ||||||
|             cls.server_thread.terminate() |             cls.server_thread.terminate() | ||||||
|  |  | ||||||
|         # Restore sqlite in-memory database connections' non-shareability |             # Restore sqlite in-memory database connections' non-shareability. | ||||||
|         for conn in connections.all(): |             for conn in cls.server_thread.connections_override.values(): | ||||||
|             if conn.vendor == 'sqlite' and conn.is_in_memory_db(): |                 conn.dec_thread_sharing() | ||||||
|                 conn.allow_thread_sharing = False |  | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def tearDownClass(cls): |     def tearDownClass(cls): | ||||||
|   | |||||||
| @@ -286,6 +286,9 @@ backends. | |||||||
|   * ``_delete_fk_sql()`` (to pair with ``_create_fk_sql()``) |   * ``_delete_fk_sql()`` (to pair with ``_create_fk_sql()``) | ||||||
|   * ``_create_check_sql()`` and ``_delete_check_sql()`` |   * ``_create_check_sql()`` and ``_delete_check_sql()`` | ||||||
|  |  | ||||||
|  | * The third argument of ``DatabaseWrapper.__init__()``, | ||||||
|  |   ``allow_thread_sharing``, is removed. | ||||||
|  |  | ||||||
| Admin actions are no longer collected from base ``ModelAdmin`` classes | Admin actions are no longer collected from base ``ModelAdmin`` classes | ||||||
| ---------------------------------------------------------------------- | ---------------------------------------------------------------------- | ||||||
|  |  | ||||||
|   | |||||||
| @@ -605,21 +605,25 @@ class ThreadTests(TransactionTestCase): | |||||||
|             connection = connections[DEFAULT_DB_ALIAS] |             connection = connections[DEFAULT_DB_ALIAS] | ||||||
|             # Allow thread sharing so the connection can be closed by the |             # Allow thread sharing so the connection can be closed by the | ||||||
|             # main thread. |             # main thread. | ||||||
|             connection.allow_thread_sharing = True |             connection.inc_thread_sharing() | ||||||
|             connection.cursor() |             connection.cursor() | ||||||
|             connections_dict[id(connection)] = connection |             connections_dict[id(connection)] = connection | ||||||
|  |         try: | ||||||
|             for x in range(2): |             for x in range(2): | ||||||
|                 t = threading.Thread(target=runner) |                 t = threading.Thread(target=runner) | ||||||
|                 t.start() |                 t.start() | ||||||
|                 t.join() |                 t.join() | ||||||
|             # Each created connection got different inner connection. |             # Each created connection got different inner connection. | ||||||
|             self.assertEqual(len({conn.connection for conn in connections_dict.values()}), 3) |             self.assertEqual(len({conn.connection for conn in connections_dict.values()}), 3) | ||||||
|         # Finish by closing the connections opened by the other threads (the |         finally: | ||||||
|         # connection opened in the main thread will automatically be closed on |             # Finish by closing the connections opened by the other threads | ||||||
|         # teardown). |             # (the connection opened in the main thread will automatically be | ||||||
|  |             # closed on teardown). | ||||||
|             for conn in connections_dict.values(): |             for conn in connections_dict.values(): | ||||||
|                 if conn is not connection: |                 if conn is not connection: | ||||||
|  |                     if conn.allow_thread_sharing: | ||||||
|                         conn.close() |                         conn.close() | ||||||
|  |                         conn.dec_thread_sharing() | ||||||
|  |  | ||||||
|     def test_connections_thread_local(self): |     def test_connections_thread_local(self): | ||||||
|         """ |         """ | ||||||
| @@ -636,19 +640,23 @@ class ThreadTests(TransactionTestCase): | |||||||
|             for conn in connections.all(): |             for conn in connections.all(): | ||||||
|                 # Allow thread sharing so the connection can be closed by the |                 # Allow thread sharing so the connection can be closed by the | ||||||
|                 # main thread. |                 # main thread. | ||||||
|                 conn.allow_thread_sharing = True |                 conn.inc_thread_sharing() | ||||||
|                 connections_dict[id(conn)] = conn |                 connections_dict[id(conn)] = conn | ||||||
|  |         try: | ||||||
|             for x in range(2): |             for x in range(2): | ||||||
|                 t = threading.Thread(target=runner) |                 t = threading.Thread(target=runner) | ||||||
|                 t.start() |                 t.start() | ||||||
|                 t.join() |                 t.join() | ||||||
|             self.assertEqual(len(connections_dict), 6) |             self.assertEqual(len(connections_dict), 6) | ||||||
|         # Finish by closing the connections opened by the other threads (the |         finally: | ||||||
|         # connection opened in the main thread will automatically be closed on |             # Finish by closing the connections opened by the other threads | ||||||
|         # teardown). |             # (the connection opened in the main thread will automatically be | ||||||
|  |             # closed on teardown). | ||||||
|             for conn in connections_dict.values(): |             for conn in connections_dict.values(): | ||||||
|                 if conn is not connection: |                 if conn is not connection: | ||||||
|  |                     if conn.allow_thread_sharing: | ||||||
|                         conn.close() |                         conn.close() | ||||||
|  |                         conn.dec_thread_sharing() | ||||||
|  |  | ||||||
|     def test_pass_connection_between_threads(self): |     def test_pass_connection_between_threads(self): | ||||||
|         """ |         """ | ||||||
| @@ -668,25 +676,21 @@ class ThreadTests(TransactionTestCase): | |||||||
|             t.start() |             t.start() | ||||||
|             t.join() |             t.join() | ||||||
|  |  | ||||||
|         # Without touching allow_thread_sharing, which should be False by default. |         # Without touching thread sharing, which should be False by default. | ||||||
|         exceptions = [] |         exceptions = [] | ||||||
|         do_thread() |         do_thread() | ||||||
|         # Forbidden! |         # Forbidden! | ||||||
|         self.assertIsInstance(exceptions[0], DatabaseError) |         self.assertIsInstance(exceptions[0], DatabaseError) | ||||||
|  |  | ||||||
|         # If explicitly setting allow_thread_sharing to False |         # After calling inc_thread_sharing() on the connection. | ||||||
|         connections['default'].allow_thread_sharing = False |         connections['default'].inc_thread_sharing() | ||||||
|         exceptions = [] |         try: | ||||||
|         do_thread() |  | ||||||
|         # Forbidden! |  | ||||||
|         self.assertIsInstance(exceptions[0], DatabaseError) |  | ||||||
|  |  | ||||||
|         # If explicitly setting allow_thread_sharing to True |  | ||||||
|         connections['default'].allow_thread_sharing = True |  | ||||||
|             exceptions = [] |             exceptions = [] | ||||||
|             do_thread() |             do_thread() | ||||||
|             # All good |             # All good | ||||||
|             self.assertEqual(exceptions, []) |             self.assertEqual(exceptions, []) | ||||||
|  |         finally: | ||||||
|  |             connections['default'].dec_thread_sharing() | ||||||
|  |  | ||||||
|     def test_closing_non_shared_connections(self): |     def test_closing_non_shared_connections(self): | ||||||
|         """ |         """ | ||||||
| @@ -721,16 +725,33 @@ class ThreadTests(TransactionTestCase): | |||||||
|                 except DatabaseError as e: |                 except DatabaseError as e: | ||||||
|                     exceptions.add(e) |                     exceptions.add(e) | ||||||
|             # Enable thread sharing |             # Enable thread sharing | ||||||
|             connections['default'].allow_thread_sharing = True |             connections['default'].inc_thread_sharing() | ||||||
|  |             try: | ||||||
|                 t2 = threading.Thread(target=runner2, args=[connections['default']]) |                 t2 = threading.Thread(target=runner2, args=[connections['default']]) | ||||||
|                 t2.start() |                 t2.start() | ||||||
|                 t2.join() |                 t2.join() | ||||||
|  |             finally: | ||||||
|  |                 connections['default'].dec_thread_sharing() | ||||||
|         t1 = threading.Thread(target=runner1) |         t1 = threading.Thread(target=runner1) | ||||||
|         t1.start() |         t1.start() | ||||||
|         t1.join() |         t1.join() | ||||||
|         # No exception was raised |         # No exception was raised | ||||||
|         self.assertEqual(len(exceptions), 0) |         self.assertEqual(len(exceptions), 0) | ||||||
|  |  | ||||||
|  |     def test_thread_sharing_count(self): | ||||||
|  |         self.assertIs(connection.allow_thread_sharing, False) | ||||||
|  |         connection.inc_thread_sharing() | ||||||
|  |         self.assertIs(connection.allow_thread_sharing, True) | ||||||
|  |         connection.inc_thread_sharing() | ||||||
|  |         self.assertIs(connection.allow_thread_sharing, True) | ||||||
|  |         connection.dec_thread_sharing() | ||||||
|  |         self.assertIs(connection.allow_thread_sharing, True) | ||||||
|  |         connection.dec_thread_sharing() | ||||||
|  |         self.assertIs(connection.allow_thread_sharing, False) | ||||||
|  |         msg = 'Cannot decrement the thread sharing count below zero.' | ||||||
|  |         with self.assertRaisesMessage(RuntimeError, msg): | ||||||
|  |             connection.dec_thread_sharing() | ||||||
|  |  | ||||||
|  |  | ||||||
| class MySQLPKZeroTests(TestCase): | class MySQLPKZeroTests(TestCase): | ||||||
|     """ |     """ | ||||||
|   | |||||||
| @@ -18,11 +18,10 @@ class LiveServerThreadTest(TestCase): | |||||||
|         # Pass a connection to the thread to check they are being closed. |         # Pass a connection to the thread to check they are being closed. | ||||||
|         connections_override = {DEFAULT_DB_ALIAS: conn} |         connections_override = {DEFAULT_DB_ALIAS: conn} | ||||||
|  |  | ||||||
|         saved_sharing = conn.allow_thread_sharing |         conn.inc_thread_sharing() | ||||||
|         try: |         try: | ||||||
|             conn.allow_thread_sharing = True |  | ||||||
|             self.assertTrue(conn.is_usable()) |             self.assertTrue(conn.is_usable()) | ||||||
|             self.run_live_server_thread(connections_override) |             self.run_live_server_thread(connections_override) | ||||||
|             self.assertFalse(conn.is_usable()) |             self.assertFalse(conn.is_usable()) | ||||||
|         finally: |         finally: | ||||||
|             conn.allow_thread_sharing = saved_sharing |             conn.dec_thread_sharing() | ||||||
|   | |||||||
| @@ -64,6 +64,9 @@ class StaticLiveServerChecks(LiveServerBase): | |||||||
|             # app without having set the required STATIC_URL setting.") |             # app without having set the required STATIC_URL setting.") | ||||||
|             pass |             pass | ||||||
|         finally: |         finally: | ||||||
|  |             # Use del to avoid decrementing the database thread sharing count a | ||||||
|  |             # second time. | ||||||
|  |             del cls.server_thread | ||||||
|             super().tearDownClass() |             super().tearDownClass() | ||||||
|  |  | ||||||
|     def test_test_test(self): |     def test_test_test(self): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user