From 1a03a984ab3728253f964ba16cd8d806f76bddf9 Mon Sep 17 00:00:00 2001 From: Natalia <124304+nessita@users.noreply.github.com> Date: Tue, 3 Jun 2025 15:54:16 -0300 Subject: [PATCH] Fixed #36380 -- Deferred SQL formatting when running tests with --debug-sql. Thanks to Jacob Walls for the report and previous iterations of this fix, to Simon Charette for the logging formatter idea, and to Tim Graham for testing and ensuring that 3rd party backends remain compatible. This partially reverts d8f093908c504ae0dbc39d3f5231f7d7920dde37. Refs #36112, #35448. Co-authored-by: Jacob Walls --- django/db/backends/utils.py | 2 +- django/test/runner.py | 55 ++++++--- tests/backends/tests.py | 11 +- tests/test_runner/test_debug_sql.py | 176 +++++++++++++++++++++++++++- 4 files changed, 216 insertions(+), 28 deletions(-) diff --git a/django/db/backends/utils.py b/django/db/backends/utils.py index 568f510a67..ab0ea8258b 100644 --- a/django/db/backends/utils.py +++ b/django/db/backends/utils.py @@ -151,7 +151,7 @@ class CursorDebugWrapper(CursorWrapper): logger.debug( "(%.3f) %s; args=%s; alias=%s", duration, - self.db.ops.format_debug_sql(sql), + sql, params, self.db.alias, extra={ diff --git a/django/test/runner.py b/django/test/runner.py index c8bb16e7b3..3e5c319ade 100644 --- a/django/test/runner.py +++ b/django/test/runner.py @@ -16,7 +16,6 @@ import unittest.suite from collections import defaultdict from contextlib import contextmanager from importlib import import_module -from io import StringIO import django from django.core.management import call_command @@ -41,16 +40,47 @@ except ImportError: tblib = None +class QueryFormatter(logging.Formatter): + def format(self, record): + if (alias := getattr(record, "alias", None)) in connections: + format_sql = connections[alias].ops.format_debug_sql + + sql = None + formatted_sql = None + if args := record.args: + if isinstance(args, tuple) and len(args) > 1 and (sql := args[1]): + record.args = (args[0], formatted_sql := format_sql(sql), *args[2:]) + elif isinstance(record.args, dict) and (sql := record.args.get("sql")): + record.args["sql"] = formatted_sql = format_sql(sql) + + if extra_sql := getattr(record, "sql", None): + if extra_sql == sql: + record.sql = formatted_sql + else: + record.sql = format_sql(extra_sql) + + return super().format(record) + + class DebugSQLTextTestResult(unittest.TextTestResult): def __init__(self, stream, descriptions, verbosity): self.logger = logging.getLogger("django.db.backends") self.logger.setLevel(logging.DEBUG) - self.debug_sql_stream = None + self.handler = None super().__init__(stream, descriptions, verbosity) + def _read_logger_stream(self): + if self.handler is None: + # Error before tests e.g. in setUpTestData(). + sql = "" + else: + self.handler.stream.seek(0) + sql = self.handler.stream.read() + return sql + def startTest(self, test): - self.debug_sql_stream = StringIO() - self.handler = logging.StreamHandler(self.debug_sql_stream) + self.handler = logging.StreamHandler(io.StringIO()) + self.handler.setFormatter(QueryFormatter()) self.logger.addHandler(self.handler) super().startTest(test) @@ -58,35 +88,26 @@ class DebugSQLTextTestResult(unittest.TextTestResult): super().stopTest(test) self.logger.removeHandler(self.handler) if self.showAll: - self.debug_sql_stream.seek(0) - self.stream.write(self.debug_sql_stream.read()) + self.stream.write(self._read_logger_stream()) self.stream.writeln(self.separator2) def addError(self, test, err): super().addError(test, err) - if self.debug_sql_stream is None: - # Error before tests e.g. in setUpTestData(). - sql = "" - else: - self.debug_sql_stream.seek(0) - sql = self.debug_sql_stream.read() - self.errors[-1] = self.errors[-1] + (sql,) + self.errors[-1] = self.errors[-1] + (self._read_logger_stream(),) def addFailure(self, test, err): super().addFailure(test, err) - self.debug_sql_stream.seek(0) - self.failures[-1] = self.failures[-1] + (self.debug_sql_stream.read(),) + self.failures[-1] = self.failures[-1] + (self._read_logger_stream(),) def addSubTest(self, test, subtest, err): super().addSubTest(test, subtest, err) if err is not None: - self.debug_sql_stream.seek(0) errors = ( self.failures if issubclass(err[0], test.failureException) else self.errors ) - errors[-1] = errors[-1] + (self.debug_sql_stream.read(),) + errors[-1] = errors[-1] + (self._read_logger_stream(),) def printErrorList(self, flavour, errors): for test, err, sql_debug in errors: diff --git a/tests/backends/tests.py b/tests/backends/tests.py index 0e5348e248..5e38f0112d 100644 --- a/tests/backends/tests.py +++ b/tests/backends/tests.py @@ -83,12 +83,7 @@ class LastExecutedQueryTest(TestCase): connection.ops.last_executed_query(cursor, "SELECT %s" + suffix, (1,)) def test_debug_sql(self): - qs = Reporter.objects.filter(first_name="test") - ops = connections[qs.db].ops - with mock.patch.object(ops, "format_debug_sql") as format_debug_sql: - list(qs) - # Queries are formatted with DatabaseOperations.format_debug_sql(). - format_debug_sql.assert_called() + list(Reporter.objects.filter(first_name="test")) sql = connection.queries[-1]["sql"].lower() self.assertIn("select", sql) self.assertIn(Reporter._meta.db_table, sql) @@ -580,13 +575,13 @@ class BackendTestCase(TransactionTestCase): @mock.patch("django.db.backends.utils.logger") @override_settings(DEBUG=True) def test_queries_logger(self, mocked_logger): - sql = "SELECT 1" + connection.features.bare_select_suffix - sql = connection.ops.format_debug_sql(sql) + sql = "select 1" + connection.features.bare_select_suffix with connection.cursor() as cursor: cursor.execute(sql) params, kwargs = mocked_logger.debug.call_args self.assertIn("; alias=%s", params[0]) self.assertEqual(params[2], sql) + self.assertNotEqual(params[2], connection.ops.format_debug_sql(sql)) self.assertIsNone(params[3]) self.assertEqual(params[4], connection.alias) self.assertEqual( diff --git a/tests/test_runner/test_debug_sql.py b/tests/test_runner/test_debug_sql.py index 27fc4001c2..acf66633ef 100644 --- a/tests/test_runner/test_debug_sql.py +++ b/tests/test_runner/test_debug_sql.py @@ -1,12 +1,184 @@ +import logging import unittest from io import StringIO +from time import time +from unittest import mock -from django.db import connection +from django.db import DEFAULT_DB_ALIAS, connection, connections from django.test import TestCase -from django.test.runner import DiscoverRunner +from django.test.runner import DiscoverRunner, QueryFormatter from .models import Person +logger = logging.getLogger(__name__) + + +class QueryFormatterTests(unittest.TestCase): + + def setUp(self): + super().setUp() + self.format_sql_calls = [] + + def new_format_sql(self, sql): + # Use time() to introduce some uniqueness. + formatted = "Formatted! %s at %s" % (sql.upper(), time()) + self.format_sql_calls.append({sql: formatted}) + return formatted + + def make_handler(self, **formatter_kwargs): + formatter = QueryFormatter(**formatter_kwargs) + + handler = logging.StreamHandler(StringIO()) + handler.setLevel(logging.DEBUG) + handler.setFormatter(formatter) + + original_level = logger.getEffectiveLevel() + logger.setLevel(logging.DEBUG) + self.addCleanup(logger.setLevel, original_level) + logger.addHandler(handler) + self.addCleanup(logger.removeHandler, handler) + + return handler + + def do_log(self, msg, *logger_args, alias=DEFAULT_DB_ALIAS, extra=None): + if extra is None: + extra = {} + if alias and "alias" not in extra: + extra["alias"] = alias + # Patch connection's format_debug_sql to ensure it was properly called. + with mock.patch.object( + connections[alias].ops, "format_debug_sql", side_effect=self.new_format_sql + ): + logger.info(msg, *logger_args, extra=extra) + + def assertLogRecord(self, handler, expected): + handler.stream.seek(0) + self.assertEqual(handler.stream.read().strip(), expected) + + def assertSQLFormatted(self, handler, sql, total_calls=1): + self.assertEqual(len(self.format_sql_calls), total_calls) + formatted_sql = self.format_sql_calls[0][sql] + expected = f"=> Executing query duration=3.142 sql={formatted_sql}" + self.assertLogRecord(handler, expected) + + def test_formats_sql_bracket_format_style(self): + handler = self.make_handler( + fmt="{message} duration={duration:.3f} sql={sql}", style="{" + ) + msg = "=> Executing query" + sql = "select * from foo" + + self.do_log(msg, extra={"sql": sql, "duration": 3.1416}) + self.assertSQLFormatted(handler, sql) + + def test_formats_sql_named_fmt_format_style(self): + handler = self.make_handler( + fmt="%(message)s duration=%(duration).3f sql=%(sql)s" + ) + msg = "=> Executing query" + sql = "select * from foo" + + self.do_log(msg, extra={"sql": sql, "duration": 3.1416}) + self.assertSQLFormatted(handler, sql) + + def test_formats_sql_named_percent_format_style(self): + handler = self.make_handler() + msg = "=> Executing query duration=%(duration).3f sql=%(sql)s" + sql = "select * from foo" + + self.do_log(msg, {"duration": 3.1416, "sql": sql}) + self.assertSQLFormatted(handler, sql) + + def test_formats_sql_default_percent_format_style(self): + handler = self.make_handler() + msg = "=> Executing query duration=%.3f sql=%s" + sql = "select * from foo" + + self.do_log(msg, 3.1416, sql) + self.assertSQLFormatted(handler, sql) + + def test_formats_sql_multiple_matching_sql(self): + handler = self.make_handler() + msg = "=> Executing query duration=%.3f sql=%s" + sql = "select * from foo" + + self.do_log(msg, 3.1416, sql, extra={"duration": 3.1416, "sql": sql}) + self.assertSQLFormatted(handler, sql) + + def test_formats_sql_multiple_non_matching_sql(self): + handler = self.make_handler() + msg = "=> Executing query duration=%.3f sql=%s" + sql1 = "select * from foo" + sql2 = "select * from other" + + self.do_log(msg, 3.1416, sql1, extra={"duration": 3.1416, "sql": sql2}) + self.assertSQLFormatted(handler, sql1, total_calls=2) + # Second format call is triggered since the sql are different. + self.assertEqual(list(self.format_sql_calls[1].keys()), [sql2]) + + def test_log_record_no_args(self): + handler = self.make_handler() + msg = "=> Executing query no args" + + self.do_log(msg) + self.assertEqual(self.format_sql_calls, []) + self.assertLogRecord(handler, msg) + + def test_log_record_not_enough_args(self): + handler = self.make_handler() + msg = "=> Executing query one args %r" + args = "not formatted" + + self.do_log(msg, args) + self.assertEqual(self.format_sql_calls, []) + self.assertLogRecord(handler, msg % args) + + def test_log_record_not_key_in_dict_args(self): + handler = self.make_handler() + msg = "=> Executing query missing sql key %(foo)r" + args = {"foo": "bar"} + + self.do_log(msg, args) + self.assertEqual(self.format_sql_calls, []) + self.assertLogRecord(handler, msg % args) + + def test_log_record_no_alias(self): + handler = self.make_handler() + msg = "=> Executing query duration=%.3f sql=%s" + args = (3.1416, "select * from foo") + + self.do_log(msg, *args, extra={"alias": "does not exist"}) + self.assertEqual(self.format_sql_calls, []) + self.assertLogRecord(handler, msg % args) + + def test_log_record_sql_arg_none(self): + handler = self.make_handler() + msg = "=> Executing query duration=%.3f sql=%s" + args = (3.1416, None) + + self.do_log(msg, *args) + self.assertEqual(self.format_sql_calls, []) + self.assertLogRecord(handler, msg % args) + + def test_log_record_sql_key_none(self): + handler = self.make_handler() + msg = "=> Executing query duration=%(duration).3f sql=%(sql)s" + args = {"duration": 3.1416, "sql": None} + + self.do_log(msg, args) + self.assertEqual(self.format_sql_calls, []) + self.assertLogRecord(handler, msg % args) + + def test_log_record_sql_extra_none(self): + handler = self.make_handler( + fmt="{message} duration={duration:.3f} sql={sql}", style="{" + ) + msg = "=> Executing query" + + self.do_log(msg, extra={"sql": None, "duration": 3.1416}) + self.assertEqual(self.format_sql_calls, []) + self.assertLogRecord(handler, f"{msg} duration=3.142 sql=None") + @unittest.skipUnless( connection.vendor == "sqlite", "Only run on sqlite so we can check output SQL."