mirror of
				https://github.com/django/django.git
				synced 2025-10-25 06:36:07 +00:00 
			
		
		
		
	Fixed #28337 -- Preserved extra headers of requests made with django.test.Client in assertRedirects().
Co-Authored-By: Hasan Ramezani <hasan.r67@gmail.com>
This commit is contained in:
		
				
					committed by
					
						 Mariusz Felisiak
						Mariusz Felisiak
					
				
			
			
				
	
			
			
			
						parent
						
							3ca9df51c7
						
					
				
				
					commit
					46e74a5256
				
			| @@ -444,6 +444,7 @@ class Client(RequestFactory): | |||||||
|         self.handler = ClientHandler(enforce_csrf_checks) |         self.handler = ClientHandler(enforce_csrf_checks) | ||||||
|         self.raise_request_exception = raise_request_exception |         self.raise_request_exception = raise_request_exception | ||||||
|         self.exc_info = None |         self.exc_info = None | ||||||
|  |         self.extra = None | ||||||
|  |  | ||||||
|     def store_exc_info(self, **kwargs): |     def store_exc_info(self, **kwargs): | ||||||
|         """Store exceptions when they are generated by a view.""" |         """Store exceptions when they are generated by a view.""" | ||||||
| @@ -515,6 +516,7 @@ class Client(RequestFactory): | |||||||
|  |  | ||||||
|     def get(self, path, data=None, follow=False, secure=False, **extra): |     def get(self, path, data=None, follow=False, secure=False, **extra): | ||||||
|         """Request a response from the server using GET.""" |         """Request a response from the server using GET.""" | ||||||
|  |         self.extra = extra | ||||||
|         response = super().get(path, data=data, secure=secure, **extra) |         response = super().get(path, data=data, secure=secure, **extra) | ||||||
|         if follow: |         if follow: | ||||||
|             response = self._handle_redirects(response, data=data, **extra) |             response = self._handle_redirects(response, data=data, **extra) | ||||||
| @@ -523,6 +525,7 @@ class Client(RequestFactory): | |||||||
|     def post(self, path, data=None, content_type=MULTIPART_CONTENT, |     def post(self, path, data=None, content_type=MULTIPART_CONTENT, | ||||||
|              follow=False, secure=False, **extra): |              follow=False, secure=False, **extra): | ||||||
|         """Request a response from the server using POST.""" |         """Request a response from the server using POST.""" | ||||||
|  |         self.extra = extra | ||||||
|         response = super().post(path, data=data, content_type=content_type, secure=secure, **extra) |         response = super().post(path, data=data, content_type=content_type, secure=secure, **extra) | ||||||
|         if follow: |         if follow: | ||||||
|             response = self._handle_redirects(response, data=data, content_type=content_type, **extra) |             response = self._handle_redirects(response, data=data, content_type=content_type, **extra) | ||||||
| @@ -530,6 +533,7 @@ class Client(RequestFactory): | |||||||
|  |  | ||||||
|     def head(self, path, data=None, follow=False, secure=False, **extra): |     def head(self, path, data=None, follow=False, secure=False, **extra): | ||||||
|         """Request a response from the server using HEAD.""" |         """Request a response from the server using HEAD.""" | ||||||
|  |         self.extra = extra | ||||||
|         response = super().head(path, data=data, secure=secure, **extra) |         response = super().head(path, data=data, secure=secure, **extra) | ||||||
|         if follow: |         if follow: | ||||||
|             response = self._handle_redirects(response, data=data, **extra) |             response = self._handle_redirects(response, data=data, **extra) | ||||||
| @@ -538,6 +542,7 @@ class Client(RequestFactory): | |||||||
|     def options(self, path, data='', content_type='application/octet-stream', |     def options(self, path, data='', content_type='application/octet-stream', | ||||||
|                 follow=False, secure=False, **extra): |                 follow=False, secure=False, **extra): | ||||||
|         """Request a response from the server using OPTIONS.""" |         """Request a response from the server using OPTIONS.""" | ||||||
|  |         self.extra = extra | ||||||
|         response = super().options(path, data=data, content_type=content_type, secure=secure, **extra) |         response = super().options(path, data=data, content_type=content_type, secure=secure, **extra) | ||||||
|         if follow: |         if follow: | ||||||
|             response = self._handle_redirects(response, data=data, content_type=content_type, **extra) |             response = self._handle_redirects(response, data=data, content_type=content_type, **extra) | ||||||
| @@ -546,6 +551,7 @@ class Client(RequestFactory): | |||||||
|     def put(self, path, data='', content_type='application/octet-stream', |     def put(self, path, data='', content_type='application/octet-stream', | ||||||
|             follow=False, secure=False, **extra): |             follow=False, secure=False, **extra): | ||||||
|         """Send a resource to the server using PUT.""" |         """Send a resource to the server using PUT.""" | ||||||
|  |         self.extra = extra | ||||||
|         response = super().put(path, data=data, content_type=content_type, secure=secure, **extra) |         response = super().put(path, data=data, content_type=content_type, secure=secure, **extra) | ||||||
|         if follow: |         if follow: | ||||||
|             response = self._handle_redirects(response, data=data, content_type=content_type, **extra) |             response = self._handle_redirects(response, data=data, content_type=content_type, **extra) | ||||||
| @@ -554,6 +560,7 @@ class Client(RequestFactory): | |||||||
|     def patch(self, path, data='', content_type='application/octet-stream', |     def patch(self, path, data='', content_type='application/octet-stream', | ||||||
|               follow=False, secure=False, **extra): |               follow=False, secure=False, **extra): | ||||||
|         """Send a resource to the server using PATCH.""" |         """Send a resource to the server using PATCH.""" | ||||||
|  |         self.extra = extra | ||||||
|         response = super().patch(path, data=data, content_type=content_type, secure=secure, **extra) |         response = super().patch(path, data=data, content_type=content_type, secure=secure, **extra) | ||||||
|         if follow: |         if follow: | ||||||
|             response = self._handle_redirects(response, data=data, content_type=content_type, **extra) |             response = self._handle_redirects(response, data=data, content_type=content_type, **extra) | ||||||
| @@ -562,6 +569,7 @@ class Client(RequestFactory): | |||||||
|     def delete(self, path, data='', content_type='application/octet-stream', |     def delete(self, path, data='', content_type='application/octet-stream', | ||||||
|                follow=False, secure=False, **extra): |                follow=False, secure=False, **extra): | ||||||
|         """Send a DELETE request to the server.""" |         """Send a DELETE request to the server.""" | ||||||
|  |         self.extra = extra | ||||||
|         response = super().delete(path, data=data, content_type=content_type, secure=secure, **extra) |         response = super().delete(path, data=data, content_type=content_type, secure=secure, **extra) | ||||||
|         if follow: |         if follow: | ||||||
|             response = self._handle_redirects(response, data=data, content_type=content_type, **extra) |             response = self._handle_redirects(response, data=data, content_type=content_type, **extra) | ||||||
| @@ -569,6 +577,7 @@ class Client(RequestFactory): | |||||||
|  |  | ||||||
|     def trace(self, path, data='', follow=False, secure=False, **extra): |     def trace(self, path, data='', follow=False, secure=False, **extra): | ||||||
|         """Send a TRACE request to the server.""" |         """Send a TRACE request to the server.""" | ||||||
|  |         self.extra = extra | ||||||
|         response = super().trace(path, data=data, secure=secure, **extra) |         response = super().trace(path, data=data, secure=secure, **extra) | ||||||
|         if follow: |         if follow: | ||||||
|             response = self._handle_redirects(response, data=data, **extra) |             response = self._handle_redirects(response, data=data, **extra) | ||||||
|   | |||||||
| @@ -347,10 +347,15 @@ class SimpleTestCase(unittest.TestCase): | |||||||
|                         "Otherwise, use assertRedirects(..., fetch_redirect_response=False)." |                         "Otherwise, use assertRedirects(..., fetch_redirect_response=False)." | ||||||
|                         % (url, domain) |                         % (url, domain) | ||||||
|                     ) |                     ) | ||||||
|                 redirect_response = response.client.get(path, QueryDict(query), secure=(scheme == 'https')) |  | ||||||
|  |  | ||||||
|                 # Get the redirection page, using the same client that was used |                 # Get the redirection page, using the same client that was used | ||||||
|                 # to obtain the original response. |                 # to obtain the original response. | ||||||
|  |                 extra = response.client.extra or {} | ||||||
|  |                 redirect_response = response.client.get( | ||||||
|  |                     path, | ||||||
|  |                     QueryDict(query), | ||||||
|  |                     secure=(scheme == 'https'), | ||||||
|  |                     **extra, | ||||||
|  |                 ) | ||||||
|                 self.assertEqual( |                 self.assertEqual( | ||||||
|                     redirect_response.status_code, target_status_code, |                     redirect_response.status_code, target_status_code, | ||||||
|                     msg_prefix + "Couldn't retrieve redirection page '%s': response code was %d (expected %d)" |                     msg_prefix + "Couldn't retrieve redirection page '%s': response code was %d (expected %d)" | ||||||
|   | |||||||
| @@ -508,6 +508,27 @@ class AssertRedirectsTests(SimpleTestCase): | |||||||
|             with self.assertRaises(AssertionError): |             with self.assertRaises(AssertionError): | ||||||
|                 self.assertRedirects(response, 'http://testserver/secure_view/', status_code=302) |                 self.assertRedirects(response, 'http://testserver/secure_view/', status_code=302) | ||||||
|  |  | ||||||
|  |     def test_redirect_fetch_redirect_response(self): | ||||||
|  |         """Preserve extra headers of requests made with django.test.Client.""" | ||||||
|  |         methods = ( | ||||||
|  |             'get', 'post', 'head', 'options', 'put', 'patch', 'delete', 'trace', | ||||||
|  |         ) | ||||||
|  |         for method in methods: | ||||||
|  |             with self.subTest(method=method): | ||||||
|  |                 req_method = getattr(self.client, method) | ||||||
|  |                 response = req_method( | ||||||
|  |                     '/redirect_based_on_extra_headers_1/', | ||||||
|  |                     follow=False, | ||||||
|  |                     HTTP_REDIRECT='val', | ||||||
|  |                 ) | ||||||
|  |                 self.assertRedirects( | ||||||
|  |                     response, | ||||||
|  |                     '/redirect_based_on_extra_headers_2/', | ||||||
|  |                     fetch_redirect_response=True, | ||||||
|  |                     status_code=302, | ||||||
|  |                     target_status_code=302, | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|  |  | ||||||
| @override_settings(ROOT_URLCONF='test_client_regress.urls') | @override_settings(ROOT_URLCONF='test_client_regress.urls') | ||||||
| class AssertFormErrorTests(SimpleTestCase): | class AssertFormErrorTests(SimpleTestCase): | ||||||
|   | |||||||
| @@ -25,6 +25,8 @@ urlpatterns = [ | |||||||
|     path('circular_redirect_2/', RedirectView.as_view(url='/circular_redirect_3/')), |     path('circular_redirect_2/', RedirectView.as_view(url='/circular_redirect_3/')), | ||||||
|     path('circular_redirect_3/', RedirectView.as_view(url='/circular_redirect_1/')), |     path('circular_redirect_3/', RedirectView.as_view(url='/circular_redirect_1/')), | ||||||
|     path('redirect_other_host/', RedirectView.as_view(url='https://otherserver:8443/no_template_view/')), |     path('redirect_other_host/', RedirectView.as_view(url='https://otherserver:8443/no_template_view/')), | ||||||
|  |     path('redirect_based_on_extra_headers_1/', views.redirect_based_on_extra_headers_1_view), | ||||||
|  |     path('redirect_based_on_extra_headers_2/', views.redirect_based_on_extra_headers_2_view), | ||||||
|     path('set_session/', views.set_session_view), |     path('set_session/', views.set_session_view), | ||||||
|     path('check_session/', views.check_session_view), |     path('check_session/', views.check_session_view), | ||||||
|     path('request_methods/', views.request_methods_view), |     path('request_methods/', views.request_methods_view), | ||||||
|   | |||||||
| @@ -154,3 +154,15 @@ def render_template_multiple_times(request): | |||||||
|     """A view that renders a template multiple times.""" |     """A view that renders a template multiple times.""" | ||||||
|     return HttpResponse( |     return HttpResponse( | ||||||
|         render_to_string('base.html') + render_to_string('base.html')) |         render_to_string('base.html') + render_to_string('base.html')) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def redirect_based_on_extra_headers_1_view(request): | ||||||
|  |     if 'HTTP_REDIRECT' in request.META: | ||||||
|  |         return HttpResponseRedirect('/redirect_based_on_extra_headers_2/') | ||||||
|  |     return HttpResponse() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def redirect_based_on_extra_headers_2_view(request): | ||||||
|  |     if 'HTTP_REDIRECT' in request.META: | ||||||
|  |         return HttpResponseRedirect('/redirects/further/more/') | ||||||
|  |     return HttpResponse() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user