diff --git a/django/test/client.py b/django/test/client.py index 8fdce54d4d..c303ca3d74 100644 --- a/django/test/client.py +++ b/django/test/client.py @@ -6,7 +6,7 @@ from copy import copy from functools import partial from http import HTTPStatus from importlib import import_module -from io import BytesIO +from io import BytesIO, IOBase from urllib.parse import unquote_to_bytes, urljoin, urlparse, urlsplit from asgiref.sync import sync_to_async @@ -55,7 +55,7 @@ class RedirectCycleError(Exception): self.redirect_chain = last_response.redirect_chain -class FakePayload: +class FakePayload(IOBase): """ A wrapper around BytesIO that restricts what can be read since data from the network can't be sought and cannot be read outside of its content @@ -63,39 +63,49 @@ class FakePayload: that wouldn't work in real life. """ - def __init__(self, content=None): + def __init__(self, initial_bytes=None): self.__content = BytesIO() self.__len = 0 self.read_started = False - if content is not None: - self.write(content) + if initial_bytes is not None: + self.write(initial_bytes) def __len__(self): return self.__len - def read(self, num_bytes=None): + def read(self, size=-1, /): if not self.read_started: self.__content.seek(0) self.read_started = True - if num_bytes is None: - num_bytes = self.__len or 0 + if size == -1 or size is None: + size = self.__len assert ( - self.__len >= num_bytes + self.__len >= size ), "Cannot read more than the available bytes from the HTTP incoming data." - content = self.__content.read(num_bytes) - self.__len -= num_bytes + content = self.__content.read(size) + self.__len -= len(content) return content - def write(self, content): + def readline(self, size=-1, /): + if not self.read_started: + self.__content.seek(0) + self.read_started = True + if size == -1 or size is None: + size = self.__len + assert ( + self.__len >= size + ), "Cannot read more than the available bytes from the HTTP incoming data." + content = self.__content.readline(size) + self.__len -= len(content) + return content + + def write(self, b, /): if self.read_started: raise ValueError("Unable to write a payload after it's been read") - content = force_bytes(content) + content = force_bytes(b) self.__content.write(content) self.__len += len(content) - def close(self): - pass - def closing_iterator_wrapper(iterable, close): try: diff --git a/tests/requests/tests.py b/tests/requests/tests.py index 4aef752894..ef218afe2f 100644 --- a/tests/requests/tests.py +++ b/tests/requests/tests.py @@ -290,7 +290,7 @@ class RequestsTests(SimpleTestCase): self.assertEqual(stream.read(2), b"") self.assertEqual(stream.read(), b"") - def test_stream(self): + def test_stream_read(self): payload = FakePayload("name=value") request = WSGIRequest( { @@ -302,6 +302,19 @@ class RequestsTests(SimpleTestCase): ) self.assertEqual(request.read(), b"name=value") + def test_stream_readline(self): + payload = FakePayload("name=value\nother=string") + request = WSGIRequest( + { + "REQUEST_METHOD": "POST", + "CONTENT_TYPE": "application/x-www-form-urlencoded", + "CONTENT_LENGTH": len(payload), + "wsgi.input": payload, + }, + ) + self.assertEqual(request.readline(), b"name=value\n") + self.assertEqual(request.readline(), b"other=string") + def test_read_after_value(self): """ Reading from request is allowed after accessing request contents as diff --git a/tests/test_client/test_fakepayload.py b/tests/test_client/test_fakepayload.py index 191bf0e111..222bef3b00 100644 --- a/tests/test_client/test_fakepayload.py +++ b/tests/test_client/test_fakepayload.py @@ -5,7 +5,9 @@ from django.test.client import FakePayload class FakePayloadTests(SimpleTestCase): def test_write_after_read(self): payload = FakePayload() - payload.read() - msg = "Unable to write a payload after it's been read" - with self.assertRaisesMessage(ValueError, msg): - payload.write(b"abc") + for operation in [payload.read, payload.readline]: + with self.subTest(operation=operation.__name__): + operation() + msg = "Unable to write a payload after it's been read" + with self.assertRaisesMessage(ValueError, msg): + payload.write(b"abc")