mirror of
				https://github.com/django/django.git
				synced 2025-10-26 15:16:09 +00:00 
			
		
		
		
	Fixed #30173 -- Simplified db.backends.postgresql.client.
This commit is contained in:
		
				
					committed by
					
						 Tim Graham
						Tim Graham
					
				
			
			
				
	
			
			
			
						parent
						
							ddb2936852
						
					
				
				
					commit
					cf826c9a91
				
			| @@ -2,17 +2,9 @@ import os | |||||||
| import signal | import signal | ||||||
| import subprocess | import subprocess | ||||||
|  |  | ||||||
| from django.core.files.temp import NamedTemporaryFile |  | ||||||
| from django.db.backends.base.client import BaseDatabaseClient | from django.db.backends.base.client import BaseDatabaseClient | ||||||
|  |  | ||||||
|  |  | ||||||
| def _escape_pgpass(txt): |  | ||||||
|     """ |  | ||||||
|     Escape a fragment of a PostgreSQL .pgpass file. |  | ||||||
|     """ |  | ||||||
|     return txt.replace('\\', '\\\\').replace(':', '\\:') |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class DatabaseClient(BaseDatabaseClient): | class DatabaseClient(BaseDatabaseClient): | ||||||
|     executable_name = 'psql' |     executable_name = 'psql' | ||||||
|  |  | ||||||
| @@ -34,38 +26,17 @@ class DatabaseClient(BaseDatabaseClient): | |||||||
|             args += ['-p', str(port)] |             args += ['-p', str(port)] | ||||||
|         args += [dbname] |         args += [dbname] | ||||||
|  |  | ||||||
|         temp_pgpass = None |  | ||||||
|         sigint_handler = signal.getsignal(signal.SIGINT) |         sigint_handler = signal.getsignal(signal.SIGINT) | ||||||
|  |         subprocess_env = os.environ.copy() | ||||||
|  |         if passwd: | ||||||
|  |             subprocess_env['PGPASSWORD'] = str(passwd) | ||||||
|         try: |         try: | ||||||
|             if passwd: |  | ||||||
|                 # Create temporary .pgpass file. |  | ||||||
|                 temp_pgpass = NamedTemporaryFile(mode='w+') |  | ||||||
|                 try: |  | ||||||
|                     print( |  | ||||||
|                         _escape_pgpass(host) or '*', |  | ||||||
|                         str(port) or '*', |  | ||||||
|                         _escape_pgpass(dbname) or '*', |  | ||||||
|                         _escape_pgpass(user) or '*', |  | ||||||
|                         _escape_pgpass(passwd), |  | ||||||
|                         file=temp_pgpass, |  | ||||||
|                         sep=':', |  | ||||||
|                         flush=True, |  | ||||||
|                     ) |  | ||||||
|                     os.environ['PGPASSFILE'] = temp_pgpass.name |  | ||||||
|                 except UnicodeEncodeError: |  | ||||||
|                     # If the current locale can't encode the data, let the |  | ||||||
|                     # user input the password manually. |  | ||||||
|                     pass |  | ||||||
|             # Allow SIGINT to pass to psql to abort queries. |             # Allow SIGINT to pass to psql to abort queries. | ||||||
|             signal.signal(signal.SIGINT, signal.SIG_IGN) |             signal.signal(signal.SIGINT, signal.SIG_IGN) | ||||||
|             subprocess.check_call(args) |             subprocess.run(args, check=True, env=subprocess_env) | ||||||
|         finally: |         finally: | ||||||
|             # Restore the original SIGINT handler. |             # Restore the original SIGINT handler. | ||||||
|             signal.signal(signal.SIGINT, sigint_handler) |             signal.signal(signal.SIGINT, sigint_handler) | ||||||
|             if temp_pgpass: |  | ||||||
|                 temp_pgpass.close() |  | ||||||
|                 if 'PGPASSFILE' in os.environ:  # unit tests need cleanup |  | ||||||
|                     del os.environ['PGPASSFILE'] |  | ||||||
|  |  | ||||||
|     def runshell(self): |     def runshell(self): | ||||||
|         DatabaseClient.runshell_db(self.connection.get_connection_params()) |         DatabaseClient.runshell_db(self.connection.get_connection_params()) | ||||||
|   | |||||||
| @@ -1,5 +1,6 @@ | |||||||
| import os | import os | ||||||
| import signal | import signal | ||||||
|  | import subprocess | ||||||
| from unittest import mock | from unittest import mock | ||||||
|  |  | ||||||
| from django.db.backends.postgresql.client import DatabaseClient | from django.db.backends.postgresql.client import DatabaseClient | ||||||
| @@ -11,23 +12,17 @@ class PostgreSqlDbshellCommandTestCase(SimpleTestCase): | |||||||
|     def _run_it(self, dbinfo): |     def _run_it(self, dbinfo): | ||||||
|         """ |         """ | ||||||
|         That function invokes the runshell command, while mocking |         That function invokes the runshell command, while mocking | ||||||
|         subprocess.call. It returns a 2-tuple with: |         subprocess.run(). It returns a 2-tuple with: | ||||||
|         - The command line list |         - The command line list | ||||||
|         - The content of the file pointed by environment PGPASSFILE, or None. |         - The the value of the PGPASSWORD environment variable, or None. | ||||||
|         """ |         """ | ||||||
|         def _mock_subprocess_call(*args): |         def _mock_subprocess_run(*args, env=os.environ, **kwargs): | ||||||
|             self.subprocess_args = list(*args) |             self.subprocess_args = list(*args) | ||||||
|             if 'PGPASSFILE' in os.environ: |             self.pgpassword = env.get('PGPASSWORD') | ||||||
|                 with open(os.environ['PGPASSFILE']) as f: |             return subprocess.CompletedProcess(self.subprocess_args, 0) | ||||||
|                     self.pgpass = f.read().strip()  # ignore line endings |         with mock.patch('subprocess.run', new=_mock_subprocess_run): | ||||||
|             else: |  | ||||||
|                 self.pgpass = None |  | ||||||
|             return 0 |  | ||||||
|         self.subprocess_args = None |  | ||||||
|         self.pgpass = None |  | ||||||
|         with mock.patch('subprocess.call', new=_mock_subprocess_call): |  | ||||||
|             DatabaseClient.runshell_db(dbinfo) |             DatabaseClient.runshell_db(dbinfo) | ||||||
|         return self.subprocess_args, self.pgpass |         return self.subprocess_args, self.pgpassword | ||||||
|  |  | ||||||
|     def test_basic(self): |     def test_basic(self): | ||||||
|         self.assertEqual( |         self.assertEqual( | ||||||
| @@ -39,7 +34,7 @@ class PostgreSqlDbshellCommandTestCase(SimpleTestCase): | |||||||
|                 'port': '444', |                 'port': '444', | ||||||
|             }), ( |             }), ( | ||||||
|                 ['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname'], |                 ['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname'], | ||||||
|                 'somehost:444:dbname:someuser:somepassword', |                 'somepassword', | ||||||
|             ) |             ) | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
| @@ -66,28 +61,13 @@ class PostgreSqlDbshellCommandTestCase(SimpleTestCase): | |||||||
|                 'port': '444', |                 'port': '444', | ||||||
|             }), ( |             }), ( | ||||||
|                 ['psql', '-U', 'some:user', '-h', '::1', '-p', '444', 'dbname'], |                 ['psql', '-U', 'some:user', '-h', '::1', '-p', '444', 'dbname'], | ||||||
|                 '\\:\\:1:444:dbname:some\\:user:some\\:password', |                 'some:password', | ||||||
|             ) |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def test_escape_characters(self): |  | ||||||
|         self.assertEqual( |  | ||||||
|             self._run_it({ |  | ||||||
|                 'database': 'dbname', |  | ||||||
|                 'user': 'some\\user', |  | ||||||
|                 'password': 'some\\password', |  | ||||||
|                 'host': 'somehost', |  | ||||||
|                 'port': '444', |  | ||||||
|             }), ( |  | ||||||
|                 ['psql', '-U', 'some\\user', '-h', 'somehost', '-p', '444', 'dbname'], |  | ||||||
|                 'somehost:444:dbname:some\\\\user:some\\\\password', |  | ||||||
|             ) |             ) | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def test_accent(self): |     def test_accent(self): | ||||||
|         username = 'rôle' |         username = 'rôle' | ||||||
|         password = 'sésame' |         password = 'sésame' | ||||||
|         pgpass_string = 'somehost:444:dbname:%s:%s' % (username, password) |  | ||||||
|         self.assertEqual( |         self.assertEqual( | ||||||
|             self._run_it({ |             self._run_it({ | ||||||
|                 'database': 'dbname', |                 'database': 'dbname', | ||||||
| @@ -97,20 +77,20 @@ class PostgreSqlDbshellCommandTestCase(SimpleTestCase): | |||||||
|                 'port': '444', |                 'port': '444', | ||||||
|             }), ( |             }), ( | ||||||
|                 ['psql', '-U', username, '-h', 'somehost', '-p', '444', 'dbname'], |                 ['psql', '-U', username, '-h', 'somehost', '-p', '444', 'dbname'], | ||||||
|                 pgpass_string, |                 password, | ||||||
|             ) |             ) | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def test_sigint_handler(self): |     def test_sigint_handler(self): | ||||||
|         """SIGINT is ignored in Python and passed to psql to abort quries.""" |         """SIGINT is ignored in Python and passed to psql to abort quries.""" | ||||||
|         def _mock_subprocess_call(*args): |         def _mock_subprocess_run(*args, **kwargs): | ||||||
|             handler = signal.getsignal(signal.SIGINT) |             handler = signal.getsignal(signal.SIGINT) | ||||||
|             self.assertEqual(handler, signal.SIG_IGN) |             self.assertEqual(handler, signal.SIG_IGN) | ||||||
|  |  | ||||||
|         sigint_handler = signal.getsignal(signal.SIGINT) |         sigint_handler = signal.getsignal(signal.SIGINT) | ||||||
|         # The default handler isn't SIG_IGN. |         # The default handler isn't SIG_IGN. | ||||||
|         self.assertNotEqual(sigint_handler, signal.SIG_IGN) |         self.assertNotEqual(sigint_handler, signal.SIG_IGN) | ||||||
|         with mock.patch('subprocess.check_call', new=_mock_subprocess_call): |         with mock.patch('subprocess.run', new=_mock_subprocess_run): | ||||||
|             DatabaseClient.runshell_db({}) |             DatabaseClient.runshell_db({}) | ||||||
|         # dbshell restores the original handler. |         # dbshell restores the original handler. | ||||||
|         self.assertEqual(sigint_handler, signal.getsignal(signal.SIGINT)) |         self.assertEqual(sigint_handler, signal.getsignal(signal.SIGINT)) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user