1
0
mirror of https://github.com/django/django.git synced 2024-12-22 17:16:24 +00:00

Fixed #30448, Fixed #35618 -- Do not close connection in atomic block.

Previously, the `django.db.close_old_connections` handler for
`request_started` and `request_finished` would incorrectly close any
connection within a `transaction.atomic` block (such as the one
automatically used by `django.db.TestCase`), leading to
`InterfaceError: connection already closed`.

The test client and many of our tests have been working around this
bug by manually suppressing the `close_old_connections` handler, but
the workarounds are incomplete and the bug still affects other
projects.  Fix the bug and remove the workarounds.

Signed-off-by: Anders Kaseorg <andersk@mit.edu>
This commit is contained in:
Anders Kaseorg 2024-07-19 13:04:59 -07:00
parent 9cb8baa0c4
commit c91e92fbfc
8 changed files with 14 additions and 49 deletions

View File

@ -593,7 +593,7 @@ class BaseDatabaseWrapper:
Close the current connection if unrecoverable errors have occurred Close the current connection if unrecoverable errors have occurred
or if it outlived its maximum age. or if it outlived its maximum age.
""" """
if self.connection is not None: if self.connection is not None and not self.in_atomic_block:
self.health_check_done = False self.health_check_done = False
# If the application didn't restore the original autocommit setting, # If the application didn't restore the original autocommit setting,
# don't take chances, drop the connection. # don't take chances, drop the connection.

View File

@ -17,8 +17,7 @@ from django.core.handlers.asgi import ASGIRequest
from django.core.handlers.base import BaseHandler from django.core.handlers.base import BaseHandler
from django.core.handlers.wsgi import LimitedStream, WSGIRequest from django.core.handlers.wsgi import LimitedStream, WSGIRequest
from django.core.serializers.json import DjangoJSONEncoder from django.core.serializers.json import DjangoJSONEncoder
from django.core.signals import got_request_exception, request_finished, request_started from django.core.signals import got_request_exception, request_started
from django.db import close_old_connections
from django.http import HttpHeaders, HttpRequest, QueryDict, SimpleCookie from django.http import HttpHeaders, HttpRequest, QueryDict, SimpleCookie
from django.test import signals from django.test import signals
from django.test.utils import ContextList from django.test.utils import ContextList
@ -121,9 +120,7 @@ def closing_iterator_wrapper(iterable, close):
try: try:
yield from iterable yield from iterable
finally: finally:
request_finished.disconnect(close_old_connections)
close() # will fire request_finished close() # will fire request_finished
request_finished.connect(close_old_connections)
async def aclosing_iterator_wrapper(iterable, close): async def aclosing_iterator_wrapper(iterable, close):
@ -131,9 +128,7 @@ async def aclosing_iterator_wrapper(iterable, close):
async for chunk in iterable: async for chunk in iterable:
yield chunk yield chunk
finally: finally:
request_finished.disconnect(close_old_connections)
close() # will fire request_finished close() # will fire request_finished
request_finished.connect(close_old_connections)
def conditional_content_removal(request, response): def conditional_content_removal(request, response):
@ -172,9 +167,7 @@ class ClientHandler(BaseHandler):
if self._middleware_chain is None: if self._middleware_chain is None:
self.load_middleware() self.load_middleware()
request_started.disconnect(close_old_connections)
request_started.send(sender=self.__class__, environ=environ) request_started.send(sender=self.__class__, environ=environ)
request_started.connect(close_old_connections)
request = WSGIRequest(environ) request = WSGIRequest(environ)
# sneaky little hack so that we can easily get round # sneaky little hack so that we can easily get round
# CsrfViewMiddleware. This makes life easier, and is probably # CsrfViewMiddleware. This makes life easier, and is probably
@ -203,9 +196,7 @@ class ClientHandler(BaseHandler):
response.streaming_content, response.close response.streaming_content, response.close
) )
else: else:
request_finished.disconnect(close_old_connections)
response.close() # will fire request_finished response.close() # will fire request_finished
request_finished.connect(close_old_connections)
return response return response
@ -228,9 +219,7 @@ class AsyncClientHandler(BaseHandler):
else: else:
body_file = FakePayload("") body_file = FakePayload("")
request_started.disconnect(close_old_connections)
await request_started.asend(sender=self.__class__, scope=scope) await request_started.asend(sender=self.__class__, scope=scope)
request_started.connect(close_old_connections)
# Wrap FakePayload body_file to allow large read() in test environment. # Wrap FakePayload body_file to allow large read() in test environment.
request = ASGIRequest(scope, LimitedStream(body_file, len(body_file))) request = ASGIRequest(scope, LimitedStream(body_file, len(body_file)))
# Sneaky little hack so that we can easily get round # Sneaky little hack so that we can easily get round
@ -255,10 +244,8 @@ class AsyncClientHandler(BaseHandler):
response.streaming_content, response.close response.streaming_content, response.close
) )
else: else:
request_finished.disconnect(close_old_connections)
# Will fire request_finished. # Will fire request_finished.
await sync_to_async(response.close, thread_sensitive=False)() await sync_to_async(response.close, thread_sensitive=False)()
request_finished.connect(close_old_connections)
return response return response

View File

@ -12,7 +12,6 @@ from django.core.asgi import get_asgi_application
from django.core.exceptions import RequestDataTooBig from django.core.exceptions import RequestDataTooBig
from django.core.handlers.asgi import ASGIHandler, ASGIRequest from django.core.handlers.asgi import ASGIHandler, ASGIRequest
from django.core.signals import request_finished, request_started from django.core.signals import request_finished, request_started
from django.db import close_old_connections
from django.http import HttpResponse, StreamingHttpResponse from django.http import HttpResponse, StreamingHttpResponse
from django.test import ( from django.test import (
AsyncRequestFactory, AsyncRequestFactory,
@ -45,10 +44,6 @@ class SignalHandler:
class ASGITest(SimpleTestCase): class ASGITest(SimpleTestCase):
async_request_factory = AsyncRequestFactory() async_request_factory = AsyncRequestFactory()
def setUp(self):
request_started.disconnect(close_old_connections)
self.addCleanup(request_started.connect, close_old_connections)
async def test_get_asgi_application(self): async def test_get_asgi_application(self):
""" """
get_asgi_application() returns a functioning ASGI callable. get_asgi_application() returns a functioning ASGI callable.

View File

@ -381,6 +381,11 @@ class ConnectionHealthChecksTests(SimpleTestCase):
connection.set_autocommit(True) connection.set_autocommit(True)
self.assertIs(new_connection, connection.connection) self.assertIs(new_connection, connection.connection)
def test_no_close_in_atomic(self):
with transaction.atomic():
connection.close_if_unusable_or_obsolete()
self.run_query()
class MultiDatabaseTests(TestCase): class MultiDatabaseTests(TestCase):
databases = {"default", "other"} databases = {"default", "other"}

16
tests/cache/tests.py vendored
View File

@ -27,7 +27,7 @@ from django.core.cache import (
from django.core.cache.backends.base import InvalidCacheBackendError from django.core.cache.backends.base import InvalidCacheBackendError
from django.core.cache.backends.redis import RedisCacheClient from django.core.cache.backends.redis import RedisCacheClient
from django.core.cache.utils import make_template_fragment_key from django.core.cache.utils import make_template_fragment_key
from django.db import close_old_connections, connection, connections from django.db import connection, connections
from django.db.backends.utils import CursorWrapper from django.db.backends.utils import CursorWrapper
from django.http import ( from django.http import (
HttpRequest, HttpRequest,
@ -1560,15 +1560,11 @@ class BaseMemcachedTests(BaseCacheTests):
def test_close(self): def test_close(self):
# For clients that don't manage their connections properly, the # For clients that don't manage their connections properly, the
# connection is closed when the request is complete. # connection is closed when the request is complete.
signals.request_finished.disconnect(close_old_connections) with mock.patch.object(
try: cache._class, "disconnect_all", autospec=True
with mock.patch.object( ) as mock_disconnect:
cache._class, "disconnect_all", autospec=True signals.request_finished.send(self.__class__)
) as mock_disconnect: self.assertIs(mock_disconnect.called, self.should_disconnect_on_close)
signals.request_finished.send(self.__class__)
self.assertIs(mock_disconnect.called, self.should_disconnect_on_close)
finally:
signals.request_finished.connect(close_old_connections)
def test_set_many_returns_failing_keys(self): def test_set_many_returns_failing_keys(self):
def fail_set_multi(mapping, *args, **kwargs): def fail_set_multi(mapping, *args, **kwargs):

View File

@ -1,7 +1,7 @@
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.core.handlers.wsgi import WSGIHandler, WSGIRequest, get_script_name from django.core.handlers.wsgi import WSGIHandler, WSGIRequest, get_script_name
from django.core.signals import request_finished, request_started from django.core.signals import request_finished, request_started
from django.db import close_old_connections, connection from django.db import connection
from django.test import ( from django.test import (
AsyncRequestFactory, AsyncRequestFactory,
RequestFactory, RequestFactory,
@ -14,10 +14,6 @@ from django.test import (
class HandlerTests(SimpleTestCase): class HandlerTests(SimpleTestCase):
request_factory = RequestFactory() request_factory = RequestFactory()
def setUp(self):
request_started.disconnect(close_old_connections)
self.addCleanup(request_started.connect, close_old_connections)
def test_middleware_initialized(self): def test_middleware_initialized(self):
handler = WSGIHandler() handler = WSGIHandler()
self.assertIsNotNone(handler._middleware_chain) self.assertIsNotNone(handler._middleware_chain)

View File

@ -7,8 +7,6 @@ import uuid
from django.core.exceptions import DisallowedRedirect from django.core.exceptions import DisallowedRedirect
from django.core.serializers.json import DjangoJSONEncoder from django.core.serializers.json import DjangoJSONEncoder
from django.core.signals import request_finished
from django.db import close_old_connections
from django.http import ( from django.http import (
BadHeaderError, BadHeaderError,
HttpResponse, HttpResponse,
@ -758,12 +756,6 @@ class StreamingHttpResponseTests(SimpleTestCase):
class FileCloseTests(SimpleTestCase): class FileCloseTests(SimpleTestCase):
def setUp(self):
# Disable the request_finished signal during this test
# to avoid interfering with the database connection.
request_finished.disconnect(close_old_connections)
self.addCleanup(request_finished.connect, close_old_connections)
def test_response(self): def test_response(self):
filename = os.path.join(os.path.dirname(__file__), "abc.txt") filename = os.path.join(os.path.dirname(__file__), "abc.txt")

View File

@ -1,8 +1,6 @@
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.core.servers.basehttp import get_internal_wsgi_application from django.core.servers.basehttp import get_internal_wsgi_application
from django.core.signals import request_started
from django.core.wsgi import get_wsgi_application from django.core.wsgi import get_wsgi_application
from django.db import close_old_connections
from django.http import FileResponse from django.http import FileResponse
from django.test import SimpleTestCase, override_settings from django.test import SimpleTestCase, override_settings
from django.test.client import RequestFactory from django.test.client import RequestFactory
@ -12,10 +10,6 @@ from django.test.client import RequestFactory
class WSGITest(SimpleTestCase): class WSGITest(SimpleTestCase):
request_factory = RequestFactory() request_factory = RequestFactory()
def setUp(self):
request_started.disconnect(close_old_connections)
self.addCleanup(request_started.connect, close_old_connections)
def test_get_wsgi_application(self): def test_get_wsgi_application(self):
""" """
get_wsgi_application() returns a functioning WSGI callable. get_wsgi_application() returns a functioning WSGI callable.