diff --git a/django/test/runner.py b/django/test/runner.py index 8ce7ac718b..71bee37297 100644 --- a/django/test/runner.py +++ b/django/test/runner.py @@ -404,8 +404,8 @@ class ParallelTestSuite(unittest.TestSuite): run_subsuite = _run_subsuite runner_class = RemoteTestRunner - def __init__(self, suite, processes, failfast=False, buffer=False): - self.subsuites = partition_suite_by_case(suite) + def __init__(self, subsuites, processes, failfast=False, buffer=False): + self.subsuites = subsuites self.processes = processes self.failfast = failfast self.buffer = buffer @@ -685,22 +685,17 @@ class DiscoverRunner: suite = self.test_suite(all_tests) if self.parallel > 1: - parallel_suite = self.parallel_test_suite( - suite, - self.parallel, - self.failfast, - self.buffer, - ) - + subsuites = partition_suite_by_case(suite) # Since tests are distributed across processes on a per-TestCase # basis, there's no need for more processes than TestCases. - parallel_units = len(parallel_suite.subsuites) - self.parallel = min(self.parallel, parallel_units) - - # If there's only one TestCase, parallelization isn't needed. - if self.parallel > 1: - suite = parallel_suite - + processes = min(self.parallel, len(subsuites)) + if processes > 1: + suite = self.parallel_test_suite( + subsuites, + processes, + self.failfast, + self.buffer, + ) return suite def setup_databases(self, **kwargs): diff --git a/tests/test_runner/test_discover_runner.py b/tests/test_runner/test_discover_runner.py index 4eff76c5fd..3f840ee049 100644 --- a/tests/test_runner/test_discover_runner.py +++ b/tests/test_runner/test_discover_runner.py @@ -348,6 +348,12 @@ class DiscoverRunnerTests(SimpleTestCase): with self.assertRaisesMessage(ValueError, msg): DiscoverRunner(pdb=True, parallel=2) + def test_number_of_parallel_workers(self): + """Number of processes doesn't exceed the number of TestCases.""" + runner = DiscoverRunner(parallel=5, verbosity=0) + suite = runner.build_suite(['test_runner_apps.tagged']) + self.assertEqual(suite.processes, len(suite.subsuites)) + def test_buffer_mode_test_pass(self): runner = DiscoverRunner(buffer=True, verbose=0) with captured_stdout() as stdout, captured_stderr() as stderr: