mirror of
				https://github.com/django/django.git
				synced 2025-10-30 17:16:10 +00:00 
			
		
		
		
	Fixed #5416 -- Added TestCase.assertNumQueries, which tests that a given function executes the correct number of queries.
git-svn-id: http://code.djangoproject.com/svn/django/trunk@14183 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
		| @@ -21,6 +21,7 @@ class BaseDatabaseWrapper(local): | ||||
|         self.settings_dict = settings_dict | ||||
|         self.alias = alias | ||||
|         self.vendor = 'unknown' | ||||
|         self.use_debug_cursor = None | ||||
|  | ||||
|     def __eq__(self, other): | ||||
|         return self.settings_dict == other.settings_dict | ||||
| @@ -74,7 +75,8 @@ class BaseDatabaseWrapper(local): | ||||
|     def cursor(self): | ||||
|         from django.conf import settings | ||||
|         cursor = self._cursor() | ||||
|         if settings.DEBUG: | ||||
|         if (self.use_debug_cursor or | ||||
|             (self.use_debug_cursor is None and settings.DEBUG)): | ||||
|             return self.make_debug_cursor(cursor) | ||||
|         return cursor | ||||
|  | ||||
|   | ||||
| @@ -1,4 +1,5 @@ | ||||
| import re | ||||
| import sys | ||||
| from urlparse import urlsplit, urlunsplit | ||||
| from xml.dom.minidom import parseString, Node | ||||
|  | ||||
| @@ -205,6 +206,33 @@ class DocTestRunner(doctest.DocTestRunner): | ||||
|         for conn in connections: | ||||
|             transaction.rollback_unless_managed(using=conn) | ||||
|  | ||||
| class _AssertNumQueriesContext(object): | ||||
|     def __init__(self, test_case, num, connection): | ||||
|         self.test_case = test_case | ||||
|         self.num = num | ||||
|         self.connection = connection | ||||
|  | ||||
|     def __enter__(self): | ||||
|         self.old_debug_cursor = self.connection.use_debug_cursor | ||||
|         self.connection.use_debug_cursor = True | ||||
|         self.starting_queries = len(self.connection.queries) | ||||
|         return self | ||||
|  | ||||
|     def __exit__(self, exc_type, exc_value, traceback): | ||||
|         if exc_type is not None: | ||||
|             return | ||||
|  | ||||
|         self.connection.use_debug_cursor = self.old_debug_cursor | ||||
|         final_queries = len(self.connection.queries) | ||||
|         executed = final_queries - self.starting_queries | ||||
|  | ||||
|         self.test_case.assertEqual( | ||||
|             executed, self.num, "%d queries executed, %d expected" % ( | ||||
|                 executed, self.num | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class TransactionTestCase(unittest.TestCase): | ||||
|     # The class we'll use for the test client self.client. | ||||
|     # Can be overridden in derived classes. | ||||
| @@ -469,6 +497,22 @@ class TransactionTestCase(unittest.TestCase): | ||||
|     def assertQuerysetEqual(self, qs, values, transform=repr): | ||||
|         return self.assertEqual(map(transform, qs), values) | ||||
|  | ||||
|     def assertNumQueries(self, num, func=None, *args, **kwargs): | ||||
|         using = kwargs.pop("using", DEFAULT_DB_ALIAS) | ||||
|         connection = connections[using] | ||||
|  | ||||
|         context = _AssertNumQueriesContext(self, num, connection) | ||||
|         if func is None: | ||||
|             return context | ||||
|  | ||||
|         # Basically emulate the `with` statement here. | ||||
|  | ||||
|         context.__enter__() | ||||
|         try: | ||||
|             func(*args, **kwargs) | ||||
|         finally: | ||||
|             context.__exit__(*sys.exc_info()) | ||||
|  | ||||
| def connections_support_transactions(): | ||||
|     """ | ||||
|     Returns True if all connections support transactions.  This is messy | ||||
|   | ||||
| @@ -1372,6 +1372,32 @@ cause of an failure in your test suite. | ||||
|     implicit ordering, you will need to apply a ``order_by()`` clause to your | ||||
|     queryset to ensure that the test will pass reliably. | ||||
|  | ||||
| .. method:: TestCase.assertNumQueries(num, func, *args, **kwargs): | ||||
|  | ||||
|     .. versionadded:: 1.3 | ||||
|  | ||||
|     Asserts that when ``func`` is called with ``*args`` and ``**kwargs`` that | ||||
|     ``num`` database queries are executed. | ||||
|  | ||||
|     If a ``"using"`` key is present in ``kwargs`` it is used as the database | ||||
|     alias for which to check the number of queries.  If you wish to call a | ||||
|     function with a ``using`` parameter you can do it by wrapping the call with | ||||
|     a ``lambda`` to add an extra parameter:: | ||||
|  | ||||
|         self.assertNumQueries(7, lambda: my_function(using=7)) | ||||
|  | ||||
|     If you're using Python 2.5 or greater you can also use this as a context | ||||
|     manager:: | ||||
|  | ||||
|         # This is necessary in Python 2.5 to enable the with statement, in 2.6 | ||||
|         # and up it is no longer necessary. | ||||
|         from __future__ import with_statement | ||||
|  | ||||
|         with self.assertNumQueries(2): | ||||
|             Person.objects.create(name="Aaron") | ||||
|             Person.objects.create(name="Daniel") | ||||
|  | ||||
|  | ||||
| .. _topics-testing-email: | ||||
|  | ||||
| E-mail services | ||||
|   | ||||
| @@ -1,6 +1,4 @@ | ||||
| from django.test import TestCase | ||||
| from django.conf import settings | ||||
| from django import db | ||||
|  | ||||
| from models import Domain, Kingdom, Phylum, Klass, Order, Family, Genus, Species | ||||
|  | ||||
| @@ -36,73 +34,73 @@ class SelectRelatedTests(TestCase): | ||||
|         # queries so we'll set it to True here and reset it at the end of the | ||||
|         # test case. | ||||
|         self.create_base_data() | ||||
|         settings.DEBUG = True | ||||
|         db.reset_queries() | ||||
|  | ||||
|     def tearDown(self): | ||||
|         settings.DEBUG = False | ||||
|  | ||||
|     def test_access_fks_without_select_related(self): | ||||
|         """ | ||||
|         Normally, accessing FKs doesn't fill in related objects | ||||
|         """ | ||||
|         fly = Species.objects.get(name="melanogaster") | ||||
|         domain = fly.genus.family.order.klass.phylum.kingdom.domain | ||||
|         self.assertEqual(domain.name, 'Eukaryota') | ||||
|         self.assertEqual(len(db.connection.queries), 8) | ||||
|         def test(): | ||||
|             fly = Species.objects.get(name="melanogaster") | ||||
|             domain = fly.genus.family.order.klass.phylum.kingdom.domain | ||||
|             self.assertEqual(domain.name, 'Eukaryota') | ||||
|         self.assertNumQueries(8, test) | ||||
|  | ||||
|     def test_access_fks_with_select_related(self): | ||||
|         """ | ||||
|         A select_related() call will fill in those related objects without any | ||||
|         extra queries | ||||
|         """ | ||||
|         person = Species.objects.select_related(depth=10).get(name="sapiens") | ||||
|         domain = person.genus.family.order.klass.phylum.kingdom.domain | ||||
|         self.assertEqual(domain.name, 'Eukaryota') | ||||
|         self.assertEqual(len(db.connection.queries), 1) | ||||
|         def test(): | ||||
|             person = Species.objects.select_related(depth=10).get(name="sapiens") | ||||
|             domain = person.genus.family.order.klass.phylum.kingdom.domain | ||||
|             self.assertEqual(domain.name, 'Eukaryota') | ||||
|         self.assertNumQueries(1, test) | ||||
|  | ||||
|     def test_list_without_select_related(self): | ||||
|         """ | ||||
|         select_related() also of course applies to entire lists, not just | ||||
|         items. This test verifies the expected behavior without select_related. | ||||
|         """ | ||||
|         world = Species.objects.all() | ||||
|         families = [o.genus.family.name for o in world] | ||||
|         self.assertEqual(families, [ | ||||
|             'Drosophilidae', | ||||
|             'Hominidae', | ||||
|             'Fabaceae', | ||||
|             'Amanitacae', | ||||
|         ]) | ||||
|         self.assertEqual(len(db.connection.queries), 9) | ||||
|         def test(): | ||||
|             world = Species.objects.all() | ||||
|             families = [o.genus.family.name for o in world] | ||||
|             self.assertEqual(families, [ | ||||
|                 'Drosophilidae', | ||||
|                 'Hominidae', | ||||
|                 'Fabaceae', | ||||
|                 'Amanitacae', | ||||
|             ]) | ||||
|         self.assertNumQueries(9, test) | ||||
|  | ||||
|     def test_list_with_select_related(self): | ||||
|         """ | ||||
|         select_related() also of course applies to entire lists, not just | ||||
|         items. This test verifies the expected behavior with select_related. | ||||
|         """ | ||||
|         world = Species.objects.all().select_related() | ||||
|         families = [o.genus.family.name for o in world] | ||||
|         self.assertEqual(families, [ | ||||
|             'Drosophilidae', | ||||
|             'Hominidae', | ||||
|             'Fabaceae', | ||||
|             'Amanitacae', | ||||
|         ]) | ||||
|         self.assertEqual(len(db.connection.queries), 1) | ||||
|         def test(): | ||||
|             world = Species.objects.all().select_related() | ||||
|             families = [o.genus.family.name for o in world] | ||||
|             self.assertEqual(families, [ | ||||
|                 'Drosophilidae', | ||||
|                 'Hominidae', | ||||
|                 'Fabaceae', | ||||
|                 'Amanitacae', | ||||
|             ]) | ||||
|         self.assertNumQueries(1, test) | ||||
|  | ||||
|     def test_depth(self, depth=1, expected=7): | ||||
|         """ | ||||
|         The "depth" argument to select_related() will stop the descent at a | ||||
|         particular level. | ||||
|         """ | ||||
|         pea = Species.objects.select_related(depth=depth).get(name="sativum") | ||||
|         self.assertEqual( | ||||
|             pea.genus.family.order.klass.phylum.kingdom.domain.name, | ||||
|             'Eukaryota' | ||||
|         ) | ||||
|         def test(): | ||||
|             pea = Species.objects.select_related(depth=depth).get(name="sativum") | ||||
|             self.assertEqual( | ||||
|                 pea.genus.family.order.klass.phylum.kingdom.domain.name, | ||||
|                 'Eukaryota' | ||||
|             ) | ||||
|         # Notice: one fewer queries than above because of depth=1 | ||||
|         self.assertEqual(len(db.connection.queries), expected) | ||||
|         self.assertNumQueries(expected, test) | ||||
|  | ||||
|     def test_larger_depth(self): | ||||
|         """ | ||||
| @@ -116,11 +114,12 @@ class SelectRelatedTests(TestCase): | ||||
|         The "depth" argument to select_related() will stop the descent at a | ||||
|         particular level. This can be used on lists as well. | ||||
|         """ | ||||
|         world = Species.objects.all().select_related(depth=2) | ||||
|         orders = [o.genus.family.order.name for o in world] | ||||
|         self.assertEqual(orders, | ||||
|             ['Diptera', 'Primates', 'Fabales', 'Agaricales']) | ||||
|         self.assertEqual(len(db.connection.queries), 5) | ||||
|         def test(): | ||||
|             world = Species.objects.all().select_related(depth=2) | ||||
|             orders = [o.genus.family.order.name for o in world] | ||||
|             self.assertEqual(orders, | ||||
|                 ['Diptera', 'Primates', 'Fabales', 'Agaricales']) | ||||
|         self.assertNumQueries(5, test) | ||||
|  | ||||
|     def test_select_related_with_extra(self): | ||||
|         s = Species.objects.all().select_related(depth=1)\ | ||||
| @@ -136,28 +135,31 @@ class SelectRelatedTests(TestCase): | ||||
|         In this case, we explicitly say to select the 'genus' and | ||||
|         'genus.family' models, leading to the same number of queries as before. | ||||
|         """ | ||||
|         world = Species.objects.select_related('genus__family') | ||||
|         families = [o.genus.family.name for o in world] | ||||
|         self.assertEqual(families, | ||||
|             ['Drosophilidae', 'Hominidae', 'Fabaceae', 'Amanitacae']) | ||||
|         self.assertEqual(len(db.connection.queries), 1) | ||||
|         def test(): | ||||
|             world = Species.objects.select_related('genus__family') | ||||
|             families = [o.genus.family.name for o in world] | ||||
|             self.assertEqual(families, | ||||
|                 ['Drosophilidae', 'Hominidae', 'Fabaceae', 'Amanitacae']) | ||||
|         self.assertNumQueries(1, test) | ||||
|  | ||||
|     def test_more_certain_fields(self): | ||||
|         """ | ||||
|         In this case, we explicitly say to select the 'genus' and | ||||
|         'genus.family' models, leading to the same number of queries as before. | ||||
|         """ | ||||
|         world = Species.objects.filter(genus__name='Amanita')\ | ||||
|             .select_related('genus__family') | ||||
|         orders = [o.genus.family.order.name for o in world] | ||||
|         self.assertEqual(orders, [u'Agaricales']) | ||||
|         self.assertEqual(len(db.connection.queries), 2) | ||||
|         def test(): | ||||
|             world = Species.objects.filter(genus__name='Amanita')\ | ||||
|                 .select_related('genus__family') | ||||
|             orders = [o.genus.family.order.name for o in world] | ||||
|             self.assertEqual(orders, [u'Agaricales']) | ||||
|         self.assertNumQueries(2, test) | ||||
|  | ||||
|     def test_field_traversal(self): | ||||
|         s = Species.objects.all().select_related('genus__family__order' | ||||
|             ).order_by('id')[0:1].get().genus.family.order.name | ||||
|         self.assertEqual(s, u'Diptera') | ||||
|         self.assertEqual(len(db.connection.queries), 1) | ||||
|         def test(): | ||||
|             s = Species.objects.all().select_related('genus__family__order' | ||||
|                 ).order_by('id')[0:1].get().genus.family.order.name | ||||
|             self.assertEqual(s, u'Diptera') | ||||
|         self.assertNumQueries(1, test) | ||||
|  | ||||
|     def test_depth_fields_fails(self): | ||||
|         self.assertRaises(TypeError, | ||||
|   | ||||
| @@ -2,9 +2,11 @@ import datetime | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.db import connection | ||||
| from django.test import TestCase | ||||
| from django.utils import unittest | ||||
|  | ||||
| from models import CustomPKModel, UniqueTogetherModel, UniqueFieldsModel, UniqueForDateModel, ModelToValidate | ||||
| from models import (CustomPKModel, UniqueTogetherModel, UniqueFieldsModel, | ||||
|     UniqueForDateModel, ModelToValidate) | ||||
|  | ||||
|  | ||||
| class GetUniqueCheckTests(unittest.TestCase): | ||||
| @@ -51,37 +53,26 @@ class GetUniqueCheckTests(unittest.TestCase): | ||||
|             ), m._get_unique_checks(exclude='start_date') | ||||
|         ) | ||||
|  | ||||
| class PerformUniqueChecksTest(unittest.TestCase): | ||||
|     def setUp(self): | ||||
|         # Set debug to True to gain access to connection.queries. | ||||
|         self._old_debug, settings.DEBUG = settings.DEBUG, True | ||||
|         super(PerformUniqueChecksTest, self).setUp() | ||||
|  | ||||
|     def tearDown(self): | ||||
|         # Restore old debug value. | ||||
|         settings.DEBUG = self._old_debug | ||||
|         super(PerformUniqueChecksTest, self).tearDown() | ||||
|  | ||||
| class PerformUniqueChecksTest(TestCase): | ||||
|     def test_primary_key_unique_check_not_performed_when_adding_and_pk_not_specified(self): | ||||
|         # Regression test for #12560 | ||||
|         query_count = len(connection.queries) | ||||
|         mtv = ModelToValidate(number=10, name='Some Name') | ||||
|         setattr(mtv, '_adding', True) | ||||
|         mtv.full_clean() | ||||
|         self.assertEqual(query_count, len(connection.queries)) | ||||
|         def test(): | ||||
|             mtv = ModelToValidate(number=10, name='Some Name') | ||||
|             setattr(mtv, '_adding', True) | ||||
|             mtv.full_clean() | ||||
|         self.assertNumQueries(0, test) | ||||
|  | ||||
|     def test_primary_key_unique_check_performed_when_adding_and_pk_specified(self): | ||||
|         # Regression test for #12560 | ||||
|         query_count = len(connection.queries) | ||||
|         mtv = ModelToValidate(number=10, name='Some Name', id=123) | ||||
|         setattr(mtv, '_adding', True) | ||||
|         mtv.full_clean() | ||||
|         self.assertEqual(query_count + 1, len(connection.queries)) | ||||
|         def test(): | ||||
|             mtv = ModelToValidate(number=10, name='Some Name', id=123) | ||||
|             setattr(mtv, '_adding', True) | ||||
|             mtv.full_clean() | ||||
|         self.assertNumQueries(1, test) | ||||
|  | ||||
|     def test_primary_key_unique_check_not_performed_when_not_adding(self): | ||||
|         # Regression test for #12132 | ||||
|         query_count= len(connection.queries) | ||||
|         mtv = ModelToValidate(number=10, name='Some Name') | ||||
|         mtv.full_clean() | ||||
|         self.assertEqual(query_count, len(connection.queries)) | ||||
|  | ||||
|         def test(): | ||||
|             mtv = ModelToValidate(number=10, name='Some Name') | ||||
|             mtv.full_clean() | ||||
|         self.assertNumQueries(0, test) | ||||
|   | ||||
| @@ -6,7 +6,8 @@ from modeltests.validation.models import Author, Article, ModelToValidate | ||||
|  | ||||
| # Import other tests for this package. | ||||
| from modeltests.validation.validators import TestModelsWithValidators | ||||
| from modeltests.validation.test_unique import GetUniqueCheckTests, PerformUniqueChecksTest | ||||
| from modeltests.validation.test_unique import (GetUniqueCheckTests, | ||||
|     PerformUniqueChecksTest) | ||||
| from modeltests.validation.test_custom_messages import CustomMessagesTest | ||||
|  | ||||
|  | ||||
| @@ -111,4 +112,3 @@ class ModelFormsTests(TestCase): | ||||
|         article = Article(author_id=self.author.id) | ||||
|         form = ArticleForm(data, instance=article) | ||||
|         self.assertEqual(form.errors.keys(), ['pub_date']) | ||||
|  | ||||
|   | ||||
| @@ -11,17 +11,6 @@ from models import ResolveThis, Item, RelatedItem, Child, Leaf | ||||
|  | ||||
|  | ||||
| class DeferRegressionTest(TestCase): | ||||
|     def assert_num_queries(self, n, func, *args, **kwargs): | ||||
|         old_DEBUG = settings.DEBUG | ||||
|         settings.DEBUG = True | ||||
|         starting_queries = len(connection.queries) | ||||
|         try: | ||||
|             func(*args, **kwargs) | ||||
|         finally: | ||||
|             settings.DEBUG = old_DEBUG | ||||
|         self.assertEqual(starting_queries + n, len(connection.queries)) | ||||
|  | ||||
|  | ||||
|     def test_basic(self): | ||||
|         # Deferred fields should really be deferred and not accidentally use | ||||
|         # the field's default value just because they aren't passed to __init__ | ||||
| @@ -33,19 +22,19 @@ class DeferRegressionTest(TestCase): | ||||
|         def test(): | ||||
|             self.assertEqual(obj.name, "first") | ||||
|             self.assertEqual(obj.other_value, 0) | ||||
|         self.assert_num_queries(0, test) | ||||
|         self.assertNumQueries(0, test) | ||||
|  | ||||
|         def test(): | ||||
|             self.assertEqual(obj.value, 42) | ||||
|         self.assert_num_queries(1, test) | ||||
|         self.assertNumQueries(1, test) | ||||
|  | ||||
|         def test(): | ||||
|             self.assertEqual(obj.text, "xyzzy") | ||||
|         self.assert_num_queries(1, test) | ||||
|         self.assertNumQueries(1, test) | ||||
|  | ||||
|         def test(): | ||||
|             self.assertEqual(obj.text, "xyzzy") | ||||
|         self.assert_num_queries(0, test) | ||||
|         self.assertNumQueries(0, test) | ||||
|  | ||||
|         # Regression test for #10695. Make sure different instances don't | ||||
|         # inadvertently share data in the deferred descriptor objects. | ||||
|   | ||||
| @@ -1,10 +1,9 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
| import datetime | ||||
| import tempfile | ||||
| import shutil | ||||
| import tempfile | ||||
|  | ||||
| from django.db import models, connection | ||||
| from django.conf import settings | ||||
| from django.db import models | ||||
| # Can't import as "forms" due to implementation details in the test suite (the | ||||
| # current file is called "forms" and is already imported). | ||||
| from django import forms as django_forms | ||||
| @@ -77,19 +76,13 @@ class TestTicket12510(TestCase): | ||||
|     ''' It is not necessary to generate choices for ModelChoiceField (regression test for #12510). ''' | ||||
|     def setUp(self): | ||||
|         self.groups = [Group.objects.create(name=name) for name in 'abc'] | ||||
|         self.old_debug = settings.DEBUG | ||||
|         # turn debug on to get access to connection.queries | ||||
|         settings.DEBUG = True | ||||
|  | ||||
|     def tearDown(self): | ||||
|         settings.DEBUG = self.old_debug | ||||
|  | ||||
|     def test_choices_not_fetched_when_not_rendering(self): | ||||
|         initial_queries = len(connection.queries) | ||||
|         field = django_forms.ModelChoiceField(Group.objects.order_by('-name')) | ||||
|         self.assertEqual('a', field.clean(self.groups[0].pk).name) | ||||
|         def test(): | ||||
|             field = django_forms.ModelChoiceField(Group.objects.order_by('-name')) | ||||
|             self.assertEqual('a', field.clean(self.groups[0].pk).name) | ||||
|         # only one query is required to pull the model from DB | ||||
|         self.assertEqual(initial_queries+1, len(connection.queries)) | ||||
|         self.assertNumQueries(1, test) | ||||
|  | ||||
| class ModelFormCallableModelDefault(TestCase): | ||||
|     def test_no_empty_option(self): | ||||
|   | ||||
| @@ -1,10 +1,8 @@ | ||||
| import unittest | ||||
| from datetime import date | ||||
|  | ||||
| from django import db | ||||
| from django import forms | ||||
| from django.forms.models import modelform_factory, ModelChoiceField | ||||
| from django.conf import settings | ||||
| from django.test import TestCase | ||||
| from django.core.exceptions import FieldError, ValidationError | ||||
| from django.core.files.uploadedfile import SimpleUploadedFile | ||||
| @@ -14,14 +12,6 @@ from models import Person, RealPerson, Triple, FilePathModel, Article, \ | ||||
|  | ||||
|  | ||||
| class ModelMultipleChoiceFieldTests(TestCase): | ||||
|  | ||||
|     def setUp(self): | ||||
|         self.old_debug = settings.DEBUG | ||||
|         settings.DEBUG = True | ||||
|  | ||||
|     def tearDown(self): | ||||
|         settings.DEBUG = self.old_debug | ||||
|  | ||||
|     def test_model_multiple_choice_number_of_queries(self): | ||||
|         """ | ||||
|         Test that ModelMultipleChoiceField does O(1) queries instead of | ||||
| @@ -30,10 +20,8 @@ class ModelMultipleChoiceFieldTests(TestCase): | ||||
|         for i in range(30): | ||||
|             Person.objects.create(name="Person %s" % i) | ||||
|  | ||||
|         db.reset_queries() | ||||
|         f = forms.ModelMultipleChoiceField(queryset=Person.objects.all()) | ||||
|         selected = f.clean([1, 3, 5, 7, 9]) | ||||
|         self.assertEquals(len(db.connection.queries), 1) | ||||
|         self.assertNumQueries(1, f.clean, [1, 3, 5, 7, 9]) | ||||
|  | ||||
| class TripleForm(forms.ModelForm): | ||||
|     class Meta: | ||||
|   | ||||
| @@ -7,11 +7,6 @@ from models import (User, UserProfile, UserStat, UserStatResult, StatDetails, | ||||
|  | ||||
| class ReverseSelectRelatedTestCase(TestCase): | ||||
|     def setUp(self): | ||||
|         # Explicitly enable debug for these tests - we need to count | ||||
|         # the queries that have been issued. | ||||
|         self.old_debug = settings.DEBUG | ||||
|         settings.DEBUG = True | ||||
|  | ||||
|         user = User.objects.create(username="test") | ||||
|         userprofile = UserProfile.objects.create(user=user, state="KS", | ||||
|                                                  city="Lawrence") | ||||
| @@ -26,65 +21,66 @@ class ReverseSelectRelatedTestCase(TestCase): | ||||
|                                                   results=results2) | ||||
|         StatDetails.objects.create(base_stats=advstat, comments=250) | ||||
|  | ||||
|         db.reset_queries() | ||||
|  | ||||
|     def assertQueries(self, queries): | ||||
|         self.assertEqual(len(db.connection.queries), queries) | ||||
|  | ||||
|     def tearDown(self): | ||||
|         settings.DEBUG = self.old_debug | ||||
|  | ||||
|     def test_basic(self): | ||||
|         u = User.objects.select_related("userprofile").get(username="test") | ||||
|         self.assertEqual(u.userprofile.state, "KS") | ||||
|         self.assertQueries(1) | ||||
|         def test(): | ||||
|             u = User.objects.select_related("userprofile").get(username="test") | ||||
|             self.assertEqual(u.userprofile.state, "KS") | ||||
|         self.assertNumQueries(1, test) | ||||
|  | ||||
|     def test_follow_next_level(self): | ||||
|         u = User.objects.select_related("userstat__results").get(username="test") | ||||
|         self.assertEqual(u.userstat.posts, 150) | ||||
|         self.assertEqual(u.userstat.results.results, 'first results') | ||||
|         self.assertQueries(1) | ||||
|         def test(): | ||||
|             u = User.objects.select_related("userstat__results").get(username="test") | ||||
|             self.assertEqual(u.userstat.posts, 150) | ||||
|             self.assertEqual(u.userstat.results.results, 'first results') | ||||
|         self.assertNumQueries(1, test) | ||||
|  | ||||
|     def test_follow_two(self): | ||||
|         u = User.objects.select_related("userprofile", "userstat").get(username="test") | ||||
|         self.assertEqual(u.userprofile.state, "KS") | ||||
|         self.assertEqual(u.userstat.posts, 150) | ||||
|         self.assertQueries(1) | ||||
|         def test(): | ||||
|             u = User.objects.select_related("userprofile", "userstat").get(username="test") | ||||
|             self.assertEqual(u.userprofile.state, "KS") | ||||
|             self.assertEqual(u.userstat.posts, 150) | ||||
|         self.assertNumQueries(1, test) | ||||
|  | ||||
|     def test_follow_two_next_level(self): | ||||
|         u = User.objects.select_related("userstat__results", "userstat__statdetails").get(username="test") | ||||
|         self.assertEqual(u.userstat.results.results, 'first results') | ||||
|         self.assertEqual(u.userstat.statdetails.comments, 259) | ||||
|         self.assertQueries(1) | ||||
|         def test(): | ||||
|             u = User.objects.select_related("userstat__results", "userstat__statdetails").get(username="test") | ||||
|             self.assertEqual(u.userstat.results.results, 'first results') | ||||
|             self.assertEqual(u.userstat.statdetails.comments, 259) | ||||
|         self.assertNumQueries(1, test) | ||||
|  | ||||
|     def test_forward_and_back(self): | ||||
|         stat = UserStat.objects.select_related("user__userprofile").get(user__username="test") | ||||
|         self.assertEqual(stat.user.userprofile.state, 'KS') | ||||
|         self.assertEqual(stat.user.userstat.posts, 150) | ||||
|         self.assertQueries(1) | ||||
|         def test(): | ||||
|             stat = UserStat.objects.select_related("user__userprofile").get(user__username="test") | ||||
|             self.assertEqual(stat.user.userprofile.state, 'KS') | ||||
|             self.assertEqual(stat.user.userstat.posts, 150) | ||||
|         self.assertNumQueries(1, test) | ||||
|  | ||||
|     def test_back_and_forward(self): | ||||
|         u = User.objects.select_related("userstat").get(username="test") | ||||
|         self.assertEqual(u.userstat.user.username, 'test') | ||||
|         self.assertQueries(1) | ||||
|         def test(): | ||||
|             u = User.objects.select_related("userstat").get(username="test") | ||||
|             self.assertEqual(u.userstat.user.username, 'test') | ||||
|         self.assertNumQueries(1, test) | ||||
|  | ||||
|     def test_not_followed_by_default(self): | ||||
|         u = User.objects.select_related().get(username="test") | ||||
|         self.assertEqual(u.userstat.posts, 150) | ||||
|         self.assertQueries(2) | ||||
|         def test(): | ||||
|             u = User.objects.select_related().get(username="test") | ||||
|             self.assertEqual(u.userstat.posts, 150) | ||||
|         self.assertNumQueries(2, test) | ||||
|  | ||||
|     def test_follow_from_child_class(self): | ||||
|         stat = AdvancedUserStat.objects.select_related('user', 'statdetails').get(posts=200) | ||||
|         self.assertEqual(stat.statdetails.comments, 250) | ||||
|         self.assertEqual(stat.user.username, 'bob') | ||||
|         self.assertQueries(1) | ||||
|         def test(): | ||||
|             stat = AdvancedUserStat.objects.select_related('user', 'statdetails').get(posts=200) | ||||
|             self.assertEqual(stat.statdetails.comments, 250) | ||||
|             self.assertEqual(stat.user.username, 'bob') | ||||
|         self.assertNumQueries(1, test) | ||||
|  | ||||
|     def test_follow_inheritance(self): | ||||
|         stat = UserStat.objects.select_related('user', 'advanceduserstat').get(posts=200) | ||||
|         self.assertEqual(stat.advanceduserstat.posts, 200) | ||||
|         self.assertEqual(stat.user.username, 'bob') | ||||
|         self.assertEqual(stat.advanceduserstat.user.username, 'bob') | ||||
|         self.assertQueries(1) | ||||
|         def test(): | ||||
|             stat = UserStat.objects.select_related('user', 'advanceduserstat').get(posts=200) | ||||
|             self.assertEqual(stat.advanceduserstat.posts, 200) | ||||
|             self.assertEqual(stat.user.username, 'bob') | ||||
|             self.assertEqual(stat.advanceduserstat.user.username, 'bob') | ||||
|         self.assertNumQueries(1, test) | ||||
|  | ||||
|     def test_nullable_relation(self): | ||||
|         im = Image.objects.create(name="imag1") | ||||
|   | ||||
| @@ -0,0 +1,5 @@ | ||||
| from django.db import models | ||||
|  | ||||
|  | ||||
| class Person(models.Model): | ||||
|     name = models.CharField(max_length=100) | ||||
|   | ||||
							
								
								
									
										30
									
								
								tests/regressiontests/test_utils/python_25.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								tests/regressiontests/test_utils/python_25.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,30 @@ | ||||
| from __future__ import with_statement | ||||
|  | ||||
| from django.test import TestCase | ||||
|  | ||||
| from models import Person | ||||
|  | ||||
|  | ||||
| class AssertNumQueriesTests(TestCase): | ||||
|     def test_simple(self): | ||||
|         with self.assertNumQueries(0): | ||||
|             pass | ||||
|  | ||||
|         with self.assertNumQueries(1): | ||||
|             # Guy who wrote Linux | ||||
|             Person.objects.create(name="Linus Torvalds") | ||||
|  | ||||
|         with self.assertNumQueries(2): | ||||
|             # Guy who owns the bagel place I like | ||||
|             Person.objects.create(name="Uncle Ricky") | ||||
|             self.assertEqual(Person.objects.count(), 2) | ||||
|  | ||||
|     def test_failure(self): | ||||
|         with self.assertRaises(AssertionError) as exc_info: | ||||
|             with self.assertNumQueries(2): | ||||
|                 Person.objects.count() | ||||
|         self.assertEqual(str(exc_info.exception), "1 != 2 : 1 queries executed, 2 expected") | ||||
|  | ||||
|         with self.assertRaises(TypeError): | ||||
|             with self.assertNumQueries(4000): | ||||
|                 raise TypeError | ||||
| @@ -1,4 +1,10 @@ | ||||
| r""" | ||||
| import sys | ||||
|  | ||||
| if sys.version_info >= (2, 5): | ||||
|     from python_25 import AssertNumQueriesTests | ||||
|  | ||||
|  | ||||
| __test__ = {"API_TEST": r""" | ||||
| # Some checks of the doctest output normalizer. | ||||
| # Standard doctests do fairly | ||||
| >>> from django.utils import simplejson | ||||
| @@ -69,4 +75,4 @@ r""" | ||||
| >>> produce_xml_fragment() | ||||
| '<foo bbb="2.0" aaa="1.0">Hello</foo><bar ddd="4.0" ccc="3.0"></bar>' | ||||
|  | ||||
| """ | ||||
| """} | ||||
|   | ||||
		Reference in New Issue
	
	Block a user