diff --git a/django/test/__init__.py b/django/test/__init__.py index b3198fcf88..68aea9aa49 100644 --- a/django/test/__init__.py +++ b/django/test/__init__.py @@ -2,6 +2,6 @@ Django Unit Test and Doctest framework. """ -from django.test.client import Client +from django.test.client import Client, RequestFactory from django.test.testcases import TestCase, TransactionTestCase, skipIfDBFeature, skipUnlessDBFeature from django.test.utils import Approximate diff --git a/django/test/client.py b/django/test/client.py index 05880884f0..f93664f6cb 100644 --- a/django/test/client.py +++ b/django/test/client.py @@ -156,55 +156,29 @@ def encode_file(boundary, key, file): file.read() ] -class Client(object): + + +class RequestFactory(object): """ - A class that can act as a client for testing purposes. + Class that lets you create mock Request objects for use in testing. - It allows the user to compose GET and POST requests, and - obtain the response that the server gave to those requests. - The server Response objects are annotated with the details - of the contexts and templates that were rendered during the - process of serving the request. + Usage: - Client objects are stateful - they will retain cookie (and - thus session) details for the lifetime of the Client instance. + rf = RequestFactory() + get_request = rf.get('/hello/') + post_request = rf.post('/submit/', {'foo': 'bar'}) - This is not intended as a replacement for Twill/Selenium or - the like - it is here to allow testing against the - contexts and templates produced by a view, rather than the - HTML rendered to the end-user. + Once you have a request object you can pass it to any view function, + just as if that view had been hooked up using a URLconf. """ - def __init__(self, enforce_csrf_checks=False, **defaults): - self.handler = ClientHandler(enforce_csrf_checks) + def __init__(self, **defaults): self.defaults = defaults self.cookies = SimpleCookie() - self.exc_info = None self.errors = StringIO() - def store_exc_info(self, **kwargs): + def _base_environ(self, **request): """ - Stores exceptions when they are generated by a view. - """ - self.exc_info = sys.exc_info() - - def _session(self): - """ - Obtains the current session variables. - """ - if 'django.contrib.sessions' in settings.INSTALLED_APPS: - engine = import_module(settings.SESSION_ENGINE) - cookie = self.cookies.get(settings.SESSION_COOKIE_NAME, None) - if cookie: - return engine.SessionStore(cookie.value) - return {} - session = property(_session) - - def request(self, **request): - """ - The master request method. Composes the environment dictionary - and passes to the handler, returning the result of the handler. - Assumes defaults for the query environment, which can be overridden - using the arguments to the request. + The base environment for a request. """ environ = { 'HTTP_COOKIE': self.cookies.output(header='', sep='; '), @@ -225,6 +199,171 @@ class Client(object): } environ.update(self.defaults) environ.update(request) + return environ + + def request(self, **request): + "Construct a generic request object." + return WSGIRequest(self._base_environ(**request)) + + def get(self, path, data={}, **extra): + "Construct a GET request" + + parsed = urlparse(path) + r = { + 'CONTENT_TYPE': 'text/html; charset=utf-8', + 'PATH_INFO': urllib.unquote(parsed[2]), + 'QUERY_STRING': urlencode(data, doseq=True) or parsed[4], + 'REQUEST_METHOD': 'GET', + 'wsgi.input': FakePayload('') + } + r.update(extra) + return self.request(**r) + + def post(self, path, data={}, content_type=MULTIPART_CONTENT, + **extra): + "Construct a POST request." + + if content_type is MULTIPART_CONTENT: + post_data = encode_multipart(BOUNDARY, data) + else: + # Encode the content so that the byte representation is correct. + match = CONTENT_TYPE_RE.match(content_type) + if match: + charset = match.group(1) + else: + charset = settings.DEFAULT_CHARSET + post_data = smart_str(data, encoding=charset) + + parsed = urlparse(path) + r = { + 'CONTENT_LENGTH': len(post_data), + 'CONTENT_TYPE': content_type, + 'PATH_INFO': urllib.unquote(parsed[2]), + 'QUERY_STRING': parsed[4], + 'REQUEST_METHOD': 'POST', + 'wsgi.input': FakePayload(post_data), + } + r.update(extra) + return self.request(**r) + + def head(self, path, data={}, **extra): + "Construct a HEAD request." + + parsed = urlparse(path) + r = { + 'CONTENT_TYPE': 'text/html; charset=utf-8', + 'PATH_INFO': urllib.unquote(parsed[2]), + 'QUERY_STRING': urlencode(data, doseq=True) or parsed[4], + 'REQUEST_METHOD': 'HEAD', + 'wsgi.input': FakePayload('') + } + r.update(extra) + return self.request(**r) + + def options(self, path, data={}, **extra): + "Constrict an OPTIONS request" + + parsed = urlparse(path) + r = { + 'PATH_INFO': urllib.unquote(parsed[2]), + 'QUERY_STRING': urlencode(data, doseq=True) or parsed[4], + 'REQUEST_METHOD': 'OPTIONS', + 'wsgi.input': FakePayload('') + } + r.update(extra) + return self.request(**r) + + def put(self, path, data={}, content_type=MULTIPART_CONTENT, + **extra): + "Construct a PUT request." + + if content_type is MULTIPART_CONTENT: + post_data = encode_multipart(BOUNDARY, data) + else: + post_data = data + + # Make `data` into a querystring only if it's not already a string. If + # it is a string, we'll assume that the caller has already encoded it. + query_string = None + if not isinstance(data, basestring): + query_string = urlencode(data, doseq=True) + + parsed = urlparse(path) + r = { + 'CONTENT_LENGTH': len(post_data), + 'CONTENT_TYPE': content_type, + 'PATH_INFO': urllib.unquote(parsed[2]), + 'QUERY_STRING': query_string or parsed[4], + 'REQUEST_METHOD': 'PUT', + 'wsgi.input': FakePayload(post_data), + } + r.update(extra) + return self.request(**r) + + def delete(self, path, data={}, **extra): + "Construct a DELETE request." + + parsed = urlparse(path) + r = { + 'PATH_INFO': urllib.unquote(parsed[2]), + 'QUERY_STRING': urlencode(data, doseq=True) or parsed[4], + 'REQUEST_METHOD': 'DELETE', + 'wsgi.input': FakePayload('') + } + r.update(extra) + return self.request(**r) + + +class Client(RequestFactory): + """ + A class that can act as a client for testing purposes. + + It allows the user to compose GET and POST requests, and + obtain the response that the server gave to those requests. + The server Response objects are annotated with the details + of the contexts and templates that were rendered during the + process of serving the request. + + Client objects are stateful - they will retain cookie (and + thus session) details for the lifetime of the Client instance. + + This is not intended as a replacement for Twill/Selenium or + the like - it is here to allow testing against the + contexts and templates produced by a view, rather than the + HTML rendered to the end-user. + """ + def __init__(self, enforce_csrf_checks=False, **defaults): + super(Client, self).__init__(**defaults) + self.handler = ClientHandler(enforce_csrf_checks) + self.exc_info = None + + def store_exc_info(self, **kwargs): + """ + Stores exceptions when they are generated by a view. + """ + self.exc_info = sys.exc_info() + + def _session(self): + """ + Obtains the current session variables. + """ + if 'django.contrib.sessions' in settings.INSTALLED_APPS: + engine = import_module(settings.SESSION_ENGINE) + cookie = self.cookies.get(settings.SESSION_COOKIE_NAME, None) + if cookie: + return engine.SessionStore(cookie.value) + return {} + session = property(_session) + + + def request(self, **request): + """ + The master request method. Composes the environment dictionary + and passes to the handler, returning the result of the handler. + Assumes defaults for the query environment, which can be overridden + using the arguments to the request. + """ + environ = self._base_environ(**request) # Curry a data dictionary into an instance of the template renderer # callback function. @@ -290,22 +429,11 @@ class Client(object): signals.template_rendered.disconnect(dispatch_uid="template-render") got_request_exception.disconnect(dispatch_uid="request-exception") - def get(self, path, data={}, follow=False, **extra): """ Requests a response from the server using GET. """ - parsed = urlparse(path) - r = { - 'CONTENT_TYPE': 'text/html; charset=utf-8', - 'PATH_INFO': urllib.unquote(parsed[2]), - 'QUERY_STRING': urlencode(data, doseq=True) or parsed[4], - 'REQUEST_METHOD': 'GET', - 'wsgi.input': FakePayload('') - } - r.update(extra) - - response = self.request(**r) + response = super(Client, self).get(path, data=data, **extra) if follow: response = self._handle_redirects(response, **extra) return response @@ -315,29 +443,7 @@ class Client(object): """ Requests a response from the server using POST. """ - if content_type is MULTIPART_CONTENT: - post_data = encode_multipart(BOUNDARY, data) - else: - # Encode the content so that the byte representation is correct. - match = CONTENT_TYPE_RE.match(content_type) - if match: - charset = match.group(1) - else: - charset = settings.DEFAULT_CHARSET - post_data = smart_str(data, encoding=charset) - - parsed = urlparse(path) - r = { - 'CONTENT_LENGTH': len(post_data), - 'CONTENT_TYPE': content_type, - 'PATH_INFO': urllib.unquote(parsed[2]), - 'QUERY_STRING': parsed[4], - 'REQUEST_METHOD': 'POST', - 'wsgi.input': FakePayload(post_data), - } - r.update(extra) - - response = self.request(**r) + response = super(Client, self).post(path, data=data, content_type=content_type, **extra) if follow: response = self._handle_redirects(response, **extra) return response @@ -346,17 +452,7 @@ class Client(object): """ Request a response from the server using HEAD. """ - parsed = urlparse(path) - r = { - 'CONTENT_TYPE': 'text/html; charset=utf-8', - 'PATH_INFO': urllib.unquote(parsed[2]), - 'QUERY_STRING': urlencode(data, doseq=True) or parsed[4], - 'REQUEST_METHOD': 'HEAD', - 'wsgi.input': FakePayload('') - } - r.update(extra) - - response = self.request(**r) + response = super(Client, self).head(path, data=data, **extra) if follow: response = self._handle_redirects(response, **extra) return response @@ -365,16 +461,7 @@ class Client(object): """ Request a response from the server using OPTIONS. """ - parsed = urlparse(path) - r = { - 'PATH_INFO': urllib.unquote(parsed[2]), - 'QUERY_STRING': urlencode(data, doseq=True) or parsed[4], - 'REQUEST_METHOD': 'OPTIONS', - 'wsgi.input': FakePayload('') - } - r.update(extra) - - response = self.request(**r) + response = super(Client, self).options(path, data=data, **extra) if follow: response = self._handle_redirects(response, **extra) return response @@ -384,29 +471,7 @@ class Client(object): """ Send a resource to the server using PUT. """ - if content_type is MULTIPART_CONTENT: - post_data = encode_multipart(BOUNDARY, data) - else: - post_data = data - - # Make `data` into a querystring only if it's not already a string. If - # it is a string, we'll assume that the caller has already encoded it. - query_string = None - if not isinstance(data, basestring): - query_string = urlencode(data, doseq=True) - - parsed = urlparse(path) - r = { - 'CONTENT_LENGTH': len(post_data), - 'CONTENT_TYPE': content_type, - 'PATH_INFO': urllib.unquote(parsed[2]), - 'QUERY_STRING': query_string or parsed[4], - 'REQUEST_METHOD': 'PUT', - 'wsgi.input': FakePayload(post_data), - } - r.update(extra) - - response = self.request(**r) + response = super(Client, self).put(path, data=data, content_type=content_type, **extra) if follow: response = self._handle_redirects(response, **extra) return response @@ -415,23 +480,14 @@ class Client(object): """ Send a DELETE request to the server. """ - parsed = urlparse(path) - r = { - 'PATH_INFO': urllib.unquote(parsed[2]), - 'QUERY_STRING': urlencode(data, doseq=True) or parsed[4], - 'REQUEST_METHOD': 'DELETE', - 'wsgi.input': FakePayload('') - } - r.update(extra) - - response = self.request(**r) + response = super(Client, self).delete(path, data=data, **extra) if follow: response = self._handle_redirects(response, **extra) return response def login(self, **credentials): """ - Sets the Client to appear as if it has successfully logged into a site. + Sets the Factory to appear as if it has successfully logged into a site. Returns True if login is possible; False if the provided credentials are incorrect, or the user is inactive, or if the sessions framework is @@ -506,4 +562,3 @@ class Client(object): if response.redirect_chain[-1] in response.redirect_chain[0:-1]: break return response - diff --git a/docs/topics/testing.txt b/docs/topics/testing.txt index 465807021a..efd2593b71 100644 --- a/docs/topics/testing.txt +++ b/docs/topics/testing.txt @@ -1014,6 +1014,51 @@ The following is a simple unit test using the test client:: # Check that the rendered context contains 5 customers. self.assertEqual(len(response.context['customers']), 5) +The request factory +------------------- + +.. Class:: RequestFactory + +The :class:`~django.test.client.RequestFactory` is a simplified +version of the test client that provides a way to generate a request +instance that can be used as the first argument to any view. This +means you can test a view function the same way as you would test any +other function -- as a black box, with exactly known inputs, testing +for specific outputs. + +The API for the :class:`~django.test.client.RequestFactory` is a slightly +restricted subset of the test client API: + + * It only has access to the HTTP methods :meth:`~Client.get()`, + :meth:`~Client.post()`, :meth:`~Client.put()`, + :meth:`~Client.delete()`, :meth:`~Client.head()` and + :meth:`~Client.options()`. + + * These methods accept all the same arguments *except* for + ``follows``. Since this is just a factory for producing + requests, it's up to you to handle the response. + +Example +~~~~~~~ + +The following is a simple unit test using the request factory:: + + from django.utils import unittest + from django.test.client import RequestFactory + + class SimpleTest(unittest.TestCase): + def setUp(self): + # Every test needs a client. + self.factory = RequestFactory() + + def test_details(self): + # Issue a GET request. + request = self.factory.get('/customer/details') + + # Test my_view() as if it were deployed at /customer/details + response = my_view(request) + self.assertEquals(response.status_code, 200) + TestCase -------- diff --git a/tests/modeltests/test_client/models.py b/tests/modeltests/test_client/models.py index 4a0c9d2ae0..59814a9bf1 100644 --- a/tests/modeltests/test_client/models.py +++ b/tests/modeltests/test_client/models.py @@ -20,9 +20,12 @@ testing against the contexts and templates produced by a view, rather than the HTML rendered to the end-user. """ -from django.test import Client, TestCase from django.conf import settings from django.core import mail +from django.test import Client, TestCase, RequestFactory + +from views import get_view + class ClientTest(TestCase): fixtures = ['testdata.json'] @@ -469,3 +472,12 @@ class CustomTestClientTest(TestCase): """A test case can specify a custom class for self.client.""" self.assertEqual(hasattr(self.client, "i_am_customized"), True) + +class RequestFactoryTest(TestCase): + def test_request_factory(self): + factory = RequestFactory() + request = factory.get('/somewhere/') + response = get_view(request) + + self.assertEqual(response.status_code, 200) + self.assertContains(response, 'This is a test')