mirror of
https://github.com/django/django.git
synced 2025-10-31 09:41:08 +00:00
Fixed #34074 -- Added headers argument to RequestFactory and Client classes.
This commit is contained in:
committed by
Mariusz Felisiak
parent
b181cae2e3
commit
67da22f08e
@@ -1,5 +1,6 @@
|
||||
from django.http.cookie import SimpleCookie, parse_cookie
|
||||
from django.http.request import (
|
||||
HttpHeaders,
|
||||
HttpRequest,
|
||||
QueryDict,
|
||||
RawPostDataException,
|
||||
@@ -27,6 +28,7 @@ from django.http.response import (
|
||||
__all__ = [
|
||||
"SimpleCookie",
|
||||
"parse_cookie",
|
||||
"HttpHeaders",
|
||||
"HttpRequest",
|
||||
"QueryDict",
|
||||
"RawPostDataException",
|
||||
|
||||
@@ -461,6 +461,31 @@ class HttpHeaders(CaseInsensitiveMapping):
|
||||
return None
|
||||
return header.replace("_", "-").title()
|
||||
|
||||
@classmethod
|
||||
def to_wsgi_name(cls, header):
|
||||
header = header.replace("-", "_").upper()
|
||||
if header in cls.UNPREFIXED_HEADERS:
|
||||
return header
|
||||
return f"{cls.HTTP_PREFIX}{header}"
|
||||
|
||||
@classmethod
|
||||
def to_asgi_name(cls, header):
|
||||
return header.replace("-", "_").upper()
|
||||
|
||||
@classmethod
|
||||
def to_wsgi_names(cls, headers):
|
||||
return {
|
||||
cls.to_wsgi_name(header_name): value
|
||||
for header_name, value in headers.items()
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def to_asgi_names(cls, headers):
|
||||
return {
|
||||
cls.to_asgi_name(header_name): value
|
||||
for header_name, value in headers.items()
|
||||
}
|
||||
|
||||
|
||||
class QueryDict(MultiValueDict):
|
||||
"""
|
||||
|
||||
@@ -11,8 +11,7 @@ from urllib.parse import urlparse
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import DisallowedHost, ImproperlyConfigured
|
||||
from django.http import UnreadablePostError
|
||||
from django.http.request import HttpHeaders
|
||||
from django.http import HttpHeaders, UnreadablePostError
|
||||
from django.urls import get_callable
|
||||
from django.utils.cache import patch_vary_headers
|
||||
from django.utils.crypto import constant_time_compare, get_random_string
|
||||
|
||||
@@ -18,7 +18,7 @@ from django.core.handlers.wsgi import LimitedStream, WSGIRequest
|
||||
from django.core.serializers.json import DjangoJSONEncoder
|
||||
from django.core.signals import got_request_exception, request_finished, request_started
|
||||
from django.db import close_old_connections
|
||||
from django.http import HttpRequest, QueryDict, SimpleCookie
|
||||
from django.http import HttpHeaders, HttpRequest, QueryDict, SimpleCookie
|
||||
from django.test import signals
|
||||
from django.test.utils import ContextList
|
||||
from django.urls import resolve
|
||||
@@ -346,11 +346,13 @@ class RequestFactory:
|
||||
just as if that view had been hooked up using a URLconf.
|
||||
"""
|
||||
|
||||
def __init__(self, *, json_encoder=DjangoJSONEncoder, **defaults):
|
||||
def __init__(self, *, json_encoder=DjangoJSONEncoder, headers=None, **defaults):
|
||||
self.json_encoder = json_encoder
|
||||
self.defaults = defaults
|
||||
self.cookies = SimpleCookie()
|
||||
self.errors = BytesIO()
|
||||
if headers:
|
||||
self.defaults.update(HttpHeaders.to_wsgi_names(headers))
|
||||
|
||||
def _base_environ(self, **request):
|
||||
"""
|
||||
@@ -422,13 +424,14 @@ class RequestFactory:
|
||||
# Refs comment in `get_bytes_from_wsgi()`.
|
||||
return path.decode("iso-8859-1")
|
||||
|
||||
def get(self, path, data=None, secure=False, **extra):
|
||||
def get(self, path, data=None, secure=False, *, headers=None, **extra):
|
||||
"""Construct a GET request."""
|
||||
data = {} if data is None else data
|
||||
return self.generic(
|
||||
"GET",
|
||||
path,
|
||||
secure=secure,
|
||||
headers=headers,
|
||||
**{
|
||||
"QUERY_STRING": urlencode(data, doseq=True),
|
||||
**extra,
|
||||
@@ -436,32 +439,46 @@ class RequestFactory:
|
||||
)
|
||||
|
||||
def post(
|
||||
self, path, data=None, content_type=MULTIPART_CONTENT, secure=False, **extra
|
||||
self,
|
||||
path,
|
||||
data=None,
|
||||
content_type=MULTIPART_CONTENT,
|
||||
secure=False,
|
||||
*,
|
||||
headers=None,
|
||||
**extra,
|
||||
):
|
||||
"""Construct a POST request."""
|
||||
data = self._encode_json({} if data is None else data, content_type)
|
||||
post_data = self._encode_data(data, content_type)
|
||||
|
||||
return self.generic(
|
||||
"POST", path, post_data, content_type, secure=secure, **extra
|
||||
"POST",
|
||||
path,
|
||||
post_data,
|
||||
content_type,
|
||||
secure=secure,
|
||||
headers=headers,
|
||||
**extra,
|
||||
)
|
||||
|
||||
def head(self, path, data=None, secure=False, **extra):
|
||||
def head(self, path, data=None, secure=False, *, headers=None, **extra):
|
||||
"""Construct a HEAD request."""
|
||||
data = {} if data is None else data
|
||||
return self.generic(
|
||||
"HEAD",
|
||||
path,
|
||||
secure=secure,
|
||||
headers=headers,
|
||||
**{
|
||||
"QUERY_STRING": urlencode(data, doseq=True),
|
||||
**extra,
|
||||
},
|
||||
)
|
||||
|
||||
def trace(self, path, secure=False, **extra):
|
||||
def trace(self, path, secure=False, *, headers=None, **extra):
|
||||
"""Construct a TRACE request."""
|
||||
return self.generic("TRACE", path, secure=secure, **extra)
|
||||
return self.generic("TRACE", path, secure=secure, headers=headers, **extra)
|
||||
|
||||
def options(
|
||||
self,
|
||||
@@ -469,10 +486,14 @@ class RequestFactory:
|
||||
data="",
|
||||
content_type="application/octet-stream",
|
||||
secure=False,
|
||||
*,
|
||||
headers=None,
|
||||
**extra,
|
||||
):
|
||||
"Construct an OPTIONS request."
|
||||
return self.generic("OPTIONS", path, data, content_type, secure=secure, **extra)
|
||||
return self.generic(
|
||||
"OPTIONS", path, data, content_type, secure=secure, headers=headers, **extra
|
||||
)
|
||||
|
||||
def put(
|
||||
self,
|
||||
@@ -480,11 +501,15 @@ class RequestFactory:
|
||||
data="",
|
||||
content_type="application/octet-stream",
|
||||
secure=False,
|
||||
*,
|
||||
headers=None,
|
||||
**extra,
|
||||
):
|
||||
"""Construct a PUT request."""
|
||||
data = self._encode_json(data, content_type)
|
||||
return self.generic("PUT", path, data, content_type, secure=secure, **extra)
|
||||
return self.generic(
|
||||
"PUT", path, data, content_type, secure=secure, headers=headers, **extra
|
||||
)
|
||||
|
||||
def patch(
|
||||
self,
|
||||
@@ -492,11 +517,15 @@ class RequestFactory:
|
||||
data="",
|
||||
content_type="application/octet-stream",
|
||||
secure=False,
|
||||
*,
|
||||
headers=None,
|
||||
**extra,
|
||||
):
|
||||
"""Construct a PATCH request."""
|
||||
data = self._encode_json(data, content_type)
|
||||
return self.generic("PATCH", path, data, content_type, secure=secure, **extra)
|
||||
return self.generic(
|
||||
"PATCH", path, data, content_type, secure=secure, headers=headers, **extra
|
||||
)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
@@ -504,11 +533,15 @@ class RequestFactory:
|
||||
data="",
|
||||
content_type="application/octet-stream",
|
||||
secure=False,
|
||||
*,
|
||||
headers=None,
|
||||
**extra,
|
||||
):
|
||||
"""Construct a DELETE request."""
|
||||
data = self._encode_json(data, content_type)
|
||||
return self.generic("DELETE", path, data, content_type, secure=secure, **extra)
|
||||
return self.generic(
|
||||
"DELETE", path, data, content_type, secure=secure, headers=headers, **extra
|
||||
)
|
||||
|
||||
def generic(
|
||||
self,
|
||||
@@ -517,6 +550,8 @@ class RequestFactory:
|
||||
data="",
|
||||
content_type="application/octet-stream",
|
||||
secure=False,
|
||||
*,
|
||||
headers=None,
|
||||
**extra,
|
||||
):
|
||||
"""Construct an arbitrary HTTP request."""
|
||||
@@ -536,6 +571,8 @@ class RequestFactory:
|
||||
"wsgi.input": FakePayload(data),
|
||||
}
|
||||
)
|
||||
if headers:
|
||||
extra.update(HttpHeaders.to_wsgi_names(headers))
|
||||
r.update(extra)
|
||||
# If QUERY_STRING is absent or empty, we want to extract it from the URL.
|
||||
if not r.get("QUERY_STRING"):
|
||||
@@ -611,6 +648,8 @@ class AsyncRequestFactory(RequestFactory):
|
||||
data="",
|
||||
content_type="application/octet-stream",
|
||||
secure=False,
|
||||
*,
|
||||
headers=None,
|
||||
**extra,
|
||||
):
|
||||
"""Construct an arbitrary HTTP request."""
|
||||
@@ -636,6 +675,8 @@ class AsyncRequestFactory(RequestFactory):
|
||||
s["follow"] = follow
|
||||
if query_string := extra.pop("QUERY_STRING", None):
|
||||
s["query_string"] = query_string
|
||||
if headers:
|
||||
extra.update(HttpHeaders.to_asgi_names(headers))
|
||||
s["headers"] += [
|
||||
(key.lower().encode("ascii"), value.encode("latin1"))
|
||||
for key, value in extra.items()
|
||||
@@ -782,9 +823,14 @@ class Client(ClientMixin, RequestFactory):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, enforce_csrf_checks=False, raise_request_exception=True, **defaults
|
||||
self,
|
||||
enforce_csrf_checks=False,
|
||||
raise_request_exception=True,
|
||||
*,
|
||||
headers=None,
|
||||
**defaults,
|
||||
):
|
||||
super().__init__(**defaults)
|
||||
super().__init__(headers=headers, **defaults)
|
||||
self.handler = ClientHandler(enforce_csrf_checks)
|
||||
self.raise_request_exception = raise_request_exception
|
||||
self.exc_info = None
|
||||
@@ -837,12 +883,23 @@ class Client(ClientMixin, RequestFactory):
|
||||
self.cookies.update(response.cookies)
|
||||
return response
|
||||
|
||||
def get(self, path, data=None, follow=False, secure=False, **extra):
|
||||
def get(
|
||||
self,
|
||||
path,
|
||||
data=None,
|
||||
follow=False,
|
||||
secure=False,
|
||||
*,
|
||||
headers=None,
|
||||
**extra,
|
||||
):
|
||||
"""Request a response from the server using GET."""
|
||||
self.extra = extra
|
||||
response = super().get(path, data=data, secure=secure, **extra)
|
||||
response = super().get(path, data=data, secure=secure, headers=headers, **extra)
|
||||
if follow:
|
||||
response = self._handle_redirects(response, data=data, **extra)
|
||||
response = self._handle_redirects(
|
||||
response, data=data, headers=headers, **extra
|
||||
)
|
||||
return response
|
||||
|
||||
def post(
|
||||
@@ -852,25 +909,45 @@ class Client(ClientMixin, RequestFactory):
|
||||
content_type=MULTIPART_CONTENT,
|
||||
follow=False,
|
||||
secure=False,
|
||||
*,
|
||||
headers=None,
|
||||
**extra,
|
||||
):
|
||||
"""Request a response from the server using POST."""
|
||||
self.extra = extra
|
||||
response = super().post(
|
||||
path, data=data, content_type=content_type, secure=secure, **extra
|
||||
path,
|
||||
data=data,
|
||||
content_type=content_type,
|
||||
secure=secure,
|
||||
headers=headers,
|
||||
**extra,
|
||||
)
|
||||
if follow:
|
||||
response = self._handle_redirects(
|
||||
response, data=data, content_type=content_type, **extra
|
||||
response, data=data, content_type=content_type, headers=headers, **extra
|
||||
)
|
||||
return response
|
||||
|
||||
def head(self, path, data=None, follow=False, secure=False, **extra):
|
||||
def head(
|
||||
self,
|
||||
path,
|
||||
data=None,
|
||||
follow=False,
|
||||
secure=False,
|
||||
*,
|
||||
headers=None,
|
||||
**extra,
|
||||
):
|
||||
"""Request a response from the server using HEAD."""
|
||||
self.extra = extra
|
||||
response = super().head(path, data=data, secure=secure, **extra)
|
||||
response = super().head(
|
||||
path, data=data, secure=secure, headers=headers, **extra
|
||||
)
|
||||
if follow:
|
||||
response = self._handle_redirects(response, data=data, **extra)
|
||||
response = self._handle_redirects(
|
||||
response, data=data, headers=headers, **extra
|
||||
)
|
||||
return response
|
||||
|
||||
def options(
|
||||
@@ -880,16 +957,23 @@ class Client(ClientMixin, RequestFactory):
|
||||
content_type="application/octet-stream",
|
||||
follow=False,
|
||||
secure=False,
|
||||
*,
|
||||
headers=None,
|
||||
**extra,
|
||||
):
|
||||
"""Request a response from the server using OPTIONS."""
|
||||
self.extra = extra
|
||||
response = super().options(
|
||||
path, data=data, content_type=content_type, secure=secure, **extra
|
||||
path,
|
||||
data=data,
|
||||
content_type=content_type,
|
||||
secure=secure,
|
||||
headers=headers,
|
||||
**extra,
|
||||
)
|
||||
if follow:
|
||||
response = self._handle_redirects(
|
||||
response, data=data, content_type=content_type, **extra
|
||||
response, data=data, content_type=content_type, headers=headers, **extra
|
||||
)
|
||||
return response
|
||||
|
||||
@@ -900,16 +984,23 @@ class Client(ClientMixin, RequestFactory):
|
||||
content_type="application/octet-stream",
|
||||
follow=False,
|
||||
secure=False,
|
||||
*,
|
||||
headers=None,
|
||||
**extra,
|
||||
):
|
||||
"""Send a resource to the server using PUT."""
|
||||
self.extra = extra
|
||||
response = super().put(
|
||||
path, data=data, content_type=content_type, secure=secure, **extra
|
||||
path,
|
||||
data=data,
|
||||
content_type=content_type,
|
||||
secure=secure,
|
||||
headers=headers,
|
||||
**extra,
|
||||
)
|
||||
if follow:
|
||||
response = self._handle_redirects(
|
||||
response, data=data, content_type=content_type, **extra
|
||||
response, data=data, content_type=content_type, headers=headers, **extra
|
||||
)
|
||||
return response
|
||||
|
||||
@@ -920,16 +1011,23 @@ class Client(ClientMixin, RequestFactory):
|
||||
content_type="application/octet-stream",
|
||||
follow=False,
|
||||
secure=False,
|
||||
*,
|
||||
headers=None,
|
||||
**extra,
|
||||
):
|
||||
"""Send a resource to the server using PATCH."""
|
||||
self.extra = extra
|
||||
response = super().patch(
|
||||
path, data=data, content_type=content_type, secure=secure, **extra
|
||||
path,
|
||||
data=data,
|
||||
content_type=content_type,
|
||||
secure=secure,
|
||||
headers=headers,
|
||||
**extra,
|
||||
)
|
||||
if follow:
|
||||
response = self._handle_redirects(
|
||||
response, data=data, content_type=content_type, **extra
|
||||
response, data=data, content_type=content_type, headers=headers, **extra
|
||||
)
|
||||
return response
|
||||
|
||||
@@ -940,28 +1038,55 @@ class Client(ClientMixin, RequestFactory):
|
||||
content_type="application/octet-stream",
|
||||
follow=False,
|
||||
secure=False,
|
||||
*,
|
||||
headers=None,
|
||||
**extra,
|
||||
):
|
||||
"""Send a DELETE request to the server."""
|
||||
self.extra = extra
|
||||
response = super().delete(
|
||||
path, data=data, content_type=content_type, secure=secure, **extra
|
||||
path,
|
||||
data=data,
|
||||
content_type=content_type,
|
||||
secure=secure,
|
||||
headers=headers,
|
||||
**extra,
|
||||
)
|
||||
if follow:
|
||||
response = self._handle_redirects(
|
||||
response, data=data, content_type=content_type, **extra
|
||||
response, data=data, content_type=content_type, headers=headers, **extra
|
||||
)
|
||||
return response
|
||||
|
||||
def trace(self, path, data="", follow=False, secure=False, **extra):
|
||||
def trace(
|
||||
self,
|
||||
path,
|
||||
data="",
|
||||
follow=False,
|
||||
secure=False,
|
||||
*,
|
||||
headers=None,
|
||||
**extra,
|
||||
):
|
||||
"""Send a TRACE request to the server."""
|
||||
self.extra = extra
|
||||
response = super().trace(path, data=data, secure=secure, **extra)
|
||||
response = super().trace(
|
||||
path, data=data, secure=secure, headers=headers, **extra
|
||||
)
|
||||
if follow:
|
||||
response = self._handle_redirects(response, data=data, **extra)
|
||||
response = self._handle_redirects(
|
||||
response, data=data, headers=headers, **extra
|
||||
)
|
||||
return response
|
||||
|
||||
def _handle_redirects(self, response, data="", content_type="", **extra):
|
||||
def _handle_redirects(
|
||||
self,
|
||||
response,
|
||||
data="",
|
||||
content_type="",
|
||||
headers=None,
|
||||
**extra,
|
||||
):
|
||||
"""
|
||||
Follow any redirects by requesting responses from the server using GET.
|
||||
"""
|
||||
@@ -1010,7 +1135,12 @@ class Client(ClientMixin, RequestFactory):
|
||||
content_type = None
|
||||
|
||||
response = request_method(
|
||||
path, data=data, content_type=content_type, follow=False, **extra
|
||||
path,
|
||||
data=data,
|
||||
content_type=content_type,
|
||||
follow=False,
|
||||
headers=headers,
|
||||
**extra,
|
||||
)
|
||||
response.redirect_chain = redirect_chain
|
||||
|
||||
@@ -1038,9 +1168,14 @@ class AsyncClient(ClientMixin, AsyncRequestFactory):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, enforce_csrf_checks=False, raise_request_exception=True, **defaults
|
||||
self,
|
||||
enforce_csrf_checks=False,
|
||||
raise_request_exception=True,
|
||||
*,
|
||||
headers=None,
|
||||
**defaults,
|
||||
):
|
||||
super().__init__(**defaults)
|
||||
super().__init__(headers=headers, **defaults)
|
||||
self.handler = AsyncClientHandler(enforce_csrf_checks)
|
||||
self.raise_request_exception = raise_request_exception
|
||||
self.exc_info = None
|
||||
|
||||
Reference in New Issue
Block a user