From bbb9ef3c62674e94ad8e6556933b112d84891f3d Mon Sep 17 00:00:00 2001 From: Natalia <124304+nessita@users.noreply.github.com> Date: Wed, 31 Jan 2024 11:12:03 -0300 Subject: [PATCH] [5.0.x] Refs #35059 -- Made asgi tests' SignalHandler helper class re-usable by other tests. Backport of a43d75e81da783fda08bf8d3493252e3676d11ea from main --- tests/asgi/tests.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/tests/asgi/tests.py b/tests/asgi/tests.py index ced24c658e..50d15d2306 100644 --- a/tests/asgi/tests.py +++ b/tests/asgi/tests.py @@ -26,6 +26,17 @@ from .urls import sync_waiter, test_filename TEST_STATIC_ROOT = Path(__file__).parent / "project" / "static" +class SignalHandler: + """Helper class to track threads and kwargs when signals are dispatched.""" + + def __init__(self): + super().__init__() + self.calls = [] + + def __call__(self, signal, **kwargs): + self.calls.append({"thread": threading.current_thread(), "kwargs": kwargs}) + + @override_settings(ROOT_URLCONF="asgi.urls") class ASGITest(SimpleTestCase): async_request_factory = AsyncRequestFactory() @@ -312,17 +323,12 @@ class ASGITest(SimpleTestCase): self.assertEqual(response_body["body"], b"") async def test_request_lifecycle_signals_dispatched_with_thread_sensitive(self): - class SignalHandler: - """Track threads handler is dispatched on.""" - - threads = [] - - def __call__(self, **kwargs): - self.threads.append(threading.current_thread()) - + # Track request_started and request_finished signals. signal_handler = SignalHandler() request_started.connect(signal_handler) + self.addCleanup(request_started.disconnect, signal_handler) request_finished.connect(signal_handler) + self.addCleanup(request_finished.disconnect, signal_handler) # Perform a basic request. application = get_asgi_application() @@ -339,10 +345,9 @@ class ASGITest(SimpleTestCase): await communicator.wait() # AsyncToSync should have executed the signals in the same thread. - request_started_thread, request_finished_thread = signal_handler.threads - self.assertEqual(request_started_thread, request_finished_thread) - request_started.disconnect(signal_handler) - request_finished.disconnect(signal_handler) + self.assertEqual(len(signal_handler.calls), 2) + request_started_call, request_finished_call = signal_handler.calls + self.assertEqual(request_started_call["thread"], request_finished_call["thread"]) async def test_concurrent_async_uses_multiple_thread_pools(self): sync_waiter.active_threads.clear()