mirror of
https://github.com/django/django.git
synced 2024-12-22 09:05:43 +00:00
Fixed #35448 -- Created a custom StreamHandler
for DebugSQLTextTestResult
to keep track of the SQL query and reformat the records.
This commit is contained in:
parent
adae619426
commit
f879952744
@ -42,16 +42,49 @@ except ImportError:
|
||||
tblib = None
|
||||
|
||||
|
||||
class DebugSQLStreamHandler(logging.StreamHandler):
|
||||
def __init__(self, stream):
|
||||
super().__init__(stream)
|
||||
self.records = []
|
||||
|
||||
def handle(self, record):
|
||||
self.records.append(record)
|
||||
super().handle(record)
|
||||
|
||||
def handleError(self, record):
|
||||
self.records.append(record)
|
||||
super().handleError(record)
|
||||
|
||||
def format_sql_records(self):
|
||||
message = []
|
||||
for record in self.records:
|
||||
line = self.format(record)
|
||||
if hasattr(record, "sql"):
|
||||
formatted_sql = sqlparse.format(
|
||||
record.sql,
|
||||
reindent=True,
|
||||
keyword_case="upper",
|
||||
)
|
||||
line = line.replace("; ", ";\n\n").replace(
|
||||
f"{record.sql}",
|
||||
f"\n{formatted_sql}",
|
||||
)
|
||||
|
||||
message.append(line)
|
||||
return "\n\n\n".join(message)
|
||||
|
||||
|
||||
class DebugSQLTextTestResult(unittest.TextTestResult):
|
||||
def __init__(self, stream, descriptions, verbosity):
|
||||
self.logger = logging.getLogger("django.db.backends")
|
||||
self.logger.setLevel(logging.DEBUG)
|
||||
self.debug_sql_stream = None
|
||||
self.handler = None
|
||||
super().__init__(stream, descriptions, verbosity)
|
||||
|
||||
def startTest(self, test):
|
||||
self.debug_sql_stream = StringIO()
|
||||
self.handler = logging.StreamHandler(self.debug_sql_stream)
|
||||
self.handler = DebugSQLStreamHandler(self.debug_sql_stream)
|
||||
self.logger.addHandler(self.handler)
|
||||
super().startTest(test)
|
||||
|
||||
@ -62,6 +95,7 @@ class DebugSQLTextTestResult(unittest.TextTestResult):
|
||||
self.debug_sql_stream.seek(0)
|
||||
self.stream.write(self.debug_sql_stream.read())
|
||||
self.stream.writeln(self.separator2)
|
||||
self.handler = None
|
||||
|
||||
def addError(self, test, err):
|
||||
super().addError(test, err)
|
||||
@ -71,12 +105,16 @@ class DebugSQLTextTestResult(unittest.TextTestResult):
|
||||
else:
|
||||
self.debug_sql_stream.seek(0)
|
||||
sql = self.debug_sql_stream.read()
|
||||
self.errors[-1] = self.errors[-1] + (sql,)
|
||||
|
||||
self.errors[-1] = self.errors[-1] + (sql, "")
|
||||
|
||||
def addFailure(self, test, err):
|
||||
super().addFailure(test, err)
|
||||
self.debug_sql_stream.seek(0)
|
||||
self.failures[-1] = self.failures[-1] + (self.debug_sql_stream.read(),)
|
||||
self.failures[-1] = self.failures[-1] + (
|
||||
self.debug_sql_stream.read(),
|
||||
self.handler.format_sql_records(),
|
||||
)
|
||||
|
||||
def addSubTest(self, test, subtest, err):
|
||||
super().addSubTest(test, subtest, err)
|
||||
@ -87,17 +125,18 @@ class DebugSQLTextTestResult(unittest.TextTestResult):
|
||||
if issubclass(err[0], test.failureException)
|
||||
else self.errors
|
||||
)
|
||||
errors[-1] = errors[-1] + (self.debug_sql_stream.read(),)
|
||||
errors[-1] = errors[-1] + (self.debug_sql_stream.read(), "")
|
||||
|
||||
def printErrorList(self, flavour, errors):
|
||||
for test, err, sql_debug in errors:
|
||||
for test, err, sql_debug, formatted_sql_debug in errors:
|
||||
self.stream.writeln(self.separator1)
|
||||
self.stream.writeln("%s: %s" % (flavour, self.getDescription(test)))
|
||||
self.stream.writeln(self.separator2)
|
||||
self.stream.writeln(err)
|
||||
self.stream.writeln(self.separator2)
|
||||
self.stream.writeln(
|
||||
sqlparse.format(sql_debug, reindent=True, keyword_case="upper")
|
||||
formatted_sql_debug
|
||||
or sqlparse.format(sql_debug, reindent=True, keyword_case="upper")
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user