mirror of
https://github.com/django/django.git
synced 2025-01-03 23:16:41 +00:00
798e38c2b9
Thanks to Petter Strandmark for the original idea and Mariusz Felisiak for advice during the DjangoConUS 2022 Sprint!
392 lines
16 KiB
Python
392 lines
16 KiB
Python
from unittest.mock import MagicMock, patch
|
|
|
|
from django.db import DEFAULT_DB_ALIAS, connection, connections, transaction
|
|
from django.db.backends.base.base import BaseDatabaseWrapper
|
|
from django.test import (
|
|
SimpleTestCase,
|
|
TestCase,
|
|
TransactionTestCase,
|
|
skipUnlessDBFeature,
|
|
)
|
|
from django.test.utils import CaptureQueriesContext, override_settings
|
|
|
|
from ..models import Person, Square
|
|
|
|
|
|
class DatabaseWrapperTests(SimpleTestCase):
|
|
def test_repr(self):
|
|
conn = connections[DEFAULT_DB_ALIAS]
|
|
self.assertEqual(
|
|
repr(conn),
|
|
f"<DatabaseWrapper vendor={connection.vendor!r} alias='default'>",
|
|
)
|
|
|
|
def test_initialization_class_attributes(self):
|
|
"""
|
|
The "initialization" class attributes like client_class and
|
|
creation_class should be set on the class and reflected in the
|
|
corresponding instance attributes of the instantiated backend.
|
|
"""
|
|
conn = connections[DEFAULT_DB_ALIAS]
|
|
conn_class = type(conn)
|
|
attr_names = [
|
|
("client_class", "client"),
|
|
("creation_class", "creation"),
|
|
("features_class", "features"),
|
|
("introspection_class", "introspection"),
|
|
("ops_class", "ops"),
|
|
("validation_class", "validation"),
|
|
]
|
|
for class_attr_name, instance_attr_name in attr_names:
|
|
class_attr_value = getattr(conn_class, class_attr_name)
|
|
self.assertIsNotNone(class_attr_value)
|
|
instance_attr_value = getattr(conn, instance_attr_name)
|
|
self.assertIsInstance(instance_attr_value, class_attr_value)
|
|
|
|
def test_initialization_display_name(self):
|
|
self.assertEqual(BaseDatabaseWrapper.display_name, "unknown")
|
|
self.assertNotEqual(connection.display_name, "unknown")
|
|
|
|
def test_get_database_version(self):
|
|
with patch.object(BaseDatabaseWrapper, "__init__", return_value=None):
|
|
msg = (
|
|
"subclasses of BaseDatabaseWrapper may require a "
|
|
"get_database_version() method."
|
|
)
|
|
with self.assertRaisesMessage(NotImplementedError, msg):
|
|
BaseDatabaseWrapper().get_database_version()
|
|
|
|
def test_check_database_version_supported_with_none_as_database_version(self):
|
|
with patch.object(connection.features, "minimum_database_version", None):
|
|
connection.check_database_version_supported()
|
|
|
|
|
|
class DatabaseWrapperLoggingTests(TransactionTestCase):
|
|
available_apps = []
|
|
|
|
@override_settings(DEBUG=True)
|
|
def test_commit_debug_log(self):
|
|
conn = connections[DEFAULT_DB_ALIAS]
|
|
with CaptureQueriesContext(conn):
|
|
with self.assertLogs("django.db.backends", "DEBUG") as cm:
|
|
with transaction.atomic():
|
|
Person.objects.create(first_name="first", last_name="last")
|
|
|
|
self.assertGreaterEqual(len(conn.queries_log), 3)
|
|
self.assertEqual(conn.queries_log[-3]["sql"], "BEGIN")
|
|
self.assertRegex(
|
|
cm.output[0],
|
|
r"DEBUG:django.db.backends:\(\d+.\d{3}\) "
|
|
rf"BEGIN; args=None; alias={DEFAULT_DB_ALIAS}",
|
|
)
|
|
self.assertEqual(conn.queries_log[-1]["sql"], "COMMIT")
|
|
self.assertRegex(
|
|
cm.output[-1],
|
|
r"DEBUG:django.db.backends:\(\d+.\d{3}\) "
|
|
rf"COMMIT; args=None; alias={DEFAULT_DB_ALIAS}",
|
|
)
|
|
|
|
@override_settings(DEBUG=True)
|
|
def test_rollback_debug_log(self):
|
|
conn = connections[DEFAULT_DB_ALIAS]
|
|
with CaptureQueriesContext(conn):
|
|
with self.assertLogs("django.db.backends", "DEBUG") as cm:
|
|
with self.assertRaises(Exception), transaction.atomic():
|
|
Person.objects.create(first_name="first", last_name="last")
|
|
raise Exception("Force rollback")
|
|
|
|
self.assertEqual(conn.queries_log[-1]["sql"], "ROLLBACK")
|
|
self.assertRegex(
|
|
cm.output[-1],
|
|
r"DEBUG:django.db.backends:\(\d+.\d{3}\) "
|
|
rf"ROLLBACK; args=None; alias={DEFAULT_DB_ALIAS}",
|
|
)
|
|
|
|
def test_no_logs_without_debug(self):
|
|
with self.assertNoLogs("django.db.backends", "DEBUG"):
|
|
with self.assertRaises(Exception), transaction.atomic():
|
|
Person.objects.create(first_name="first", last_name="last")
|
|
raise Exception("Force rollback")
|
|
|
|
conn = connections[DEFAULT_DB_ALIAS]
|
|
self.assertEqual(len(conn.queries_log), 0)
|
|
|
|
|
|
class ExecuteWrapperTests(TestCase):
|
|
@staticmethod
|
|
def call_execute(connection, params=None):
|
|
ret_val = "1" if params is None else "%s"
|
|
sql = "SELECT " + ret_val + connection.features.bare_select_suffix
|
|
with connection.cursor() as cursor:
|
|
cursor.execute(sql, params)
|
|
|
|
def call_executemany(self, connection, params=None):
|
|
# executemany() must use an update query. Make sure it does nothing
|
|
# by putting a false condition in the WHERE clause.
|
|
sql = "DELETE FROM {} WHERE 0=1 AND 0=%s".format(Square._meta.db_table)
|
|
if params is None:
|
|
params = [(i,) for i in range(3)]
|
|
with connection.cursor() as cursor:
|
|
cursor.executemany(sql, params)
|
|
|
|
@staticmethod
|
|
def mock_wrapper():
|
|
return MagicMock(side_effect=lambda execute, *args: execute(*args))
|
|
|
|
def test_wrapper_invoked(self):
|
|
wrapper = self.mock_wrapper()
|
|
with connection.execute_wrapper(wrapper):
|
|
self.call_execute(connection)
|
|
self.assertTrue(wrapper.called)
|
|
(_, sql, params, many, context), _ = wrapper.call_args
|
|
self.assertIn("SELECT", sql)
|
|
self.assertIsNone(params)
|
|
self.assertIs(many, False)
|
|
self.assertEqual(context["connection"], connection)
|
|
|
|
def test_wrapper_invoked_many(self):
|
|
wrapper = self.mock_wrapper()
|
|
with connection.execute_wrapper(wrapper):
|
|
self.call_executemany(connection)
|
|
self.assertTrue(wrapper.called)
|
|
(_, sql, param_list, many, context), _ = wrapper.call_args
|
|
self.assertIn("DELETE", sql)
|
|
self.assertIsInstance(param_list, (list, tuple))
|
|
self.assertIs(many, True)
|
|
self.assertEqual(context["connection"], connection)
|
|
|
|
def test_database_queried(self):
|
|
wrapper = self.mock_wrapper()
|
|
with connection.execute_wrapper(wrapper):
|
|
with connection.cursor() as cursor:
|
|
sql = "SELECT 17" + connection.features.bare_select_suffix
|
|
cursor.execute(sql)
|
|
seventeen = cursor.fetchall()
|
|
self.assertEqual(list(seventeen), [(17,)])
|
|
self.call_executemany(connection)
|
|
|
|
def test_nested_wrapper_invoked(self):
|
|
outer_wrapper = self.mock_wrapper()
|
|
inner_wrapper = self.mock_wrapper()
|
|
with connection.execute_wrapper(outer_wrapper), connection.execute_wrapper(
|
|
inner_wrapper
|
|
):
|
|
self.call_execute(connection)
|
|
self.assertEqual(inner_wrapper.call_count, 1)
|
|
self.call_executemany(connection)
|
|
self.assertEqual(inner_wrapper.call_count, 2)
|
|
|
|
def test_outer_wrapper_blocks(self):
|
|
def blocker(*args):
|
|
pass
|
|
|
|
wrapper = self.mock_wrapper()
|
|
c = connection # This alias shortens the next line.
|
|
with c.execute_wrapper(wrapper), c.execute_wrapper(blocker), c.execute_wrapper(
|
|
wrapper
|
|
):
|
|
with c.cursor() as cursor:
|
|
cursor.execute("The database never sees this")
|
|
self.assertEqual(wrapper.call_count, 1)
|
|
cursor.executemany("The database never sees this %s", [("either",)])
|
|
self.assertEqual(wrapper.call_count, 2)
|
|
|
|
def test_wrapper_gets_sql(self):
|
|
wrapper = self.mock_wrapper()
|
|
sql = "SELECT 'aloha'" + connection.features.bare_select_suffix
|
|
with connection.execute_wrapper(wrapper), connection.cursor() as cursor:
|
|
cursor.execute(sql)
|
|
(_, reported_sql, _, _, _), _ = wrapper.call_args
|
|
self.assertEqual(reported_sql, sql)
|
|
|
|
def test_wrapper_connection_specific(self):
|
|
wrapper = self.mock_wrapper()
|
|
with connections["other"].execute_wrapper(wrapper):
|
|
self.assertEqual(connections["other"].execute_wrappers, [wrapper])
|
|
self.call_execute(connection)
|
|
self.assertFalse(wrapper.called)
|
|
self.assertEqual(connection.execute_wrappers, [])
|
|
self.assertEqual(connections["other"].execute_wrappers, [])
|
|
|
|
|
|
class ConnectionHealthChecksTests(SimpleTestCase):
|
|
databases = {"default"}
|
|
|
|
def setUp(self):
|
|
# All test cases here need newly configured and created connections.
|
|
# Use the default db connection for convenience.
|
|
connection.close()
|
|
self.addCleanup(connection.close)
|
|
|
|
def patch_settings_dict(self, conn_health_checks):
|
|
self.settings_dict_patcher = patch.dict(
|
|
connection.settings_dict,
|
|
{
|
|
**connection.settings_dict,
|
|
"CONN_MAX_AGE": None,
|
|
"CONN_HEALTH_CHECKS": conn_health_checks,
|
|
},
|
|
)
|
|
self.settings_dict_patcher.start()
|
|
self.addCleanup(self.settings_dict_patcher.stop)
|
|
|
|
def run_query(self):
|
|
with connection.cursor() as cursor:
|
|
cursor.execute("SELECT 42" + connection.features.bare_select_suffix)
|
|
|
|
@skipUnlessDBFeature("test_db_allows_multiple_connections")
|
|
def test_health_checks_enabled(self):
|
|
self.patch_settings_dict(conn_health_checks=True)
|
|
self.assertIsNone(connection.connection)
|
|
# Newly created connections are considered healthy without performing
|
|
# the health check.
|
|
with patch.object(connection, "is_usable", side_effect=AssertionError):
|
|
self.run_query()
|
|
|
|
old_connection = connection.connection
|
|
# Simulate request_finished.
|
|
connection.close_if_unusable_or_obsolete()
|
|
self.assertIs(old_connection, connection.connection)
|
|
|
|
# Simulate connection health check failing.
|
|
with patch.object(
|
|
connection, "is_usable", return_value=False
|
|
) as mocked_is_usable:
|
|
self.run_query()
|
|
new_connection = connection.connection
|
|
# A new connection is established.
|
|
self.assertIsNot(new_connection, old_connection)
|
|
# Only one health check per "request" is performed, so the next
|
|
# query will carry on even if the health check fails. Next query
|
|
# succeeds because the real connection is healthy and only the
|
|
# health check failure is mocked.
|
|
self.run_query()
|
|
self.assertIs(new_connection, connection.connection)
|
|
self.assertEqual(mocked_is_usable.call_count, 1)
|
|
|
|
# Simulate request_finished.
|
|
connection.close_if_unusable_or_obsolete()
|
|
# The underlying connection is being reused further with health checks
|
|
# succeeding.
|
|
self.run_query()
|
|
self.run_query()
|
|
self.assertIs(new_connection, connection.connection)
|
|
|
|
@skipUnlessDBFeature("test_db_allows_multiple_connections")
|
|
def test_health_checks_enabled_errors_occurred(self):
|
|
self.patch_settings_dict(conn_health_checks=True)
|
|
self.assertIsNone(connection.connection)
|
|
# Newly created connections are considered healthy without performing
|
|
# the health check.
|
|
with patch.object(connection, "is_usable", side_effect=AssertionError):
|
|
self.run_query()
|
|
|
|
old_connection = connection.connection
|
|
# Simulate errors_occurred.
|
|
connection.errors_occurred = True
|
|
# Simulate request_started (the connection is healthy).
|
|
connection.close_if_unusable_or_obsolete()
|
|
# Persistent connections are enabled.
|
|
self.assertIs(old_connection, connection.connection)
|
|
# No additional health checks after the one in
|
|
# close_if_unusable_or_obsolete() are executed during this "request"
|
|
# when running queries.
|
|
with patch.object(connection, "is_usable", side_effect=AssertionError):
|
|
self.run_query()
|
|
|
|
@skipUnlessDBFeature("test_db_allows_multiple_connections")
|
|
def test_health_checks_disabled(self):
|
|
self.patch_settings_dict(conn_health_checks=False)
|
|
self.assertIsNone(connection.connection)
|
|
# Newly created connections are considered healthy without performing
|
|
# the health check.
|
|
with patch.object(connection, "is_usable", side_effect=AssertionError):
|
|
self.run_query()
|
|
|
|
old_connection = connection.connection
|
|
# Simulate request_finished.
|
|
connection.close_if_unusable_or_obsolete()
|
|
# Persistent connections are enabled (connection is not).
|
|
self.assertIs(old_connection, connection.connection)
|
|
# Health checks are not performed.
|
|
with patch.object(connection, "is_usable", side_effect=AssertionError):
|
|
self.run_query()
|
|
# Health check wasn't performed and the connection is unchanged.
|
|
self.assertIs(old_connection, connection.connection)
|
|
self.run_query()
|
|
# The connection is unchanged after the next query either during
|
|
# the current "request".
|
|
self.assertIs(old_connection, connection.connection)
|
|
|
|
@skipUnlessDBFeature("test_db_allows_multiple_connections")
|
|
def test_set_autocommit_health_checks_enabled(self):
|
|
self.patch_settings_dict(conn_health_checks=True)
|
|
self.assertIsNone(connection.connection)
|
|
# Newly created connections are considered healthy without performing
|
|
# the health check.
|
|
with patch.object(connection, "is_usable", side_effect=AssertionError):
|
|
# Simulate outermost atomic block: changing autocommit for
|
|
# a connection.
|
|
connection.set_autocommit(False)
|
|
self.run_query()
|
|
connection.commit()
|
|
connection.set_autocommit(True)
|
|
|
|
old_connection = connection.connection
|
|
# Simulate request_finished.
|
|
connection.close_if_unusable_or_obsolete()
|
|
# Persistent connections are enabled.
|
|
self.assertIs(old_connection, connection.connection)
|
|
|
|
# Simulate connection health check failing.
|
|
with patch.object(
|
|
connection, "is_usable", return_value=False
|
|
) as mocked_is_usable:
|
|
# Simulate outermost atomic block: changing autocommit for
|
|
# a connection.
|
|
connection.set_autocommit(False)
|
|
new_connection = connection.connection
|
|
self.assertIsNot(new_connection, old_connection)
|
|
# Only one health check per "request" is performed, so a query will
|
|
# carry on even if the health check fails. This query succeeds
|
|
# because the real connection is healthy and only the health check
|
|
# failure is mocked.
|
|
self.run_query()
|
|
connection.commit()
|
|
connection.set_autocommit(True)
|
|
# The connection is unchanged.
|
|
self.assertIs(new_connection, connection.connection)
|
|
self.assertEqual(mocked_is_usable.call_count, 1)
|
|
|
|
# Simulate request_finished.
|
|
connection.close_if_unusable_or_obsolete()
|
|
# The underlying connection is being reused further with health checks
|
|
# succeeding.
|
|
connection.set_autocommit(False)
|
|
self.run_query()
|
|
connection.commit()
|
|
connection.set_autocommit(True)
|
|
self.assertIs(new_connection, connection.connection)
|
|
|
|
|
|
class MultiDatabaseTests(TestCase):
|
|
databases = {"default", "other"}
|
|
|
|
def test_multi_database_init_connection_state_called_once(self):
|
|
for db in self.databases:
|
|
with self.subTest(database=db):
|
|
with patch.object(connections[db], "commit", return_value=None):
|
|
with patch.object(
|
|
connections[db],
|
|
"check_database_version_supported",
|
|
) as mocked_check_database_version_supported:
|
|
connections[db].init_connection_state()
|
|
after_first_calls = len(
|
|
mocked_check_database_version_supported.mock_calls
|
|
)
|
|
connections[db].init_connection_state()
|
|
self.assertEqual(
|
|
len(mocked_check_database_version_supported.mock_calls),
|
|
after_first_calls,
|
|
)
|