mirror of
https://github.com/django/django.git
synced 2025-03-13 10:50:55 +00:00
Fixed #33277 -- Disallowed database connections in threads in SimpleTestCase.
This commit is contained in:
parent
45f778eded
commit
8fb0be3500
@ -10,6 +10,7 @@ from contextlib import contextmanager
|
|||||||
from copy import copy, deepcopy
|
from copy import copy, deepcopy
|
||||||
from difflib import get_close_matches
|
from difflib import get_close_matches
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
from unittest import mock
|
||||||
from unittest.suite import _DebugResult
|
from unittest.suite import _DebugResult
|
||||||
from unittest.util import safe_repr
|
from unittest.util import safe_repr
|
||||||
from urllib.parse import (
|
from urllib.parse import (
|
||||||
@ -37,6 +38,7 @@ from django.core.management.sql import emit_post_migrate_signal
|
|||||||
from django.core.servers.basehttp import ThreadedWSGIServer, WSGIRequestHandler
|
from django.core.servers.basehttp import ThreadedWSGIServer, WSGIRequestHandler
|
||||||
from django.core.signals import setting_changed
|
from django.core.signals import setting_changed
|
||||||
from django.db import DEFAULT_DB_ALIAS, connection, connections, transaction
|
from django.db import DEFAULT_DB_ALIAS, connection, connections, transaction
|
||||||
|
from django.db.backends.base.base import NO_DB_ALIAS, BaseDatabaseWrapper
|
||||||
from django.forms.fields import CharField
|
from django.forms.fields import CharField
|
||||||
from django.http import QueryDict
|
from django.http import QueryDict
|
||||||
from django.http.request import split_domain_port, validate_host
|
from django.http.request import split_domain_port, validate_host
|
||||||
@ -255,6 +257,13 @@ class SimpleTestCase(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
method = getattr(connection, name)
|
method = getattr(connection, name)
|
||||||
setattr(connection, name, _DatabaseFailure(method, message))
|
setattr(connection, name, _DatabaseFailure(method, message))
|
||||||
|
cls.enterClassContext(
|
||||||
|
mock.patch.object(
|
||||||
|
BaseDatabaseWrapper,
|
||||||
|
"ensure_connection",
|
||||||
|
new=cls.ensure_connection_patch_method(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _remove_databases_failures(cls):
|
def _remove_databases_failures(cls):
|
||||||
@ -266,6 +275,28 @@ class SimpleTestCase(unittest.TestCase):
|
|||||||
method = getattr(connection, name)
|
method = getattr(connection, name)
|
||||||
setattr(connection, name, method.wrapped)
|
setattr(connection, name, method.wrapped)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def ensure_connection_patch_method(cls):
|
||||||
|
real_ensure_connection = BaseDatabaseWrapper.ensure_connection
|
||||||
|
|
||||||
|
def patched_ensure_connection(self, *args, **kwargs):
|
||||||
|
if (
|
||||||
|
self.connection is None
|
||||||
|
and self.alias not in cls.databases
|
||||||
|
and self.alias != NO_DB_ALIAS
|
||||||
|
):
|
||||||
|
# Connection has not yet been established, but the alias is not allowed.
|
||||||
|
message = cls._disallowed_database_msg % {
|
||||||
|
"test": f"{cls.__module__}.{cls.__qualname__}",
|
||||||
|
"alias": self.alias,
|
||||||
|
"operation": "threaded connections",
|
||||||
|
}
|
||||||
|
return _DatabaseFailure(self.ensure_connection, message)()
|
||||||
|
|
||||||
|
real_ensure_connection(self, *args, **kwargs)
|
||||||
|
|
||||||
|
return patched_ensure_connection
|
||||||
|
|
||||||
def __call__(self, result=None):
|
def __call__(self, result=None):
|
||||||
"""
|
"""
|
||||||
Wrapper around default __call__ method to perform common Django test
|
Wrapper around default __call__ method to perform common Django test
|
||||||
|
@ -250,6 +250,9 @@ Tests
|
|||||||
* The new :meth:`.SimpleTestCase.assertNotInHTML` assertion allows testing that
|
* The new :meth:`.SimpleTestCase.assertNotInHTML` assertion allows testing that
|
||||||
an HTML fragment is not contained in the given HTML haystack.
|
an HTML fragment is not contained in the given HTML haystack.
|
||||||
|
|
||||||
|
* In order to enforce test isolation, database connections inside threads are
|
||||||
|
no longer allowed in :class:`~django.test.SimpleTestCase`.
|
||||||
|
|
||||||
URLs
|
URLs
|
||||||
~~~~
|
~~~~
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import threading
|
||||||
import unittest
|
import unittest
|
||||||
import warnings
|
import warnings
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
@ -2093,6 +2094,29 @@ class DisallowedDatabaseQueriesTests(SimpleTestCase):
|
|||||||
with self.assertRaisesMessage(DatabaseOperationForbidden, expected_message):
|
with self.assertRaisesMessage(DatabaseOperationForbidden, expected_message):
|
||||||
next(Car.objects.iterator())
|
next(Car.objects.iterator())
|
||||||
|
|
||||||
|
def test_disallowed_thread_database_connection(self):
|
||||||
|
expected_message = (
|
||||||
|
"Database threaded connections to 'default' are not allowed in "
|
||||||
|
"SimpleTestCase subclasses. Either subclass TestCase or TransactionTestCase"
|
||||||
|
" to ensure proper test isolation or add 'default' to "
|
||||||
|
"test_utils.tests.DisallowedDatabaseQueriesTests.databases to "
|
||||||
|
"silence this failure."
|
||||||
|
)
|
||||||
|
|
||||||
|
exceptions = []
|
||||||
|
|
||||||
|
def thread_func():
|
||||||
|
try:
|
||||||
|
Car.objects.first()
|
||||||
|
except DatabaseOperationForbidden as e:
|
||||||
|
exceptions.append(e)
|
||||||
|
|
||||||
|
t = threading.Thread(target=thread_func)
|
||||||
|
t.start()
|
||||||
|
t.join()
|
||||||
|
self.assertEqual(len(exceptions), 1)
|
||||||
|
self.assertEqual(exceptions[0].args[0], expected_message)
|
||||||
|
|
||||||
|
|
||||||
class AllowedDatabaseQueriesTests(SimpleTestCase):
|
class AllowedDatabaseQueriesTests(SimpleTestCase):
|
||||||
databases = {"default"}
|
databases = {"default"}
|
||||||
@ -2103,6 +2127,14 @@ class AllowedDatabaseQueriesTests(SimpleTestCase):
|
|||||||
def test_allowed_database_chunked_cursor_queries(self):
|
def test_allowed_database_chunked_cursor_queries(self):
|
||||||
next(Car.objects.iterator(), None)
|
next(Car.objects.iterator(), None)
|
||||||
|
|
||||||
|
def test_allowed_threaded_database_queries(self):
|
||||||
|
def thread_func():
|
||||||
|
next(Car.objects.iterator(), None)
|
||||||
|
|
||||||
|
t = threading.Thread(target=thread_func)
|
||||||
|
t.start()
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
|
||||||
class DatabaseAliasTests(SimpleTestCase):
|
class DatabaseAliasTests(SimpleTestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user