import os
import unittest
import warnings
from io import StringIO
from unittest import mock
from django.conf import settings
from django.contrib.staticfiles.finders import get_finder, get_finders
from django.contrib.staticfiles.storage import staticfiles_storage
from django.core.exceptions import ImproperlyConfigured
from django.core.files.storage import default_storage
from django.db import (
IntegrityError, connection, connections, models, router, transaction,
)
from django.forms import EmailField, IntegerField
from django.http import HttpResponse
from django.template.loader import render_to_string
from django.test import (
SimpleTestCase, TestCase, TransactionTestCase, ignore_warnings,
skipIfDBFeature, skipUnlessDBFeature,
)
from django.test.html import HTMLParseError, parse_html
from django.test.utils import (
CaptureQueriesContext, TestContextDecorator, isolate_apps,
override_settings, setup_test_environment,
)
from django.urls import NoReverseMatch, path, reverse, reverse_lazy
from django.utils.deprecation import RemovedInDjango41Warning
from .models import Car, Person, PossessedCar
from .views import empty_response
class SkippingTestCase(SimpleTestCase):
    """
    Tests for skipUnlessDBFeature() and skipIfDBFeature() applied to plain
    functions and to methods of a locally defined SimpleTestCase.
    """

    def _assert_skipping(self, func, expected_exc, msg=None):
        """
        Call ``func`` and assert that it raises ``expected_exc``.

        When ``msg`` is given, the exception message is checked too. If
        ``func`` raises unittest.SkipTest while a different exception was
        expected, the calling test fails instead of being reported as skipped.
        """
        try:
            if msg is not None:
                with self.assertRaisesMessage(expected_exc, msg):
                    func()
            else:
                with self.assertRaises(expected_exc):
                    func()
        except unittest.SkipTest:
            self.fail('%s should not result in a skipped test.' % func.__name__)

    def test_skip_unless_db_feature(self):
        """
        Testing the django.test.skipUnlessDBFeature decorator.
        """
        # Total hack, but it works, just want an attribute that's always true.
        @skipUnlessDBFeature("__class__")
        def test_func():
            raise ValueError

        @skipUnlessDBFeature("notprovided")
        def test_func2():
            raise ValueError

        @skipUnlessDBFeature("__class__", "__class__")
        def test_func3():
            raise ValueError

        @skipUnlessDBFeature("__class__", "notprovided")
        def test_func4():
            raise ValueError

        # The body (ValueError) runs only when every named feature is truthy;
        # otherwise the decorator raises SkipTest.
        self._assert_skipping(test_func, ValueError)
        self._assert_skipping(test_func2, unittest.SkipTest)
        self._assert_skipping(test_func3, ValueError)
        self._assert_skipping(test_func4, unittest.SkipTest)

        class SkipTestCase(SimpleTestCase):
            @skipUnlessDBFeature('missing')
            def test_foo(self):
                pass

        # SkipTestCase is defined inside this method, so its qualified name
        # contains the "<locals>" component.
        self._assert_skipping(
            SkipTestCase('test_foo').test_foo,
            ValueError,
            "skipUnlessDBFeature cannot be used on test_foo (test_utils.tests."
            "SkippingTestCase.test_skip_unless_db_feature.<locals>.SkipTestCase) "
            "as SkippingTestCase.test_skip_unless_db_feature.<locals>.SkipTestCase "
            "doesn't allow queries against the 'default' database."
        )

    def test_skip_if_db_feature(self):
        """
        Testing the django.test.skipIfDBFeature decorator.
        """
        @skipIfDBFeature("__class__")
        def test_func():
            raise ValueError

        @skipIfDBFeature("notprovided")
        def test_func2():
            raise ValueError

        @skipIfDBFeature("__class__", "__class__")
        def test_func3():
            raise ValueError

        @skipIfDBFeature("__class__", "notprovided")
        def test_func4():
            raise ValueError

        @skipIfDBFeature("notprovided", "notprovided")
        def test_func5():
            raise ValueError

        # A single truthy feature is enough to skip; the body (ValueError)
        # runs only when none of the named features is truthy.
        self._assert_skipping(test_func, unittest.SkipTest)
        self._assert_skipping(test_func2, ValueError)
        self._assert_skipping(test_func3, unittest.SkipTest)
        self._assert_skipping(test_func4, unittest.SkipTest)
        self._assert_skipping(test_func5, ValueError)

        class SkipTestCase(SimpleTestCase):
            @skipIfDBFeature('missing')
            def test_foo(self):
                pass

        # SkipTestCase is defined inside this method, so its qualified name
        # contains the "<locals>" component.
        self._assert_skipping(
            SkipTestCase('test_foo').test_foo,
            ValueError,
            "skipIfDBFeature cannot be used on test_foo (test_utils.tests."
            "SkippingTestCase.test_skip_if_db_feature.<locals>.SkipTestCase) "
            "as SkippingTestCase.test_skip_if_db_feature.<locals>.SkipTestCase "
            "doesn't allow queries against the 'default' database."
        )
class SkippingClassTestCase(TestCase):
    """
    Tests for skipUnlessDBFeature() and skipIfDBFeature() used as class
    decorators, run through a private unittest suite and runner.
    """

    def test_skip_class_unless_db_feature(self):
        """
        A decorated class is skipped when the test runs, not when the test
        instance is created, and a subclass is skipped along with its parent.
        """
        @skipUnlessDBFeature("__class__")
        class NotSkippedTests(TestCase):
            def test_dummy(self):
                return

        @skipUnlessDBFeature("missing")
        @skipIfDBFeature("__class__")
        class SkippedTests(TestCase):
            def test_will_be_skipped(self):
                self.fail("We should never arrive here.")

        @skipIfDBFeature("__dict__")
        class SkippedTestsSubclass(SkippedTests):
            pass

        test_suite = unittest.TestSuite()
        test_suite.addTest(NotSkippedTests('test_dummy'))
        try:
            # Instantiating the test cases must not raise SkipTest; skipping
            # is expected to happen only once the runner executes them.
            test_suite.addTest(SkippedTests('test_will_be_skipped'))
            test_suite.addTest(SkippedTestsSubclass('test_will_be_skipped'))
        except unittest.SkipTest:
            self.fail('SkipTest should not be raised here.')
        # The runner's output goes to a throwaway buffer.
        result = unittest.TextTestRunner(stream=StringIO()).run(test_suite)
        self.assertEqual(result.testsRun, 3)
        self.assertEqual(len(result.skipped), 2)
        # Both skipped tests report the skipIfDBFeature("__class__") reason.
        self.assertEqual(result.skipped[0][1], 'Database has feature(s) __class__')
        self.assertEqual(result.skipped[1][1], 'Database has feature(s) __class__')

    def test_missing_default_databases(self):
        """
        Running a decorated SimpleTestCase subclass raises ValueError naming
        the class and the 'default' database.
        """
        @skipIfDBFeature('missing')
        class MissingDatabases(SimpleTestCase):
            def test_assertion_error(self):
                pass

        suite = unittest.TestSuite()
        try:
            suite.addTest(MissingDatabases('test_assertion_error'))
        except unittest.SkipTest:
            self.fail("SkipTest should not be raised at this stage")
        runner = unittest.TextTestRunner(stream=StringIO())
        # The message embeds repr() of the class; because MissingDatabases is
        # defined inside this method, its qualified name contains "<locals>".
        msg = (
            "skipIfDBFeature cannot be used on <class 'test_utils.tests."
            "SkippingClassTestCase.test_missing_default_databases.<locals>."
            "MissingDatabases'> as it doesn't allow queries against the "
            "'default' database."
        )
        with self.assertRaisesMessage(ValueError, msg):
            runner.run(suite)
@override_settings(ROOT_URLCONF='test_utils.urls')
class AssertNumQueriesTests(TestCase):
    """Exercise assertNumQueries() when it is handed a callable."""

    def test_assert_num_queries(self):
        """An exception raised by the callable reaches the caller."""
        def raise_value_error():
            raise ValueError

        with self.assertRaises(ValueError):
            self.assertNumQueries(2, raise_value_error)

    def test_assert_num_queries_with_client(self):
        """One query is expected per request to the get_person view."""
        person = Person.objects.create(name='test')
        url = "/test_utils/get_person/%s/" % person.pk

        # Two independent single-request checks, passing the URL as an extra
        # positional argument for the callable.
        self.assertNumQueries(1, self.client.get, url)
        self.assertNumQueries(1, self.client.get, url)

        def request_twice():
            self.client.get(url)
            self.client.get(url)

        self.assertNumQueries(2, request_twice)
@unittest.skipUnless(
    not (connection.vendor == 'sqlite' and connection.is_in_memory_db()),
    'For SQLite in-memory tests, closing the connection destroys the database.'
)
class AssertNumQueriesUponConnectionTests(TransactionTestCase):
    """assertNumQueries() versus queries issued while a connection opens."""

    available_apps = []

    def test_ignores_connection_configuration_queries(self):
        """
        A query executed as part of (re)opening the connection is not counted
        by assertNumQueries().
        """
        original_ensure_connection = connection.ensure_connection
        connection.close()

        def query_while_connecting():
            # Record the state before delegating: the real method is what
            # opens the connection.
            was_closed = connection.connection is None
            original_ensure_connection()
            if not was_closed:
                return
            # Avoid infinite recursion. Creating a cursor calls
            # ensure_connection() which is currently mocked by this method.
            with connection.cursor() as cursor:
                cursor.execute('SELECT 1' + connection.features.bare_select_suffix)

        patch_target = 'django.db.backends.base.base.BaseDatabaseWrapper.ensure_connection'
        patcher = mock.patch(patch_target, side_effect=query_while_connecting)
        with patcher, self.assertNumQueries(1):
            list(Car.objects.all())
class AssertQuerysetEqualTests(TestCase):
    """Checks of assertQuerysetEqual() against two Person rows, p1 and p2."""

    @classmethod
    def setUpTestData(cls):
        cls.p1, cls.p2 = (
            Person.objects.create(name=name) for name in ('p1', 'p2')
        )

    def test_empty(self):
        no_match = Person.objects.filter(name='p3')
        self.assertQuerysetEqual(no_match, [])

    def test_ordered(self):
        people = Person.objects.all().order_by('name')
        self.assertQuerysetEqual(people, [self.p1, self.p2])

    def test_unordered(self):
        people = Person.objects.all().order_by('name')
        # Expected values are listed in reverse; ordered=False must accept it.
        self.assertQuerysetEqual(people, [self.p2, self.p1], ordered=False)

    def test_queryset(self):
        # Two distinct queryset objects built from the same query.
        actual = Person.objects.all().order_by('name')
        expected = Person.objects.all().order_by('name')
        self.assertQuerysetEqual(actual, expected)

    def test_flat_values_list(self):
        person_names = Person.objects.all().order_by('name').values_list('name', flat=True)
        self.assertQuerysetEqual(person_names, ['p1', 'p2'])

    def test_transform(self):
        people = Person.objects.all().order_by('name')
        expected_pks = [self.p1.pk, self.p2.pk]
        self.assertQuerysetEqual(people, expected_pks, transform=lambda x: x.pk)

    def test_repr_transform(self):
        people = Person.objects.all().order_by('name')
        expected_reprs = [repr(person) for person in (self.p1, self.p2)]
        self.assertQuerysetEqual(people, expected_reprs, transform=repr)

    def test_undefined_order(self):
        # Comparing a queryset without an ordering to several ordered values
        # must be rejected.
        msg = 'Trying to compare non-ordered queryset against more than one ordered values'
        unordered_people = Person.objects.all()
        with self.assertRaisesMessage(ValueError, msg):
            self.assertQuerysetEqual(unordered_people, [self.p1, self.p2])
        # A single expected value is accepted.
        only_p1 = Person.objects.filter(name='p1')
        self.assertQuerysetEqual(only_p1, [self.p1])

    def test_repeated_values(self):
        """
        With ordered=False, assertQuerysetEqual still compares how many times
        each item occurs.
        """
        batmobile = Car.objects.create(name='Batmobile')
        k2000 = Car.objects.create(name='K 2000')
        # p1 owns the Batmobile twice and the K 2000 four times.
        owned_cars = [batmobile, batmobile, k2000, k2000, k2000, k2000]
        PossessedCar.objects.bulk_create([
            PossessedCar(car=car, belongs_to=self.p1) for car in owned_cars
        ])
        with self.assertRaises(AssertionError):
            self.assertQuerysetEqual(
                self.p1.cars.all(), [batmobile, k2000], ordered=False,
            )
        self.assertQuerysetEqual(self.p1.cars.all(), owned_cars, ordered=False)

    def test_maxdiff(self):
        """The failure message honours self.maxDiff."""
        truncation_hint = 'Set self.maxDiff to None to see it.'
        names = ['Joe Smith %s' % i for i in range(20)]
        Person.objects.bulk_create([Person(name=name) for name in names])
        # One expected name has no matching row, so the comparison fails.
        names.append('Extra Person')

        def failure_message():
            # Run the comparison that must fail and return its message.
            with self.assertRaises(AssertionError) as ctx:
                self.assertQuerysetEqual(
                    Person.objects.filter(name__startswith='Joe'),
                    names,
                    ordered=False,
                    transform=lambda p: p.name,
                )
            return str(ctx.exception)

        self.assertIn(truncation_hint, failure_message())

        saved_max_diff = self.maxDiff
        self.maxDiff = None
        try:
            full_message = failure_message()
        finally:
            self.maxDiff = saved_max_diff
        self.assertNotIn(truncation_hint, full_message)
        for name in names:
            self.assertIn(name, full_message)
class AssertQuerysetEqualDeprecationTests(TestCase):
    """Deprecated comparison of a queryset against string values."""

    @classmethod
    def setUpTestData(cls):
        cls.p1, cls.p2 = (
            Person.objects.create(name=name) for name in ('p1', 'p2')
        )

    @ignore_warnings(category=RemovedInDjango41Warning)
    def test_str_values(self):
        # With the deprecation warning silenced, the comparison against
        # repr() strings is expected to pass.
        people = Person.objects.all().order_by('name')
        expected_reprs = [repr(person) for person in (self.p1, self.p2)]
        self.assertQuerysetEqual(people, expected_reprs)

    def test_str_values_warning(self):
        msg = (
            "In Django 4.1, repr() will not be called automatically on a "
            "queryset when compared to string values. Set an explicit "
            "'transform' to silence this warning."
        )
        people = Person.objects.all().order_by('name')
        expected_reprs = [repr(person) for person in (self.p1, self.p2)]
        with self.assertRaisesMessage(RemovedInDjango41Warning, msg):
            self.assertQuerysetEqual(people, expected_reprs)
@override_settings(ROOT_URLCONF='test_utils.urls')
class CaptureQueriesContextManagerTests(TestCase):
    """CaptureQueriesContext used directly as a context manager."""

    @classmethod
    def setUpTestData(cls):
        person = Person.objects.create(name='test')
        # Kept as a string so it can be looked for inside the captured SQL.
        cls.person_pk = str(person.pk)

    def test_simple(self):
        with CaptureQueriesContext(connection) as queries:
            Person.objects.get(pk=self.person_pk)
        self.assertEqual(len(queries), 1)
        self.assertIn(self.person_pk, queries[0]['sql'])

        # An empty block captures nothing.
        with CaptureQueriesContext(connection) as queries:
            pass
        self.assertEqual(0, len(queries))

    def test_within(self):
        # The captured queries can be inspected before the block exits.
        with CaptureQueriesContext(connection) as queries:
            Person.objects.get(pk=self.person_pk)
            self.assertEqual(len(queries), 1)
            self.assertIn(self.person_pk, queries[0]['sql'])

    def test_nested(self):
        with CaptureQueriesContext(connection) as outer_queries:
            Person.objects.count()
            with CaptureQueriesContext(connection) as inner_queries:
                Person.objects.count()
        # The inner context sees only its own query; the outer sees both.
        self.assertEqual(1, len(inner_queries))
        self.assertEqual(2, len(outer_queries))

    def test_failure(self):
        # An exception raised inside the block is not swallowed.
        with self.assertRaises(TypeError), CaptureQueriesContext(connection):
            raise TypeError

    def test_with_client(self):
        url = "/test_utils/get_person/%s/" % self.person_pk

        with CaptureQueriesContext(connection) as queries:
            self.client.get(url)
        self.assertEqual(len(queries), 1)
        self.assertIn(self.person_pk, queries[0]['sql'])

        # Same request again in a fresh context.
        with CaptureQueriesContext(connection) as queries:
            self.client.get(url)
        self.assertEqual(len(queries), 1)
        self.assertIn(self.person_pk, queries[0]['sql'])

        with CaptureQueriesContext(connection) as queries:
            self.client.get(url)
            self.client.get(url)
        self.assertEqual(len(queries), 2)
        self.assertIn(self.person_pk, queries[0]['sql'])
        self.assertIn(self.person_pk, queries[1]['sql'])
@override_settings(ROOT_URLCONF='test_utils.urls')
class AssertNumQueriesContextManagerTests(TestCase):
    """assertNumQueries() used as a context manager."""

    def test_simple(self):
        with self.assertNumQueries(0):
            pass

        with self.assertNumQueries(1):
            Person.objects.count()

        with self.assertNumQueries(2):
            Person.objects.count()
            Person.objects.count()

    def test_failure(self):
        expected_message = (
            '1 != 2 : 1 queries executed, 2 expected\nCaptured queries were:\n'
            '1.'
        )
        # A wrong count is reported as an AssertionError listing the queries.
        with self.assertRaisesMessage(AssertionError, expected_message), self.assertNumQueries(2):
            Person.objects.count()

        # An unrelated exception from the block is not replaced by a count
        # mismatch.
        with self.assertRaises(TypeError), self.assertNumQueries(4000):
            raise TypeError

    def test_with_client(self):
        person = Person.objects.create(name="test")
        url = "/test_utils/get_person/%s/" % person.pk

        with self.assertNumQueries(1):
            self.client.get(url)

        with self.assertNumQueries(1):
            self.client.get(url)

        with self.assertNumQueries(2):
            self.client.get(url)
            self.client.get(url)
@override_settings(ROOT_URLCONF='test_utils.urls')
class AssertTemplateUsedContextManagerTests(SimpleTestCase):
    """assertTemplateUsed()/assertTemplateNotUsed() as context managers."""

    def test_usage(self):
        base = 'template_used/base.html'

        with self.assertTemplateUsed(base):
            render_to_string(base)

        with self.assertTemplateUsed(template_name=base):
            render_to_string(base)

        # Rendering include.html or extends.html is expected to count as
        # using base.html as well.
        with self.assertTemplateUsed(base):
            render_to_string('template_used/include.html')

        with self.assertTemplateUsed(base):
            render_to_string('template_used/extends.html')

        # Rendering the template more than once is fine.
        with self.assertTemplateUsed(base):
            render_to_string(base)
            render_to_string(base)

    def test_nested_usage(self):
        base = 'template_used/base.html'
        include = 'template_used/include.html'
        extends = 'template_used/extends.html'
        alternative = 'template_used/alternative.html'

        with self.assertTemplateUsed(base), self.assertTemplateUsed(include):
            render_to_string(include)

        with self.assertTemplateUsed(extends), self.assertTemplateUsed(base):
            render_to_string(extends)

        with self.assertTemplateUsed(base):
            with self.assertTemplateUsed(alternative):
                render_to_string(alternative)
            # Rendered outside the inner block, inside the outer one.
            render_to_string(base)

        with self.assertTemplateUsed(base):
            render_to_string(extends)
            # The inner check only covers what is rendered inside it.
            with self.assertTemplateNotUsed(base):
                render_to_string(alternative)
            render_to_string(base)

    def test_not_used(self):
        for unused in ('template_used/base.html', 'template_used/alternative.html'):
            with self.assertTemplateNotUsed(unused):
                pass

    def test_error_message(self):
        base = 'template_used/base.html'
        nothing_rendered = 'template_used/base.html was not rendered. No template was rendered.'
        with self.assertRaisesMessage(AssertionError, nothing_rendered), self.assertTemplateUsed(base):
            pass

        with self.assertRaisesMessage(AssertionError, nothing_rendered):
            with self.assertTemplateUsed(template_name=base):
                pass

        other_rendered = (
            'template_used/base.html was not rendered. Following templates '
            'were rendered: template_used/alternative.html'
        )
        with self.assertRaisesMessage(AssertionError, other_rendered), self.assertTemplateUsed(base):
            render_to_string('template_used/alternative.html')

        # Response form of the assertion, against a view that is expected to
        # render no template.
        with self.assertRaisesMessage(AssertionError, 'No templates used to render the response'):
            response = self.client.get('/test_utils/no_template_used/')
            self.assertTemplateUsed(response, 'template_used/base.html')

    def test_failure(self):
        missing_args = 'response and/or template_name argument must be provided'
        with self.assertRaisesMessage(TypeError, missing_args), self.assertTemplateUsed():
            pass

        # An empty template name, positional or keyword, with or without a
        # render inside the block.
        none_used = 'No templates used to render the response'
        with self.assertRaisesMessage(AssertionError, none_used), self.assertTemplateUsed(''):
            pass

        with self.assertRaisesMessage(AssertionError, none_used), self.assertTemplateUsed(''):
            render_to_string('template_used/base.html')

        with self.assertRaisesMessage(AssertionError, none_used):
            with self.assertTemplateUsed(template_name=''):
                pass

        other_rendered = (
            'template_used/base.html was not rendered. Following '
            'templates were rendered: template_used/alternative.html'
        )
        with self.assertRaisesMessage(AssertionError, other_rendered):
            with self.assertTemplateUsed('template_used/base.html'):
                render_to_string('template_used/alternative.html')

    def test_assert_used_on_http_response(self):
        # A bare HttpResponse that was not fetched through the test client.
        plain_response = HttpResponse()
        error_msg = (
            'assertTemplateUsed() and assertTemplateNotUsed() are only '
            'usable on responses fetched using the Django test Client.'
        )
        for assertion in (self.assertTemplateUsed, self.assertTemplateNotUsed):
            with self.assertRaisesMessage(ValueError, error_msg):
                assertion(plain_response, 'template.html')
class HTMLEqualTests(SimpleTestCase):
def test_html_parser(self):
element = parse_html('
')
self.assertEqual(dom2.count(dom1), 0)
# HTML with a root element contains the same HTML with no root element.
dom1 = parse_html('
foo
bar
')
dom2 = parse_html('
foo
bar
')
self.assertEqual(dom2.count(dom1), 1)
# Target of search is a sequence of child elements and appears more
# than once.
dom2 = parse_html('
foo
bar
foo
bar
')
self.assertEqual(dom2.count(dom1), 2)
# Searched HTML has additional children.
dom1 = parse_html('')
dom2 = parse_html('')
self.assertEqual(dom2.count(dom1), 1)
# No match found in children.
dom1 = parse_html('')
self.assertEqual(dom2.count(dom1), 0)
# Target of search found among children and grandchildren.
dom1 = parse_html('')
dom2 = parse_html('')
self.assertEqual(dom2.count(dom1), 2)
def test_parsing_errors(self):
with self.assertRaises(AssertionError):
self.assertHTMLEqual('
', '')
with self.assertRaises(AssertionError):
self.assertHTMLEqual('', '
')
error_msg = (
"First argument is not valid HTML:\n"
"('Unexpected end tag `div` (Line 1, Column 6)', (1, 6))"
)
with self.assertRaisesMessage(AssertionError, error_msg):
self.assertHTMLEqual('< div> div>', '
')
with self.assertRaises(HTMLParseError):
parse_html('')
def test_contains_html(self):
response = HttpResponse('''
This is a form: ''')
self.assertNotContains(response, "")
self.assertContains(response, '