mirror of
https://github.com/django/django.git
synced 2025-01-07 17:06:04 +00:00
9ac3ef59f9
Thanks Tim Graham for the review.
335 lines
14 KiB
Python
335 lines
14 KiB
Python
from unittest.mock import MagicMock, patch
|
|
|
|
from django.db import DEFAULT_DB_ALIAS, connection, connections
|
|
from django.db.backends.base.base import BaseDatabaseWrapper
|
|
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
|
|
|
|
from ..models import Square
|
|
|
|
|
|
class DatabaseWrapperTests(SimpleTestCase):
    """Checks of the basic attributes and API of the database wrapper."""

    def test_repr(self):
        # repr() of the default wrapper shows its vendor and its alias.
        wrapper = connections[DEFAULT_DB_ALIAS]
        expected = f"<DatabaseWrapper vendor={connection.vendor!r} alias='default'>"
        self.assertEqual(repr(wrapper), expected)

    def test_initialization_class_attributes(self):
        """
        Every "initialization" class attribute (client_class, creation_class,
        and so on) is defined on the backend class, and the matching instance
        attribute of the instantiated backend is an instance of it.
        """
        wrapper = connections[DEFAULT_DB_ALIAS]
        wrapper_class = type(wrapper)
        # Pairs of (attribute on the class, attribute on the instance).
        attribute_pairs = (
            ("client_class", "client"),
            ("creation_class", "creation"),
            ("features_class", "features"),
            ("introspection_class", "introspection"),
            ("ops_class", "ops"),
            ("validation_class", "validation"),
        )
        for class_attr, instance_attr in attribute_pairs:
            declared_class = getattr(wrapper_class, class_attr)
            self.assertIsNotNone(declared_class)
            built_instance = getattr(wrapper, instance_attr)
            self.assertIsInstance(built_instance, declared_class)

    def test_initialization_display_name(self):
        # The base class carries the "unknown" placeholder; the configured
        # backend is expected to replace it with something else.
        self.assertEqual(BaseDatabaseWrapper.display_name, "unknown")
        self.assertNotEqual(connection.display_name, "unknown")

    def test_get_database_version(self):
        expected_message = (
            "subclasses of BaseDatabaseWrapper may require a "
            "get_database_version() method."
        )
        # __init__ is stubbed out so that the base wrapper can be built
        # without any arguments.
        with patch.object(BaseDatabaseWrapper, "__init__", return_value=None):
            with self.assertRaisesMessage(NotImplementedError, expected_message):
                base_wrapper = BaseDatabaseWrapper()
                base_wrapper.get_database_version()

    def test_check_database_version_supported_with_none_as_database_version(self):
        # With the minimum version patched to None, the check must complete
        # without raising.
        features = connection.features
        with patch.object(features, "minimum_database_version", None):
            connection.check_database_version_supported()
|
|
|
|
|
|
class ExecuteWrapperTests(TestCase):
    """Tests for wrappers installed with connection.execute_wrapper()."""

    @staticmethod
    def call_execute(connection, params=None):
        # Select a literal 1 when no params are supplied, a placeholder
        # otherwise.
        if params is None:
            selected = "1"
        else:
            selected = "%s"
        sql = "SELECT " + selected + connection.features.bare_select_suffix
        with connection.cursor() as cursor:
            cursor.execute(sql, params)

    def call_executemany(self, connection, params=None):
        # executemany() must use an update query. The always-false WHERE
        # clause makes sure the statement does nothing.
        table_name = Square._meta.db_table
        sql = "DELETE FROM {} WHERE 0=1 AND 0=%s".format(table_name)
        param_list = [(i,) for i in range(3)] if params is None else params
        with connection.cursor() as cursor:
            cursor.executemany(sql, param_list)

    @staticmethod
    def mock_wrapper():
        # A recording wrapper that simply forwards to the wrapped execute.
        def pass_through(execute, *args):
            return execute(*args)

        return MagicMock(side_effect=pass_through)

    def test_wrapper_invoked(self):
        wrapper = self.mock_wrapper()
        with connection.execute_wrapper(wrapper):
            self.call_execute(connection)
        self.assertTrue(wrapper.called)
        # The wrapper receives (execute, sql, params, many, context).
        positional, _ = wrapper.call_args
        _, sql, params, many, context = positional
        self.assertIn("SELECT", sql)
        self.assertIsNone(params)
        self.assertIs(many, False)
        self.assertEqual(context["connection"], connection)

    def test_wrapper_invoked_many(self):
        wrapper = self.mock_wrapper()
        with connection.execute_wrapper(wrapper):
            self.call_executemany(connection)
        self.assertTrue(wrapper.called)
        # The wrapper receives (execute, sql, param_list, many, context).
        positional, _ = wrapper.call_args
        _, sql, param_list, many, context = positional
        self.assertIn("DELETE", sql)
        self.assertIsInstance(param_list, (list, tuple))
        self.assertIs(many, True)
        self.assertEqual(context["connection"], connection)

    def test_database_queried(self):
        # A pass-through wrapper must not prevent real results from coming
        # back.
        wrapper = self.mock_wrapper()
        with connection.execute_wrapper(wrapper):
            with connection.cursor() as cursor:
                sql = "SELECT 17" + connection.features.bare_select_suffix
                cursor.execute(sql)
                rows = cursor.fetchall()
                self.assertEqual(list(rows), [(17,)])
            self.call_executemany(connection)

    def test_nested_wrapper_invoked(self):
        outer_wrapper = self.mock_wrapper()
        inner_wrapper = self.mock_wrapper()
        with connection.execute_wrapper(outer_wrapper):
            with connection.execute_wrapper(inner_wrapper):
                self.call_execute(connection)
                self.assertEqual(inner_wrapper.call_count, 1)
                self.call_executemany(connection)
                self.assertEqual(inner_wrapper.call_count, 2)

    def test_outer_wrapper_blocks(self):
        # blocker never calls the execute callable it is handed, so only one
        # of the two installations of ``wrapper`` is reached per statement.
        def blocker(*args):
            return None

        wrapper = self.mock_wrapper()
        with connection.execute_wrapper(wrapper):
            with connection.execute_wrapper(blocker):
                with connection.execute_wrapper(wrapper):
                    with connection.cursor() as cursor:
                        cursor.execute("The database never sees this")
                        self.assertEqual(wrapper.call_count, 1)
                        cursor.executemany(
                            "The database never sees this %s", [("either",)]
                        )
                        self.assertEqual(wrapper.call_count, 2)

    def test_wrapper_gets_sql(self):
        wrapper = self.mock_wrapper()
        sql = "SELECT 'aloha'" + connection.features.bare_select_suffix
        with connection.execute_wrapper(wrapper):
            with connection.cursor() as cursor:
                cursor.execute(sql)
        # The SQL is the second positional argument passed to the wrapper.
        positional, _ = wrapper.call_args
        _, reported_sql, _, _, _ = positional
        self.assertEqual(reported_sql, sql)

    def test_wrapper_connection_specific(self):
        # A wrapper installed on the "other" connection must not see queries
        # run on the default one, and must be removed on exit.
        wrapper = self.mock_wrapper()
        other = connections["other"]
        with other.execute_wrapper(wrapper):
            self.assertEqual(other.execute_wrappers, [wrapper])
            self.call_execute(connection)
        self.assertFalse(wrapper.called)
        self.assertEqual(connection.execute_wrappers, [])
        self.assertEqual(other.execute_wrappers, [])
|
|
|
|
|
|
class ConnectionHealthChecksTests(SimpleTestCase):
    """Tests for the CONN_HEALTH_CHECKS setting on the default connection."""

    databases = {"default"}

    def setUp(self):
        # Every test here needs a freshly configured and created connection.
        # The default db connection is used for convenience; it is closed
        # before the test and again on cleanup.
        connection.close()
        self.addCleanup(connection.close)

    def patch_settings_dict(self, conn_health_checks):
        # Override CONN_MAX_AGE and CONN_HEALTH_CHECKS for the duration of
        # the test; the patch is undone on cleanup.
        overrides = {
            **connection.settings_dict,
            "CONN_MAX_AGE": None,
            "CONN_HEALTH_CHECKS": conn_health_checks,
        }
        patcher = patch.dict(connection.settings_dict, overrides)
        self.settings_dict_patcher = patcher
        patcher.start()
        self.addCleanup(patcher.stop)

    def run_query(self):
        sql = "SELECT 42" + connection.features.bare_select_suffix
        with connection.cursor() as cursor:
            cursor.execute(sql)

    def _forbid_health_check(self):
        # While this patch is active, any call to is_usable() raises
        # AssertionError and so fails the test.
        return patch.object(connection, "is_usable", side_effect=AssertionError)

    @skipUnlessDBFeature("test_db_allows_multiple_connections")
    def test_health_checks_enabled(self):
        self.patch_settings_dict(conn_health_checks=True)
        self.assertIsNone(connection.connection)
        # A newly created connection is considered healthy, so no health
        # check may run for it.
        with self._forbid_health_check():
            self.run_query()

        first_conn = connection.connection
        # Simulate request_finished.
        connection.close_if_unusable_or_obsolete()
        self.assertIs(first_conn, connection.connection)

        # Simulate the connection health check failing.
        failing_check = patch.object(connection, "is_usable", return_value=False)
        with failing_check as is_usable_mock:
            self.run_query()
            replacement_conn = connection.connection
            # A new connection has been established.
            self.assertIsNot(replacement_conn, first_conn)
            # Only one health check is performed per "request", so the next
            # query carries on even though the check would fail. That query
            # succeeds because the real connection is healthy; only the
            # health check failure is mocked.
            self.run_query()
            self.assertIs(replacement_conn, connection.connection)
        self.assertEqual(is_usable_mock.call_count, 1)

        # Simulate request_finished.
        connection.close_if_unusable_or_obsolete()
        # The underlying connection keeps being reused while the health
        # checks succeed.
        self.run_query()
        self.run_query()
        self.assertIs(replacement_conn, connection.connection)

    @skipUnlessDBFeature("test_db_allows_multiple_connections")
    def test_health_checks_enabled_errors_occurred(self):
        self.patch_settings_dict(conn_health_checks=True)
        self.assertIsNone(connection.connection)
        # A newly created connection is considered healthy, so no health
        # check may run for it.
        with self._forbid_health_check():
            self.run_query()

        first_conn = connection.connection
        # Simulate errors_occurred.
        connection.errors_occurred = True
        # Simulate request_started (the connection is healthy).
        connection.close_if_unusable_or_obsolete()
        # Persistent connections are enabled.
        self.assertIs(first_conn, connection.connection)
        # After the check made in close_if_unusable_or_obsolete(), no further
        # health check is run by queries during this "request".
        with self._forbid_health_check():
            self.run_query()

    @skipUnlessDBFeature("test_db_allows_multiple_connections")
    def test_health_checks_disabled(self):
        self.patch_settings_dict(conn_health_checks=False)
        self.assertIsNone(connection.connection)
        # A newly created connection is considered healthy, so no health
        # check may run for it.
        with self._forbid_health_check():
            self.run_query()

        first_conn = connection.connection
        # Simulate request_finished.
        connection.close_if_unusable_or_obsolete()
        # Persistent connections are enabled, so the connection is kept.
        self.assertIs(first_conn, connection.connection)
        # Health checks are not performed.
        with self._forbid_health_check():
            self.run_query()
            # No health check ran and the connection is unchanged.
            self.assertIs(first_conn, connection.connection)
            self.run_query()
            # The connection is still unchanged after another query in the
            # current "request".
            self.assertIs(first_conn, connection.connection)

    @skipUnlessDBFeature("test_db_allows_multiple_connections")
    def test_set_autocommit_health_checks_enabled(self):
        self.patch_settings_dict(conn_health_checks=True)
        self.assertIsNone(connection.connection)
        # A newly created connection is considered healthy, so no health
        # check may run for it.
        with self._forbid_health_check():
            # Simulate an outermost atomic block, which changes autocommit
            # for the connection.
            connection.set_autocommit(False)
            self.run_query()
            connection.commit()
            connection.set_autocommit(True)

        first_conn = connection.connection
        # Simulate request_finished.
        connection.close_if_unusable_or_obsolete()
        # Persistent connections are enabled.
        self.assertIs(first_conn, connection.connection)

        # Simulate the connection health check failing.
        failing_check = patch.object(connection, "is_usable", return_value=False)
        with failing_check as is_usable_mock:
            # Simulate an outermost atomic block, which changes autocommit
            # for the connection.
            connection.set_autocommit(False)
            replacement_conn = connection.connection
            self.assertIsNot(replacement_conn, first_conn)
            # Only one health check is performed per "request", so a query
            # carries on even though the check would fail. This query
            # succeeds because the real connection is healthy; only the
            # health check failure is mocked.
            self.run_query()
            connection.commit()
            connection.set_autocommit(True)
            # The connection is unchanged.
            self.assertIs(replacement_conn, connection.connection)
        self.assertEqual(is_usable_mock.call_count, 1)

        # Simulate request_finished.
        connection.close_if_unusable_or_obsolete()
        # The underlying connection keeps being reused while the health
        # checks succeed.
        connection.set_autocommit(False)
        self.run_query()
        connection.commit()
        connection.set_autocommit(True)
        self.assertIs(replacement_conn, connection.connection)
|
|
|
|
|
|
class MultiDatabaseTests(TestCase):
    """Connection-state checks run against every configured test database."""

    databases = {"default", "other"}

    def test_multi_database_init_connection_state_called_once(self):
        for alias in self.databases:
            with self.subTest(database=alias):
                conn = connections[alias]
                # commit() is stubbed out, and the version check is replaced
                # by a mock so that its calls can be counted.
                commit_stub = patch.object(conn, "commit", return_value=None)
                version_check_stub = patch.object(
                    conn, "check_database_version_supported"
                )
                with commit_stub, version_check_stub as version_check:
                    conn.init_connection_state()
                    calls_after_first = len(version_check.mock_calls)
                    # A second initialization must not add any call to the
                    # version check.
                    conn.init_connection_state()
                    self.assertEqual(
                        len(version_check.mock_calls),
                        calls_after_first,
                    )
|