From 52b054824e899db40ba48f908a9a00dadc56cb89 Mon Sep 17 00:00:00 2001 From: Alexandre Spaeth Date: Wed, 15 Feb 2023 15:16:51 -0800 Subject: [PATCH] Fixed #34342, Refs #33735 -- Fixed test client handling of async streaming responses. Bug in 0bd2c0c9015b53c41394a1c0989afbfd94dc2830. Co-authored-by: Carlton Gibson --- django/test/client.py | 35 ++++++++++++++++++++++++++--------- tests/handlers/tests.py | 17 +++++++++++++++++ tests/handlers/urls.py | 1 + tests/handlers/views.py | 9 +++++++++ 4 files changed, 53 insertions(+), 9 deletions(-) diff --git a/django/test/client.py b/django/test/client.py index c699eb9264..cf63265faa 100644 --- a/django/test/client.py +++ b/django/test/client.py @@ -116,6 +116,16 @@ def closing_iterator_wrapper(iterable, close): request_finished.connect(close_old_connections) +async def aclosing_iterator_wrapper(iterable, close): + try: + async for chunk in iterable: + yield chunk + finally: + request_finished.disconnect(close_old_connections) + close() # will fire request_finished + request_finished.connect(close_old_connections) + + def conditional_content_removal(request, response): """ Simulate the behavior of most web servers by removing the content of @@ -174,9 +184,14 @@ class ClientHandler(BaseHandler): # Emulate a WSGI server by calling the close method on completion. if response.streaming: - response.streaming_content = closing_iterator_wrapper( - response.streaming_content, response.close - ) + if response.is_async: + response.streaming_content = aclosing_iterator_wrapper( + response.streaming_content, response.close + ) + else: + response.streaming_content = closing_iterator_wrapper( + response.streaming_content, response.close + ) else: request_finished.disconnect(close_old_connections) response.close() # will fire request_finished @@ -223,12 +238,14 @@ class AsyncClientHandler(BaseHandler): response.asgi_request = request # Emulate a server by calling the close method on completion. if response.streaming: - response.streaming_content = await sync_to_async( - closing_iterator_wrapper, thread_sensitive=False - )( - response.streaming_content, - response.close, - ) + if response.is_async: + response.streaming_content = aclosing_iterator_wrapper( + response.streaming_content, response.close + ) + else: + response.streaming_content = closing_iterator_wrapper( + response.streaming_content, response.close + ) else: request_finished.disconnect(close_old_connections) # Will fire request_finished. diff --git a/tests/handlers/tests.py b/tests/handlers/tests.py index 0df481c2fc..0348b8e5d6 100644 --- a/tests/handlers/tests.py +++ b/tests/handlers/tests.py @@ -253,6 +253,16 @@ class HandlerRequestTests(SimpleTestCase): self.assertEqual(response.status_code, 200) self.assertEqual(b"".join(list(response)), b"streaming content") + def test_async_streaming(self): + response = self.client.get("/async_streaming/") + self.assertEqual(response.status_code, 200) + msg = ( + "StreamingHttpResponse must consume asynchronous iterators in order to " + "serve them synchronously. Use a synchronous iterator instead." + ) + with self.assertWarnsMessage(Warning, msg): + self.assertEqual(b"".join(list(response)), b"streaming content") + class ScriptNameTests(SimpleTestCase): def test_get_script_name(self): @@ -329,3 +339,10 @@ class AsyncHandlerRequestTests(SimpleTestCase): self.assertEqual( b"".join([chunk async for chunk in response]), b"streaming content" ) + + async def test_async_streaming(self): + response = await self.async_client.get("/async_streaming/") + self.assertEqual(response.status_code, 200) + self.assertEqual( + b"".join([chunk async for chunk in response]), b"streaming content" + ) diff --git a/tests/handlers/urls.py b/tests/handlers/urls.py index 73d99c7edf..a0efece602 100644 --- a/tests/handlers/urls.py +++ b/tests/handlers/urls.py @@ -8,6 +8,7 @@ urlpatterns = [ path("no_response_fbv/", views.no_response), path("no_response_cbv/", views.NoResponse()), path("streaming/", views.streaming), + path("async_streaming/", views.async_streaming), path("in_transaction/", views.in_transaction), path("not_in_transaction/", views.not_in_transaction), path("not_in_transaction_using_none/", views.not_in_transaction_using_none), diff --git a/tests/handlers/views.py b/tests/handlers/views.py index 351eb65264..95d663323d 100644 --- a/tests/handlers/views.py +++ b/tests/handlers/views.py @@ -65,6 +65,15 @@ async def async_regular(request): return HttpResponse(b"regular content") +async def async_streaming(request): + async def async_streaming_generator(): + yield b"streaming" + yield b" " + yield b"content" + + return StreamingHttpResponse(async_streaming_generator()) + + class CoroutineClearingView: def __call__(self, request): """Return an unawaited coroutine (common error for async views)."""