mirror of
				https://github.com/django/django.git
				synced 2025-10-25 06:36:07 +00:00 
			
		
		
		
	Fixed #32416 -- Made ThreadedWSGIServer close connections after each thread.
ThreadedWSGIServer is used by LiveServerTestCase.
This commit is contained in:
		
				
					committed by
					
						 Mariusz Felisiak
						Mariusz Felisiak
					
				
			
			
				
	
			
			
			
						parent
						
							71a936f9d8
						
					
				
				
					commit
					823a9e6bac
				
			| @@ -16,6 +16,7 @@ from wsgiref import simple_server | |||||||
| from django.core.exceptions import ImproperlyConfigured | from django.core.exceptions import ImproperlyConfigured | ||||||
| from django.core.handlers.wsgi import LimitedStream | from django.core.handlers.wsgi import LimitedStream | ||||||
| from django.core.wsgi import get_wsgi_application | from django.core.wsgi import get_wsgi_application | ||||||
|  | from django.db import connections | ||||||
| from django.utils.module_loading import import_string | from django.utils.module_loading import import_string | ||||||
|  |  | ||||||
| __all__ = ('WSGIServer', 'WSGIRequestHandler') | __all__ = ('WSGIServer', 'WSGIRequestHandler') | ||||||
| @@ -81,6 +82,28 @@ class ThreadedWSGIServer(socketserver.ThreadingMixIn, WSGIServer): | |||||||
|     """A threaded version of the WSGIServer""" |     """A threaded version of the WSGIServer""" | ||||||
|     daemon_threads = True |     daemon_threads = True | ||||||
|  |  | ||||||
|  |     def __init__(self, *args, connections_override=None, **kwargs): | ||||||
|  |         super().__init__(*args, **kwargs) | ||||||
|  |         self.connections_override = connections_override | ||||||
|  |  | ||||||
|  |     # socketserver.ThreadingMixIn.process_request() passes this method as | ||||||
|  |     # the target to a new Thread object. | ||||||
|  |     def process_request_thread(self, request, client_address): | ||||||
|  |         if self.connections_override: | ||||||
|  |             # Override this thread's database connections with the ones | ||||||
|  |             # provided by the parent thread. | ||||||
|  |             for alias, conn in self.connections_override.items(): | ||||||
|  |                 connections[alias] = conn | ||||||
|  |         super().process_request_thread(request, client_address) | ||||||
|  |  | ||||||
|  |     def _close_connections(self): | ||||||
|  |         # Used for mocking in tests. | ||||||
|  |         connections.close_all() | ||||||
|  |  | ||||||
|  |     def close_request(self, request): | ||||||
|  |         self._close_connections() | ||||||
|  |         super().close_request(request) | ||||||
|  |  | ||||||
|  |  | ||||||
| class ServerHandler(simple_server.ServerHandler): | class ServerHandler(simple_server.ServerHandler): | ||||||
|     http_version = '1.1' |     http_version = '1.1' | ||||||
|   | |||||||
| @@ -83,6 +83,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): | |||||||
|                 "the sqlite backend's close() method is a no-op when using an " |                 "the sqlite backend's close() method is a no-op when using an " | ||||||
|                 "in-memory database": { |                 "in-memory database": { | ||||||
|                     'servers.test_liveserverthread.LiveServerThreadTest.test_closes_connections', |                     'servers.test_liveserverthread.LiveServerThreadTest.test_closes_connections', | ||||||
|  |                     'servers.tests.LiveServerTestCloseConnectionTest.test_closes_connections', | ||||||
|                 }, |                 }, | ||||||
|             }) |             }) | ||||||
|         return skips |         return skips | ||||||
|   | |||||||
| @@ -1513,11 +1513,12 @@ class LiveServerThread(threading.Thread): | |||||||
|         finally: |         finally: | ||||||
|             connections.close_all() |             connections.close_all() | ||||||
|  |  | ||||||
|     def _create_server(self): |     def _create_server(self, connections_override=None): | ||||||
|         return self.server_class( |         return self.server_class( | ||||||
|             (self.host, self.port), |             (self.host, self.port), | ||||||
|             QuietWSGIRequestHandler, |             QuietWSGIRequestHandler, | ||||||
|             allow_reuse_address=False, |             allow_reuse_address=False, | ||||||
|  |             connections_override=connections_override, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def terminate(self): |     def terminate(self): | ||||||
| @@ -1600,7 +1601,7 @@ class LiveServerTestCase(TransactionTestCase): | |||||||
|     def _tearDownClassInternal(cls): |     def _tearDownClassInternal(cls): | ||||||
|         # Terminate the live server's thread. |         # Terminate the live server's thread. | ||||||
|         cls.server_thread.terminate() |         cls.server_thread.terminate() | ||||||
|         # Restore sqlite in-memory database connections' non-shareability. |         # Restore shared connections' non-shareability. | ||||||
|         for conn in cls.server_thread.connections_override.values(): |         for conn in cls.server_thread.connections_override.values(): | ||||||
|             conn.dec_thread_sharing() |             conn.dec_thread_sharing() | ||||||
|  |  | ||||||
|   | |||||||
| @@ -4,13 +4,15 @@ Tests for django.core.servers. | |||||||
| import errno | import errno | ||||||
| import os | import os | ||||||
| import socket | import socket | ||||||
|  | import threading | ||||||
| from http.client import HTTPConnection | from http.client import HTTPConnection | ||||||
| from urllib.error import HTTPError | from urllib.error import HTTPError | ||||||
| from urllib.parse import urlencode | from urllib.parse import urlencode | ||||||
| from urllib.request import urlopen | from urllib.request import urlopen | ||||||
|  |  | ||||||
| from django.conf import settings | from django.conf import settings | ||||||
| from django.core.servers.basehttp import WSGIServer | from django.core.servers.basehttp import ThreadedWSGIServer, WSGIServer | ||||||
|  | from django.db import DEFAULT_DB_ALIAS, connections | ||||||
| from django.test import LiveServerTestCase, override_settings | from django.test import LiveServerTestCase, override_settings | ||||||
| from django.test.testcases import LiveServerThread, QuietWSGIRequestHandler | from django.test.testcases import LiveServerThread, QuietWSGIRequestHandler | ||||||
|  |  | ||||||
| @@ -40,6 +42,71 @@ class LiveServerBase(LiveServerTestCase): | |||||||
|         return urlopen(self.live_server_url + url) |         return urlopen(self.live_server_url + url) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class CloseConnectionTestServer(ThreadedWSGIServer): | ||||||
|  |     def __init__(self, *args, **kwargs): | ||||||
|  |         super().__init__(*args, **kwargs) | ||||||
|  |         # This event is set right after the first time a request closes its | ||||||
|  |         # database connections. | ||||||
|  |         self._connections_closed = threading.Event() | ||||||
|  |  | ||||||
|  |     def _close_connections(self): | ||||||
|  |         super()._close_connections() | ||||||
|  |         self._connections_closed.set() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class CloseConnectionTestLiveServerThread(LiveServerThread): | ||||||
|  |  | ||||||
|  |     server_class = CloseConnectionTestServer | ||||||
|  |  | ||||||
|  |     def _create_server(self, connections_override=None): | ||||||
|  |         return super()._create_server(connections_override=self.connections_override) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class LiveServerTestCloseConnectionTest(LiveServerBase): | ||||||
|  |  | ||||||
|  |     server_thread_class = CloseConnectionTestLiveServerThread | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def _make_connections_override(cls): | ||||||
|  |         conn = connections[DEFAULT_DB_ALIAS] | ||||||
|  |         cls.conn = conn | ||||||
|  |         cls.old_conn_max_age = conn.settings_dict['CONN_MAX_AGE'] | ||||||
|  |         # Set the connection's CONN_MAX_AGE to None to simulate the | ||||||
|  |         # CONN_MAX_AGE setting being set to None on the server. This prevents | ||||||
|  |         # Django from closing the connection and allows testing that | ||||||
|  |         # ThreadedWSGIServer closes connections. | ||||||
|  |         conn.settings_dict['CONN_MAX_AGE'] = None | ||||||
|  |         # Pass a database connection through to the server to check it is being | ||||||
|  |         # closed by ThreadedWSGIServer. | ||||||
|  |         return {DEFAULT_DB_ALIAS: conn} | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def tearDownConnectionTest(cls): | ||||||
|  |         cls.conn.settings_dict['CONN_MAX_AGE'] = cls.old_conn_max_age | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def tearDownClass(cls): | ||||||
|  |         cls.tearDownConnectionTest() | ||||||
|  |         super().tearDownClass() | ||||||
|  |  | ||||||
|  |     def test_closes_connections(self): | ||||||
|  |         # The server's request thread sets this event after closing | ||||||
|  |         # its database connections. | ||||||
|  |         closed_event = self.server_thread.httpd._connections_closed | ||||||
|  |         conn = self.conn | ||||||
|  |         # Open a connection to the database. | ||||||
|  |         conn.connect() | ||||||
|  |         self.assertIsNotNone(conn.connection) | ||||||
|  |         with self.urlopen('/model_view/') as f: | ||||||
|  |             # The server can access the database. | ||||||
|  |             self.assertEqual(f.read().splitlines(), [b'jane', b'robert']) | ||||||
|  |         # Wait for the server's request thread to close the connection. | ||||||
|  |         # A timeout of 0.1 seconds should be more than enough. If the wait | ||||||
|  |         # times out, the assertion after should fail. | ||||||
|  |         closed_event.wait(timeout=0.1) | ||||||
|  |         self.assertIsNone(conn.connection) | ||||||
|  |  | ||||||
|  |  | ||||||
| class FailingLiveServerThread(LiveServerThread): | class FailingLiveServerThread(LiveServerThread): | ||||||
|     def _create_server(self): |     def _create_server(self): | ||||||
|         raise RuntimeError('Error creating server.') |         raise RuntimeError('Error creating server.') | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user