diff --git a/django/core/servers/basehttp.py b/django/core/servers/basehttp.py index 41719034fb..d62b88d286 100644 --- a/django/core/servers/basehttp.py +++ b/django/core/servers/basehttp.py @@ -18,6 +18,7 @@ from django.core.exceptions import ImproperlyConfigured from django.core.handlers.wsgi import LimitedStream from django.core.wsgi import get_wsgi_application from django.db import connections +from django.utils.log import log_message from django.utils.module_loading import import_string __all__ = ("WSGIServer", "WSGIRequestHandler") @@ -182,35 +183,27 @@ class WSGIRequestHandler(simple_server.WSGIRequestHandler): return self.client_address[0] def log_message(self, format, *args): - extra = { - "request": self.request, - "server_time": self.log_date_time_string(), - } - if args[1][0] == "4": + if args[1][0] == "4" and args[0].startswith("\x16\x03"): # 0x16 = Handshake, 0x03 = SSL 3.0 or TLS 1.x - if args[0].startswith("\x16\x03"): - extra["status_code"] = 500 - logger.error( - "You're accessing the development server over HTTPS, but " - "it only supports HTTP.", - extra=extra, - ) - return - - if args[1].isdigit() and len(args[1]) == 3: + format = ( + "You're accessing the development server over HTTPS, but it only " + "supports HTTP." + ) + status_code = 500 + args = () + elif args[1].isdigit() and len(args[1]) == 3: status_code = int(args[1]) - extra["status_code"] = status_code - - if status_code >= 500: - level = logger.error - elif status_code >= 400: - level = logger.warning - else: - level = logger.info else: - level = logger.info + status_code = None - level(format, *args, extra=extra) + log_message( + logger, + format, + *args, + request=self.request, + status_code=status_code, + server_time=self.log_date_time_string(), + ) def get_environ(self): # Strip all headers with underscores in the name before constructing diff --git a/django/utils/log.py b/django/utils/log.py index 67a40270f0..d4e96a9816 100644 --- a/django/utils/log.py +++ b/django/utils/log.py @@ -214,6 +214,46 @@ class ServerFormatter(logging.Formatter): return self._fmt.find("{server_time}") >= 0 +def log_message( + logger, + message, + *args, + level=None, + status_code=None, + request=None, + exception=None, + **extra, +): + """Log `message` using `logger` based on `status_code` and logger `level`. + + Pass `request`, `status_code` (if defined) and any provided `extra` as such + to the logging method, + + Arguments from `args` will be escaped to avoid potential log injections. + + """ + extra = {"request": request, **extra} + if status_code is not None: + extra["status_code"] = status_code + if level is None: + if status_code >= 500: + level = "error" + elif status_code >= 400: + level = "warning" + + escaped_args = tuple( + a.encode("unicode_escape").decode("ascii") if isinstance(a, str) else a + for a in args + ) + + getattr(logger, level or "info")( + message, + *escaped_args, + extra=extra, + exc_info=exception, + ) + + def log_response( message, *args, @@ -237,26 +277,13 @@ def log_response( if getattr(response, "_has_been_logged", False): return - if level is None: - if response.status_code >= 500: - level = "error" - elif response.status_code >= 400: - level = "warning" - else: - level = "info" - - escaped_args = tuple( - a.encode("unicode_escape").decode("ascii") if isinstance(a, str) else a - for a in args - ) - - getattr(logger, level)( + log_message( + logger, message, - *escaped_args, - extra={ - "status_code": response.status_code, - "request": request, - }, - exc_info=exception, + *args, + level=level, + status_code=response.status_code, + request=request, + exception=exception, ) response._has_been_logged = True diff --git a/tests/servers/test_basehttp.py b/tests/servers/test_basehttp.py index cc4701114a..9190fc8a20 100644 --- a/tests/servers/test_basehttp.py +++ b/tests/servers/test_basehttp.py @@ -50,6 +50,21 @@ class WSGIRequestHandlerTestCase(SimpleTestCase): cm.records[0].levelname, wrong_level.upper() ) + def test_log_message_escapes_control_sequences(self): + request = WSGIRequest(self.request_factory.get("/").environ) + request.makefile = lambda *args, **kwargs: BytesIO() + handler = WSGIRequestHandler(request, "192.168.0.2", None) + + malicious_path = "\x1b[31mALERT\x1b[0m" + + with self.assertLogs("django.server", "WARNING") as cm: + handler.log_message("GET %s %s", malicious_path, "404") + + log = cm.output[0] + + self.assertNotIn("\x1b[31m", log) + self.assertIn("\\x1b[31mALERT\\x1b[0m", log) + def test_https(self): request = WSGIRequest(self.request_factory.get("/").environ) request.makefile = lambda *args, **kwargs: BytesIO()