mirror of
https://github.com/django/django.git
synced 2025-06-05 03:29:12 +00:00
[5.0.x] Fixed #35059 -- Ensured that ASGIHandler always sends the request_finished signal.
Prior to this work, when async tasks that process the request are cancelled due to receiving an early "http.disconnect" ASGI message, the request_finished signal was not being sent, potentially leading to resource leaks (such as database connections). This branch ensures that the request_finished signal is sent even in the case of early termination of the response. Regression in 64cea1e48f285ea2162c669208d95188b32bbc82. Co-authored-by: Natalia <124304+nessita@users.noreply.github.com> Co-authored-by: Carlton Gibson <carlton.gibson@noumenal.es> Backport of 11393ab1316f973c5fbb534305750740d909b4e4 from main
This commit is contained in:
parent
bbb9ef3c62
commit
f1fbd061ac
@ -186,11 +186,18 @@ class ASGIHandler(base.BaseHandler):
|
|||||||
if request is None:
|
if request is None:
|
||||||
body_file.close()
|
body_file.close()
|
||||||
await self.send_response(error_response, send)
|
await self.send_response(error_response, send)
|
||||||
|
await sync_to_async(error_response.close)()
|
||||||
return
|
return
|
||||||
|
|
||||||
async def process_request(request, send):
|
async def process_request(request, send):
|
||||||
response = await self.run_get_response(request)
|
response = await self.run_get_response(request)
|
||||||
await self.send_response(response, send)
|
try:
|
||||||
|
await self.send_response(response, send)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# Client disconnected during send_response (ignore exception).
|
||||||
|
pass
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
# Try to catch a disconnect while getting response.
|
# Try to catch a disconnect while getting response.
|
||||||
tasks = [
|
tasks = [
|
||||||
@ -221,6 +228,14 @@ class ASGIHandler(base.BaseHandler):
|
|||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
# Task re-raised the CancelledError as expected.
|
# Task re-raised the CancelledError as expected.
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = tasks[1].result()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
await signals.request_finished.asend(sender=self.__class__)
|
||||||
|
else:
|
||||||
|
await sync_to_async(response.close)()
|
||||||
|
|
||||||
body_file.close()
|
body_file.close()
|
||||||
|
|
||||||
async def listen_for_disconnect(self, receive):
|
async def listen_for_disconnect(self, receive):
|
||||||
@ -346,7 +361,6 @@ class ASGIHandler(base.BaseHandler):
|
|||||||
"more_body": not last,
|
"more_body": not last,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
await sync_to_async(response.close, thread_sensitive=True)()
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def chunk_bytes(cls, data):
|
def chunk_bytes(cls, data):
|
||||||
|
@ -28,3 +28,7 @@ Bugfixes
|
|||||||
* Fixed a regression in Django 5.0 that caused a crash of the ``dumpdata``
|
* Fixed a regression in Django 5.0 that caused a crash of the ``dumpdata``
|
||||||
management command when a base queryset used ``prefetch_related()``
|
management command when a base queryset used ``prefetch_related()``
|
||||||
(:ticket:`35159`).
|
(:ticket:`35159`).
|
||||||
|
|
||||||
|
* Fixed a regression in Django 5.0 that caused the ``request_finished`` signal to
|
||||||
|
sometimes not be fired when running Django through an ASGI server, resulting
|
||||||
|
in potential resource leaks (:ticket:`35059`).
|
||||||
|
@ -1,12 +1,15 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from asgiref.sync import sync_to_async
|
||||||
from asgiref.testing import ApplicationCommunicator
|
from asgiref.testing import ApplicationCommunicator
|
||||||
|
|
||||||
from django.contrib.staticfiles.handlers import ASGIStaticFilesHandler
|
from django.contrib.staticfiles.handlers import ASGIStaticFilesHandler
|
||||||
from django.core.asgi import get_asgi_application
|
from django.core.asgi import get_asgi_application
|
||||||
|
from django.core.exceptions import RequestDataTooBig
|
||||||
from django.core.handlers.asgi import ASGIHandler, ASGIRequest
|
from django.core.handlers.asgi import ASGIHandler, ASGIRequest
|
||||||
from django.core.signals import request_finished, request_started
|
from django.core.signals import request_finished, request_started
|
||||||
from django.db import close_old_connections
|
from django.db import close_old_connections
|
||||||
@ -20,6 +23,7 @@ from django.test import (
|
|||||||
)
|
)
|
||||||
from django.urls import path
|
from django.urls import path
|
||||||
from django.utils.http import http_date
|
from django.utils.http import http_date
|
||||||
|
from django.views.decorators.csrf import csrf_exempt
|
||||||
|
|
||||||
from .urls import sync_waiter, test_filename
|
from .urls import sync_waiter, test_filename
|
||||||
|
|
||||||
@ -207,6 +211,96 @@ class ASGITest(SimpleTestCase):
|
|||||||
self.assertEqual(response_body["type"], "http.response.body")
|
self.assertEqual(response_body["type"], "http.response.body")
|
||||||
self.assertEqual(response_body["body"], b"Echo!")
|
self.assertEqual(response_body["body"], b"Echo!")
|
||||||
|
|
||||||
|
async def test_create_request_error(self):
|
||||||
|
# Track request_finished signal.
|
||||||
|
signal_handler = SignalHandler()
|
||||||
|
request_finished.connect(signal_handler)
|
||||||
|
self.addCleanup(request_finished.disconnect, signal_handler)
|
||||||
|
|
||||||
|
# Request class that always fails creation with RequestDataTooBig.
|
||||||
|
class TestASGIRequest(ASGIRequest):
|
||||||
|
|
||||||
|
def __init__(self, scope, body_file):
|
||||||
|
super().__init__(scope, body_file)
|
||||||
|
raise RequestDataTooBig()
|
||||||
|
|
||||||
|
# Handler to use the custom request class.
|
||||||
|
class TestASGIHandler(ASGIHandler):
|
||||||
|
request_class = TestASGIRequest
|
||||||
|
|
||||||
|
application = TestASGIHandler()
|
||||||
|
scope = self.async_request_factory._base_scope(path="/not-important/")
|
||||||
|
communicator = ApplicationCommunicator(application, scope)
|
||||||
|
|
||||||
|
# Initiate request.
|
||||||
|
await communicator.send_input({"type": "http.request"})
|
||||||
|
# Give response.close() time to finish.
|
||||||
|
await communicator.wait()
|
||||||
|
|
||||||
|
self.assertEqual(len(signal_handler.calls), 1)
|
||||||
|
self.assertNotEqual(
|
||||||
|
signal_handler.calls[0]["thread"], threading.current_thread()
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_cancel_post_request_with_sync_processing(self):
|
||||||
|
"""
|
||||||
|
The request.body object should be available and readable in view
|
||||||
|
code, even if the ASGIHandler cancels processing part way through.
|
||||||
|
"""
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
# Events to monitor the view processing from the parent test code.
|
||||||
|
view_started_event = asyncio.Event()
|
||||||
|
view_finished_event = asyncio.Event()
|
||||||
|
# Record received request body or exceptions raised in the test view
|
||||||
|
outcome = []
|
||||||
|
|
||||||
|
# This view will run in a new thread because it is wrapped in
|
||||||
|
# sync_to_async. The view consumes the POST body data after a short
|
||||||
|
# delay. The test will cancel the request using http.disconnect during
|
||||||
|
# the delay, but because this is a sync view the code runs to
|
||||||
|
# completion. There should be no exceptions raised inside the view
|
||||||
|
# code.
|
||||||
|
@csrf_exempt
|
||||||
|
@sync_to_async
|
||||||
|
def post_view(request):
|
||||||
|
try:
|
||||||
|
loop.call_soon_threadsafe(view_started_event.set)
|
||||||
|
time.sleep(0.1)
|
||||||
|
# Do something to read request.body after pause
|
||||||
|
outcome.append({"request_body": request.body})
|
||||||
|
return HttpResponse("ok")
|
||||||
|
except Exception as e:
|
||||||
|
outcome.append({"exception": e})
|
||||||
|
finally:
|
||||||
|
loop.call_soon_threadsafe(view_finished_event.set)
|
||||||
|
|
||||||
|
# Request class to use the view.
|
||||||
|
class TestASGIRequest(ASGIRequest):
|
||||||
|
urlconf = (path("post/", post_view),)
|
||||||
|
|
||||||
|
# Handler to use request class.
|
||||||
|
class TestASGIHandler(ASGIHandler):
|
||||||
|
request_class = TestASGIRequest
|
||||||
|
|
||||||
|
application = TestASGIHandler()
|
||||||
|
scope = self.async_request_factory._base_scope(
|
||||||
|
method="POST",
|
||||||
|
path="/post/",
|
||||||
|
)
|
||||||
|
communicator = ApplicationCommunicator(application, scope)
|
||||||
|
|
||||||
|
await communicator.send_input({"type": "http.request", "body": b"Body data!"})
|
||||||
|
|
||||||
|
# Wait until the view code has started, then send http.disconnect.
|
||||||
|
await view_started_event.wait()
|
||||||
|
await communicator.send_input({"type": "http.disconnect"})
|
||||||
|
# Wait until view code has finished.
|
||||||
|
await view_finished_event.wait()
|
||||||
|
with self.assertRaises(asyncio.TimeoutError):
|
||||||
|
await communicator.receive_output()
|
||||||
|
|
||||||
|
self.assertEqual(outcome, [{"request_body": b"Body data!"}])
|
||||||
|
|
||||||
async def test_untouched_request_body_gets_closed(self):
|
async def test_untouched_request_body_gets_closed(self):
|
||||||
application = get_asgi_application()
|
application = get_asgi_application()
|
||||||
scope = self.async_request_factory._base_scope(method="POST", path="/post/")
|
scope = self.async_request_factory._base_scope(method="POST", path="/post/")
|
||||||
@ -347,7 +441,9 @@ class ASGITest(SimpleTestCase):
|
|||||||
# AsyncToSync should have executed the signals in the same thread.
|
# AsyncToSync should have executed the signals in the same thread.
|
||||||
self.assertEqual(len(signal_handler.calls), 2)
|
self.assertEqual(len(signal_handler.calls), 2)
|
||||||
request_started_call, request_finished_call = signal_handler.calls
|
request_started_call, request_finished_call = signal_handler.calls
|
||||||
self.assertEqual(request_started_call["thread"], request_finished_call["thread"])
|
self.assertEqual(
|
||||||
|
request_started_call["thread"], request_finished_call["thread"]
|
||||||
|
)
|
||||||
|
|
||||||
async def test_concurrent_async_uses_multiple_thread_pools(self):
|
async def test_concurrent_async_uses_multiple_thread_pools(self):
|
||||||
sync_waiter.active_threads.clear()
|
sync_waiter.active_threads.clear()
|
||||||
@ -383,6 +479,10 @@ class ASGITest(SimpleTestCase):
|
|||||||
async def test_asyncio_cancel_error(self):
|
async def test_asyncio_cancel_error(self):
|
||||||
# Flag to check if the view was cancelled.
|
# Flag to check if the view was cancelled.
|
||||||
view_did_cancel = False
|
view_did_cancel = False
|
||||||
|
# Track request_finished signal.
|
||||||
|
signal_handler = SignalHandler()
|
||||||
|
request_finished.connect(signal_handler)
|
||||||
|
self.addCleanup(request_finished.disconnect, signal_handler)
|
||||||
|
|
||||||
# A view that will listen for the cancelled error.
|
# A view that will listen for the cancelled error.
|
||||||
async def view(request):
|
async def view(request):
|
||||||
@ -417,6 +517,13 @@ class ASGITest(SimpleTestCase):
|
|||||||
# Give response.close() time to finish.
|
# Give response.close() time to finish.
|
||||||
await communicator.wait()
|
await communicator.wait()
|
||||||
self.assertIs(view_did_cancel, False)
|
self.assertIs(view_did_cancel, False)
|
||||||
|
# Exactly one call to request_finished handler.
|
||||||
|
self.assertEqual(len(signal_handler.calls), 1)
|
||||||
|
handler_call = signal_handler.calls.pop()
|
||||||
|
# It was NOT on the async thread.
|
||||||
|
self.assertNotEqual(handler_call["thread"], threading.current_thread())
|
||||||
|
# The signal sender is the handler class.
|
||||||
|
self.assertEqual(handler_call["kwargs"], {"sender": TestASGIHandler})
|
||||||
|
|
||||||
# Request cycle with a disconnect before the view can respond.
|
# Request cycle with a disconnect before the view can respond.
|
||||||
application = TestASGIHandler()
|
application = TestASGIHandler()
|
||||||
@ -432,11 +539,22 @@ class ASGITest(SimpleTestCase):
|
|||||||
await communicator.receive_output()
|
await communicator.receive_output()
|
||||||
await communicator.wait()
|
await communicator.wait()
|
||||||
self.assertIs(view_did_cancel, True)
|
self.assertIs(view_did_cancel, True)
|
||||||
|
# Exactly one call to request_finished handler.
|
||||||
|
self.assertEqual(len(signal_handler.calls), 1)
|
||||||
|
handler_call = signal_handler.calls.pop()
|
||||||
|
# It was NOT on the async thread.
|
||||||
|
self.assertNotEqual(handler_call["thread"], threading.current_thread())
|
||||||
|
# The signal sender is the handler class.
|
||||||
|
self.assertEqual(handler_call["kwargs"], {"sender": TestASGIHandler})
|
||||||
|
|
||||||
async def test_asyncio_streaming_cancel_error(self):
|
async def test_asyncio_streaming_cancel_error(self):
|
||||||
# Similar to test_asyncio_cancel_error(), but during a streaming
|
# Similar to test_asyncio_cancel_error(), but during a streaming
|
||||||
# response.
|
# response.
|
||||||
view_did_cancel = False
|
view_did_cancel = False
|
||||||
|
# Track request_finished signals.
|
||||||
|
signal_handler = SignalHandler()
|
||||||
|
request_finished.connect(signal_handler)
|
||||||
|
self.addCleanup(request_finished.disconnect, signal_handler)
|
||||||
|
|
||||||
async def streaming_response():
|
async def streaming_response():
|
||||||
nonlocal view_did_cancel
|
nonlocal view_did_cancel
|
||||||
@ -471,6 +589,13 @@ class ASGITest(SimpleTestCase):
|
|||||||
self.assertEqual(response_body["body"], b"Hello World!")
|
self.assertEqual(response_body["body"], b"Hello World!")
|
||||||
await communicator.wait()
|
await communicator.wait()
|
||||||
self.assertIs(view_did_cancel, False)
|
self.assertIs(view_did_cancel, False)
|
||||||
|
# Exactly one call to request_finished handler.
|
||||||
|
self.assertEqual(len(signal_handler.calls), 1)
|
||||||
|
handler_call = signal_handler.calls.pop()
|
||||||
|
# It was NOT on the async thread.
|
||||||
|
self.assertNotEqual(handler_call["thread"], threading.current_thread())
|
||||||
|
# The signal sender is the handler class.
|
||||||
|
self.assertEqual(handler_call["kwargs"], {"sender": TestASGIHandler})
|
||||||
|
|
||||||
# Request cycle with a disconnect.
|
# Request cycle with a disconnect.
|
||||||
application = TestASGIHandler()
|
application = TestASGIHandler()
|
||||||
@ -489,6 +614,13 @@ class ASGITest(SimpleTestCase):
|
|||||||
await communicator.receive_output()
|
await communicator.receive_output()
|
||||||
await communicator.wait()
|
await communicator.wait()
|
||||||
self.assertIs(view_did_cancel, True)
|
self.assertIs(view_did_cancel, True)
|
||||||
|
# Exactly one call to request_finished handler.
|
||||||
|
self.assertEqual(len(signal_handler.calls), 1)
|
||||||
|
handler_call = signal_handler.calls.pop()
|
||||||
|
# It was NOT on the async thread.
|
||||||
|
self.assertNotEqual(handler_call["thread"], threading.current_thread())
|
||||||
|
# The signal sender is the handler class.
|
||||||
|
self.assertEqual(handler_call["kwargs"], {"sender": TestASGIHandler})
|
||||||
|
|
||||||
async def test_streaming(self):
|
async def test_streaming(self):
|
||||||
scope = self.async_request_factory._base_scope(
|
scope = self.async_request_factory._base_scope(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user