From 852fa7617e24a68a990eaf0f7a597edb434ffd76 Mon Sep 17 00:00:00 2001 From: Virtosu Bogdan Date: Fri, 23 Jul 2021 12:13:31 +0200 Subject: [PATCH] Refs #32329 -- Allowed specifying request class in csrf_tests test hooks. --- tests/csrf_tests/tests.py | 82 +++++++++++++++++++-------------------- 1 file changed, 39 insertions(+), 43 deletions(-) diff --git a/tests/csrf_tests/tests.py b/tests/csrf_tests/tests.py index 216625067c..e823ff11ee 100644 --- a/tests/csrf_tests/tests.py +++ b/tests/csrf_tests/tests.py @@ -99,6 +99,23 @@ class TestingHttpRequest(HttpRequest): return getattr(self, '_is_secure_override', False) +class PostErrorRequest(TestingHttpRequest): + """ + TestingHttpRequest that can raise errors when accessing POST data. + """ + post_error = None + + def _get_post(self): + if self.post_error is not None: + raise self.post_error + return self._post + + def _set_post(self, post): + self._post = post + + POST = property(_get_post, _set_post) + + class CsrfViewMiddlewareTestMixin: """ Shared methods and tests for session-based and cookie-based tokens. @@ -131,10 +148,12 @@ class CsrfViewMiddlewareTestMixin: secrets_set = [_unmask_cipher_token(cookie) for cookie in cookies_set] self.assertEqual(secrets_set, expected_secrets) - def _get_request(self, method=None, cookie=None): + def _get_request(self, method=None, cookie=None, request_class=None): if method is None: method = 'GET' - req = TestingHttpRequest() + if request_class is None: + request_class = TestingHttpRequest + req = request_class() req.method = method if cookie is not None: self._set_csrf_cookie(req, cookie) @@ -142,7 +161,7 @@ class CsrfViewMiddlewareTestMixin: def _get_csrf_cookie_request( self, method=None, cookie=None, post_token=None, meta_token=None, - token_header=None, + token_header=None, request_class=None, ): """ The method argument defaults to "GET". The cookie argument defaults to @@ -156,7 +175,11 @@ class CsrfViewMiddlewareTestMixin: cookie = self._csrf_id_cookie if token_header is None: token_header = 'HTTP_X_CSRFTOKEN' - req = self._get_request(method=method, cookie=cookie) + req = self._get_request( + method=method, + cookie=cookie, + request_class=request_class, + ) if post_token is not None: req.POST['csrfmiddlewaretoken'] = post_token if meta_token is not None: @@ -165,15 +188,21 @@ class CsrfViewMiddlewareTestMixin: def _get_POST_csrf_cookie_request( self, cookie=None, post_token=None, meta_token=None, token_header=None, + request_class=None, ): return self._get_csrf_cookie_request( method='POST', cookie=cookie, post_token=post_token, meta_token=meta_token, token_header=token_header, + request_class=request_class, ) - def _get_POST_request_with_token(self, cookie=None): + def _get_POST_request_with_token(self, cookie=None, request_class=None): """The cookie argument defaults to this class's default test cookie.""" - return self._get_POST_csrf_cookie_request(cookie=cookie, post_token=self._csrf_id_token) + return self._get_POST_csrf_cookie_request( + cookie=cookie, + post_token=self._csrf_id_token, + request_class=request_class, + ) def _check_token_present(self, response, csrf_id=None): text = str(response.content, response.charset) @@ -702,49 +731,16 @@ class CsrfViewMiddlewareTestMixin: def test_post_data_read_failure(self): """ OSErrors during POST data reading are caught and treated as if the - POST data wasn't there (#20128). + POST data wasn't there. """ - class CsrfPostRequest(HttpRequest): - """ - HttpRequest that can raise an OSError when accessing POST data - """ - def __init__(self, token, raise_error): - super().__init__() - self.method = 'POST' - - self.raise_error = False - self.COOKIES[settings.CSRF_COOKIE_NAME] = token - - # Handle both cases here to prevent duplicate code in the - # session tests. - self.session = {} - self.session[CSRF_SESSION_KEY] = token - - self.POST['csrfmiddlewaretoken'] = token - self.raise_error = raise_error - - def _load_post_and_files(self): - raise OSError('error reading input data') - - def _get_post(self): - if self.raise_error: - self._load_post_and_files() - return self._post - - def _set_post(self, post): - self._post = post - - POST = property(_get_post, _set_post) - - token = ('ABC' + self._csrf_id_token)[:CSRF_TOKEN_LENGTH] - - req = CsrfPostRequest(token, raise_error=False) + req = self._get_POST_request_with_token() mw = CsrfViewMiddleware(post_form_view) mw.process_request(req) resp = mw.process_view(req, post_form_view, (), {}) self.assertIsNone(resp) - req = CsrfPostRequest(token, raise_error=True) + req = self._get_POST_request_with_token(request_class=PostErrorRequest) + req.post_error = OSError('error reading input data') mw.process_request(req) with self.assertLogs('django.security.csrf', 'WARNING') as cm: resp = mw.process_view(req, post_form_view, (), {})