1
0
mirror of https://github.com/django/django.git synced 2025-10-24 22:26:08 +00:00

Fixed #34063 -- Fixed reading request body with async request factory and client.

Co-authored-by: Kevan Swanberg <kevswanberg@gmail.com>
Co-authored-by: Carlton Gibson <carlton.gibson@noumenal.es>
This commit is contained in:
Scott Halgrim
2022-11-08 12:19:59 +01:00
committed by Carlton Gibson
parent 8e6ea1d153
commit c4eaa67e2b
3 changed files with 27 additions and 3 deletions

View File

@@ -14,7 +14,7 @@ from asgiref.sync import sync_to_async
from django.conf import settings from django.conf import settings
from django.core.handlers.asgi import ASGIRequest from django.core.handlers.asgi import ASGIRequest
from django.core.handlers.base import BaseHandler from django.core.handlers.base import BaseHandler
from django.core.handlers.wsgi import WSGIRequest from django.core.handlers.wsgi import LimitedStream, WSGIRequest
from django.core.serializers.json import DjangoJSONEncoder from django.core.serializers.json import DjangoJSONEncoder
from django.core.signals import got_request_exception, request_finished, request_started from django.core.signals import got_request_exception, request_finished, request_started
from django.db import close_old_connections from django.db import close_old_connections
@@ -198,7 +198,8 @@ class AsyncClientHandler(BaseHandler):
sender=self.__class__, scope=scope sender=self.__class__, scope=scope
) )
request_started.connect(close_old_connections) request_started.connect(close_old_connections)
request = ASGIRequest(scope, body_file) # Wrap FakePayload body_file to allow large read() in test environment.
request = ASGIRequest(scope, LimitedStream(body_file, len(body_file)))
# Sneaky little hack so that we can easily get round # Sneaky little hack so that we can easily get round
# CsrfViewMiddleware. This makes life easier, and is probably required # CsrfViewMiddleware. This makes life easier, and is probably required
# for backwards compatibility with external tests against admin views. # for backwards compatibility with external tests against admin views.
@@ -598,7 +599,10 @@ class AsyncRequestFactory(RequestFactory):
body_file = request.pop("_body_file") body_file = request.pop("_body_file")
else: else:
body_file = FakePayload("") body_file = FakePayload("")
return ASGIRequest(self._base_scope(**request), body_file) # Wrap FakePayload body_file to allow large read() in test environment.
return ASGIRequest(
self._base_scope(**request), LimitedStream(body_file, len(body_file))
)
def generic( def generic(
self, self,

View File

@@ -1103,6 +1103,14 @@ class AsyncClientTest(TestCase):
response = await self.async_client.get("/get_view/", {"var": "val"}) response = await self.async_client.get("/get_view/", {"var": "val"})
self.assertContains(response, "This is a test. val is the value.") self.assertContains(response, "This is a test. val is the value.")
async def test_post_data(self):
response = await self.async_client.post("/post_view/", {"value": 37})
self.assertContains(response, "Data received: 37 is the value.")
async def test_body_read_on_get_data(self):
response = await self.async_client.get("/post_view/")
self.assertContains(response, "Viewing GET page.")
@override_settings(ROOT_URLCONF="test_client.urls") @override_settings(ROOT_URLCONF="test_client.urls")
class AsyncRequestFactoryTest(SimpleTestCase): class AsyncRequestFactoryTest(SimpleTestCase):
@@ -1147,6 +1155,16 @@ class AsyncRequestFactoryTest(SimpleTestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'{"example": "data"}') self.assertEqual(response.content, b'{"example": "data"}')
async def test_request_limited_read(self):
tests = ["GET", "POST"]
for method in tests:
with self.subTest(method=method):
request = self.request_factory.generic(
method,
"/somewhere",
)
self.assertEqual(request.read(200), b"")
def test_request_factory_sets_headers(self): def test_request_factory_sets_headers(self):
request = self.request_factory.get( request = self.request_factory.get(
"/somewhere/", "/somewhere/",

View File

@@ -90,6 +90,8 @@ def post_view(request):
c = Context() c = Context()
else: else:
t = Template("Viewing GET page.", name="Empty GET Template") t = Template("Viewing GET page.", name="Empty GET Template")
# Used by test_body_read_on_get_data.
request.read(200)
c = Context() c = Context()
return HttpResponse(t.render(c)) return HttpResponse(t.render(c))