import os
import unittest
import warnings
from io import StringIO
from unittest import mock
from django.conf import STATICFILES_STORAGE_ALIAS, settings
from django.contrib.staticfiles.finders import get_finder, get_finders
from django.contrib.staticfiles.storage import staticfiles_storage
from django.core.exceptions import ImproperlyConfigured
from django.core.files.storage import default_storage
from django.db import (
IntegrityError,
connection,
connections,
models,
router,
transaction,
)
from django.forms import (
CharField,
EmailField,
Form,
IntegerField,
ValidationError,
formset_factory,
)
from django.http import HttpResponse
from django.template.loader import render_to_string
from django.test import (
SimpleTestCase,
TestCase,
TransactionTestCase,
skipIfDBFeature,
skipUnlessDBFeature,
)
from django.test.html import HTMLParseError, parse_html
from django.test.testcases import DatabaseOperationForbidden
from django.test.utils import (
CaptureQueriesContext,
TestContextDecorator,
ignore_warnings,
isolate_apps,
override_settings,
setup_test_environment,
)
from django.urls import NoReverseMatch, path, reverse, reverse_lazy
from django.utils.deprecation import RemovedInDjango51Warning
from django.utils.html import VOID_ELEMENTS
from django.utils.version import PY311
from .models import Car, Person, PossessedCar
from .views import empty_response
class SkippingTestCase(SimpleTestCase):
    """Tests for skipUnlessDBFeature/skipIfDBFeature applied to callables."""

    def _assert_skipping(self, func, expected_exc, msg=None):
        """
        Call ``func`` and assert that it raises ``expected_exc``.

        When ``msg`` is given the exception message is checked as well. A
        ``unittest.SkipTest`` that escapes the assertion context (i.e. one
        that wasn't the expected exception) is reported as a test failure.
        """
        try:
            if msg is not None:
                with self.assertRaisesMessage(expected_exc, msg):
                    func()
            else:
                with self.assertRaises(expected_exc):
                    func()
        except unittest.SkipTest:
            self.fail("%s should not result in a skipped test." % func.__name__)

    def test_skip_unless_db_feature(self):
        """
        Testing the django.test.skipUnlessDBFeature decorator.
        """

        # Total hack, but it works, just want an attribute that's always true.
        @skipUnlessDBFeature("__class__")
        def test_func():
            raise ValueError

        @skipUnlessDBFeature("notprovided")
        def test_func2():
            raise ValueError

        @skipUnlessDBFeature("__class__", "__class__")
        def test_func3():
            raise ValueError

        @skipUnlessDBFeature("__class__", "notprovided")
        def test_func4():
            raise ValueError

        # The body only runs (and raises ValueError) when every named feature
        # is available; otherwise the decorator skips.
        self._assert_skipping(test_func, ValueError)
        self._assert_skipping(test_func2, unittest.SkipTest)
        self._assert_skipping(test_func3, ValueError)
        self._assert_skipping(test_func4, unittest.SkipTest)

        class SkipTestCase(SimpleTestCase):
            @skipUnlessDBFeature("missing")
            def test_foo(self):
                pass

        # SkipTestCase is defined inside this method, so its qualified name
        # contains the ".<locals>." marker.
        self._assert_skipping(
            SkipTestCase("test_foo").test_foo,
            ValueError,
            "skipUnlessDBFeature cannot be used on test_foo (test_utils.tests."
            "SkippingTestCase.test_skip_unless_db_feature.<locals>.SkipTestCase%s) "
            "as SkippingTestCase.test_skip_unless_db_feature.<locals>.SkipTestCase "
            "doesn't allow queries against the 'default' database."
            # Python 3.11 uses fully qualified test name in the output.
            % (".test_foo" if PY311 else ""),
        )

    def test_skip_if_db_feature(self):
        """
        Testing the django.test.skipIfDBFeature decorator.
        """

        @skipIfDBFeature("__class__")
        def test_func():
            raise ValueError

        @skipIfDBFeature("notprovided")
        def test_func2():
            raise ValueError

        @skipIfDBFeature("__class__", "__class__")
        def test_func3():
            raise ValueError

        @skipIfDBFeature("__class__", "notprovided")
        def test_func4():
            raise ValueError

        @skipIfDBFeature("notprovided", "notprovided")
        def test_func5():
            raise ValueError

        # A single available feature is enough to skip; the body only runs
        # (and raises ValueError) when none of the named features is available.
        self._assert_skipping(test_func, unittest.SkipTest)
        self._assert_skipping(test_func2, ValueError)
        self._assert_skipping(test_func3, unittest.SkipTest)
        self._assert_skipping(test_func4, unittest.SkipTest)
        self._assert_skipping(test_func5, ValueError)

        class SkipTestCase(SimpleTestCase):
            @skipIfDBFeature("missing")
            def test_foo(self):
                pass

        # SkipTestCase is defined inside this method, so its qualified name
        # contains the ".<locals>." marker.
        self._assert_skipping(
            SkipTestCase("test_foo").test_foo,
            ValueError,
            "skipIfDBFeature cannot be used on test_foo (test_utils.tests."
            "SkippingTestCase.test_skip_if_db_feature.<locals>.SkipTestCase%s) "
            "as SkippingTestCase.test_skip_if_db_feature.<locals>.SkipTestCase "
            "doesn't allow queries against the 'default' database."
            # Python 3.11 uses fully qualified test name in the output.
            % (".test_foo" if PY311 else ""),
        )
class SkippingClassTestCase(TestCase):
    """Tests for skipUnlessDBFeature/skipIfDBFeature applied to classes."""

    def test_skip_class_unless_db_feature(self):
        """
        Decorated test classes are skipped when the suite runs, not when the
        test instances are created and added to it.
        """

        @skipUnlessDBFeature("__class__")
        class NotSkippedTests(TestCase):
            def test_dummy(self):
                return

        @skipUnlessDBFeature("missing")
        @skipIfDBFeature("__class__")
        class SkippedTests(TestCase):
            def test_will_be_skipped(self):
                self.fail("We should never arrive here.")

        @skipIfDBFeature("__dict__")
        class SkippedTestsSubclass(SkippedTests):
            pass

        test_suite = unittest.TestSuite()
        test_suite.addTest(NotSkippedTests("test_dummy"))
        try:
            test_suite.addTest(SkippedTests("test_will_be_skipped"))
            test_suite.addTest(SkippedTestsSubclass("test_will_be_skipped"))
        except unittest.SkipTest:
            self.fail("SkipTest should not be raised here.")
        # Discard the runner's report so it doesn't pollute the outer output.
        result = unittest.TextTestRunner(stream=StringIO()).run(test_suite)
        self.assertEqual(result.testsRun, 3)
        self.assertEqual(len(result.skipped), 2)
        # Both skipped entries are expected to carry the "__class__" reason,
        # including the one for the subclass decorated with "__dict__".
        self.assertEqual(result.skipped[0][1], "Database has feature(s) __class__")
        self.assertEqual(result.skipped[1][1], "Database has feature(s) __class__")

    def test_missing_default_databases(self):
        """
        Running a decorated SimpleTestCase raises ValueError naming the class,
        as it doesn't allow queries against the 'default' database.
        """

        @skipIfDBFeature("missing")
        class MissingDatabases(SimpleTestCase):
            def test_assertion_error(self):
                pass

        suite = unittest.TestSuite()
        try:
            suite.addTest(MissingDatabases("test_assertion_error"))
        except unittest.SkipTest:
            self.fail("SkipTest should not be raised at this stage")
        runner = unittest.TextTestRunner(stream=StringIO())
        # The message embeds the class repr; MissingDatabases is defined
        # inside this method, hence the ".<locals>." in its qualified name.
        msg = (
            "skipIfDBFeature cannot be used on <class 'test_utils.tests."
            "SkippingClassTestCase.test_missing_default_databases.<locals>."
            "MissingDatabases'> as it doesn't allow queries against the "
            "'default' database."
        )
        with self.assertRaisesMessage(ValueError, msg):
            runner.run(suite)
@override_settings(ROOT_URLCONF="test_utils.urls")
class AssertNumQueriesTests(TestCase):
    """Exercise the callable form of assertNumQueries()."""

    def test_assert_num_queries(self):
        """An exception raised by the wrapped callable reaches the caller."""

        def failing_callable():
            raise ValueError

        with self.assertRaises(ValueError):
            self.assertNumQueries(2, failing_callable)

    def test_assert_num_queries_with_client(self):
        """Queries triggered through the test client are counted."""
        person = Person.objects.create(name="test")
        person_url = "/test_utils/get_person/%s/" % person.pk

        # Two separate single requests, each expected to run one query.
        for _ in range(2):
            self.assertNumQueries(1, self.client.get, person_url)

        def fetch_twice():
            for _ in range(2):
                self.client.get(person_url)

        self.assertNumQueries(2, fetch_twice)
class AssertNumQueriesUponConnectionTests(TransactionTestCase):
    """assertNumQueries() behavior when the connection is opened inside it."""

    available_apps = []

    def test_ignores_connection_configuration_queries(self):
        """
        A query executed while the connection is being opened isn't counted:
        only one query is expected for the ORM fetch below.
        """
        # Keep a reference to the unpatched method before mocking it, and
        # close the connection so the next use has to reopen it.
        real_ensure_connection = connection.ensure_connection
        connection.close()

        def make_configuration_query():
            # Stand-in for ensure_connection() that runs an extra "SELECT 1"
            # right after the connection is opened.
            is_opening_connection = connection.connection is None
            real_ensure_connection()

            if is_opening_connection:
                # Avoid infinite recursion. Creating a cursor calls
                # ensure_connection() which is currently mocked by this method.
                with connection.cursor() as cursor:
                    cursor.execute("SELECT 1" + connection.features.bare_select_suffix)

        ensure_connection = (
            "django.db.backends.base.base.BaseDatabaseWrapper.ensure_connection"
        )
        with mock.patch(ensure_connection, side_effect=make_configuration_query):
            with self.assertNumQueries(1):
                list(Car.objects.all())
class AssertQuerySetEqualTests(TestCase):
    """Tests for assertQuerySetEqual() and its deprecated alias."""

    @classmethod
    def setUpTestData(cls):
        # Two people shared by the tests of this class.
        cls.p1, cls.p2 = [
            Person.objects.create(name=person_name) for person_name in ("p1", "p2")
        ]

    def test_rename_assertquerysetequal_deprecation_warning(self):
        """Calling the old spelling raises the deprecation warning."""
        deprecation_msg = (
            "assertQuerysetEqual() is deprecated in favor of assertQuerySetEqual()."
        )
        with self.assertRaisesMessage(RemovedInDjango51Warning, deprecation_msg):
            self.assertQuerysetEqual()

    @ignore_warnings(category=RemovedInDjango51Warning)
    def test_deprecated_assertquerysetequal(self):
        """The old spelling still performs the comparison."""
        no_match = Person.objects.filter(name="p3")
        self.assertQuerysetEqual(no_match, [])

    def test_empty(self):
        """An empty queryset equals an empty list."""
        no_match = Person.objects.filter(name="p3")
        self.assertQuerySetEqual(no_match, [])

    def test_ordered(self):
        """An ordered queryset is compared against a list in the same order."""
        people_by_name = Person.objects.order_by("name")
        self.assertQuerySetEqual(people_by_name, [self.p1, self.p2])

    def test_unordered(self):
        """With ordered=False the order of the expected values is ignored."""
        people_by_name = Person.objects.order_by("name")
        self.assertQuerySetEqual(people_by_name, [self.p2, self.p1], ordered=False)

    def test_queryset(self):
        """The expected values may themselves be a queryset."""
        actual = Person.objects.order_by("name")
        expected = Person.objects.order_by("name")
        self.assertQuerySetEqual(actual, expected)

    def test_flat_values_list(self):
        """A flat values_list() queryset is compared against plain values."""
        flat_names = Person.objects.order_by("name").values_list("name", flat=True)
        self.assertQuerySetEqual(flat_names, ["p1", "p2"])

    def test_transform(self):
        """The transform callable is applied to each queryset item."""
        expected_pks = [self.p1.pk, self.p2.pk]
        self.assertQuerySetEqual(
            Person.objects.order_by("name"),
            expected_pks,
            transform=lambda person: person.pk,
        )

    def test_repr_transform(self):
        """repr can be used as the transform."""
        expected_reprs = [repr(person) for person in (self.p1, self.p2)]
        self.assertQuerySetEqual(
            Person.objects.order_by("name"),
            expected_reprs,
            transform=repr,
        )

    def test_undefined_order(self):
        """
        Comparing an unordered queryset with several ordered values is an
        error; a single value is accepted.
        """
        expected_error = (
            "Trying to compare non-ordered queryset against more than one "
            "ordered value."
        )
        with self.assertRaisesMessage(ValueError, expected_error):
            self.assertQuerySetEqual(Person.objects.all(), [self.p1, self.p2])
        # A single expected value doesn't trigger the error.
        only_p1 = Person.objects.filter(name="p1")
        self.assertQuerySetEqual(only_p1, [self.p1])

    def test_repeated_values(self):
        """
        With ordered=False, the number of occurrences of each item is taken
        into account by assertQuerySetEqual.
        """
        batmobile = Car.objects.create(name="Batmobile")
        k2000 = Car.objects.create(name="K 2000")
        # p1 owns the Batmobile twice and K 2000 four times.
        possessions = [
            PossessedCar(car=batmobile, belongs_to=self.p1) for _ in range(2)
        ] + [PossessedCar(car=k2000, belongs_to=self.p1) for _ in range(4)]
        PossessedCar.objects.bulk_create(possessions)

        # One occurrence of each car isn't enough to match.
        with self.assertRaises(AssertionError):
            self.assertQuerySetEqual(
                self.p1.cars.all(), [batmobile, k2000], ordered=False
            )

        matching_counts = [batmobile] * 2 + [k2000] * 4
        self.assertQuerySetEqual(self.p1.cars.all(), matching_counts, ordered=False)

    def test_maxdiff(self):
        """The failure message honors self.maxDiff."""
        names = ["Joe Smith %s" % i for i in range(20)]
        Person.objects.bulk_create([Person(name=name) for name in names])
        # One expected name with no matching row, so the comparison fails.
        names.append("Extra Person")
        truncation_hint = "Set self.maxDiff to None to see it."

        def failure_message():
            # Run the failing comparison and return the assertion message.
            with self.assertRaises(AssertionError) as ctx:
                self.assertQuerySetEqual(
                    Person.objects.filter(name__startswith="Joe"),
                    names,
                    ordered=False,
                    transform=lambda p: p.name,
                )
            return str(ctx.exception)

        self.assertIn(truncation_hint, failure_message())

        original = self.maxDiff
        self.maxDiff = None
        try:
            exception_msg = failure_message()
        finally:
            # Restore maxDiff even if the comparison behaved unexpectedly.
            self.maxDiff = original
        self.assertNotIn(truncation_hint, exception_msg)
        for name in names:
            self.assertIn(name, exception_msg)
@override_settings(ROOT_URLCONF="test_utils.urls")
class CaptureQueriesContextManagerTests(TestCase):
    """Tests for CaptureQueriesContext used as a context manager."""

    @classmethod
    def setUpTestData(cls):
        # Stored as a string so it can be searched for in the captured SQL.
        person = Person.objects.create(name="test")
        cls.person_pk = str(person.pk)

    def test_simple(self):
        """One lookup is captured; an empty block captures nothing."""
        with CaptureQueriesContext(connection) as single_lookup:
            Person.objects.get(pk=self.person_pk)
        self.assertEqual(len(single_lookup), 1)
        self.assertIn(self.person_pk, single_lookup[0]["sql"])

        with CaptureQueriesContext(connection) as nothing_run:
            pass
        self.assertEqual(0, len(nothing_run))

    def test_within(self):
        """Captured queries can be inspected before the block exits."""
        with CaptureQueriesContext(connection) as queries:
            Person.objects.get(pk=self.person_pk)
            self.assertEqual(len(queries), 1)
            self.assertIn(self.person_pk, queries[0]["sql"])

    def test_nested(self):
        """An outer context also captures the queries of a nested one."""
        with CaptureQueriesContext(connection) as outer:
            Person.objects.count()
            with CaptureQueriesContext(connection) as inner:
                Person.objects.count()
        self.assertEqual(1, len(inner))
        self.assertEqual(2, len(outer))

    def test_failure(self):
        """An exception raised inside the block propagates."""
        with self.assertRaises(TypeError):
            with CaptureQueriesContext(connection):
                raise TypeError

    def test_with_client(self):
        """Queries triggered through the test client are captured."""
        person_url = "/test_utils/get_person/%s/" % self.person_pk

        # Two separate single requests, each expected to capture one query.
        for _ in range(2):
            with CaptureQueriesContext(connection) as queries:
                self.client.get(person_url)
            self.assertEqual(len(queries), 1)
            self.assertIn(self.person_pk, queries[0]["sql"])

        with CaptureQueriesContext(connection) as queries:
            self.client.get(person_url)
            self.client.get(person_url)
        self.assertEqual(len(queries), 2)
        for position in (0, 1):
            self.assertIn(self.person_pk, queries[position]["sql"])
@override_settings(ROOT_URLCONF="test_utils.urls")
class AssertNumQueriesContextManagerTests(TestCase):
    """Exercise the context-manager form of assertNumQueries()."""

    def test_simple(self):
        """Zero, one and two queries are counted as such."""
        for expected_count in range(3):
            with self.assertNumQueries(expected_count):
                for _ in range(expected_count):
                    Person.objects.count()

    def test_failure(self):
        """A wrong count fails with a message; other exceptions propagate."""
        mismatch_msg = (
            "1 != 2 : 1 queries executed, 2 expected\nCaptured queries were:\n1."
        )
        with self.assertRaisesMessage(AssertionError, mismatch_msg):
            with self.assertNumQueries(2):
                Person.objects.count()

        # The exception raised in the block surfaces rather than a count
        # mismatch.
        with self.assertRaises(TypeError):
            with self.assertNumQueries(4000):
                raise TypeError

    def test_with_client(self):
        """Queries triggered through the test client are counted."""
        person = Person.objects.create(name="test")
        person_url = "/test_utils/get_person/%s/" % person.pk

        # Two separate single requests, each expected to run one query.
        for _ in range(2):
            with self.assertNumQueries(1):
                self.client.get(person_url)

        with self.assertNumQueries(2):
            self.client.get(person_url)
            self.client.get(person_url)
@override_settings(ROOT_URLCONF="test_utils.urls")
class AssertTemplateUsedContextManagerTests(SimpleTestCase):
    """Tests for assertTemplateUsed()/assertTemplateNotUsed() as context managers."""

    def test_usage(self):
        """base.html is detected for direct and indirect renderings."""
        base = "template_used/base.html"

        with self.assertTemplateUsed(base):
            render_to_string(base)

        with self.assertTemplateUsed(template_name=base):
            render_to_string(base)

        # Rendering either of these templates is expected to use base.html.
        for rendered in ("template_used/include.html", "template_used/extends.html"):
            with self.assertTemplateUsed(base):
                render_to_string(rendered)

        # Rendering the template more than once is fine without a count.
        with self.assertTemplateUsed(base):
            render_to_string(base)
            render_to_string(base)

    def test_nested_usage(self):
        """Nested assertion contexts each see the templates rendered in them."""
        base = "template_used/base.html"
        include = "template_used/include.html"
        extends = "template_used/extends.html"
        alternative = "template_used/alternative.html"

        with self.assertTemplateUsed(base):
            with self.assertTemplateUsed(include):
                render_to_string(include)

        with self.assertTemplateUsed(extends):
            with self.assertTemplateUsed(base):
                render_to_string(extends)

        with self.assertTemplateUsed(base):
            with self.assertTemplateUsed(alternative):
                render_to_string(alternative)
            render_to_string(base)

        with self.assertTemplateUsed(base):
            render_to_string(extends)
            # base.html is rendered before and after this block, not inside.
            with self.assertTemplateNotUsed(base):
                render_to_string(alternative)
            render_to_string(base)

    def test_not_used(self):
        """An empty block uses no template at all."""
        for unused in ("template_used/base.html", "template_used/alternative.html"):
            with self.assertTemplateNotUsed(unused):
                pass

    def test_error_message(self):
        """Failure messages name the expected and the actual templates."""
        base = "template_used/base.html"
        nothing_rendered = "No templates used to render the response"

        with self.assertRaisesMessage(AssertionError, nothing_rendered):
            with self.assertTemplateUsed(base):
                pass

        with self.assertRaisesMessage(AssertionError, nothing_rendered):
            with self.assertTemplateUsed(template_name=base):
                pass

        wrong_template = (
            "Template 'template_used/base.html' was not a template used to render "
            "the response. Actual template(s) used: template_used/alternative.html"
        )
        with self.assertRaisesMessage(AssertionError, wrong_template):
            with self.assertTemplateUsed(base):
                render_to_string("template_used/alternative.html")

        # Same message for the response-based form of the assertion.
        with self.assertRaisesMessage(
            AssertionError, "No templates used to render the response"
        ):
            response = self.client.get("/test_utils/no_template_used/")
            self.assertTemplateUsed(response, base)

    def test_msg_prefix(self):
        """msg_prefix is prepended to the failure messages."""
        base = "template_used/base.html"
        msg_prefix = "Prefix"
        nothing_rendered = f"{msg_prefix}: No templates used to render the response"

        with self.assertRaisesMessage(AssertionError, nothing_rendered):
            with self.assertTemplateUsed(base, msg_prefix=msg_prefix):
                pass

        with self.assertRaisesMessage(AssertionError, nothing_rendered):
            with self.assertTemplateUsed(template_name=base, msg_prefix=msg_prefix):
                pass

        wrong_template = (
            f"{msg_prefix}: Template 'template_used/base.html' was not a "
            f"template used to render the response. Actual template(s) used: "
            f"template_used/alternative.html"
        )
        with self.assertRaisesMessage(AssertionError, wrong_template):
            with self.assertTemplateUsed(base, msg_prefix=msg_prefix):
                render_to_string("template_used/alternative.html")

    def test_count(self):
        """count must match the number of times the template was rendered."""
        base = "template_used/base.html"

        with self.assertTemplateUsed(base, count=2):
            render_to_string(base)
            render_to_string(base)

        count_mismatch = (
            "Template 'template_used/base.html' was expected to be rendered "
            "3 time(s) but was actually rendered 2 time(s)."
        )
        with self.assertRaisesMessage(AssertionError, count_mismatch):
            with self.assertTemplateUsed(base, count=3):
                render_to_string(base)
                render_to_string(base)

    def test_failure(self):
        """Missing, empty and unmatched template names all fail."""
        base = "template_used/base.html"

        missing_args = "response and/or template_name argument must be provided"
        with self.assertRaisesMessage(TypeError, missing_args):
            with self.assertTemplateUsed():
                pass

        nothing_rendered = "No templates used to render the response"
        with self.assertRaisesMessage(AssertionError, nothing_rendered):
            with self.assertTemplateUsed(""):
                pass

        # An empty name gives the same message even though a template is
        # rendered inside the block.
        with self.assertRaisesMessage(AssertionError, nothing_rendered):
            with self.assertTemplateUsed(""):
                render_to_string(base)

        with self.assertRaisesMessage(AssertionError, nothing_rendered):
            with self.assertTemplateUsed(template_name=""):
                pass

        wrong_template = (
            "Template 'template_used/base.html' was not a template used to "
            "render the response. Actual template(s) used: "
            "template_used/alternative.html"
        )
        with self.assertRaisesMessage(AssertionError, wrong_template):
            with self.assertTemplateUsed(base):
                render_to_string("template_used/alternative.html")

    def test_assert_used_on_http_response(self):
        """A plain HttpResponse is rejected by both assertions."""
        response = HttpResponse()
        msg = "%s() is only usable on responses fetched using the Django test Client."
        for assertion_name in ("assertTemplateUsed", "assertTemplateNotUsed"):
            assertion = getattr(self, assertion_name)
            with self.assertRaisesMessage(ValueError, msg % assertion_name):
                assertion(response, "template.html")
class HTMLEqualTests(SimpleTestCase):
def test_html_parser(self):
element = parse_html("
")
self.assertEqual(dom2.count(dom1), 0)
# HTML with a root element contains the same HTML with no root element.
dom1 = parse_html("
foo
bar
")
dom2 = parse_html("
foo
bar
")
self.assertEqual(dom2.count(dom1), 1)
# Target of search is a sequence of child elements and appears more
# than once.
dom2 = parse_html("
foo
bar
foo
bar
")
self.assertEqual(dom2.count(dom1), 2)
# Searched HTML has additional children.
dom1 = parse_html("")
dom2 = parse_html("")
self.assertEqual(dom2.count(dom1), 1)
# No match found in children.
dom1 = parse_html("")
self.assertEqual(dom2.count(dom1), 0)
# Target of search found among children and grandchildren.
dom1 = parse_html("")
dom2 = parse_html("")
self.assertEqual(dom2.count(dom1), 2)
def test_root_element_escaped_html(self):
html = "<br>"
parsed = parse_html(html)
self.assertEqual(str(parsed), html)
def test_parsing_errors(self):
with self.assertRaises(AssertionError):
self.assertHTMLEqual("
", "")
with self.assertRaises(AssertionError):
self.assertHTMLEqual("", "
")
error_msg = (
"First argument is not valid HTML:\n"
"('Unexpected end tag `div` (Line 1, Column 6)', (1, 6))"
)
with self.assertRaisesMessage(AssertionError, error_msg):
self.assertHTMLEqual("< div> div>", "
")
with self.assertRaises(HTMLParseError):
parse_html("")
def test_escaped_html_errors(self):
msg = "
\n\n
!=
\n<foo>\n
\n"
with self.assertRaisesMessage(AssertionError, msg):
self.assertHTMLEqual("
", "
<foo>
")
with self.assertRaisesMessage(AssertionError, msg):
self.assertHTMLEqual("
", "
<foo>
")
def test_contains_html(self):
response = HttpResponse(
"""
This is a form: """
)
self.assertNotContains(response, "")
self.assertContains(response, '