diff --git a/django/db/backends/utils.py b/django/db/backends/utils.py index 568f510a67..ab0ea8258b 100644 --- a/django/db/backends/utils.py +++ b/django/db/backends/utils.py @@ -151,7 +151,7 @@ class CursorDebugWrapper(CursorWrapper): logger.debug( "(%.3f) %s; args=%s; alias=%s", duration, - self.db.ops.format_debug_sql(sql), + sql, params, self.db.alias, extra={ diff --git a/django/test/runner.py b/django/test/runner.py index c8bb16e7b3..3e5c319ade 100644 --- a/django/test/runner.py +++ b/django/test/runner.py @@ -16,7 +16,6 @@ import unittest.suite from collections import defaultdict from contextlib import contextmanager from importlib import import_module -from io import StringIO import django from django.core.management import call_command @@ -41,16 +40,47 @@ except ImportError: tblib = None +class QueryFormatter(logging.Formatter): + def format(self, record): + if (alias := getattr(record, "alias", None)) in connections: + format_sql = connections[alias].ops.format_debug_sql + + sql = None + formatted_sql = None + if args := record.args: + if isinstance(args, tuple) and len(args) > 1 and (sql := args[1]): + record.args = (args[0], formatted_sql := format_sql(sql), *args[2:]) + elif isinstance(record.args, dict) and (sql := record.args.get("sql")): + record.args["sql"] = formatted_sql = format_sql(sql) + + if extra_sql := getattr(record, "sql", None): + if extra_sql == sql: + record.sql = formatted_sql + else: + record.sql = format_sql(extra_sql) + + return super().format(record) + + class DebugSQLTextTestResult(unittest.TextTestResult): def __init__(self, stream, descriptions, verbosity): self.logger = logging.getLogger("django.db.backends") self.logger.setLevel(logging.DEBUG) - self.debug_sql_stream = None + self.handler = None super().__init__(stream, descriptions, verbosity) + def _read_logger_stream(self): + if self.handler is None: + # Error before tests e.g. in setUpTestData(). + sql = "" + else: + self.handler.stream.seek(0) + sql = self.handler.stream.read() + return sql + def startTest(self, test): - self.debug_sql_stream = StringIO() - self.handler = logging.StreamHandler(self.debug_sql_stream) + self.handler = logging.StreamHandler(io.StringIO()) + self.handler.setFormatter(QueryFormatter()) self.logger.addHandler(self.handler) super().startTest(test) @@ -58,35 +88,26 @@ class DebugSQLTextTestResult(unittest.TextTestResult): super().stopTest(test) self.logger.removeHandler(self.handler) if self.showAll: - self.debug_sql_stream.seek(0) - self.stream.write(self.debug_sql_stream.read()) + self.stream.write(self._read_logger_stream()) self.stream.writeln(self.separator2) def addError(self, test, err): super().addError(test, err) - if self.debug_sql_stream is None: - # Error before tests e.g. in setUpTestData(). - sql = "" - else: - self.debug_sql_stream.seek(0) - sql = self.debug_sql_stream.read() - self.errors[-1] = self.errors[-1] + (sql,) + self.errors[-1] = self.errors[-1] + (self._read_logger_stream(),) def addFailure(self, test, err): super().addFailure(test, err) - self.debug_sql_stream.seek(0) - self.failures[-1] = self.failures[-1] + (self.debug_sql_stream.read(),) + self.failures[-1] = self.failures[-1] + (self._read_logger_stream(),) def addSubTest(self, test, subtest, err): super().addSubTest(test, subtest, err) if err is not None: - self.debug_sql_stream.seek(0) errors = ( self.failures if issubclass(err[0], test.failureException) else self.errors ) - errors[-1] = errors[-1] + (self.debug_sql_stream.read(),) + errors[-1] = errors[-1] + (self._read_logger_stream(),) def printErrorList(self, flavour, errors): for test, err, sql_debug in errors: diff --git a/tests/backends/tests.py b/tests/backends/tests.py index 0e5348e248..5e38f0112d 100644 --- a/tests/backends/tests.py +++ b/tests/backends/tests.py @@ -83,12 +83,7 @@ class LastExecutedQueryTest(TestCase): connection.ops.last_executed_query(cursor, "SELECT %s" + suffix, (1,)) def test_debug_sql(self): - qs = Reporter.objects.filter(first_name="test") - ops = connections[qs.db].ops - with mock.patch.object(ops, "format_debug_sql") as format_debug_sql: - list(qs) - # Queries are formatted with DatabaseOperations.format_debug_sql(). - format_debug_sql.assert_called() + list(Reporter.objects.filter(first_name="test")) sql = connection.queries[-1]["sql"].lower() self.assertIn("select", sql) self.assertIn(Reporter._meta.db_table, sql) @@ -580,13 +575,13 @@ class BackendTestCase(TransactionTestCase): @mock.patch("django.db.backends.utils.logger") @override_settings(DEBUG=True) def test_queries_logger(self, mocked_logger): - sql = "SELECT 1" + connection.features.bare_select_suffix - sql = connection.ops.format_debug_sql(sql) + sql = "select 1" + connection.features.bare_select_suffix with connection.cursor() as cursor: cursor.execute(sql) params, kwargs = mocked_logger.debug.call_args self.assertIn("; alias=%s", params[0]) self.assertEqual(params[2], sql) + self.assertNotEqual(params[2], connection.ops.format_debug_sql(sql)) self.assertIsNone(params[3]) self.assertEqual(params[4], connection.alias) self.assertEqual( diff --git a/tests/test_runner/test_debug_sql.py b/tests/test_runner/test_debug_sql.py index 27fc4001c2..acf66633ef 100644 --- a/tests/test_runner/test_debug_sql.py +++ b/tests/test_runner/test_debug_sql.py @@ -1,12 +1,184 @@ +import logging import unittest from io import StringIO +from time import time +from unittest import mock -from django.db import connection +from django.db import DEFAULT_DB_ALIAS, connection, connections from django.test import TestCase -from django.test.runner import DiscoverRunner +from django.test.runner import DiscoverRunner, QueryFormatter from .models import Person +logger = logging.getLogger(__name__) + + +class QueryFormatterTests(unittest.TestCase): + + def setUp(self): + super().setUp() + self.format_sql_calls = [] + + def new_format_sql(self, sql): + # Use time() to introduce some uniqueness. + formatted = "Formatted! %s at %s" % (sql.upper(), time()) + self.format_sql_calls.append({sql: formatted}) + return formatted + + def make_handler(self, **formatter_kwargs): + formatter = QueryFormatter(**formatter_kwargs) + + handler = logging.StreamHandler(StringIO()) + handler.setLevel(logging.DEBUG) + handler.setFormatter(formatter) + + original_level = logger.getEffectiveLevel() + logger.setLevel(logging.DEBUG) + self.addCleanup(logger.setLevel, original_level) + logger.addHandler(handler) + self.addCleanup(logger.removeHandler, handler) + + return handler + + def do_log(self, msg, *logger_args, alias=DEFAULT_DB_ALIAS, extra=None): + if extra is None: + extra = {} + if alias and "alias" not in extra: + extra["alias"] = alias + # Patch connection's format_debug_sql to ensure it was properly called. + with mock.patch.object( + connections[alias].ops, "format_debug_sql", side_effect=self.new_format_sql + ): + logger.info(msg, *logger_args, extra=extra) + + def assertLogRecord(self, handler, expected): + handler.stream.seek(0) + self.assertEqual(handler.stream.read().strip(), expected) + + def assertSQLFormatted(self, handler, sql, total_calls=1): + self.assertEqual(len(self.format_sql_calls), total_calls) + formatted_sql = self.format_sql_calls[0][sql] + expected = f"=> Executing query duration=3.142 sql={formatted_sql}" + self.assertLogRecord(handler, expected) + + def test_formats_sql_bracket_format_style(self): + handler = self.make_handler( + fmt="{message} duration={duration:.3f} sql={sql}", style="{" + ) + msg = "=> Executing query" + sql = "select * from foo" + + self.do_log(msg, extra={"sql": sql, "duration": 3.1416}) + self.assertSQLFormatted(handler, sql) + + def test_formats_sql_named_fmt_format_style(self): + handler = self.make_handler( + fmt="%(message)s duration=%(duration).3f sql=%(sql)s" + ) + msg = "=> Executing query" + sql = "select * from foo" + + self.do_log(msg, extra={"sql": sql, "duration": 3.1416}) + self.assertSQLFormatted(handler, sql) + + def test_formats_sql_named_percent_format_style(self): + handler = self.make_handler() + msg = "=> Executing query duration=%(duration).3f sql=%(sql)s" + sql = "select * from foo" + + self.do_log(msg, {"duration": 3.1416, "sql": sql}) + self.assertSQLFormatted(handler, sql) + + def test_formats_sql_default_percent_format_style(self): + handler = self.make_handler() + msg = "=> Executing query duration=%.3f sql=%s" + sql = "select * from foo" + + self.do_log(msg, 3.1416, sql) + self.assertSQLFormatted(handler, sql) + + def test_formats_sql_multiple_matching_sql(self): + handler = self.make_handler() + msg = "=> Executing query duration=%.3f sql=%s" + sql = "select * from foo" + + self.do_log(msg, 3.1416, sql, extra={"duration": 3.1416, "sql": sql}) + self.assertSQLFormatted(handler, sql) + + def test_formats_sql_multiple_non_matching_sql(self): + handler = self.make_handler() + msg = "=> Executing query duration=%.3f sql=%s" + sql1 = "select * from foo" + sql2 = "select * from other" + + self.do_log(msg, 3.1416, sql1, extra={"duration": 3.1416, "sql": sql2}) + self.assertSQLFormatted(handler, sql1, total_calls=2) + # Second format call is triggered since the sql are different. + self.assertEqual(list(self.format_sql_calls[1].keys()), [sql2]) + + def test_log_record_no_args(self): + handler = self.make_handler() + msg = "=> Executing query no args" + + self.do_log(msg) + self.assertEqual(self.format_sql_calls, []) + self.assertLogRecord(handler, msg) + + def test_log_record_not_enough_args(self): + handler = self.make_handler() + msg = "=> Executing query one args %r" + args = "not formatted" + + self.do_log(msg, args) + self.assertEqual(self.format_sql_calls, []) + self.assertLogRecord(handler, msg % args) + + def test_log_record_not_key_in_dict_args(self): + handler = self.make_handler() + msg = "=> Executing query missing sql key %(foo)r" + args = {"foo": "bar"} + + self.do_log(msg, args) + self.assertEqual(self.format_sql_calls, []) + self.assertLogRecord(handler, msg % args) + + def test_log_record_no_alias(self): + handler = self.make_handler() + msg = "=> Executing query duration=%.3f sql=%s" + args = (3.1416, "select * from foo") + + self.do_log(msg, *args, extra={"alias": "does not exist"}) + self.assertEqual(self.format_sql_calls, []) + self.assertLogRecord(handler, msg % args) + + def test_log_record_sql_arg_none(self): + handler = self.make_handler() + msg = "=> Executing query duration=%.3f sql=%s" + args = (3.1416, None) + + self.do_log(msg, *args) + self.assertEqual(self.format_sql_calls, []) + self.assertLogRecord(handler, msg % args) + + def test_log_record_sql_key_none(self): + handler = self.make_handler() + msg = "=> Executing query duration=%(duration).3f sql=%(sql)s" + args = {"duration": 3.1416, "sql": None} + + self.do_log(msg, args) + self.assertEqual(self.format_sql_calls, []) + self.assertLogRecord(handler, msg % args) + + def test_log_record_sql_extra_none(self): + handler = self.make_handler( + fmt="{message} duration={duration:.3f} sql={sql}", style="{" + ) + msg = "=> Executing query" + + self.do_log(msg, extra={"sql": None, "duration": 3.1416}) + self.assertEqual(self.format_sql_calls, []) + self.assertLogRecord(handler, f"{msg} duration=3.142 sql=None") + @unittest.skipUnless( connection.vendor == "sqlite", "Only run on sqlite so we can check output SQL."