1
0
mirror of https://github.com/django/django.git synced 2024-12-22 09:05:43 +00:00

Fixed #35448 -- Created a custom StreamHandler for DebugSQLTextTestResult to keep track of the SQL query and reformat the records.

This commit is contained in:
HanaPoulpe 2024-06-08 16:20:01 +02:00
parent adae619426
commit f879952744

View File

@ -42,16 +42,49 @@ except ImportError:
tblib = None
class DebugSQLStreamHandler(logging.StreamHandler):
def __init__(self, stream):
super().__init__(stream)
self.records = []
def handle(self, record):
self.records.append(record)
super().handle(record)
def handleError(self, record):
self.records.append(record)
super().handleError(record)
def format_sql_records(self):
message = []
for record in self.records:
line = self.format(record)
if hasattr(record, "sql"):
formatted_sql = sqlparse.format(
record.sql,
reindent=True,
keyword_case="upper",
)
line = line.replace("; ", ";\n\n").replace(
f"{record.sql}",
f"\n{formatted_sql}",
)
message.append(line)
return "\n\n\n".join(message)
class DebugSQLTextTestResult(unittest.TextTestResult):
def __init__(self, stream, descriptions, verbosity):
self.logger = logging.getLogger("django.db.backends")
self.logger.setLevel(logging.DEBUG)
self.debug_sql_stream = None
self.handler = None
super().__init__(stream, descriptions, verbosity)
def startTest(self, test):
self.debug_sql_stream = StringIO()
self.handler = logging.StreamHandler(self.debug_sql_stream)
self.handler = DebugSQLStreamHandler(self.debug_sql_stream)
self.logger.addHandler(self.handler)
super().startTest(test)
@ -62,6 +95,7 @@ class DebugSQLTextTestResult(unittest.TextTestResult):
self.debug_sql_stream.seek(0)
self.stream.write(self.debug_sql_stream.read())
self.stream.writeln(self.separator2)
self.handler = None
def addError(self, test, err):
super().addError(test, err)
@ -71,12 +105,16 @@ class DebugSQLTextTestResult(unittest.TextTestResult):
else:
self.debug_sql_stream.seek(0)
sql = self.debug_sql_stream.read()
self.errors[-1] = self.errors[-1] + (sql,)
self.errors[-1] = self.errors[-1] + (sql, "")
def addFailure(self, test, err):
super().addFailure(test, err)
self.debug_sql_stream.seek(0)
self.failures[-1] = self.failures[-1] + (self.debug_sql_stream.read(),)
self.failures[-1] = self.failures[-1] + (
self.debug_sql_stream.read(),
self.handler.format_sql_records(),
)
def addSubTest(self, test, subtest, err):
super().addSubTest(test, subtest, err)
@ -87,17 +125,18 @@ class DebugSQLTextTestResult(unittest.TextTestResult):
if issubclass(err[0], test.failureException)
else self.errors
)
errors[-1] = errors[-1] + (self.debug_sql_stream.read(),)
errors[-1] = errors[-1] + (self.debug_sql_stream.read(), "")
def printErrorList(self, flavour, errors):
for test, err, sql_debug in errors:
for test, err, sql_debug, formatted_sql_debug in errors:
self.stream.writeln(self.separator1)
self.stream.writeln("%s: %s" % (flavour, self.getDescription(test)))
self.stream.writeln(self.separator2)
self.stream.writeln(err)
self.stream.writeln(self.separator2)
self.stream.writeln(
sqlparse.format(sql_debug, reindent=True, keyword_case="upper")
formatted_sql_debug
or sqlparse.format(sql_debug, reindent=True, keyword_case="upper")
)