diff --git a/django/test/runner.py b/django/test/runner.py index d0367ba71e..ecae164d7f 100644 --- a/django/test/runner.py +++ b/django/test/runner.py @@ -1,6 +1,7 @@ import argparse import ctypes import faulthandler +import functools import hashlib import io import itertools @@ -485,6 +486,16 @@ def _init_worker( ) +def _safe_init_worker(init_worker, counter, *args, **kwargs): + try: + init_worker(counter, *args, **kwargs) + except Exception: + with counter.get_lock(): + # Set a value that will not increment above zero any time soon. + counter.value = -1000 + raise + + def _run_subsuite(args): """ Run a suite of tests with a RemoteTestRunner and return a RemoteTestResult. @@ -558,7 +569,7 @@ class ParallelTestSuite(unittest.TestSuite): counter = multiprocessing.Value(ctypes.c_int, 0) pool = multiprocessing.Pool( processes=self.processes, - initializer=self.init_worker.__func__, + initializer=functools.partial(_safe_init_worker, self.init_worker.__func__), initargs=[ counter, self.initial_settings, @@ -585,7 +596,11 @@ class ParallelTestSuite(unittest.TestSuite): try: subsuite_index, events = test_results.next(timeout=0.1) - except multiprocessing.TimeoutError: + except multiprocessing.TimeoutError as err: + if counter.value < 0: + err.add_note("ERROR: _init_worker failed, see prior traceback") + pool.close() + raise continue except StopIteration: pool.close() diff --git a/tests/test_runner/tests.py b/tests/test_runner/tests.py index 2c1fc3ad68..9173fa5d36 100644 --- a/tests/test_runner/tests.py +++ b/tests/test_runner/tests.py @@ -3,6 +3,7 @@ Tests for django test runner """ import collections.abc +import functools import multiprocessing import os import sys @@ -738,8 +739,10 @@ class TestRunnerInitializerTests(SimpleTestCase): "test_runner_apps.simple.tests", ] ) - # Initializer must be a function. - self.assertIs(mocked_pool.call_args.kwargs["initializer"], _init_worker) + # Initializer must be a partial function binding _init_worker. + initializer = mocked_pool.call_args.kwargs["initializer"] + self.assertIsInstance(initializer, functools.partial) + self.assertIs(initializer.args[0], _init_worker) initargs = mocked_pool.call_args.kwargs["initargs"] self.assertEqual(len(initargs), 7) self.assertEqual(initargs[5], True) # debug_mode