1
0
mirror of https://github.com/django/django.git synced 2024-12-22 17:16:24 +00:00

Fixed #30448, Fixed #35618 -- Do not close connection in atomic block.

Previously, the `django.db.close_old_connections` handler for
`request_started` and `request_finished` would incorrectly close any
connection within a `transaction.atomic` block (such as the one
automatically used by `django.db.TestCase`), leading to
`InterfaceError: connection already closed`.

The test client and many of our tests have been working around this
bug by manually suppressing the `close_old_connections` handler, but
the workarounds are incomplete and the bug still affects other
projects.  Fix the bug and remove the workarounds.

Signed-off-by: Anders Kaseorg <andersk@mit.edu>
This commit is contained in:
Anders Kaseorg 2024-07-19 13:04:59 -07:00
parent 9cb8baa0c4
commit c91e92fbfc
8 changed files with 14 additions and 49 deletions

View File

@ -593,7 +593,7 @@ class BaseDatabaseWrapper:
Close the current connection if unrecoverable errors have occurred
or if it outlived its maximum age.
"""
if self.connection is not None:
if self.connection is not None and not self.in_atomic_block:
self.health_check_done = False
# If the application didn't restore the original autocommit setting,
# don't take chances, drop the connection.

View File

@ -17,8 +17,7 @@ from django.core.handlers.asgi import ASGIRequest
from django.core.handlers.base import BaseHandler
from django.core.handlers.wsgi import LimitedStream, WSGIRequest
from django.core.serializers.json import DjangoJSONEncoder
from django.core.signals import got_request_exception, request_finished, request_started
from django.db import close_old_connections
from django.core.signals import got_request_exception, request_started
from django.http import HttpHeaders, HttpRequest, QueryDict, SimpleCookie
from django.test import signals
from django.test.utils import ContextList
@ -121,9 +120,7 @@ def closing_iterator_wrapper(iterable, close):
try:
yield from iterable
finally:
request_finished.disconnect(close_old_connections)
close() # will fire request_finished
request_finished.connect(close_old_connections)
async def aclosing_iterator_wrapper(iterable, close):
@ -131,9 +128,7 @@ async def aclosing_iterator_wrapper(iterable, close):
async for chunk in iterable:
yield chunk
finally:
request_finished.disconnect(close_old_connections)
close() # will fire request_finished
request_finished.connect(close_old_connections)
def conditional_content_removal(request, response):
@ -172,9 +167,7 @@ class ClientHandler(BaseHandler):
if self._middleware_chain is None:
self.load_middleware()
request_started.disconnect(close_old_connections)
request_started.send(sender=self.__class__, environ=environ)
request_started.connect(close_old_connections)
request = WSGIRequest(environ)
# sneaky little hack so that we can easily get round
# CsrfViewMiddleware. This makes life easier, and is probably
@ -203,9 +196,7 @@ class ClientHandler(BaseHandler):
response.streaming_content, response.close
)
else:
request_finished.disconnect(close_old_connections)
response.close() # will fire request_finished
request_finished.connect(close_old_connections)
return response
@ -228,9 +219,7 @@ class AsyncClientHandler(BaseHandler):
else:
body_file = FakePayload("")
request_started.disconnect(close_old_connections)
await request_started.asend(sender=self.__class__, scope=scope)
request_started.connect(close_old_connections)
# Wrap FakePayload body_file to allow large read() in test environment.
request = ASGIRequest(scope, LimitedStream(body_file, len(body_file)))
# Sneaky little hack so that we can easily get round
@ -255,10 +244,8 @@ class AsyncClientHandler(BaseHandler):
response.streaming_content, response.close
)
else:
request_finished.disconnect(close_old_connections)
# Will fire request_finished.
await sync_to_async(response.close, thread_sensitive=False)()
request_finished.connect(close_old_connections)
return response

View File

@ -12,7 +12,6 @@ from django.core.asgi import get_asgi_application
from django.core.exceptions import RequestDataTooBig
from django.core.handlers.asgi import ASGIHandler, ASGIRequest
from django.core.signals import request_finished, request_started
from django.db import close_old_connections
from django.http import HttpResponse, StreamingHttpResponse
from django.test import (
AsyncRequestFactory,
@ -45,10 +44,6 @@ class SignalHandler:
class ASGITest(SimpleTestCase):
async_request_factory = AsyncRequestFactory()
def setUp(self):
request_started.disconnect(close_old_connections)
self.addCleanup(request_started.connect, close_old_connections)
async def test_get_asgi_application(self):
"""
get_asgi_application() returns a functioning ASGI callable.

View File

@ -381,6 +381,11 @@ class ConnectionHealthChecksTests(SimpleTestCase):
connection.set_autocommit(True)
self.assertIs(new_connection, connection.connection)
def test_no_close_in_atomic(self):
with transaction.atomic():
connection.close_if_unusable_or_obsolete()
self.run_query()
class MultiDatabaseTests(TestCase):
databases = {"default", "other"}

View File

@ -27,7 +27,7 @@ from django.core.cache import (
from django.core.cache.backends.base import InvalidCacheBackendError
from django.core.cache.backends.redis import RedisCacheClient
from django.core.cache.utils import make_template_fragment_key
from django.db import close_old_connections, connection, connections
from django.db import connection, connections
from django.db.backends.utils import CursorWrapper
from django.http import (
HttpRequest,
@ -1560,15 +1560,11 @@ class BaseMemcachedTests(BaseCacheTests):
def test_close(self):
# For clients that don't manage their connections properly, the
# connection is closed when the request is complete.
signals.request_finished.disconnect(close_old_connections)
try:
with mock.patch.object(
cache._class, "disconnect_all", autospec=True
) as mock_disconnect:
signals.request_finished.send(self.__class__)
self.assertIs(mock_disconnect.called, self.should_disconnect_on_close)
finally:
signals.request_finished.connect(close_old_connections)
def test_set_many_returns_failing_keys(self):
def fail_set_multi(mapping, *args, **kwargs):

View File

@ -1,7 +1,7 @@
from django.core.exceptions import ImproperlyConfigured
from django.core.handlers.wsgi import WSGIHandler, WSGIRequest, get_script_name
from django.core.signals import request_finished, request_started
from django.db import close_old_connections, connection
from django.db import connection
from django.test import (
AsyncRequestFactory,
RequestFactory,
@ -14,10 +14,6 @@ from django.test import (
class HandlerTests(SimpleTestCase):
request_factory = RequestFactory()
def setUp(self):
request_started.disconnect(close_old_connections)
self.addCleanup(request_started.connect, close_old_connections)
def test_middleware_initialized(self):
handler = WSGIHandler()
self.assertIsNotNone(handler._middleware_chain)

View File

@ -7,8 +7,6 @@ import uuid
from django.core.exceptions import DisallowedRedirect
from django.core.serializers.json import DjangoJSONEncoder
from django.core.signals import request_finished
from django.db import close_old_connections
from django.http import (
BadHeaderError,
HttpResponse,
@ -758,12 +756,6 @@ class StreamingHttpResponseTests(SimpleTestCase):
class FileCloseTests(SimpleTestCase):
def setUp(self):
# Disable the request_finished signal during this test
# to avoid interfering with the database connection.
request_finished.disconnect(close_old_connections)
self.addCleanup(request_finished.connect, close_old_connections)
def test_response(self):
filename = os.path.join(os.path.dirname(__file__), "abc.txt")

View File

@ -1,8 +1,6 @@
from django.core.exceptions import ImproperlyConfigured
from django.core.servers.basehttp import get_internal_wsgi_application
from django.core.signals import request_started
from django.core.wsgi import get_wsgi_application
from django.db import close_old_connections
from django.http import FileResponse
from django.test import SimpleTestCase, override_settings
from django.test.client import RequestFactory
@ -12,10 +10,6 @@ from django.test.client import RequestFactory
class WSGITest(SimpleTestCase):
request_factory = RequestFactory()
def setUp(self):
request_started.disconnect(close_old_connections)
self.addCleanup(request_started.connect, close_old_connections)
def test_get_wsgi_application(self):
"""
get_wsgi_application() returns a functioning WSGI callable.