1
0
mirror of https://github.com/django/django.git synced 2025-10-25 06:36:07 +00:00

Refs #31117 -- Made various tests properly handle unexpected databases aliases.

- Used selected "databases" instead of django.db.connections.
- Made routers in tests.migrations skip migrations on unexpected
  databases.
- Added DiscoverRunnerGetDatabasesTests.assertSkippedDatabases() hook
  which properly asserts messages about skipped databases.
This commit is contained in:
Matthijs Kooijman
2019-12-20 20:49:56 +01:00
committed by Mariusz Felisiak
parent 26be703fe6
commit 2a2ea4ee18
9 changed files with 46 additions and 34 deletions

View File

@@ -2,7 +2,6 @@ from unittest import mock
from django.contrib import admin from django.contrib import admin
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.db import connections
from django.test import TestCase, override_settings from django.test import TestCase, override_settings
from django.urls import path, reverse from django.urls import path, reverse
@@ -34,7 +33,7 @@ class MultiDatabaseTests(TestCase):
def setUpTestData(cls): def setUpTestData(cls):
cls.superusers = {} cls.superusers = {}
cls.test_book_ids = {} cls.test_book_ids = {}
for db in connections: for db in cls.databases:
Router.target_db = db Router.target_db = db
cls.superusers[db] = User.objects.create_superuser( cls.superusers[db] = User.objects.create_superuser(
username='admin', password='something', email='test@test.org', username='admin', password='something', email='test@test.org',
@@ -45,7 +44,7 @@ class MultiDatabaseTests(TestCase):
@mock.patch('django.contrib.admin.options.transaction') @mock.patch('django.contrib.admin.options.transaction')
def test_add_view(self, mock): def test_add_view(self, mock):
for db in connections: for db in self.databases:
with self.subTest(db=db): with self.subTest(db=db):
Router.target_db = db Router.target_db = db
self.client.force_login(self.superusers[db]) self.client.force_login(self.superusers[db])
@@ -57,7 +56,7 @@ class MultiDatabaseTests(TestCase):
@mock.patch('django.contrib.admin.options.transaction') @mock.patch('django.contrib.admin.options.transaction')
def test_change_view(self, mock): def test_change_view(self, mock):
for db in connections: for db in self.databases:
with self.subTest(db=db): with self.subTest(db=db):
Router.target_db = db Router.target_db = db
self.client.force_login(self.superusers[db]) self.client.force_login(self.superusers[db])
@@ -69,7 +68,7 @@ class MultiDatabaseTests(TestCase):
@mock.patch('django.contrib.admin.options.transaction') @mock.patch('django.contrib.admin.options.transaction')
def test_delete_view(self, mock): def test_delete_view(self, mock):
for db in connections: for db in self.databases:
with self.subTest(db=db): with self.subTest(db=db):
Router.target_db = db Router.target_db = db
self.client.force_login(self.superusers[db]) self.client.force_login(self.superusers[db])

View File

@@ -3,7 +3,6 @@ from unittest import mock
from django.contrib import admin from django.contrib import admin
from django.contrib.auth.admin import UserAdmin from django.contrib.auth.admin import UserAdmin
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.db import connections
from django.test import TestCase, override_settings from django.test import TestCase, override_settings
from django.urls import path, reverse from django.urls import path, reverse
@@ -32,7 +31,7 @@ class MultiDatabaseTests(TestCase):
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
cls.superusers = {} cls.superusers = {}
for db in connections: for db in cls.databases:
Router.target_db = db Router.target_db = db
cls.superusers[db] = User.objects.create_superuser( cls.superusers[db] = User.objects.create_superuser(
username='admin', password='something', email='test@test.org', username='admin', password='something', email='test@test.org',
@@ -40,7 +39,7 @@ class MultiDatabaseTests(TestCase):
@mock.patch('django.contrib.auth.admin.transaction') @mock.patch('django.contrib.auth.admin.transaction')
def test_add_view(self, mock): def test_add_view(self, mock):
for db in connections: for db in self.databases:
with self.subTest(db_connection=db): with self.subTest(db_connection=db):
Router.target_db = db Router.target_db = db
self.client.force_login(self.superusers[db]) self.client.force_login(self.superusers[db])

View File

@@ -687,11 +687,15 @@ class ThreadTests(TransactionTestCase):
conn.inc_thread_sharing() conn.inc_thread_sharing()
connections_dict[id(conn)] = conn connections_dict[id(conn)] = conn
try: try:
for x in range(2): num_new_threads = 2
for x in range(num_new_threads):
t = threading.Thread(target=runner) t = threading.Thread(target=runner)
t.start() t.start()
t.join() t.join()
self.assertEqual(len(connections_dict), 6) self.assertEqual(
len(connections_dict),
len(connections.all()) * (num_new_threads + 1),
)
finally: finally:
# Finish by closing the connections opened by the other threads # Finish by closing the connections opened by the other threads
# (the connection opened in the main thread will automatically be # (the connection opened in the main thread will automatically be

View File

@@ -1,5 +1,6 @@
class EmptyRouter: class DefaultOtherRouter:
pass def allow_migrate(self, db, app_label, model_name=None, **hints):
return db in {'default', 'other'}
class TestRouter: class TestRouter:
@@ -9,5 +10,5 @@ class TestRouter:
""" """
if model_name == 'tribble': if model_name == 'tribble':
return db == 'other' return db == 'other'
elif db == 'other': elif db != 'default':
return False return False

View File

@@ -24,7 +24,7 @@ class MigrationTestBase(TransactionTestCase):
def tearDown(self): def tearDown(self):
# Reset applied-migrations state. # Reset applied-migrations state.
for db in connections: for db in self.databases:
recorder = MigrationRecorder(connections[db]) recorder = MigrationRecorder(connections[db])
recorder.migration_qs.filter(app='migrations').delete() recorder.migration_qs.filter(app='migrations').delete()

View File

@@ -129,7 +129,7 @@ class MigrateTests(MigrationTestBase):
that check. that check.
""" """
# Make sure no tables are created # Make sure no tables are created
for db in connections: for db in self.databases:
self.assertTableNotExists("migrations_author", using=db) self.assertTableNotExists("migrations_author", using=db)
self.assertTableNotExists("migrations_tribble", using=db) self.assertTableNotExists("migrations_tribble", using=db)
# Run the migrations to 0001 only # Run the migrations to 0001 only
@@ -192,7 +192,7 @@ class MigrateTests(MigrationTestBase):
call_command("migrate", "migrations", "zero", verbosity=0) call_command("migrate", "migrations", "zero", verbosity=0)
call_command("migrate", "migrations", "zero", verbosity=0, database="other") call_command("migrate", "migrations", "zero", verbosity=0, database="other")
# Make sure it's all gone # Make sure it's all gone
for db in connections: for db in self.databases:
self.assertTableNotExists("migrations_author", using=db) self.assertTableNotExists("migrations_author", using=db)
self.assertTableNotExists("migrations_tribble", using=db) self.assertTableNotExists("migrations_tribble", using=db)
self.assertTableNotExists("migrations_book", using=db) self.assertTableNotExists("migrations_book", using=db)
@@ -931,7 +931,7 @@ class MakeMigrationsTests(MigrationTestBase):
# With a router that doesn't prohibit migrating 'other', # With a router that doesn't prohibit migrating 'other',
# consistency is checked. # consistency is checked.
with self.settings(DATABASE_ROUTERS=['migrations.routers.EmptyRouter']): with self.settings(DATABASE_ROUTERS=['migrations.routers.DefaultOtherRouter']):
with self.assertRaisesMessage(Exception, 'Other connection'): with self.assertRaisesMessage(Exception, 'Other connection'):
call_command('makemigrations', 'migrations', verbosity=0) call_command('makemigrations', 'migrations', verbosity=0)
self.assertEqual(has_table.call_count, 4) # 'default' and 'other' self.assertEqual(has_table.call_count, 4) # 'default' and 'other'
@@ -944,12 +944,14 @@ class MakeMigrationsTests(MigrationTestBase):
allow_migrate.assert_any_call('other', 'migrations', model_name='UnicodeModel') allow_migrate.assert_any_call('other', 'migrations', model_name='UnicodeModel')
# allow_migrate() is called with the correct arguments. # allow_migrate() is called with the correct arguments.
self.assertGreater(len(allow_migrate.mock_calls), 0) self.assertGreater(len(allow_migrate.mock_calls), 0)
called_aliases = set()
for mock_call in allow_migrate.mock_calls: for mock_call in allow_migrate.mock_calls:
_, call_args, call_kwargs = mock_call _, call_args, call_kwargs = mock_call
connection_alias, app_name = call_args connection_alias, app_name = call_args
self.assertIn(connection_alias, ['default', 'other']) called_aliases.add(connection_alias)
# Raises an error if invalid app_name/model_name occurs. # Raises an error if invalid app_name/model_name occurs.
apps.get_app_config(app_name).get_model(call_kwargs['model_name']) apps.get_app_config(app_name).get_model(call_kwargs['model_name'])
self.assertEqual(called_aliases, set(connections))
self.assertEqual(has_table.call_count, 4) self.assertEqual(has_table.call_count, 4)
def test_failing_migration(self): def test_failing_migration(self):

View File

@@ -7,7 +7,7 @@ from unittest.mock import Mock
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.core import management from django.core import management
from django.db import DEFAULT_DB_ALIAS, connections, router, transaction from django.db import DEFAULT_DB_ALIAS, router, transaction
from django.db.models import signals from django.db.models import signals
from django.db.utils import ConnectionRouter from django.db.utils import ConnectionRouter
from django.test import SimpleTestCase, TestCase, override_settings from django.test import SimpleTestCase, TestCase, override_settings
@@ -1632,7 +1632,7 @@ class PickleQuerySetTestCase(TestCase):
databases = {'default', 'other'} databases = {'default', 'other'}
def test_pickling(self): def test_pickling(self):
for db in connections: for db in self.databases:
Book.objects.using(db).create(title='Dive into Python', published=datetime.date(2009, 5, 4)) Book.objects.using(db).create(title='Dive into Python', published=datetime.date(2009, 5, 4))
qs = Book.objects.all() qs = Book.objects.all()
self.assertEqual(qs.db, pickle.loads(pickle.dumps(qs)).db) self.assertEqual(qs.db, pickle.loads(pickle.dumps(qs)).db)

View File

@@ -308,6 +308,15 @@ class DiscoverRunnerGetDatabasesTests(SimpleTestCase):
databases = self.runner.get_databases(suite) databases = self.runner.get_databases(suite)
return databases, stdout.getvalue() return databases, stdout.getvalue()
def assertSkippedDatabases(self, test_labels, expected_databases):
databases, output = self.get_databases(test_labels)
self.assertEqual(databases, expected_databases)
skipped_databases = set(connections) - expected_databases
if skipped_databases:
self.assertIn(self.skip_msg + ', '.join(sorted(skipped_databases)), output)
else:
self.assertNotIn(self.skip_msg, output)
def test_mixed(self): def test_mixed(self):
databases, output = self.get_databases(['test_runner_apps.databases.tests']) databases, output = self.get_databases(['test_runner_apps.databases.tests'])
self.assertEqual(databases, set(connections)) self.assertEqual(databases, set(connections))
@@ -319,24 +328,22 @@ class DiscoverRunnerGetDatabasesTests(SimpleTestCase):
self.assertNotIn(self.skip_msg, output) self.assertNotIn(self.skip_msg, output)
def test_default_and_other(self): def test_default_and_other(self):
databases, output = self.get_databases([ self.assertSkippedDatabases([
'test_runner_apps.databases.tests.DefaultDatabaseTests', 'test_runner_apps.databases.tests.DefaultDatabaseTests',
'test_runner_apps.databases.tests.OtherDatabaseTests', 'test_runner_apps.databases.tests.OtherDatabaseTests',
]) ], {'default', 'other'})
self.assertEqual(databases, set(connections))
self.assertNotIn(self.skip_msg, output)
def test_default_only(self): def test_default_only(self):
databases, output = self.get_databases(['test_runner_apps.databases.tests.DefaultDatabaseTests']) self.assertSkippedDatabases([
self.assertEqual(databases, {'default'}) 'test_runner_apps.databases.tests.DefaultDatabaseTests',
self.assertIn(self.skip_msg + 'other', output) ], {'default'})
def test_other_only(self): def test_other_only(self):
databases, output = self.get_databases(['test_runner_apps.databases.tests.OtherDatabaseTests']) self.assertSkippedDatabases([
self.assertEqual(databases, {'other'}) 'test_runner_apps.databases.tests.OtherDatabaseTests'
self.assertIn(self.skip_msg + 'default', output) ], {'other'})
def test_no_databases_required(self): def test_no_databases_required(self):
databases, output = self.get_databases(['test_runner_apps.databases.tests.NoDatabaseTests']) self.assertSkippedDatabases([
self.assertEqual(databases, set()) 'test_runner_apps.databases.tests.NoDatabaseTests'
self.assertIn(self.skip_msg + 'default, other', output) ], set())

View File

@@ -44,7 +44,7 @@ class TransactionTestCaseDatabasesTests(TestCase):
so that it's less likely to overflow. An overflow causes so that it's less likely to overflow. An overflow causes
assertNumQueries() to fail. assertNumQueries() to fail.
""" """
for alias in connections: for alias in self.databases:
self.assertEqual(len(connections[alias].queries_log), 0, 'Failed for alias %s' % alias) self.assertEqual(len(connections[alias].queries_log), 0, 'Failed for alias %s' % alias)