from unittest.mock import MagicMock, patch

from django.db import DEFAULT_DB_ALIAS, connection, connections, transaction
from django.db.backends.base.base import BaseDatabaseWrapper
from django.test import (
    SimpleTestCase,
    TestCase,
    TransactionTestCase,
    skipUnlessDBFeature,
)
from django.test.utils import CaptureQueriesContext, override_settings

from ..models import Person, Square


class DatabaseWrapperTests(SimpleTestCase):
    """Checks on the database wrapper that don't need to run queries."""

    def test_repr(self):
        conn = connections[DEFAULT_DB_ALIAS]
        # The expected text has the "<ClassName vendor=... alias=...>" shape.
        self.assertEqual(
            repr(conn),
            f"<DatabaseWrapper vendor={connection.vendor!r} alias='default'>",
        )

    def test_initialization_class_attributes(self):
        """
        The "initialization" class attributes like client_class and
        creation_class should be set on the class and reflected in the
        corresponding instance attributes of the instantiated backend.
        """
        conn = connections[DEFAULT_DB_ALIAS]
        conn_class = type(conn)
        attr_names = [
            ("client_class", "client"),
            ("creation_class", "creation"),
            ("features_class", "features"),
            ("introspection_class", "introspection"),
            ("ops_class", "ops"),
            ("validation_class", "validation"),
        ]
        for class_attr_name, instance_attr_name in attr_names:
            class_attr_value = getattr(conn_class, class_attr_name)
            self.assertIsNotNone(class_attr_value)
            instance_attr_value = getattr(conn, instance_attr_name)
            self.assertIsInstance(instance_attr_value, class_attr_value)

    def test_initialization_display_name(self):
        self.assertEqual(BaseDatabaseWrapper.display_name, "unknown")
        self.assertNotEqual(connection.display_name, "unknown")

    def test_get_database_version(self):
        # __init__ is patched out so the base class can be instantiated
        # without any settings.
        with patch.object(BaseDatabaseWrapper, "__init__", return_value=None):
            msg = (
                "subclasses of BaseDatabaseWrapper may require a "
                "get_database_version() method."
            )
            with self.assertRaisesMessage(NotImplementedError, msg):
                BaseDatabaseWrapper().get_database_version()

    def test_check_database_version_supported_with_none_as_database_version(self):
        # Passes as long as the call doesn't raise.
        with patch.object(connection.features, "minimum_database_version", None):
            connection.check_database_version_supported()


class DatabaseWrapperLoggingTests(TransactionTestCase):
    """Checks on the "django.db.backends" debug log for transaction SQL."""

    available_apps = []

    @override_settings(DEBUG=True)
    def test_commit_debug_log(self):
        conn = connections[DEFAULT_DB_ALIAS]
        with CaptureQueriesContext(conn):
            with self.assertLogs("django.db.backends", "DEBUG") as cm:
                with transaction.atomic():
                    Person.objects.create(first_name="first", last_name="last")

                # At least BEGIN, the INSERT, and COMMIT were recorded.
                self.assertGreaterEqual(len(conn.queries_log), 3)
                self.assertEqual(conn.queries_log[-3]["sql"], "BEGIN")
                self.assertRegex(
                    cm.output[0],
                    r"DEBUG:django.db.backends:\(\d+.\d{3}\) "
                    rf"BEGIN; args=None; alias={DEFAULT_DB_ALIAS}",
                )
                self.assertEqual(conn.queries_log[-1]["sql"], "COMMIT")
                self.assertRegex(
                    cm.output[-1],
                    r"DEBUG:django.db.backends:\(\d+.\d{3}\) "
                    rf"COMMIT; args=None; alias={DEFAULT_DB_ALIAS}",
                )

    @override_settings(DEBUG=True)
    def test_rollback_debug_log(self):
        conn = connections[DEFAULT_DB_ALIAS]
        with CaptureQueriesContext(conn):
            with self.assertLogs("django.db.backends", "DEBUG") as cm:
                with self.assertRaises(Exception), transaction.atomic():
                    Person.objects.create(first_name="first", last_name="last")
                    raise Exception("Force rollback")

                self.assertEqual(conn.queries_log[-1]["sql"], "ROLLBACK")
                self.assertRegex(
                    cm.output[-1],
                    r"DEBUG:django.db.backends:\(\d+.\d{3}\) "
                    rf"ROLLBACK; args=None; alias={DEFAULT_DB_ALIAS}",
                )

    def test_no_logs_without_debug(self):
        with self.assertNoLogs("django.db.backends", "DEBUG"):
            with self.assertRaises(Exception), transaction.atomic():
                Person.objects.create(first_name="first", last_name="last")
                raise Exception("Force rollback")

            conn = connections[DEFAULT_DB_ALIAS]
            self.assertEqual(len(conn.queries_log), 0)


class ExecuteWrapperTests(TestCase):
    """Checks on wrappers installed with connection.execute_wrapper()."""

    @staticmethod
    def call_execute(connection, params=None):
        """Run a one-value SELECT through cursor.execute() on `connection`."""
        ret_val = "1" if params is None else "%s"
        sql = "SELECT " + ret_val + connection.features.bare_select_suffix
        with connection.cursor() as cursor:
            cursor.execute(sql, params)

    def call_executemany(self, connection, params=None):
        """Run a no-op DELETE through cursor.executemany() on `connection`."""
        # executemany() must use an update query. Make sure it does nothing
        # by putting a false condition in the WHERE clause.
        sql = "DELETE FROM {} WHERE 0=1 AND 0=%s".format(Square._meta.db_table)
        if params is None:
            params = [(i,) for i in range(3)]
        with connection.cursor() as cursor:
            cursor.executemany(sql, params)

    @staticmethod
    def mock_wrapper():
        """Return a MagicMock that calls `execute` with the remaining args."""
        return MagicMock(side_effect=lambda execute, *args: execute(*args))

    def test_wrapper_invoked(self):
        wrapper = self.mock_wrapper()
        with connection.execute_wrapper(wrapper):
            self.call_execute(connection)
        self.assertTrue(wrapper.called)
        (_, sql, params, many, context), _ = wrapper.call_args
        self.assertIn("SELECT", sql)
        self.assertIsNone(params)
        self.assertIs(many, False)
        self.assertEqual(context["connection"], connection)

    def test_wrapper_invoked_many(self):
        wrapper = self.mock_wrapper()
        with connection.execute_wrapper(wrapper):
            self.call_executemany(connection)
        self.assertTrue(wrapper.called)
        (_, sql, param_list, many, context), _ = wrapper.call_args
        self.assertIn("DELETE", sql)
        self.assertIsInstance(param_list, (list, tuple))
        self.assertIs(many, True)
        self.assertEqual(context["connection"], connection)

    def test_database_queried(self):
        wrapper = self.mock_wrapper()
        with connection.execute_wrapper(wrapper):
            with connection.cursor() as cursor:
                sql = "SELECT 17" + connection.features.bare_select_suffix
                cursor.execute(sql)
                seventeen = cursor.fetchall()
                self.assertEqual(list(seventeen), [(17,)])
            self.call_executemany(connection)

    def test_nested_wrapper_invoked(self):
        outer_wrapper = self.mock_wrapper()
        inner_wrapper = self.mock_wrapper()
        with connection.execute_wrapper(outer_wrapper), connection.execute_wrapper(
            inner_wrapper
        ):
            self.call_execute(connection)
            self.assertEqual(inner_wrapper.call_count, 1)
            self.call_executemany(connection)
            self.assertEqual(inner_wrapper.call_count, 2)

    def test_outer_wrapper_blocks(self):
        # A wrapper that never calls `execute`, so nothing installed before
        # it (nor the database) is reached.
        def blocker(*args):
            pass

        wrapper = self.mock_wrapper()
        c = connection  # This alias shortens the next line.
        with c.execute_wrapper(wrapper), c.execute_wrapper(blocker), c.execute_wrapper(
            wrapper
        ):
            with c.cursor() as cursor:
                cursor.execute("The database never sees this")
                self.assertEqual(wrapper.call_count, 1)
                cursor.executemany("The database never sees this %s", [("either",)])
                self.assertEqual(wrapper.call_count, 2)

    def test_wrapper_gets_sql(self):
        wrapper = self.mock_wrapper()
        sql = "SELECT 'aloha'" + connection.features.bare_select_suffix
        with connection.execute_wrapper(wrapper), connection.cursor() as cursor:
            cursor.execute(sql)
        (_, reported_sql, _, _, _), _ = wrapper.call_args
        self.assertEqual(reported_sql, sql)

    def test_wrapper_connection_specific(self):
        wrapper = self.mock_wrapper()
        with connections["other"].execute_wrapper(wrapper):
            self.assertEqual(connections["other"].execute_wrappers, [wrapper])
            self.call_execute(connection)
        self.assertFalse(wrapper.called)
        self.assertEqual(connection.execute_wrappers, [])
        self.assertEqual(connections["other"].execute_wrappers, [])


class ConnectionHealthChecksTests(SimpleTestCase):
    """Checks on connection reuse with CONN_HEALTH_CHECKS on and off."""

    databases = {"default"}

    def setUp(self):
        # All test cases here need newly configured and created connections.
        # Use the default db connection for convenience.
        connection.close()
        self.addCleanup(connection.close)

    def patch_settings_dict(self, conn_health_checks):
        """
        Patch the default connection's settings_dict for the current test,
        setting CONN_MAX_AGE to None and CONN_HEALTH_CHECKS to the given
        value. The patch is undone on cleanup.
        """
        self.settings_dict_patcher = patch.dict(
            connection.settings_dict,
            {
                **connection.settings_dict,
                "CONN_MAX_AGE": None,
                "CONN_HEALTH_CHECKS": conn_health_checks,
            },
        )
        self.settings_dict_patcher.start()
        self.addCleanup(self.settings_dict_patcher.stop)

    def run_query(self):
        """Execute a trivial SELECT on the default connection."""
        with connection.cursor() as cursor:
            cursor.execute("SELECT 42" + connection.features.bare_select_suffix)

    @skipUnlessDBFeature("test_db_allows_multiple_connections")
    def test_health_checks_enabled(self):
        self.patch_settings_dict(conn_health_checks=True)
        self.assertIsNone(connection.connection)
        # Newly created connections are considered healthy without performing
        # the health check.
        with patch.object(connection, "is_usable", side_effect=AssertionError):
            self.run_query()

        old_connection = connection.connection
        # Simulate request_finished.
        connection.close_if_unusable_or_obsolete()
        self.assertIs(old_connection, connection.connection)

        # Simulate connection health check failing.
        with patch.object(
            connection, "is_usable", return_value=False
        ) as mocked_is_usable:
            self.run_query()
            new_connection = connection.connection
            # A new connection is established.
            self.assertIsNot(new_connection, old_connection)
            # Only one health check per "request" is performed, so the next
            # query will carry on even if the health check fails. Next query
            # succeeds because the real connection is healthy and only the
            # health check failure is mocked.
            self.run_query()
            self.assertIs(new_connection, connection.connection)
        self.assertEqual(mocked_is_usable.call_count, 1)

        # Simulate request_finished.
        connection.close_if_unusable_or_obsolete()
        # The underlying connection is being reused further with health checks
        # succeeding.
        self.run_query()
        self.run_query()
        self.assertIs(new_connection, connection.connection)

    @skipUnlessDBFeature("test_db_allows_multiple_connections")
    def test_health_checks_enabled_errors_occurred(self):
        self.patch_settings_dict(conn_health_checks=True)
        self.assertIsNone(connection.connection)
        # Newly created connections are considered healthy without performing
        # the health check.
        with patch.object(connection, "is_usable", side_effect=AssertionError):
            self.run_query()

        old_connection = connection.connection
        # Simulate errors_occurred.
        connection.errors_occurred = True
        # Simulate request_started (the connection is healthy).
        connection.close_if_unusable_or_obsolete()
        # Persistent connections are enabled.
        self.assertIs(old_connection, connection.connection)
        # No additional health checks after the one in
        # close_if_unusable_or_obsolete() are executed during this "request"
        # when running queries.
        with patch.object(connection, "is_usable", side_effect=AssertionError):
            self.run_query()

    @skipUnlessDBFeature("test_db_allows_multiple_connections")
    def test_health_checks_disabled(self):
        self.patch_settings_dict(conn_health_checks=False)
        self.assertIsNone(connection.connection)
        # Newly created connections are considered healthy without performing
        # the health check.
        with patch.object(connection, "is_usable", side_effect=AssertionError):
            self.run_query()

        old_connection = connection.connection
        # Simulate request_finished.
        connection.close_if_unusable_or_obsolete()
        # Persistent connections are enabled (connection is not).
        self.assertIs(old_connection, connection.connection)
        # Health checks are not performed.
        with patch.object(connection, "is_usable", side_effect=AssertionError):
            self.run_query()
            # Health check wasn't performed and the connection is unchanged.
            self.assertIs(old_connection, connection.connection)
            self.run_query()
            # The connection is unchanged after the next query either during
            # the current "request".
            self.assertIs(old_connection, connection.connection)

    @skipUnlessDBFeature("test_db_allows_multiple_connections")
    def test_set_autocommit_health_checks_enabled(self):
        self.patch_settings_dict(conn_health_checks=True)
        self.assertIsNone(connection.connection)
        # Newly created connections are considered healthy without performing
        # the health check.
        with patch.object(connection, "is_usable", side_effect=AssertionError):
            # Simulate outermost atomic block: changing autocommit for
            # a connection.
            connection.set_autocommit(False)
            self.run_query()
            connection.commit()
            connection.set_autocommit(True)

        old_connection = connection.connection
        # Simulate request_finished.
        connection.close_if_unusable_or_obsolete()
        # Persistent connections are enabled.
        self.assertIs(old_connection, connection.connection)

        # Simulate connection health check failing.
        with patch.object(
            connection, "is_usable", return_value=False
        ) as mocked_is_usable:
            # Simulate outermost atomic block: changing autocommit for
            # a connection.
            connection.set_autocommit(False)
            new_connection = connection.connection
            self.assertIsNot(new_connection, old_connection)
            # Only one health check per "request" is performed, so a query will
            # carry on even if the health check fails. This query succeeds
            # because the real connection is healthy and only the health check
            # failure is mocked.
            self.run_query()
            connection.commit()
            connection.set_autocommit(True)
            # The connection is unchanged.
            self.assertIs(new_connection, connection.connection)
        self.assertEqual(mocked_is_usable.call_count, 1)

        # Simulate request_finished.
        connection.close_if_unusable_or_obsolete()
        # The underlying connection is being reused further with health checks
        # succeeding.
        connection.set_autocommit(False)
        self.run_query()
        connection.commit()
        connection.set_autocommit(True)
        self.assertIs(new_connection, connection.connection)


class MultiDatabaseTests(TestCase):
    """Checks that run against every configured database alias."""

    databases = {"default", "other"}

    def test_multi_database_init_connection_state_called_once(self):
        for db in self.databases:
            with self.subTest(database=db):
                with patch.object(connections[db], "commit", return_value=None):
                    with patch.object(
                        connections[db],
                        "check_database_version_supported",
                    ) as mocked_check_database_version_supported:
                        connections[db].init_connection_state()
                        after_first_calls = len(
                            mocked_check_database_version_supported.mock_calls
                        )
                        # A second call must not add any further calls to
                        # the version check.
                        connections[db].init_connection_state()
                        self.assertEqual(
                            len(mocked_check_database_version_supported.mock_calls),
                            after_first_calls,
                        )