mirror of
				https://github.com/django/django.git
				synced 2025-10-26 15:16:09 +00:00 
			
		
		
		
	Fixed #23658 -- Provided the password to PostgreSQL dbshell command
The password from settings.py is written in a temporary .pgpass file file whose name is given to psql using the PGPASSFILE environment variable.
This commit is contained in:
		
				
					committed by
					
						 Tim Graham
						Tim Graham
					
				
			
			
				
	
			
			
			
						parent
						
							eecd42ea7d
						
					
				
				
					commit
					b64c0d4d61
				
			| @@ -1,19 +1,66 @@ | |||||||
|  | import os | ||||||
| import subprocess | import subprocess | ||||||
|  |  | ||||||
|  | from django.core.files.temp import NamedTemporaryFile | ||||||
| from django.db.backends.base.client import BaseDatabaseClient | from django.db.backends.base.client import BaseDatabaseClient | ||||||
|  | from django.utils.six import print_ | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _escape_pgpass(txt): | ||||||
|  |     """ | ||||||
|  |     Escape a fragment of a PostgreSQL .pgpass file. | ||||||
|  |     """ | ||||||
|  |     return txt.replace('\\', '\\\\').replace(':', '\\:') | ||||||
|  |  | ||||||
|  |  | ||||||
| class DatabaseClient(BaseDatabaseClient): | class DatabaseClient(BaseDatabaseClient): | ||||||
|     executable_name = 'psql' |     executable_name = 'psql' | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def runshell_db(cls, settings_dict): | ||||||
|  |         args = [cls.executable_name] | ||||||
|  |  | ||||||
|  |         host = settings_dict.get('HOST', '') | ||||||
|  |         port = settings_dict.get('PORT', '') | ||||||
|  |         name = settings_dict.get('NAME', '') | ||||||
|  |         user = settings_dict.get('USER', '') | ||||||
|  |         passwd = settings_dict.get('PASSWORD', '') | ||||||
|  |  | ||||||
|  |         if user: | ||||||
|  |             args += ['-U', user] | ||||||
|  |         if host: | ||||||
|  |             args += ['-h', host] | ||||||
|  |         if port: | ||||||
|  |             args += ['-p', str(port)] | ||||||
|  |         args += [name] | ||||||
|  |  | ||||||
|  |         temp_pgpass = None | ||||||
|  |         try: | ||||||
|  |             if passwd: | ||||||
|  |                 # Create temporary .pgpass file. | ||||||
|  |                 temp_pgpass = NamedTemporaryFile(mode='w+') | ||||||
|  |                 try: | ||||||
|  |                     print_( | ||||||
|  |                         _escape_pgpass(host) or '*', | ||||||
|  |                         str(port) or '*', | ||||||
|  |                         _escape_pgpass(name) or '*', | ||||||
|  |                         _escape_pgpass(user) or '*', | ||||||
|  |                         _escape_pgpass(passwd), | ||||||
|  |                         file=temp_pgpass, | ||||||
|  |                         sep=':', | ||||||
|  |                         flush=True, | ||||||
|  |                     ) | ||||||
|  |                     os.environ['PGPASSFILE'] = temp_pgpass.name | ||||||
|  |                 except UnicodeEncodeError: | ||||||
|  |                     # If the current locale can't encode the data, we let | ||||||
|  |                     # the user input the password manually. | ||||||
|  |                     pass | ||||||
|  |             subprocess.call(args) | ||||||
|  |         finally: | ||||||
|  |             if temp_pgpass: | ||||||
|  |                 temp_pgpass.close() | ||||||
|  |                 if 'PGPASSFILE' in os.environ:  # unit tests need cleanup | ||||||
|  |                     del os.environ['PGPASSFILE'] | ||||||
|  |  | ||||||
|     def runshell(self): |     def runshell(self): | ||||||
|         settings_dict = self.connection.settings_dict |         DatabaseClient.runshell_db(self.connection.settings_dict) | ||||||
|         args = [self.executable_name] |  | ||||||
|         if settings_dict['USER']: |  | ||||||
|             args += ["-U", settings_dict['USER']] |  | ||||||
|         if settings_dict['HOST']: |  | ||||||
|             args.extend(["-h", settings_dict['HOST']]) |  | ||||||
|         if settings_dict['PORT']: |  | ||||||
|             args.extend(["-p", str(settings_dict['PORT'])]) |  | ||||||
|         args += [settings_dict['NAME']] |  | ||||||
|         subprocess.call(args) |  | ||||||
|   | |||||||
| @@ -350,6 +350,10 @@ Management Commands | |||||||
| * The :djadmin:`startapp` command creates an ``apps.py`` file and adds | * The :djadmin:`startapp` command creates an ``apps.py`` file and adds | ||||||
|   ``default_app_config`` in ``__init__.py``. |   ``default_app_config`` in ``__init__.py``. | ||||||
|  |  | ||||||
|  | * When using the PostgreSQL backend, the :djadmin:`dbshell` command can connect | ||||||
|  |   to the database using the password from your settings file (instead of | ||||||
|  |   requiring it to be manually entered). | ||||||
|  |  | ||||||
| Models | Models | ||||||
| ^^^^^^ | ^^^^^^ | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										117
									
								
								tests/dbshell/test_postgresql_psycopg2.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										117
									
								
								tests/dbshell/test_postgresql_psycopg2.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,117 @@ | |||||||
|  | # -*- coding: utf8 -*- | ||||||
|  | from __future__ import unicode_literals | ||||||
|  |  | ||||||
|  | import locale | ||||||
|  | import os | ||||||
|  |  | ||||||
|  | from django.db.backends.postgresql_psycopg2.client import DatabaseClient | ||||||
|  | from django.test import SimpleTestCase, mock | ||||||
|  | from django.utils import six | ||||||
|  | from django.utils.encoding import force_bytes, force_str | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class PostgreSqlDbshellCommandTestCase(SimpleTestCase): | ||||||
|  |  | ||||||
|  |     def _run_it(self, dbinfo): | ||||||
|  |         """ | ||||||
|  |         That function invokes the runshell command, while mocking | ||||||
|  |         subprocess.call. It returns a 2-tuple with: | ||||||
|  |         - The command line list | ||||||
|  |         - The binary content of file pointed by environment PGPASSFILE, or | ||||||
|  |           None. | ||||||
|  |         """ | ||||||
|  |         def _mock_subprocess_call(*args): | ||||||
|  |             self.subprocess_args = list(*args) | ||||||
|  |             if 'PGPASSFILE' in os.environ: | ||||||
|  |                 self.pgpass = open(os.environ['PGPASSFILE'], 'rb').read() | ||||||
|  |             else: | ||||||
|  |                 self.pgpass = None | ||||||
|  |             return 0 | ||||||
|  |         self.subprocess_args = None | ||||||
|  |         self.pgpass = None | ||||||
|  |         with mock.patch('subprocess.call', new=_mock_subprocess_call): | ||||||
|  |             DatabaseClient.runshell_db(dbinfo) | ||||||
|  |         return self.subprocess_args, self.pgpass | ||||||
|  |  | ||||||
|  |     def test_basic(self): | ||||||
|  |         self.assertEqual( | ||||||
|  |             self._run_it({ | ||||||
|  |                 'NAME': 'dbname', | ||||||
|  |                 'USER': 'someuser', | ||||||
|  |                 'PASSWORD': 'somepassword', | ||||||
|  |                 'HOST': 'somehost', | ||||||
|  |                 'PORT': 444, | ||||||
|  |             }), ( | ||||||
|  |                 ['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname'], | ||||||
|  |                 b'somehost:444:dbname:someuser:somepassword\n', | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def test_nopass(self): | ||||||
|  |         self.assertEqual( | ||||||
|  |             self._run_it({ | ||||||
|  |                 'NAME': 'dbname', | ||||||
|  |                 'USER': 'someuser', | ||||||
|  |                 'HOST': 'somehost', | ||||||
|  |                 'PORT': 444, | ||||||
|  |             }), ( | ||||||
|  |                 ['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname'], | ||||||
|  |                 None, | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def test_column(self): | ||||||
|  |         self.assertEqual( | ||||||
|  |             self._run_it({ | ||||||
|  |                 'NAME': 'dbname', | ||||||
|  |                 'USER': 'some:user', | ||||||
|  |                 'PASSWORD': 'some:password', | ||||||
|  |                 'HOST': '::1', | ||||||
|  |                 'PORT': 444, | ||||||
|  |             }), ( | ||||||
|  |                 ['psql', '-U', 'some:user', '-h', '::1', '-p', '444', 'dbname'], | ||||||
|  |                 b'\\:\\:1:444:dbname:some\\:user:some\\:password\n', | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def test_escape_characters(self): | ||||||
|  |         self.assertEqual( | ||||||
|  |             self._run_it({ | ||||||
|  |                 'NAME': 'dbname', | ||||||
|  |                 'USER': 'some\\user', | ||||||
|  |                 'PASSWORD': 'some\\password', | ||||||
|  |                 'HOST': 'somehost', | ||||||
|  |                 'PORT': 444, | ||||||
|  |             }), ( | ||||||
|  |                 ['psql', '-U', 'some\\user', '-h', 'somehost', '-p', '444', 'dbname'], | ||||||
|  |                 b'somehost:444:dbname:some\\\\user:some\\\\password\n', | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def test_accent(self): | ||||||
|  |         # The pgpass temporary file needs to be encoded using the system locale. | ||||||
|  |         encoding = locale.getpreferredencoding() | ||||||
|  |         username = 'rôle' | ||||||
|  |         password = 'sésame' | ||||||
|  |         try: | ||||||
|  |             username_str = force_str(username, encoding) | ||||||
|  |             password_str = force_str(password, encoding) | ||||||
|  |             pgpass_bytes = force_bytes( | ||||||
|  |                 'somehost:444:dbname:%s:%s\n' % (username, password), | ||||||
|  |                 encoding=encoding, | ||||||
|  |             ) | ||||||
|  |         except UnicodeEncodeError: | ||||||
|  |             if six.PY2: | ||||||
|  |                 self.skipTest("Your locale can't run this test.") | ||||||
|  |         self.assertEqual( | ||||||
|  |             self._run_it({ | ||||||
|  |                 'NAME': 'dbname', | ||||||
|  |                 'USER': username_str, | ||||||
|  |                 'PASSWORD': password_str, | ||||||
|  |                 'HOST': 'somehost', | ||||||
|  |                 'PORT': 444, | ||||||
|  |             }), ( | ||||||
|  |                 ['psql', '-U', username_str, '-h', 'somehost', '-p', '444', 'dbname'], | ||||||
|  |                 pgpass_bytes, | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
		Reference in New Issue
	
	Block a user