1
0
mirror of https://github.com/django/django.git synced 2025-08-21 01:09:13 +00:00

Fixed #36531 -- Added forkserver support to parallel test runner.

This commit is contained in:
Mariusz Felisiak 2025-07-28 23:21:17 +02:00 committed by nessita
parent d4dd3e503c
commit d55979334d
5 changed files with 29 additions and 14 deletions

View File

@ -62,7 +62,7 @@ class DatabaseCreation(BaseDatabaseCreation):
start_method = multiprocessing.get_start_method() start_method = multiprocessing.get_start_method()
if start_method == "fork": if start_method == "fork":
return orig_settings_dict return orig_settings_dict
if start_method == "spawn": if start_method in {"forkserver", "spawn"}:
return { return {
**orig_settings_dict, **orig_settings_dict,
"NAME": f"{self.connection.alias}_{suffix}.sqlite3", "NAME": f"{self.connection.alias}_{suffix}.sqlite3",
@ -99,9 +99,9 @@ class DatabaseCreation(BaseDatabaseCreation):
self.log("Got an error cloning the test database: %s" % e) self.log("Got an error cloning the test database: %s" % e)
sys.exit(2) sys.exit(2)
# Forking automatically makes a copy of an in-memory database. # Forking automatically makes a copy of an in-memory database.
# Spawn requires migrating to disk which will be re-opened in # Forkserver and spawn require migrating to disk which will be
# setup_worker_connection. # re-opened in setup_worker_connection.
elif multiprocessing.get_start_method() == "spawn": elif multiprocessing.get_start_method() in {"forkserver", "spawn"}:
ondisk_db = sqlite3.connect(target_database_name, uri=True) ondisk_db = sqlite3.connect(target_database_name, uri=True)
self.connection.connection.backup(ondisk_db) self.connection.connection.backup(ondisk_db)
ondisk_db.close() ondisk_db.close()
@ -137,7 +137,7 @@ class DatabaseCreation(BaseDatabaseCreation):
# Update settings_dict in place. # Update settings_dict in place.
self.connection.settings_dict.update(settings_dict) self.connection.settings_dict.update(settings_dict)
self.connection.close() self.connection.close()
elif start_method == "spawn": elif start_method in {"forkserver", "spawn"}:
alias = self.connection.alias alias = self.connection.alias
connection_str = ( connection_str = (
f"file:memorydb_{alias}_{_worker_id}?mode=memory&cache=shared" f"file:memorydb_{alias}_{_worker_id}?mode=memory&cache=shared"

View File

@ -404,8 +404,9 @@ def get_max_test_processes():
The maximum number of test processes when using the --parallel option. The maximum number of test processes when using the --parallel option.
""" """
# The current implementation of the parallel test runner requires # The current implementation of the parallel test runner requires
# multiprocessing to start subprocesses with fork() or spawn(). # multiprocessing to start subprocesses with fork(), forkserver(), or
if multiprocessing.get_start_method() not in {"fork", "spawn"}: # spawn().
if multiprocessing.get_start_method() not in {"fork", "spawn", "forkserver"}:
return 1 return 1
try: try:
return int(os.environ["DJANGO_TEST_PROCESSES"]) return int(os.environ["DJANGO_TEST_PROCESSES"])
@ -450,9 +451,12 @@ def _init_worker(
counter.value += 1 counter.value += 1
_worker_id = counter.value _worker_id = counter.value
start_method = multiprocessing.get_start_method() is_spawn_or_forkserver = multiprocessing.get_start_method() in {
"forkserver",
"spawn",
}
if start_method == "spawn": if is_spawn_or_forkserver:
if process_setup and callable(process_setup): if process_setup and callable(process_setup):
if process_setup_args is None: if process_setup_args is None:
process_setup_args = () process_setup_args = ()
@ -463,7 +467,7 @@ def _init_worker(
db_aliases = used_aliases if used_aliases is not None else connections db_aliases = used_aliases if used_aliases is not None else connections
for alias in db_aliases: for alias in db_aliases:
connection = connections[alias] connection = connections[alias]
if start_method == "spawn": if is_spawn_or_forkserver:
# Restore initial settings in spawned processes. # Restore initial settings in spawned processes.
connection.settings_dict.update(initial_settings[alias]) connection.settings_dict.update(initial_settings[alias])
if value := serialized_contents.get(alias): if value := serialized_contents.get(alias):
@ -606,7 +610,7 @@ class ParallelTestSuite(unittest.TestSuite):
return iter(self.subsuites) return iter(self.subsuites)
def initialize_suite(self): def initialize_suite(self):
if multiprocessing.get_start_method() == "spawn": if multiprocessing.get_start_method() in {"forkserver", "spawn"}:
self.initial_settings = { self.initial_settings = {
alias: connections[alias].settings_dict for alias in connections alias: connections[alias].settings_dict for alias in connections
} }

View File

@ -348,7 +348,8 @@ Templates
Tests Tests
~~~~~ ~~~~~
* ... * The :class:`.DiscoverRunner` now supports parallel test execution on systems
using the ``forkserver`` :mod:`multiprocessing` start method.
URLs URLs
~~~~ ~~~~

View File

@ -36,8 +36,8 @@ class TestDbSignatureTests(SimpleTestCase):
clone_settings_dict = creation_class.get_test_db_clone_settings("1") clone_settings_dict = creation_class.get_test_db_clone_settings("1")
self.assertEqual(clone_settings_dict["NAME"], expected_clone_name) self.assertEqual(clone_settings_dict["NAME"], expected_clone_name)
@mock.patch.object(multiprocessing, "get_start_method", return_value="forkserver") @mock.patch.object(multiprocessing, "get_start_method", return_value="unsupported")
def test_get_test_db_clone_settings_not_supported(self, *mocked_objects): def test_get_test_db_clone_settings_not_supported(self, *mocked_objects):
msg = "Cloning with start method 'forkserver' is not supported." msg = "Cloning with start method 'unsupported' is not supported."
with self.assertRaisesMessage(NotSupportedError, msg): with self.assertRaisesMessage(NotSupportedError, msg):
connection.creation.get_test_db_clone_settings(1) connection.creation.get_test_db_clone_settings(1)

View File

@ -97,6 +97,16 @@ class DiscoverRunnerParallelArgumentTests(SimpleTestCase):
mocked_cpu_count, mocked_cpu_count,
): ):
mocked_get_start_method.return_value = "forkserver" mocked_get_start_method.return_value = "forkserver"
self.assertEqual(get_max_test_processes(), 12)
with mock.patch.dict(os.environ, {"DJANGO_TEST_PROCESSES": "7"}):
self.assertEqual(get_max_test_processes(), 7)
def test_get_max_test_processes_other(
self,
mocked_get_start_method,
mocked_cpu_count,
):
mocked_get_start_method.return_value = "other"
self.assertEqual(get_max_test_processes(), 1) self.assertEqual(get_max_test_processes(), 1)
with mock.patch.dict(os.environ, {"DJANGO_TEST_PROCESSES": "7"}): with mock.patch.dict(os.environ, {"DJANGO_TEST_PROCESSES": "7"}):
self.assertEqual(get_max_test_processes(), 1) self.assertEqual(get_max_test_processes(), 1)