1
0
mirror of https://github.com/django/django.git synced 2024-12-22 17:16:24 +00:00

Fixed #35448 -- Created a custom StreamHandler for DebugSQLTextTestResult to keep track of the SQL query and reformat the records.

This commit is contained in:
HanaPoulpe 2024-06-08 16:20:01 +02:00
parent adae619426
commit f879952744

View File

@ -42,16 +42,49 @@ except ImportError:
tblib = None tblib = None
class DebugSQLStreamHandler(logging.StreamHandler):
def __init__(self, stream):
super().__init__(stream)
self.records = []
def handle(self, record):
self.records.append(record)
super().handle(record)
def handleError(self, record):
self.records.append(record)
super().handleError(record)
def format_sql_records(self):
message = []
for record in self.records:
line = self.format(record)
if hasattr(record, "sql"):
formatted_sql = sqlparse.format(
record.sql,
reindent=True,
keyword_case="upper",
)
line = line.replace("; ", ";\n\n").replace(
f"{record.sql}",
f"\n{formatted_sql}",
)
message.append(line)
return "\n\n\n".join(message)
class DebugSQLTextTestResult(unittest.TextTestResult): class DebugSQLTextTestResult(unittest.TextTestResult):
def __init__(self, stream, descriptions, verbosity): def __init__(self, stream, descriptions, verbosity):
self.logger = logging.getLogger("django.db.backends") self.logger = logging.getLogger("django.db.backends")
self.logger.setLevel(logging.DEBUG) self.logger.setLevel(logging.DEBUG)
self.debug_sql_stream = None self.debug_sql_stream = None
self.handler = None
super().__init__(stream, descriptions, verbosity) super().__init__(stream, descriptions, verbosity)
def startTest(self, test): def startTest(self, test):
self.debug_sql_stream = StringIO() self.debug_sql_stream = StringIO()
self.handler = logging.StreamHandler(self.debug_sql_stream) self.handler = DebugSQLStreamHandler(self.debug_sql_stream)
self.logger.addHandler(self.handler) self.logger.addHandler(self.handler)
super().startTest(test) super().startTest(test)
@ -62,6 +95,7 @@ class DebugSQLTextTestResult(unittest.TextTestResult):
self.debug_sql_stream.seek(0) self.debug_sql_stream.seek(0)
self.stream.write(self.debug_sql_stream.read()) self.stream.write(self.debug_sql_stream.read())
self.stream.writeln(self.separator2) self.stream.writeln(self.separator2)
self.handler = None
def addError(self, test, err): def addError(self, test, err):
super().addError(test, err) super().addError(test, err)
@ -71,12 +105,16 @@ class DebugSQLTextTestResult(unittest.TextTestResult):
else: else:
self.debug_sql_stream.seek(0) self.debug_sql_stream.seek(0)
sql = self.debug_sql_stream.read() sql = self.debug_sql_stream.read()
self.errors[-1] = self.errors[-1] + (sql,)
self.errors[-1] = self.errors[-1] + (sql, "")
def addFailure(self, test, err): def addFailure(self, test, err):
super().addFailure(test, err) super().addFailure(test, err)
self.debug_sql_stream.seek(0) self.debug_sql_stream.seek(0)
self.failures[-1] = self.failures[-1] + (self.debug_sql_stream.read(),) self.failures[-1] = self.failures[-1] + (
self.debug_sql_stream.read(),
self.handler.format_sql_records(),
)
def addSubTest(self, test, subtest, err): def addSubTest(self, test, subtest, err):
super().addSubTest(test, subtest, err) super().addSubTest(test, subtest, err)
@ -87,17 +125,18 @@ class DebugSQLTextTestResult(unittest.TextTestResult):
if issubclass(err[0], test.failureException) if issubclass(err[0], test.failureException)
else self.errors else self.errors
) )
errors[-1] = errors[-1] + (self.debug_sql_stream.read(),) errors[-1] = errors[-1] + (self.debug_sql_stream.read(), "")
def printErrorList(self, flavour, errors): def printErrorList(self, flavour, errors):
for test, err, sql_debug in errors: for test, err, sql_debug, formatted_sql_debug in errors:
self.stream.writeln(self.separator1) self.stream.writeln(self.separator1)
self.stream.writeln("%s: %s" % (flavour, self.getDescription(test))) self.stream.writeln("%s: %s" % (flavour, self.getDescription(test)))
self.stream.writeln(self.separator2) self.stream.writeln(self.separator2)
self.stream.writeln(err) self.stream.writeln(err)
self.stream.writeln(self.separator2) self.stream.writeln(self.separator2)
self.stream.writeln( self.stream.writeln(
sqlparse.format(sql_debug, reindent=True, keyword_case="upper") formatted_sql_debug
or sqlparse.format(sql_debug, reindent=True, keyword_case="upper")
) )