diff --git a/django/core/handlers/asgi.py b/django/core/handlers/asgi.py index 0edc98854f..7b0086fb76 100644 --- a/django/core/handlers/asgi.py +++ b/django/core/handlers/asgi.py @@ -187,30 +187,41 @@ class ASGIHandler(base.BaseHandler): body_file.close() await self.send_response(error_response, send) return + + async def process_request(request, send): + response = await self.run_get_response(request) + await self.send_response(response, send) + # Try to catch a disconnect while getting response. tasks = [ - asyncio.create_task(self.run_get_response(request)), + # Check the status of these tasks and (optionally) terminate them + # in this order. The listen_for_disconnect() task goes first + # because it should not raise unexpected errors that would prevent + # us from cancelling process_request(). asyncio.create_task(self.listen_for_disconnect(receive)), + asyncio.create_task(process_request(request, send)), ] - done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) - done, pending = done.pop(), pending.pop() - # Allow views to handle cancellation. - pending.cancel() - try: - await pending - except asyncio.CancelledError: - # Task re-raised the CancelledError as expected. - pass - try: - response = done.result() - except RequestAborted: - body_file.close() - return - except AssertionError: - body_file.close() - raise - # Send the response. - await self.send_response(response, send) + await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + # Now wait on both tasks (they may have both finished by now). + for task in tasks: + if task.done(): + try: + task.result() + except RequestAborted: + # Ignore client disconnects. + pass + except AssertionError: + body_file.close() + raise + else: + # Allow views to handle cancellation. + task.cancel() + try: + await task + except asyncio.CancelledError: + # Task re-raised the CancelledError as expected. + pass + body_file.close() async def listen_for_disconnect(self, receive): """Listen for disconnect from the client.""" diff --git a/docs/ref/request-response.txt b/docs/ref/request-response.txt index e70dae4de7..ee98b4b8b1 100644 --- a/docs/ref/request-response.txt +++ b/docs/ref/request-response.txt @@ -1282,6 +1282,36 @@ Attributes This is useful for middleware needing to wrap :attr:`StreamingHttpResponse.streaming_content`. +.. _request-response-streaming-disconnect: + +Handling disconnects +-------------------- + +.. versionadded:: 5.0 + +If the client disconnects during a streaming response, Django will cancel the +coroutine that is handling the response. If you want to clean up resources +manually, you can do so by catching the ``asyncio.CancelledError``:: + + async def streaming_response(): + try: + # Do some work here + async for chunk in my_streaming_iterator(): + yield chunk + except asyncio.CancelledError: + # Handle disconnect + ... + raise + + + async def my_streaming_view(request): + return StreamingHttpResponse(streaming_response()) + +This example only shows how to handle client disconnection while the response +is streaming. If you perform long-running operations in your view before +returning the ``StreamingHttpResponse`` object, then you may also want to +:ref:`handle disconnections in the view ` itself. + ``FileResponse`` objects ======================== diff --git a/docs/topics/async.txt b/docs/topics/async.txt index b16ffe0f78..1faf787380 100644 --- a/docs/topics/async.txt +++ b/docs/topics/async.txt @@ -197,6 +197,9 @@ cleanup:: # Handle disconnect raise +You can also :ref:`handle client disconnects in streaming responses +`. + .. _async-safety: Async safety diff --git a/tests/asgi/tests.py b/tests/asgi/tests.py index 0222b5356e..ced24c658e 100644 --- a/tests/asgi/tests.py +++ b/tests/asgi/tests.py @@ -10,7 +10,7 @@ from django.core.asgi import get_asgi_application from django.core.handlers.asgi import ASGIHandler, ASGIRequest from django.core.signals import request_finished, request_started from django.db import close_old_connections -from django.http import HttpResponse +from django.http import HttpResponse, StreamingHttpResponse from django.test import ( AsyncRequestFactory, SimpleTestCase, @@ -237,6 +237,31 @@ class ASGITest(SimpleTestCase): with self.assertRaises(asyncio.TimeoutError): await communicator.receive_output() + async def test_disconnect_both_return(self): + # Force both the disconnect listener and the task that sends the + # response to finish at the same time. + application = get_asgi_application() + scope = self.async_request_factory._base_scope(path="/") + communicator = ApplicationCommunicator(application, scope) + await communicator.send_input({"type": "http.request", "body": b"some body"}) + # Fetch response headers (this yields to asyncio and causes + # ASGHandler.send_response() to dump the body of the response in the + # queue). + await communicator.receive_output() + # Fetch response body (there's already some data queued up, so this + # doesn't actually yield to the event loop, it just succeeds + # instantly). + await communicator.receive_output() + # Send disconnect at the same time that response finishes (this just + # puts some info in a queue, it doesn't have to yield to the event + # loop). + await communicator.send_input({"type": "http.disconnect"}) + # Waiting for the communicator _does_ yield to the event loop, since + # ASGIHandler.send_response() is still waiting to do response.close(). + # It so happens that there are enough remaining yield points in both + # tasks that they both finish while the loop is running. + await communicator.wait() + async def test_disconnect_with_body(self): application = get_asgi_application() scope = self.async_request_factory._base_scope(path="/") @@ -254,7 +279,7 @@ class ASGITest(SimpleTestCase): await communicator.send_input({"type": "http.not_a_real_message"}) msg = "Invalid ASGI message after request body: http.not_a_real_message" with self.assertRaisesMessage(AssertionError, msg): - await communicator.receive_output() + await communicator.wait() async def test_delayed_disconnect_with_body(self): application = get_asgi_application() @@ -402,3 +427,95 @@ class ASGITest(SimpleTestCase): await communicator.receive_output() await communicator.wait() self.assertIs(view_did_cancel, True) + + async def test_asyncio_streaming_cancel_error(self): + # Similar to test_asyncio_cancel_error(), but during a streaming + # response. + view_did_cancel = False + + async def streaming_response(): + nonlocal view_did_cancel + try: + await asyncio.sleep(0.2) + yield b"Hello World!" + except asyncio.CancelledError: + # Set the flag. + view_did_cancel = True + raise + + async def view(request): + return StreamingHttpResponse(streaming_response()) + + class TestASGIRequest(ASGIRequest): + urlconf = (path("cancel/", view),) + + class TestASGIHandler(ASGIHandler): + request_class = TestASGIRequest + + # With no disconnect, the request cycle should complete in the same + # manner as the non-streaming response. + application = TestASGIHandler() + scope = self.async_request_factory._base_scope(path="/cancel/") + communicator = ApplicationCommunicator(application, scope) + await communicator.send_input({"type": "http.request"}) + response_start = await communicator.receive_output() + self.assertEqual(response_start["type"], "http.response.start") + self.assertEqual(response_start["status"], 200) + response_body = await communicator.receive_output() + self.assertEqual(response_body["type"], "http.response.body") + self.assertEqual(response_body["body"], b"Hello World!") + await communicator.wait() + self.assertIs(view_did_cancel, False) + + # Request cycle with a disconnect. + application = TestASGIHandler() + scope = self.async_request_factory._base_scope(path="/cancel/") + communicator = ApplicationCommunicator(application, scope) + await communicator.send_input({"type": "http.request"}) + response_start = await communicator.receive_output() + # Fetch the start of response so streaming can begin + self.assertEqual(response_start["type"], "http.response.start") + self.assertEqual(response_start["status"], 200) + await asyncio.sleep(0.1) + # Now disconnect the client. + await communicator.send_input({"type": "http.disconnect"}) + # This time the handler should not send a response. + with self.assertRaises(asyncio.TimeoutError): + await communicator.receive_output() + await communicator.wait() + self.assertIs(view_did_cancel, True) + + async def test_streaming(self): + scope = self.async_request_factory._base_scope( + path="/streaming/", query_string=b"sleep=0.001" + ) + application = get_asgi_application() + communicator = ApplicationCommunicator(application, scope) + await communicator.send_input({"type": "http.request"}) + # Fetch http.response.start. + await communicator.receive_output(timeout=1) + # Fetch the 'first' and 'last'. + first_response = await communicator.receive_output(timeout=1) + self.assertEqual(first_response["body"], b"first\n") + second_response = await communicator.receive_output(timeout=1) + self.assertEqual(second_response["body"], b"last\n") + # Fetch the rest of the response so that coroutines are cleaned up. + await communicator.receive_output(timeout=1) + with self.assertRaises(asyncio.TimeoutError): + await communicator.receive_output(timeout=1) + + async def test_streaming_disconnect(self): + scope = self.async_request_factory._base_scope( + path="/streaming/", query_string=b"sleep=0.1" + ) + application = get_asgi_application() + communicator = ApplicationCommunicator(application, scope) + await communicator.send_input({"type": "http.request"}) + await communicator.receive_output(timeout=1) + first_response = await communicator.receive_output(timeout=1) + self.assertEqual(first_response["body"], b"first\n") + # Disconnect the client. + await communicator.send_input({"type": "http.disconnect"}) + # 'last\n' isn't sent. + with self.assertRaises(asyncio.TimeoutError): + await communicator.receive_output(timeout=0.2) diff --git a/tests/asgi/urls.py b/tests/asgi/urls.py index 0f74fc9b97..931b7d5206 100644 --- a/tests/asgi/urls.py +++ b/tests/asgi/urls.py @@ -1,7 +1,8 @@ +import asyncio import threading import time -from django.http import FileResponse, HttpResponse +from django.http import FileResponse, HttpResponse, StreamingHttpResponse from django.urls import path from django.views.decorators.csrf import csrf_exempt @@ -44,6 +45,17 @@ sync_waiter.lock = threading.Lock() sync_waiter.barrier = threading.Barrier(2) +async def streaming_inner(sleep_time): + yield b"first\n" + await asyncio.sleep(sleep_time) + yield b"last\n" + + +async def streaming_view(request): + sleep_time = float(request.GET["sleep"]) + return StreamingHttpResponse(streaming_inner(sleep_time)) + + test_filename = __file__ @@ -54,4 +66,5 @@ urlpatterns = [ path("post/", post_echo), path("wait/", sync_waiter), path("delayed_hello/", hello_with_delay), + path("streaming/", streaming_view), ]