From 3f8dbe267d35f0219277f0fe2d79915a4fb2b045 Mon Sep 17 00:00:00 2001 From: Olivier Tabone Date: Fri, 4 Aug 2023 09:14:19 +0200 Subject: [PATCH] Fixed #34757 -- Added support for following redirects to AsyncClient. --- django/test/client.py | 238 +++++++++++++++++++++++++++++++++- docs/releases/5.0.txt | 2 + docs/topics/testing/tools.txt | 5 +- tests/test_client/tests.py | 19 ++- 4 files changed, 252 insertions(+), 12 deletions(-) diff --git a/django/test/client.py b/django/test/client.py index eed2d4f828..d44e30ff56 100644 --- a/django/test/client.py +++ b/django/test/client.py @@ -705,9 +705,6 @@ class AsyncRequestFactory(RequestFactory): ] ) s["_body_file"] = FakePayload(data) - follow = extra.pop("follow", None) - if follow is not None: - s["follow"] = follow if query_string := extra.pop("QUERY_STRING", None): s["query_string"] = query_string if headers: @@ -1296,10 +1293,6 @@ class AsyncClient(ClientMixin, AsyncRequestFactory): query environment, which can be overridden using the arguments to the request. """ - if "follow" in request: - raise NotImplementedError( - "AsyncClient request methods do not accept the follow parameter." - ) scope = self._base_scope(**request) # Curry a data dictionary into an instance of the template renderer # callback function. @@ -1338,3 +1331,234 @@ class AsyncClient(ClientMixin, AsyncRequestFactory): if response.cookies: self.cookies.update(response.cookies) return response + + async def get( + self, + path, + data=None, + follow=False, + secure=False, + *, + headers=None, + **extra, + ): + """Request a response from the server using GET.""" + self.extra = extra + self.headers = headers + response = await super().get( + path, data=data, secure=secure, headers=headers, **extra + ) + if follow: + response = await self._ahandle_redirects( + response, data=data, headers=headers, **extra + ) + return response + + async def post( + self, + path, + data=None, + content_type=MULTIPART_CONTENT, + follow=False, + secure=False, + *, + headers=None, + **extra, + ): + """Request a response from the server using POST.""" + self.extra = extra + self.headers = headers + response = await super().post( + path, + data=data, + content_type=content_type, + secure=secure, + headers=headers, + **extra, + ) + if follow: + response = await self._ahandle_redirects( + response, data=data, content_type=content_type, headers=headers, **extra + ) + return response + + async def head( + self, + path, + data=None, + follow=False, + secure=False, + *, + headers=None, + **extra, + ): + """Request a response from the server using HEAD.""" + self.extra = extra + self.headers = headers + response = await super().head( + path, data=data, secure=secure, headers=headers, **extra + ) + if follow: + response = await self._ahandle_redirects( + response, data=data, headers=headers, **extra + ) + return response + + async def options( + self, + path, + data="", + content_type="application/octet-stream", + follow=False, + secure=False, + *, + headers=None, + **extra, + ): + """Request a response from the server using OPTIONS.""" + self.extra = extra + self.headers = headers + response = await super().options( + path, + data=data, + content_type=content_type, + secure=secure, + headers=headers, + **extra, + ) + if follow: + response = await self._ahandle_redirects( + response, data=data, content_type=content_type, headers=headers, **extra + ) + return response + + async def put( + self, + path, + data="", + content_type="application/octet-stream", + follow=False, + secure=False, + *, + headers=None, + **extra, + ): + """Send a resource to the server using PUT.""" + self.extra = extra + self.headers = headers + response = await super().put( + path, + data=data, + content_type=content_type, + secure=secure, + headers=headers, + **extra, + ) + if follow: + response = await self._ahandle_redirects( + response, data=data, content_type=content_type, headers=headers, **extra + ) + return response + + async def patch( + self, + path, + data="", + content_type="application/octet-stream", + follow=False, + secure=False, + *, + headers=None, + **extra, + ): + """Send a resource to the server using PATCH.""" + self.extra = extra + self.headers = headers + response = await super().patch( + path, + data=data, + content_type=content_type, + secure=secure, + headers=headers, + **extra, + ) + if follow: + response = await self._ahandle_redirects( + response, data=data, content_type=content_type, headers=headers, **extra + ) + return response + + async def delete( + self, + path, + data="", + content_type="application/octet-stream", + follow=False, + secure=False, + *, + headers=None, + **extra, + ): + """Send a DELETE request to the server.""" + self.extra = extra + self.headers = headers + response = await super().delete( + path, + data=data, + content_type=content_type, + secure=secure, + headers=headers, + **extra, + ) + if follow: + response = await self._ahandle_redirects( + response, data=data, content_type=content_type, headers=headers, **extra + ) + return response + + async def trace( + self, + path, + data="", + follow=False, + secure=False, + *, + headers=None, + **extra, + ): + """Send a TRACE request to the server.""" + self.extra = extra + self.headers = headers + response = await super().trace( + path, data=data, secure=secure, headers=headers, **extra + ) + if follow: + response = await self._ahandle_redirects( + response, data=data, headers=headers, **extra + ) + return response + + async def _ahandle_redirects( + self, + response, + data="", + content_type="", + headers=None, + **extra, + ): + """ + Follow any redirects by requesting responses from the server using GET. + """ + response.redirect_chain = [] + while response.status_code in REDIRECT_STATUS_CODES: + redirect_chain = response.redirect_chain + response = await self._follow_redirect( + response, + data=data, + content_type=content_type, + headers=headers, + **extra, + ) + response.redirect_chain = redirect_chain + self._ensure_redirects_not_cyclic(response) + return response diff --git a/docs/releases/5.0.txt b/docs/releases/5.0.txt index 7379289e42..36c2c650fa 100644 --- a/docs/releases/5.0.txt +++ b/docs/releases/5.0.txt @@ -433,6 +433,8 @@ Tests :meth:`~django.test.Client.aforce_login`, and :meth:`~django.test.Client.alogout`. +* :class:`~django.test.AsyncClient` now supports the ``follow`` parameter. + URLs ~~~~ diff --git a/docs/topics/testing/tools.txt b/docs/topics/testing/tools.txt index 2378ff4ec1..a9373ff108 100644 --- a/docs/topics/testing/tools.txt +++ b/docs/topics/testing/tools.txt @@ -2032,7 +2032,6 @@ test client, with the following exceptions: * In the initialization, arbitrary keyword arguments in ``defaults`` are added directly into the ASGI scope. -* The ``follow`` parameter is not supported. * Headers passed as ``extra`` keyword arguments should not have the ``HTTP_`` prefix required by the synchronous client (see :meth:`Client.get`). For example, here is how to set an HTTP ``Accept`` header: @@ -2046,6 +2045,10 @@ test client, with the following exceptions: The ``headers`` parameter was added. +.. versionchanged:: 5.0 + + Support for the ``follow`` parameter was added to the ``AsyncClient``. + Using ``AsyncClient`` any method that makes a request must be awaited:: async def test_my_thing(self): diff --git a/tests/test_client/tests.py b/tests/test_client/tests.py index a5e980f3d0..402f282588 100644 --- a/tests/test_client/tests.py +++ b/tests/test_client/tests.py @@ -1135,8 +1135,11 @@ class AsyncClientTest(TestCase): response = await self.async_client.get("/middleware_urlconf_view/") self.assertEqual(response.resolver_match.url_name, "middleware_urlconf_view") - async def test_follow_parameter_not_implemented(self): - msg = "AsyncClient request methods do not accept the follow parameter." + async def test_redirect(self): + response = await self.async_client.get("/redirect_view/") + self.assertEqual(response.status_code, 302) + + async def test_follow_redirect(self): tests = ( "get", "post", @@ -1150,8 +1153,16 @@ class AsyncClientTest(TestCase): for method_name in tests: with self.subTest(method=method_name): method = getattr(self.async_client, method_name) - with self.assertRaisesMessage(NotImplementedError, msg): - await method("/redirect_view/", follow=True) + response = await method("/redirect_view/", follow=True) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.resolver_match.url_name, "get_view") + + async def test_follow_double_redirect(self): + response = await self.async_client.get("/double_redirect_view/", follow=True) + self.assertRedirects( + response, "/get_view/", status_code=302, target_status_code=200 + ) + self.assertEqual(len(response.redirect_chain), 2) async def test_get_data(self): response = await self.async_client.get("/get_view/", {"var": "val"})