From cf826c9a91015c8da2ad4910b12e2ed83e2fb20f Mon Sep 17 00:00:00 2001 From: Daniel Bowring Date: Mon, 11 Feb 2019 12:17:24 +1100 Subject: [PATCH] Fixed #30173 -- Simplified db.backends.postgresql.client. --- django/db/backends/postgresql/client.py | 37 +++----------------- tests/dbshell/test_postgresql.py | 46 +++++++------------------ 2 files changed, 17 insertions(+), 66 deletions(-) diff --git a/django/db/backends/postgresql/client.py b/django/db/backends/postgresql/client.py index 7fca6eff30..cf4df76882 100644 --- a/django/db/backends/postgresql/client.py +++ b/django/db/backends/postgresql/client.py @@ -2,17 +2,9 @@ import os import signal import subprocess -from django.core.files.temp import NamedTemporaryFile from django.db.backends.base.client import BaseDatabaseClient -def _escape_pgpass(txt): - """ - Escape a fragment of a PostgreSQL .pgpass file. - """ - return txt.replace('\\', '\\\\').replace(':', '\\:') - - class DatabaseClient(BaseDatabaseClient): executable_name = 'psql' @@ -34,38 +26,17 @@ class DatabaseClient(BaseDatabaseClient): args += ['-p', str(port)] args += [dbname] - temp_pgpass = None sigint_handler = signal.getsignal(signal.SIGINT) + subprocess_env = os.environ.copy() + if passwd: + subprocess_env['PGPASSWORD'] = str(passwd) try: - if passwd: - # Create temporary .pgpass file. - temp_pgpass = NamedTemporaryFile(mode='w+') - try: - print( - _escape_pgpass(host) or '*', - str(port) or '*', - _escape_pgpass(dbname) or '*', - _escape_pgpass(user) or '*', - _escape_pgpass(passwd), - file=temp_pgpass, - sep=':', - flush=True, - ) - os.environ['PGPASSFILE'] = temp_pgpass.name - except UnicodeEncodeError: - # If the current locale can't encode the data, let the - # user input the password manually. - pass # Allow SIGINT to pass to psql to abort queries. signal.signal(signal.SIGINT, signal.SIG_IGN) - subprocess.check_call(args) + subprocess.run(args, check=True, env=subprocess_env) finally: # Restore the original SIGINT handler. signal.signal(signal.SIGINT, sigint_handler) - if temp_pgpass: - temp_pgpass.close() - if 'PGPASSFILE' in os.environ: # unit tests need cleanup - del os.environ['PGPASSFILE'] def runshell(self): DatabaseClient.runshell_db(self.connection.get_connection_params()) diff --git a/tests/dbshell/test_postgresql.py b/tests/dbshell/test_postgresql.py index 0d4f28554d..a33e7f6482 100644 --- a/tests/dbshell/test_postgresql.py +++ b/tests/dbshell/test_postgresql.py @@ -1,5 +1,6 @@ import os import signal +import subprocess from unittest import mock from django.db.backends.postgresql.client import DatabaseClient @@ -11,23 +12,17 @@ class PostgreSqlDbshellCommandTestCase(SimpleTestCase): def _run_it(self, dbinfo): """ That function invokes the runshell command, while mocking - subprocess.call. It returns a 2-tuple with: + subprocess.run(). It returns a 2-tuple with: - The command line list - - The content of the file pointed by environment PGPASSFILE, or None. + - The the value of the PGPASSWORD environment variable, or None. """ - def _mock_subprocess_call(*args): + def _mock_subprocess_run(*args, env=os.environ, **kwargs): self.subprocess_args = list(*args) - if 'PGPASSFILE' in os.environ: - with open(os.environ['PGPASSFILE']) as f: - self.pgpass = f.read().strip() # ignore line endings - else: - self.pgpass = None - return 0 - self.subprocess_args = None - self.pgpass = None - with mock.patch('subprocess.call', new=_mock_subprocess_call): + self.pgpassword = env.get('PGPASSWORD') + return subprocess.CompletedProcess(self.subprocess_args, 0) + with mock.patch('subprocess.run', new=_mock_subprocess_run): DatabaseClient.runshell_db(dbinfo) - return self.subprocess_args, self.pgpass + return self.subprocess_args, self.pgpassword def test_basic(self): self.assertEqual( @@ -39,7 +34,7 @@ class PostgreSqlDbshellCommandTestCase(SimpleTestCase): 'port': '444', }), ( ['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname'], - 'somehost:444:dbname:someuser:somepassword', + 'somepassword', ) ) @@ -66,28 +61,13 @@ class PostgreSqlDbshellCommandTestCase(SimpleTestCase): 'port': '444', }), ( ['psql', '-U', 'some:user', '-h', '::1', '-p', '444', 'dbname'], - '\\:\\:1:444:dbname:some\\:user:some\\:password', - ) - ) - - def test_escape_characters(self): - self.assertEqual( - self._run_it({ - 'database': 'dbname', - 'user': 'some\\user', - 'password': 'some\\password', - 'host': 'somehost', - 'port': '444', - }), ( - ['psql', '-U', 'some\\user', '-h', 'somehost', '-p', '444', 'dbname'], - 'somehost:444:dbname:some\\\\user:some\\\\password', + 'some:password', ) ) def test_accent(self): username = 'rôle' password = 'sésame' - pgpass_string = 'somehost:444:dbname:%s:%s' % (username, password) self.assertEqual( self._run_it({ 'database': 'dbname', @@ -97,20 +77,20 @@ class PostgreSqlDbshellCommandTestCase(SimpleTestCase): 'port': '444', }), ( ['psql', '-U', username, '-h', 'somehost', '-p', '444', 'dbname'], - pgpass_string, + password, ) ) def test_sigint_handler(self): """SIGINT is ignored in Python and passed to psql to abort quries.""" - def _mock_subprocess_call(*args): + def _mock_subprocess_run(*args, **kwargs): handler = signal.getsignal(signal.SIGINT) self.assertEqual(handler, signal.SIG_IGN) sigint_handler = signal.getsignal(signal.SIGINT) # The default handler isn't SIG_IGN. self.assertNotEqual(sigint_handler, signal.SIG_IGN) - with mock.patch('subprocess.check_call', new=_mock_subprocess_call): + with mock.patch('subprocess.run', new=_mock_subprocess_run): DatabaseClient.runshell_db({}) # dbshell restores the original handler. self.assertEqual(sigint_handler, signal.getsignal(signal.SIGINT))