diff --git a/tests/asgi/tests.py b/tests/asgi/tests.py index 3aeade4c05..0fbb586f85 100644 --- a/tests/asgi/tests.py +++ b/tests/asgi/tests.py @@ -26,6 +26,17 @@ from .urls import sync_waiter, test_filename TEST_STATIC_ROOT = Path(__file__).parent / "project" / "static" +class SignalHandler: + """Helper class to track threads and kwargs when signals are dispatched.""" + + def __init__(self): + super().__init__() + self.calls = [] + + def __call__(self, signal, **kwargs): + self.calls.append({"thread": threading.current_thread(), "kwargs": kwargs}) + + @override_settings(ROOT_URLCONF="asgi.urls") class ASGITest(SimpleTestCase): async_request_factory = AsyncRequestFactory() @@ -310,17 +321,12 @@ class ASGITest(SimpleTestCase): self.assertEqual(response_body["body"], b"") async def test_request_lifecycle_signals_dispatched_with_thread_sensitive(self): - class SignalHandler: - """Track threads handler is dispatched on.""" - - threads = [] - - def __call__(self, **kwargs): - self.threads.append(threading.current_thread()) - + # Track request_started and request_finished signals. signal_handler = SignalHandler() request_started.connect(signal_handler) + self.addCleanup(request_started.disconnect, signal_handler) request_finished.connect(signal_handler) + self.addCleanup(request_finished.disconnect, signal_handler) # Perform a basic request. application = get_asgi_application() @@ -337,10 +343,9 @@ class ASGITest(SimpleTestCase): await communicator.wait() # AsyncToSync should have executed the signals in the same thread. - request_started_thread, request_finished_thread = signal_handler.threads - self.assertEqual(request_started_thread, request_finished_thread) - request_started.disconnect(signal_handler) - request_finished.disconnect(signal_handler) + self.assertEqual(len(signal_handler.calls), 2) + request_started_call, request_finished_call = signal_handler.calls + self.assertEqual(request_started_call["thread"], request_finished_call["thread"]) async def test_concurrent_async_uses_multiple_thread_pools(self): sync_waiter.active_threads.clear()