mirror of
				https://github.com/django/django.git
				synced 2025-10-31 09:41:08 +00:00 
			
		
		
		
	Previously, the `django.db.close_old_connections` handler for `request_started` and `request_finished` would incorrectly close any connection within a `transaction.atomic` block (such as the one automatically used by `django.db.TestCase`), leading to `InterfaceError: connection already closed`. The test client and many of our tests have been working around this bug by manually suppressing the `close_old_connections` handler, but the workarounds are incomplete and the bug still affects other projects. Fix the bug and remove the workarounds. Signed-off-by: Anders Kaseorg <andersk@mit.edu>
This commit is contained in:
		@@ -593,7 +593,7 @@ class BaseDatabaseWrapper:
 | 
			
		||||
        Close the current connection if unrecoverable errors have occurred
 | 
			
		||||
        or if it outlived its maximum age.
 | 
			
		||||
        """
 | 
			
		||||
        if self.connection is not None:
 | 
			
		||||
        if self.connection is not None and not self.in_atomic_block:
 | 
			
		||||
            self.health_check_done = False
 | 
			
		||||
            # If the application didn't restore the original autocommit setting,
 | 
			
		||||
            # don't take chances, drop the connection.
 | 
			
		||||
 
 | 
			
		||||
@@ -17,8 +17,7 @@ from django.core.handlers.asgi import ASGIRequest
 | 
			
		||||
from django.core.handlers.base import BaseHandler
 | 
			
		||||
from django.core.handlers.wsgi import LimitedStream, WSGIRequest
 | 
			
		||||
from django.core.serializers.json import DjangoJSONEncoder
 | 
			
		||||
from django.core.signals import got_request_exception, request_finished, request_started
 | 
			
		||||
from django.db import close_old_connections
 | 
			
		||||
from django.core.signals import got_request_exception, request_started
 | 
			
		||||
from django.http import HttpHeaders, HttpRequest, QueryDict, SimpleCookie
 | 
			
		||||
from django.test import signals
 | 
			
		||||
from django.test.utils import ContextList
 | 
			
		||||
@@ -121,9 +120,7 @@ def closing_iterator_wrapper(iterable, close):
 | 
			
		||||
    try:
 | 
			
		||||
        yield from iterable
 | 
			
		||||
    finally:
 | 
			
		||||
        request_finished.disconnect(close_old_connections)
 | 
			
		||||
        close()  # will fire request_finished
 | 
			
		||||
        request_finished.connect(close_old_connections)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def aclosing_iterator_wrapper(iterable, close):
 | 
			
		||||
@@ -131,9 +128,7 @@ async def aclosing_iterator_wrapper(iterable, close):
 | 
			
		||||
        async for chunk in iterable:
 | 
			
		||||
            yield chunk
 | 
			
		||||
    finally:
 | 
			
		||||
        request_finished.disconnect(close_old_connections)
 | 
			
		||||
        close()  # will fire request_finished
 | 
			
		||||
        request_finished.connect(close_old_connections)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def conditional_content_removal(request, response):
 | 
			
		||||
@@ -172,9 +167,7 @@ class ClientHandler(BaseHandler):
 | 
			
		||||
        if self._middleware_chain is None:
 | 
			
		||||
            self.load_middleware()
 | 
			
		||||
 | 
			
		||||
        request_started.disconnect(close_old_connections)
 | 
			
		||||
        request_started.send(sender=self.__class__, environ=environ)
 | 
			
		||||
        request_started.connect(close_old_connections)
 | 
			
		||||
        request = WSGIRequest(environ)
 | 
			
		||||
        # sneaky little hack so that we can easily get round
 | 
			
		||||
        # CsrfViewMiddleware.  This makes life easier, and is probably
 | 
			
		||||
@@ -203,9 +196,7 @@ class ClientHandler(BaseHandler):
 | 
			
		||||
                    response.streaming_content, response.close
 | 
			
		||||
                )
 | 
			
		||||
        else:
 | 
			
		||||
            request_finished.disconnect(close_old_connections)
 | 
			
		||||
            response.close()  # will fire request_finished
 | 
			
		||||
            request_finished.connect(close_old_connections)
 | 
			
		||||
 | 
			
		||||
        return response
 | 
			
		||||
 | 
			
		||||
@@ -228,9 +219,7 @@ class AsyncClientHandler(BaseHandler):
 | 
			
		||||
        else:
 | 
			
		||||
            body_file = FakePayload("")
 | 
			
		||||
 | 
			
		||||
        request_started.disconnect(close_old_connections)
 | 
			
		||||
        await request_started.asend(sender=self.__class__, scope=scope)
 | 
			
		||||
        request_started.connect(close_old_connections)
 | 
			
		||||
        # Wrap FakePayload body_file to allow large read() in test environment.
 | 
			
		||||
        request = ASGIRequest(scope, LimitedStream(body_file, len(body_file)))
 | 
			
		||||
        # Sneaky little hack so that we can easily get round
 | 
			
		||||
@@ -255,10 +244,8 @@ class AsyncClientHandler(BaseHandler):
 | 
			
		||||
                    response.streaming_content, response.close
 | 
			
		||||
                )
 | 
			
		||||
        else:
 | 
			
		||||
            request_finished.disconnect(close_old_connections)
 | 
			
		||||
            # Will fire request_finished.
 | 
			
		||||
            await sync_to_async(response.close, thread_sensitive=False)()
 | 
			
		||||
            request_finished.connect(close_old_connections)
 | 
			
		||||
        return response
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -12,7 +12,6 @@ from django.core.asgi import get_asgi_application
 | 
			
		||||
from django.core.exceptions import RequestDataTooBig
 | 
			
		||||
from django.core.handlers.asgi import ASGIHandler, ASGIRequest
 | 
			
		||||
from django.core.signals import request_finished, request_started
 | 
			
		||||
from django.db import close_old_connections
 | 
			
		||||
from django.http import HttpResponse, StreamingHttpResponse
 | 
			
		||||
from django.test import (
 | 
			
		||||
    AsyncRequestFactory,
 | 
			
		||||
@@ -45,10 +44,6 @@ class SignalHandler:
 | 
			
		||||
class ASGITest(SimpleTestCase):
 | 
			
		||||
    async_request_factory = AsyncRequestFactory()
 | 
			
		||||
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        request_started.disconnect(close_old_connections)
 | 
			
		||||
        self.addCleanup(request_started.connect, close_old_connections)
 | 
			
		||||
 | 
			
		||||
    async def test_get_asgi_application(self):
 | 
			
		||||
        """
 | 
			
		||||
        get_asgi_application() returns a functioning ASGI callable.
 | 
			
		||||
 
 | 
			
		||||
@@ -381,6 +381,11 @@ class ConnectionHealthChecksTests(SimpleTestCase):
 | 
			
		||||
        connection.set_autocommit(True)
 | 
			
		||||
        self.assertIs(new_connection, connection.connection)
 | 
			
		||||
 | 
			
		||||
    def test_no_close_in_atomic(self):
 | 
			
		||||
        with transaction.atomic():
 | 
			
		||||
            connection.close_if_unusable_or_obsolete()
 | 
			
		||||
            self.run_query()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MultiDatabaseTests(TestCase):
 | 
			
		||||
    databases = {"default", "other"}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										6
									
								
								tests/cache/tests.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										6
									
								
								tests/cache/tests.py
									
									
									
									
										vendored
									
									
								
							@@ -27,7 +27,7 @@ from django.core.cache import (
 | 
			
		||||
from django.core.cache.backends.base import InvalidCacheBackendError
 | 
			
		||||
from django.core.cache.backends.redis import RedisCacheClient
 | 
			
		||||
from django.core.cache.utils import make_template_fragment_key
 | 
			
		||||
from django.db import close_old_connections, connection, connections
 | 
			
		||||
from django.db import connection, connections
 | 
			
		||||
from django.db.backends.utils import CursorWrapper
 | 
			
		||||
from django.http import (
 | 
			
		||||
    HttpRequest,
 | 
			
		||||
@@ -1560,15 +1560,11 @@ class BaseMemcachedTests(BaseCacheTests):
 | 
			
		||||
    def test_close(self):
 | 
			
		||||
        # For clients that don't manage their connections properly, the
 | 
			
		||||
        # connection is closed when the request is complete.
 | 
			
		||||
        signals.request_finished.disconnect(close_old_connections)
 | 
			
		||||
        try:
 | 
			
		||||
        with mock.patch.object(
 | 
			
		||||
            cache._class, "disconnect_all", autospec=True
 | 
			
		||||
        ) as mock_disconnect:
 | 
			
		||||
            signals.request_finished.send(self.__class__)
 | 
			
		||||
            self.assertIs(mock_disconnect.called, self.should_disconnect_on_close)
 | 
			
		||||
        finally:
 | 
			
		||||
            signals.request_finished.connect(close_old_connections)
 | 
			
		||||
 | 
			
		||||
    def test_set_many_returns_failing_keys(self):
 | 
			
		||||
        def fail_set_multi(mapping, *args, **kwargs):
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,7 @@
 | 
			
		||||
from django.core.exceptions import ImproperlyConfigured
 | 
			
		||||
from django.core.handlers.wsgi import WSGIHandler, WSGIRequest, get_script_name
 | 
			
		||||
from django.core.signals import request_finished, request_started
 | 
			
		||||
from django.db import close_old_connections, connection
 | 
			
		||||
from django.db import connection
 | 
			
		||||
from django.test import (
 | 
			
		||||
    AsyncRequestFactory,
 | 
			
		||||
    RequestFactory,
 | 
			
		||||
@@ -14,10 +14,6 @@ from django.test import (
 | 
			
		||||
class HandlerTests(SimpleTestCase):
 | 
			
		||||
    request_factory = RequestFactory()
 | 
			
		||||
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        request_started.disconnect(close_old_connections)
 | 
			
		||||
        self.addCleanup(request_started.connect, close_old_connections)
 | 
			
		||||
 | 
			
		||||
    def test_middleware_initialized(self):
 | 
			
		||||
        handler = WSGIHandler()
 | 
			
		||||
        self.assertIsNotNone(handler._middleware_chain)
 | 
			
		||||
 
 | 
			
		||||
@@ -7,8 +7,6 @@ import uuid
 | 
			
		||||
 | 
			
		||||
from django.core.exceptions import DisallowedRedirect
 | 
			
		||||
from django.core.serializers.json import DjangoJSONEncoder
 | 
			
		||||
from django.core.signals import request_finished
 | 
			
		||||
from django.db import close_old_connections
 | 
			
		||||
from django.http import (
 | 
			
		||||
    BadHeaderError,
 | 
			
		||||
    HttpResponse,
 | 
			
		||||
@@ -758,12 +756,6 @@ class StreamingHttpResponseTests(SimpleTestCase):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FileCloseTests(SimpleTestCase):
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        # Disable the request_finished signal during this test
 | 
			
		||||
        # to avoid interfering with the database connection.
 | 
			
		||||
        request_finished.disconnect(close_old_connections)
 | 
			
		||||
        self.addCleanup(request_finished.connect, close_old_connections)
 | 
			
		||||
 | 
			
		||||
    def test_response(self):
 | 
			
		||||
        filename = os.path.join(os.path.dirname(__file__), "abc.txt")
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -1,8 +1,6 @@
 | 
			
		||||
from django.core.exceptions import ImproperlyConfigured
 | 
			
		||||
from django.core.servers.basehttp import get_internal_wsgi_application
 | 
			
		||||
from django.core.signals import request_started
 | 
			
		||||
from django.core.wsgi import get_wsgi_application
 | 
			
		||||
from django.db import close_old_connections
 | 
			
		||||
from django.http import FileResponse
 | 
			
		||||
from django.test import SimpleTestCase, override_settings
 | 
			
		||||
from django.test.client import RequestFactory
 | 
			
		||||
@@ -12,10 +10,6 @@ from django.test.client import RequestFactory
 | 
			
		||||
class WSGITest(SimpleTestCase):
 | 
			
		||||
    request_factory = RequestFactory()
 | 
			
		||||
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        request_started.disconnect(close_old_connections)
 | 
			
		||||
        self.addCleanup(request_started.connect, close_old_connections)
 | 
			
		||||
 | 
			
		||||
    def test_get_wsgi_application(self):
 | 
			
		||||
        """
 | 
			
		||||
        get_wsgi_application() returns a functioning WSGI callable.
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user