diff --git a/django/test/testcases.py b/django/test/testcases.py index a79a304547..02cd00c27f 100644 --- a/django/test/testcases.py +++ b/django/test/testcases.py @@ -6,8 +6,10 @@ from xml.dom.minidom import parseString, Node from django.conf import settings from django.core import mail from django.core.management import call_command +from django.core.signals import request_started from django.core.urlresolvers import clear_url_caches -from django.db import transaction, connection, connections, DEFAULT_DB_ALIAS +from django.db import (transaction, connection, connections, DEFAULT_DB_ALIAS, + reset_queries) from django.http import QueryDict from django.test import _doctest as doctest from django.test.client import Client @@ -220,10 +222,12 @@ class _AssertNumQueriesContext(object): self.old_debug_cursor = self.connection.use_debug_cursor self.connection.use_debug_cursor = True self.starting_queries = len(self.connection.queries) + request_started.disconnect(reset_queries) return self def __exit__(self, exc_type, exc_value, traceback): self.connection.use_debug_cursor = self.old_debug_cursor + request_started.connect(reset_queries) if exc_type is not None: return diff --git a/tests/regressiontests/test_utils/tests.py b/tests/regressiontests/test_utils/tests.py index 2a9c826451..d5dd739782 100644 --- a/tests/regressiontests/test_utils/tests.py +++ b/tests/regressiontests/test_utils/tests.py @@ -2,20 +2,13 @@ import sys from django.test import TestCase, skipUnlessDBFeature, skipIfDBFeature +from models import Person if sys.version_info >= (2, 5): - from tests_25 import AssertNumQueriesTests + from tests_25 import AssertNumQueriesContextManagerTests class SkippingTestCase(TestCase): - def test_assert_num_queries(self): - def test_func(): - raise ValueError - - self.assertRaises(ValueError, - self.assertNumQueries, 2, test_func - ) - def test_skip_unless_db_feature(self): "A test that might be skipped is actually called." # Total hack, but it works, just want an attribute that's always true. @@ -26,8 +19,37 @@ class SkippingTestCase(TestCase): self.assertRaises(ValueError, test_func) -class SaveRestoreWarningState(TestCase): +class AssertNumQueriesTests(TestCase): + def test_assert_num_queries(self): + def test_func(): + raise ValueError + self.assertRaises(ValueError, + self.assertNumQueries, 2, test_func + ) + + def test_assert_num_queries_with_client(self): + person = Person.objects.create(name='test') + + self.assertNumQueries( + 1, + self.client.get, + "/test_utils/get_person/%s/" % person.pk + ) + + self.assertNumQueries( + 1, + self.client.get, + "/test_utils/get_person/%s/" % person.pk + ) + + def test_func(): + self.client.get("/test_utils/get_person/%s/" % person.pk) + self.client.get("/test_utils/get_person/%s/" % person.pk) + self.assertNumQueries(2, test_func) + + +class SaveRestoreWarningState(TestCase): def test_save_restore_warnings_state(self): """ Ensure save_warnings_state/restore_warnings_state work correctly. diff --git a/tests/regressiontests/test_utils/tests_25.py b/tests/regressiontests/test_utils/tests_25.py index 4adea6c080..9fe9c838e5 100644 --- a/tests/regressiontests/test_utils/tests_25.py +++ b/tests/regressiontests/test_utils/tests_25.py @@ -5,7 +5,7 @@ from django.test import TestCase from models import Person -class AssertNumQueriesTests(TestCase): +class AssertNumQueriesContextManagerTests(TestCase): def test_simple(self): with self.assertNumQueries(0): pass @@ -26,3 +26,16 @@ class AssertNumQueriesTests(TestCase): with self.assertRaises(TypeError): with self.assertNumQueries(4000): raise TypeError + + def test_with_client(self): + person = Person.objects.create(name="test") + + with self.assertNumQueries(1): + self.client.get("/test_utils/get_person/%s/" % person.pk) + + with self.assertNumQueries(1): + self.client.get("/test_utils/get_person/%s/" % person.pk) + + with self.assertNumQueries(2): + self.client.get("/test_utils/get_person/%s/" % person.pk) + self.client.get("/test_utils/get_person/%s/" % person.pk) diff --git a/tests/regressiontests/test_utils/urls.py b/tests/regressiontests/test_utils/urls.py new file mode 100644 index 0000000000..2c5821bc44 --- /dev/null +++ b/tests/regressiontests/test_utils/urls.py @@ -0,0 +1,8 @@ +from django.conf.urls.defaults import patterns + +import views + + +urlpatterns = patterns('', + (r'^get_person/(\d+)/$', views.get_person), +) diff --git a/tests/regressiontests/test_utils/views.py b/tests/regressiontests/test_utils/views.py new file mode 100644 index 0000000000..62af0d9c47 --- /dev/null +++ b/tests/regressiontests/test_utils/views.py @@ -0,0 +1,7 @@ +from django.http import HttpResponse +from django.shortcuts import get_object_or_404 +from models import Person + +def get_person(request, pk): + person = get_object_or_404(Person, pk=pk) + return HttpResponse(person.name) \ No newline at end of file diff --git a/tests/urls.py b/tests/urls.py index 01d6408c5a..d254407cc4 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -1,5 +1,6 @@ from django.conf.urls.defaults import * + urlpatterns = patterns('', # test_client modeltest urls (r'^test_client/', include('modeltests.test_client.urls')), @@ -41,4 +42,7 @@ urlpatterns = patterns('', # special headers views (r'special_headers/', include('regressiontests.special_headers.urls')), + + # test util views + (r'test_utils/', include('regressiontests.test_utils.urls')), )