From 2a2ea4ee18fdcf2c95bf6435bc63b74623e3085b Mon Sep 17 00:00:00 2001 From: Matthijs Kooijman Date: Fri, 20 Dec 2019 20:49:56 +0100 Subject: [PATCH] Refs #31117 -- Made various tests properly handle unexpected databases aliases. - Used selected "databases" instead of django.db.connections. - Made routers in tests.migrations skip migrations on unexpected databases. - Added DiscoverRunnerGetDatabasesTests.assertSkippedDatabases() hook which properly asserts messages about skipped databases. --- tests/admin_views/test_multidb.py | 9 +++--- tests/auth_tests/test_admin_multidb.py | 5 ++- tests/backends/tests.py | 8 +++-- tests/migrations/routers.py | 7 +++-- tests/migrations/test_base.py | 2 +- tests/migrations/test_commands.py | 10 +++--- tests/multiple_database/tests.py | 4 +-- tests/test_runner/test_discover_runner.py | 33 ++++++++++++-------- tests/test_utils/test_transactiontestcase.py | 2 +- 9 files changed, 46 insertions(+), 34 deletions(-) diff --git a/tests/admin_views/test_multidb.py b/tests/admin_views/test_multidb.py index a02b637d34..7ff1f7f663 100644 --- a/tests/admin_views/test_multidb.py +++ b/tests/admin_views/test_multidb.py @@ -2,7 +2,6 @@ from unittest import mock from django.contrib import admin from django.contrib.auth.models import User -from django.db import connections from django.test import TestCase, override_settings from django.urls import path, reverse @@ -34,7 +33,7 @@ class MultiDatabaseTests(TestCase): def setUpTestData(cls): cls.superusers = {} cls.test_book_ids = {} - for db in connections: + for db in cls.databases: Router.target_db = db cls.superusers[db] = User.objects.create_superuser( username='admin', password='something', email='test@test.org', @@ -45,7 +44,7 @@ class MultiDatabaseTests(TestCase): @mock.patch('django.contrib.admin.options.transaction') def test_add_view(self, mock): - for db in connections: + for db in self.databases: with self.subTest(db=db): Router.target_db = db self.client.force_login(self.superusers[db]) @@ -57,7 +56,7 @@ class MultiDatabaseTests(TestCase): @mock.patch('django.contrib.admin.options.transaction') def test_change_view(self, mock): - for db in connections: + for db in self.databases: with self.subTest(db=db): Router.target_db = db self.client.force_login(self.superusers[db]) @@ -69,7 +68,7 @@ class MultiDatabaseTests(TestCase): @mock.patch('django.contrib.admin.options.transaction') def test_delete_view(self, mock): - for db in connections: + for db in self.databases: with self.subTest(db=db): Router.target_db = db self.client.force_login(self.superusers[db]) diff --git a/tests/auth_tests/test_admin_multidb.py b/tests/auth_tests/test_admin_multidb.py index 5849ef98e5..fac2c0fc7e 100644 --- a/tests/auth_tests/test_admin_multidb.py +++ b/tests/auth_tests/test_admin_multidb.py @@ -3,7 +3,6 @@ from unittest import mock from django.contrib import admin from django.contrib.auth.admin import UserAdmin from django.contrib.auth.models import User -from django.db import connections from django.test import TestCase, override_settings from django.urls import path, reverse @@ -32,7 +31,7 @@ class MultiDatabaseTests(TestCase): @classmethod def setUpTestData(cls): cls.superusers = {} - for db in connections: + for db in cls.databases: Router.target_db = db cls.superusers[db] = User.objects.create_superuser( username='admin', password='something', email='test@test.org', @@ -40,7 +39,7 @@ class MultiDatabaseTests(TestCase): @mock.patch('django.contrib.auth.admin.transaction') def test_add_view(self, mock): - for db in connections: + for db in self.databases: with self.subTest(db_connection=db): Router.target_db = db self.client.force_login(self.superusers[db]) diff --git a/tests/backends/tests.py b/tests/backends/tests.py index da20d94442..918fc32166 100644 --- a/tests/backends/tests.py +++ b/tests/backends/tests.py @@ -687,11 +687,15 @@ class ThreadTests(TransactionTestCase): conn.inc_thread_sharing() connections_dict[id(conn)] = conn try: - for x in range(2): + num_new_threads = 2 + for x in range(num_new_threads): t = threading.Thread(target=runner) t.start() t.join() - self.assertEqual(len(connections_dict), 6) + self.assertEqual( + len(connections_dict), + len(connections.all()) * (num_new_threads + 1), + ) finally: # Finish by closing the connections opened by the other threads # (the connection opened in the main thread will automatically be diff --git a/tests/migrations/routers.py b/tests/migrations/routers.py index 21dfc561bd..bc036382a7 100644 --- a/tests/migrations/routers.py +++ b/tests/migrations/routers.py @@ -1,5 +1,6 @@ -class EmptyRouter: - pass +class DefaultOtherRouter: + def allow_migrate(self, db, app_label, model_name=None, **hints): + return db in {'default', 'other'} class TestRouter: @@ -9,5 +10,5 @@ class TestRouter: """ if model_name == 'tribble': return db == 'other' - elif db == 'other': + elif db != 'default': return False diff --git a/tests/migrations/test_base.py b/tests/migrations/test_base.py index 45c5472b0f..c4c8b1ee6c 100644 --- a/tests/migrations/test_base.py +++ b/tests/migrations/test_base.py @@ -24,7 +24,7 @@ class MigrationTestBase(TransactionTestCase): def tearDown(self): # Reset applied-migrations state. - for db in connections: + for db in self.databases: recorder = MigrationRecorder(connections[db]) recorder.migration_qs.filter(app='migrations').delete() diff --git a/tests/migrations/test_commands.py b/tests/migrations/test_commands.py index 5f57dc7cad..b98b77cd99 100644 --- a/tests/migrations/test_commands.py +++ b/tests/migrations/test_commands.py @@ -129,7 +129,7 @@ class MigrateTests(MigrationTestBase): that check. """ # Make sure no tables are created - for db in connections: + for db in self.databases: self.assertTableNotExists("migrations_author", using=db) self.assertTableNotExists("migrations_tribble", using=db) # Run the migrations to 0001 only @@ -192,7 +192,7 @@ class MigrateTests(MigrationTestBase): call_command("migrate", "migrations", "zero", verbosity=0) call_command("migrate", "migrations", "zero", verbosity=0, database="other") # Make sure it's all gone - for db in connections: + for db in self.databases: self.assertTableNotExists("migrations_author", using=db) self.assertTableNotExists("migrations_tribble", using=db) self.assertTableNotExists("migrations_book", using=db) @@ -931,7 +931,7 @@ class MakeMigrationsTests(MigrationTestBase): # With a router that doesn't prohibit migrating 'other', # consistency is checked. - with self.settings(DATABASE_ROUTERS=['migrations.routers.EmptyRouter']): + with self.settings(DATABASE_ROUTERS=['migrations.routers.DefaultOtherRouter']): with self.assertRaisesMessage(Exception, 'Other connection'): call_command('makemigrations', 'migrations', verbosity=0) self.assertEqual(has_table.call_count, 4) # 'default' and 'other' @@ -944,12 +944,14 @@ class MakeMigrationsTests(MigrationTestBase): allow_migrate.assert_any_call('other', 'migrations', model_name='UnicodeModel') # allow_migrate() is called with the correct arguments. self.assertGreater(len(allow_migrate.mock_calls), 0) + called_aliases = set() for mock_call in allow_migrate.mock_calls: _, call_args, call_kwargs = mock_call connection_alias, app_name = call_args - self.assertIn(connection_alias, ['default', 'other']) + called_aliases.add(connection_alias) # Raises an error if invalid app_name/model_name occurs. apps.get_app_config(app_name).get_model(call_kwargs['model_name']) + self.assertEqual(called_aliases, set(connections)) self.assertEqual(has_table.call_count, 4) def test_failing_migration(self): diff --git a/tests/multiple_database/tests.py b/tests/multiple_database/tests.py index e39967ce78..2a6a725472 100644 --- a/tests/multiple_database/tests.py +++ b/tests/multiple_database/tests.py @@ -7,7 +7,7 @@ from unittest.mock import Mock from django.contrib.auth.models import User from django.contrib.contenttypes.models import ContentType from django.core import management -from django.db import DEFAULT_DB_ALIAS, connections, router, transaction +from django.db import DEFAULT_DB_ALIAS, router, transaction from django.db.models import signals from django.db.utils import ConnectionRouter from django.test import SimpleTestCase, TestCase, override_settings @@ -1632,7 +1632,7 @@ class PickleQuerySetTestCase(TestCase): databases = {'default', 'other'} def test_pickling(self): - for db in connections: + for db in self.databases: Book.objects.using(db).create(title='Dive into Python', published=datetime.date(2009, 5, 4)) qs = Book.objects.all() self.assertEqual(qs.db, pickle.loads(pickle.dumps(qs)).db) diff --git a/tests/test_runner/test_discover_runner.py b/tests/test_runner/test_discover_runner.py index 186dc52a44..61af22d818 100644 --- a/tests/test_runner/test_discover_runner.py +++ b/tests/test_runner/test_discover_runner.py @@ -308,6 +308,15 @@ class DiscoverRunnerGetDatabasesTests(SimpleTestCase): databases = self.runner.get_databases(suite) return databases, stdout.getvalue() + def assertSkippedDatabases(self, test_labels, expected_databases): + databases, output = self.get_databases(test_labels) + self.assertEqual(databases, expected_databases) + skipped_databases = set(connections) - expected_databases + if skipped_databases: + self.assertIn(self.skip_msg + ', '.join(sorted(skipped_databases)), output) + else: + self.assertNotIn(self.skip_msg, output) + def test_mixed(self): databases, output = self.get_databases(['test_runner_apps.databases.tests']) self.assertEqual(databases, set(connections)) @@ -319,24 +328,22 @@ class DiscoverRunnerGetDatabasesTests(SimpleTestCase): self.assertNotIn(self.skip_msg, output) def test_default_and_other(self): - databases, output = self.get_databases([ + self.assertSkippedDatabases([ 'test_runner_apps.databases.tests.DefaultDatabaseTests', 'test_runner_apps.databases.tests.OtherDatabaseTests', - ]) - self.assertEqual(databases, set(connections)) - self.assertNotIn(self.skip_msg, output) + ], {'default', 'other'}) def test_default_only(self): - databases, output = self.get_databases(['test_runner_apps.databases.tests.DefaultDatabaseTests']) - self.assertEqual(databases, {'default'}) - self.assertIn(self.skip_msg + 'other', output) + self.assertSkippedDatabases([ + 'test_runner_apps.databases.tests.DefaultDatabaseTests', + ], {'default'}) def test_other_only(self): - databases, output = self.get_databases(['test_runner_apps.databases.tests.OtherDatabaseTests']) - self.assertEqual(databases, {'other'}) - self.assertIn(self.skip_msg + 'default', output) + self.assertSkippedDatabases([ + 'test_runner_apps.databases.tests.OtherDatabaseTests' + ], {'other'}) def test_no_databases_required(self): - databases, output = self.get_databases(['test_runner_apps.databases.tests.NoDatabaseTests']) - self.assertEqual(databases, set()) - self.assertIn(self.skip_msg + 'default, other', output) + self.assertSkippedDatabases([ + 'test_runner_apps.databases.tests.NoDatabaseTests' + ], set()) diff --git a/tests/test_utils/test_transactiontestcase.py b/tests/test_utils/test_transactiontestcase.py index 3a9d173138..4e183a5196 100644 --- a/tests/test_utils/test_transactiontestcase.py +++ b/tests/test_utils/test_transactiontestcase.py @@ -44,7 +44,7 @@ class TransactionTestCaseDatabasesTests(TestCase): so that it's less likely to overflow. An overflow causes assertNumQueries() to fail. """ - for alias in connections: + for alias in self.databases: self.assertEqual(len(connections[alias].queries_log), 0, 'Failed for alias %s' % alias)