1
0
mirror of https://github.com/django/django.git synced 2024-12-22 17:16:24 +00:00

added support for parallel database setups

This commit is contained in:
leondaz 2024-10-12 17:00:45 +03:00
parent 9423f8b476
commit bcf3914528

View File

@ -6,6 +6,7 @@ import re
import sys import sys
import time import time
import warnings import warnings
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager from contextlib import contextmanager
from functools import wraps from functools import wraps
from io import StringIO from io import StringIO
@ -51,6 +52,7 @@ __all__ = (
"tag", "tag",
"requires_tz_support", "requires_tz_support",
"setup_databases", "setup_databases",
"TestDatabaseSetup",
"setup_test_environment", "setup_test_environment",
"teardown_test_environment", "teardown_test_environment",
) )
@ -170,6 +172,137 @@ def teardown_test_environment():
del mail.outbox del mail.outbox
class TestDatabaseSetup:
def __init__(
self,
verbosity,
interactive,
time_keeper=None,
keepdb=False,
debug_sql=False,
parallel=0,
aliases=None,
serialized_aliases=None,
**kwargs,
):
self.verbosity = verbosity
self.interactive = interactive
self.time_keeper = time_keeper or NullTimeKeeper()
self.keepdb = keepdb
self.debug_sql = debug_sql
self.parallel = parallel
self.aliases = aliases
self.serialized_aliases = serialized_aliases
self.test_databases = {}
self.mirrored_aliases = {}
self.old_names = []
self.kwargs = kwargs
def setup(self):
"""Main method to set up test databases."""
self.test_databases, self.mirrored_aliases = get_unique_databases_and_mirrors(
self.aliases
)
# Process each unique database
for db_sig, (db_name, aliases) in self.test_databases.items():
# Extract the first alias outside the loop
first_alias = aliases[0]
remaining_aliases = aliases[1:]
# Create the test database for the first alias
self.create_test_database(first_alias, db_name)
# Clone the test database if parallel testing is enabled
if self.parallel > 1:
self.clone_databases_in_parallel(first_alias)
# Configure remaining aliases as mirrors
for alias in remaining_aliases:
self.configure_as_mirror(alias, first_alias)
# Configure any additional test mirrors
self.configure_test_mirrors()
# Enable debug SQL logging if required
if self.debug_sql:
self.enable_debug_sql()
return self.old_names
def create_test_database(self, alias, db_name):
"""Create the test database for the given alias."""
connection = connections[alias]
self.old_names.append((connection, db_name, True))
serialize_alias = (
self.serialized_aliases is None or alias in self.serialized_aliases
)
with self.time_keeper.timed(f" Creating '{alias}'"):
connection.creation.create_test_db(
verbosity=self.verbosity,
autoclobber=not self.interactive,
keepdb=self.keepdb,
serialize=serialize_alias,
)
def clone_databases_sequentially(self, alias):
"""Clone the test database sequentially for parallel test execution."""
for index in range(self.parallel):
self.clone_test_database(alias, index)
def clone_databases_in_parallel(self, alias):
"""Clone the test database in parallel threads for parallel test execution."""
with ThreadPoolExecutor(max_workers=self.parallel) as executor:
futures = [
executor.submit(self.clone_test_database, alias, index)
for index in range(self.parallel)
]
# Optionally, wait for all futures to complete
# This just ensures that any exceptions are raised
for future in futures:
future.result()
def clone_test_database(self, alias, index):
"""Clone the test database for parallel execution."""
connection = connections[alias]
# re-init the connection per thread
connection.close()
connection.connect()
with self.time_keeper.timed(f" Cloning '{alias}'"):
connection.creation.clone_test_db(
suffix=str(index + 1),
verbosity=self.verbosity,
keepdb=self.keepdb,
)
def configure_as_mirror(self, alias, source_alias):
"""Configure an alias as a mirror of the source alias."""
connection = connections[alias]
self.old_names.append(
(connection, None, False)
) # False indicates no creation needed
source_settings = connections[source_alias].settings_dict
connection.creation.set_as_test_mirror(source_settings)
def configure_test_mirrors(self):
"""Configure any additional test mirrors."""
for alias, mirror_alias in self.mirrored_aliases.items():
connection = connections[alias]
mirror_settings = connections[mirror_alias].settings_dict
connection.creation.set_as_test_mirror(mirror_settings)
def enable_debug_sql(self):
"""Enable debug SQL logging for all connections."""
for alias in connections:
connections[alias].force_debug_cursor = True
def setup_databases( def setup_databases(
verbosity, verbosity,
interactive, interactive,
@ -182,58 +315,17 @@ def setup_databases(
serialized_aliases=None, serialized_aliases=None,
**kwargs, **kwargs,
): ):
"""Create the test databases.""" return TestDatabaseSetup(
if time_keeper is None: verbosity,
time_keeper = NullTimeKeeper() interactive,
time_keeper,
test_databases, mirrored_aliases = get_unique_databases_and_mirrors(aliases) keepdb,
debug_sql,
old_names = [] parallel,
aliases,
for db_name, aliases in test_databases.values(): serialized_aliases,
first_alias = None **kwargs,
for alias in aliases: ).setup()
connection = connections[alias]
old_names.append((connection, db_name, first_alias is None))
# Actually create the database for the first connection
if first_alias is None:
first_alias = alias
with time_keeper.timed(" Creating '%s'" % alias):
serialize_alias = (
serialized_aliases is None or alias in serialized_aliases
)
connection.creation.create_test_db(
verbosity=verbosity,
autoclobber=not interactive,
keepdb=keepdb,
serialize=serialize_alias,
)
if parallel > 1:
for index in range(parallel):
with time_keeper.timed(" Cloning '%s'" % alias):
connection.creation.clone_test_db(
suffix=str(index + 1),
verbosity=verbosity,
keepdb=keepdb,
)
# Configure all other connections as mirrors of the first one
else:
connections[alias].creation.set_as_test_mirror(
connections[first_alias].settings_dict
)
# Configure the test mirrors.
for alias, mirror_alias in mirrored_aliases.items():
connections[alias].creation.set_as_test_mirror(
connections[mirror_alias].settings_dict
)
if debug_sql:
for alias in connections:
connections[alias].force_debug_cursor = True
return old_names
def iter_test_cases(tests): def iter_test_cases(tests):