diff --git a/django/test/runner.py b/django/test/runner.py index 27eb9613e9..ed94f92808 100644 --- a/django/test/runner.py +++ b/django/test/runner.py @@ -42,16 +42,49 @@ except ImportError: tblib = None +class DebugSQLStreamHandler(logging.StreamHandler): + def __init__(self, stream): + super().__init__(stream) + self.records = [] + + def handle(self, record): + self.records.append(record) + super().handle(record) + + def handleError(self, record): + self.records.append(record) + super().handleError(record) + + def format_sql_records(self): + message = [] + for record in self.records: + line = self.format(record) + if hasattr(record, "sql"): + formatted_sql = sqlparse.format( + record.sql, + reindent=True, + keyword_case="upper", + ) + line = line.replace("; ", ";\n\n").replace( + f"{record.sql}", + f"\n{formatted_sql}", + ) + + message.append(line) + return "\n\n\n".join(message) + + class DebugSQLTextTestResult(unittest.TextTestResult): def __init__(self, stream, descriptions, verbosity): self.logger = logging.getLogger("django.db.backends") self.logger.setLevel(logging.DEBUG) self.debug_sql_stream = None + self.handler = None super().__init__(stream, descriptions, verbosity) def startTest(self, test): self.debug_sql_stream = StringIO() - self.handler = logging.StreamHandler(self.debug_sql_stream) + self.handler = DebugSQLStreamHandler(self.debug_sql_stream) self.logger.addHandler(self.handler) super().startTest(test) @@ -62,6 +95,7 @@ class DebugSQLTextTestResult(unittest.TextTestResult): self.debug_sql_stream.seek(0) self.stream.write(self.debug_sql_stream.read()) self.stream.writeln(self.separator2) + self.handler = None def addError(self, test, err): super().addError(test, err) @@ -71,12 +105,16 @@ class DebugSQLTextTestResult(unittest.TextTestResult): else: self.debug_sql_stream.seek(0) sql = self.debug_sql_stream.read() - self.errors[-1] = self.errors[-1] + (sql,) + + self.errors[-1] = self.errors[-1] + (sql, "") def addFailure(self, test, err): super().addFailure(test, err) self.debug_sql_stream.seek(0) - self.failures[-1] = self.failures[-1] + (self.debug_sql_stream.read(),) + self.failures[-1] = self.failures[-1] + ( + self.debug_sql_stream.read(), + self.handler.format_sql_records(), + ) def addSubTest(self, test, subtest, err): super().addSubTest(test, subtest, err) @@ -87,17 +125,18 @@ class DebugSQLTextTestResult(unittest.TextTestResult): if issubclass(err[0], test.failureException) else self.errors ) - errors[-1] = errors[-1] + (self.debug_sql_stream.read(),) + errors[-1] = errors[-1] + (self.debug_sql_stream.read(), "") def printErrorList(self, flavour, errors): - for test, err, sql_debug in errors: + for test, err, sql_debug, formatted_sql_debug in errors: self.stream.writeln(self.separator1) self.stream.writeln("%s: %s" % (flavour, self.getDescription(test))) self.stream.writeln(self.separator2) self.stream.writeln(err) self.stream.writeln(self.separator2) self.stream.writeln( - sqlparse.format(sql_debug, reindent=True, keyword_case="upper") + formatted_sql_debug + or sqlparse.format(sql_debug, reindent=True, keyword_case="upper") )