Commit b64c0d4d authored by Jean-Michel Vourgère's avatar Jean-Michel Vourgère Committed by Tim Graham
Browse files

Fixed #23658 -- Provided the password to PostgreSQL dbshell command

The password from settings.py is written in a temporary .pgpass file
file whose name is given to psql using the PGPASSFILE environment
variable.
parent eecd42ea
Loading
Loading
Loading
Loading
+57 −10
Original line number Diff line number Diff line
import os
import subprocess

from django.core.files.temp import NamedTemporaryFile
from django.db.backends.base.client import BaseDatabaseClient
from django.utils.six import print_


def _escape_pgpass(txt):
    """
    Escape a fragment of a PostgreSQL .pgpass file.
    """
    return txt.replace('\\', '\\\\').replace(':', '\\:')


class DatabaseClient(BaseDatabaseClient):
    executable_name = 'psql'

    def runshell(self):
        settings_dict = self.connection.settings_dict
        args = [self.executable_name]
        if settings_dict['USER']:
            args += ["-U", settings_dict['USER']]
        if settings_dict['HOST']:
            args.extend(["-h", settings_dict['HOST']])
        if settings_dict['PORT']:
            args.extend(["-p", str(settings_dict['PORT'])])
        args += [settings_dict['NAME']]
    @classmethod
    def runshell_db(cls, settings_dict):
        args = [cls.executable_name]

        host = settings_dict.get('HOST', '')
        port = settings_dict.get('PORT', '')
        name = settings_dict.get('NAME', '')
        user = settings_dict.get('USER', '')
        passwd = settings_dict.get('PASSWORD', '')

        if user:
            args += ['-U', user]
        if host:
            args += ['-h', host]
        if port:
            args += ['-p', str(port)]
        args += [name]

        temp_pgpass = None
        try:
            if passwd:
                # Create temporary .pgpass file.
                temp_pgpass = NamedTemporaryFile(mode='w+')
                try:
                    print_(
                        _escape_pgpass(host) or '*',
                        str(port) or '*',
                        _escape_pgpass(name) or '*',
                        _escape_pgpass(user) or '*',
                        _escape_pgpass(passwd),
                        file=temp_pgpass,
                        sep=':',
                        flush=True,
                    )
                    os.environ['PGPASSFILE'] = temp_pgpass.name
                except UnicodeEncodeError:
                    # If the current locale can't encode the data, we let
                    # the user input the password manually.
                    pass
            subprocess.call(args)
        finally:
            if temp_pgpass:
                temp_pgpass.close()
                if 'PGPASSFILE' in os.environ:  # unit tests need cleanup
                    del os.environ['PGPASSFILE']

    def runshell(self):
        DatabaseClient.runshell_db(self.connection.settings_dict)
+4 −0
Original line number Diff line number Diff line
@@ -350,6 +350,10 @@ Management Commands
* The :djadmin:`startapp` command creates an ``apps.py`` file and adds
  ``default_app_config`` in ``__init__.py``.

* When using the PostgreSQL backend, the :djadmin:`dbshell` command can connect
  to the database using the password from your settings file (instead of
  requiring it to be manually entered).

Models
^^^^^^

+117 −0
Original line number Diff line number Diff line
# -*- coding: utf8 -*-
from __future__ import unicode_literals

import locale
import os

from django.db.backends.postgresql_psycopg2.client import DatabaseClient
from django.test import SimpleTestCase, mock
from django.utils import six
from django.utils.encoding import force_bytes, force_str


class PostgreSqlDbshellCommandTestCase(SimpleTestCase):

    def _run_it(self, dbinfo):
        """
        That function invokes the runshell command, while mocking
        subprocess.call. It returns a 2-tuple with:
        - The command line list
        - The binary content of file pointed by environment PGPASSFILE, or
          None.
        """
        def _mock_subprocess_call(*args):
            self.subprocess_args = list(*args)
            if 'PGPASSFILE' in os.environ:
                self.pgpass = open(os.environ['PGPASSFILE'], 'rb').read()
            else:
                self.pgpass = None
            return 0
        self.subprocess_args = None
        self.pgpass = None
        with mock.patch('subprocess.call', new=_mock_subprocess_call):
            DatabaseClient.runshell_db(dbinfo)
        return self.subprocess_args, self.pgpass

    def test_basic(self):
        self.assertEqual(
            self._run_it({
                'NAME': 'dbname',
                'USER': 'someuser',
                'PASSWORD': 'somepassword',
                'HOST': 'somehost',
                'PORT': 444,
            }), (
                ['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname'],
                b'somehost:444:dbname:someuser:somepassword\n',
            )
        )

    def test_nopass(self):
        self.assertEqual(
            self._run_it({
                'NAME': 'dbname',
                'USER': 'someuser',
                'HOST': 'somehost',
                'PORT': 444,
            }), (
                ['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname'],
                None,
            )
        )

    def test_column(self):
        self.assertEqual(
            self._run_it({
                'NAME': 'dbname',
                'USER': 'some:user',
                'PASSWORD': 'some:password',
                'HOST': '::1',
                'PORT': 444,
            }), (
                ['psql', '-U', 'some:user', '-h', '::1', '-p', '444', 'dbname'],
                b'\\:\\:1:444:dbname:some\\:user:some\\:password\n',
            )
        )

    def test_escape_characters(self):
        self.assertEqual(
            self._run_it({
                'NAME': 'dbname',
                'USER': 'some\\user',
                'PASSWORD': 'some\\password',
                'HOST': 'somehost',
                'PORT': 444,
            }), (
                ['psql', '-U', 'some\\user', '-h', 'somehost', '-p', '444', 'dbname'],
                b'somehost:444:dbname:some\\\\user:some\\\\password\n',
            )
        )

    def test_accent(self):
        # The pgpass temporary file needs to be encoded using the system locale.
        encoding = locale.getpreferredencoding()
        username = 'rôle'
        password = 'sésame'
        try:
            username_str = force_str(username, encoding)
            password_str = force_str(password, encoding)
            pgpass_bytes = force_bytes(
                'somehost:444:dbname:%s:%s\n' % (username, password),
                encoding=encoding,
            )
        except UnicodeEncodeError:
            if six.PY2:
                self.skipTest("Your locale can't run this test.")
        self.assertEqual(
            self._run_it({
                'NAME': 'dbname',
                'USER': username_str,
                'PASSWORD': password_str,
                'HOST': 'somehost',
                'PORT': 444,
            }), (
                ['psql', '-U', username_str, '-h', 'somehost', '-p', '444', 'dbname'],
                pgpass_bytes,
            )
        )