Commit 53353230 authored by Loic Bistuer's avatar Loic Bistuer
Browse files

Fixed #23663 -- Initialize output streams for BaseCommand in __init__().

This helps with testability of management commands.

Thanks to trac username daveoncode for the report and to
Tim Graham and Claude Paroz for the reviews.
parent 494cd857
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -22,7 +22,7 @@ class Command(BaseCommand):
    requires_system_checks = False

    def __init__(self, *args, **kwargs):
        super(BaseCommand, self).__init__(*args, **kwargs)
        super(Command, self).__init__(*args, **kwargs)
        self.copied_files = []
        self.symlinked_files = []
        self.unmodified_files = []
+26 −13
Original line number Diff line number Diff line
@@ -79,11 +79,20 @@ class OutputWrapper(object):
    """
    Wrapper around stdout/stderr
    """
    @property
    def style_func(self):
        return self._style_func

    @style_func.setter
    def style_func(self, style_func):
        if style_func and hasattr(self._out, 'isatty') and self._out.isatty():
            self._style_func = style_func
        else:
            self._style_func = lambda x: x

    def __init__(self, out, style_func=None, ending='\n'):
        self._out = out
        self.style_func = None
        if hasattr(out, 'isatty') and out.isatty():
            self.style_func = style_func
        self.ending = ending

    def __getattr__(self, name):
@@ -93,8 +102,7 @@ class OutputWrapper(object):
        ending = self.ending if ending is None else ending
        if ending and not msg.endswith(ending):
            msg += ending
        style_func = [f for f in (style_func, self.style_func, lambda x:x)
                      if f is not None][0]
        style_func = style_func or self.style_func
        self._out.write(force_str(style_func(msg)))


@@ -221,8 +229,14 @@ class BaseCommand(object):
    #
    # requires_system_checks = True

    def __init__(self):
    def __init__(self, stdout=None, stderr=None, no_color=False):
        self.stdout = OutputWrapper(stdout or sys.stdout)
        self.stderr = OutputWrapper(stderr or sys.stderr)
        if no_color:
            self.style = no_style()
        else:
            self.style = color_style()
            self.stderr.style_func = self.style.ERROR

        # `requires_model_validation` is deprecated in favor of
        # `requires_system_checks`. If both options are present, an error is
@@ -371,9 +385,7 @@ class BaseCommand(object):
            if options.traceback or not isinstance(e, CommandError):
                raise

            # self.stderr is not guaranteed to be set here
            stderr = getattr(self, 'stderr', OutputWrapper(sys.stderr, self.style.ERROR))
            stderr.write('%s: %s' % (e.__class__.__name__, e))
            self.stderr.write('%s: %s' % (e.__class__.__name__, e))
            sys.exit(1)

    def execute(self, *args, **options):
@@ -382,12 +394,13 @@ class BaseCommand(object):
        controlled by attributes ``self.requires_system_checks`` and
        ``self.requires_model_validation``, except if force-skipped).
        """
        self.stdout = OutputWrapper(options.get('stdout', sys.stdout))
        if options.get('no_color'):
            self.style = no_style()
            self.stderr = OutputWrapper(options.get('stderr', sys.stderr))
        else:
            self.stderr = OutputWrapper(options.get('stderr', sys.stderr), self.style.ERROR)
            self.stderr.style_func = None
        if options.get('stdout'):
            self.stdout = OutputWrapper(options['stdout'])
        if options.get('stderr'):
            self.stderr = OutputWrapper(options.get('stderr'), self.stderr.style_func)

        if self.can_import_settings:
            from django.conf import settings  # NOQA
+0 −9
Original line number Diff line number Diff line
from django.core.management.base import BaseCommand


class Command(BaseCommand):
    help = "Test color output"
    requires_system_checks = False

    def handle(self, **options):
        return self.style.SQL_KEYWORD('BEGIN')
+75 −4
Original line number Diff line number Diff line
@@ -21,7 +21,7 @@ import django
from django import conf, get_version
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.core.management import BaseCommand, CommandError, call_command
from django.core.management import BaseCommand, CommandError, call_command, color
from django.db import connection
from django.utils.encoding import force_text
from django.utils._os import npath, upath
@@ -1392,12 +1392,83 @@ class CommandTypes(AdminScriptTestCase):
        self.assertOutput(out, "Prints the CREATE TABLE, custom SQL and CREATE INDEX SQL statements for the\ngiven model module name(s).")
        self.assertEqual(out.count('optional arguments'), 1)

    def test_no_color(self):
    def test_command_color(self):
        class Command(BaseCommand):
            requires_system_checks = False

            def handle(self, *args, **options):
                self.stdout.write('Hello, world!', self.style.ERROR)
                self.stderr.write('Hello, world!', self.style.ERROR)

        out = StringIO()
        err = StringIO()
        command = Command(stdout=out, stderr=err)
        command.execute()
        if color.supports_color():
            self.assertIn('Hello, world!\n', out.getvalue())
            self.assertIn('Hello, world!\n', err.getvalue())
            self.assertNotEqual(out.getvalue(), 'Hello, world!\n')
            self.assertNotEqual(err.getvalue(), 'Hello, world!\n')
        else:
            self.assertEqual(out.getvalue(), 'Hello, world!\n')
            self.assertEqual(err.getvalue(), 'Hello, world!\n')

    def test_command_no_color(self):
        "--no-color prevent colorization of the output"
        class Command(BaseCommand):
            requires_system_checks = False

            def handle(self, *args, **options):
                self.stdout.write('Hello, world!', self.style.ERROR)
                self.stderr.write('Hello, world!', self.style.ERROR)

        out = StringIO()
        err = StringIO()
        command = Command(stdout=out, stderr=err, no_color=True)
        command.execute()
        self.assertEqual(out.getvalue(), 'Hello, world!\n')
        self.assertEqual(err.getvalue(), 'Hello, world!\n')

        call_command('color_command', no_color=True, stdout=out)
        self.assertEqual(out.getvalue(), 'BEGIN\n')
        out = StringIO()
        err = StringIO()
        command = Command(stdout=out, stderr=err)
        command.execute(no_color=True)
        self.assertEqual(out.getvalue(), 'Hello, world!\n')
        self.assertEqual(err.getvalue(), 'Hello, world!\n')

    def test_custom_stdout(self):
        class Command(BaseCommand):
            requires_system_checks = False

            def handle(self, *args, **options):
                self.stdout.write("Hello, World!")

        out = StringIO()
        command = Command(stdout=out)
        command.execute()
        self.assertEqual(out.getvalue(), "Hello, World!\n")
        out.truncate(0)
        new_out = StringIO()
        command.execute(stdout=new_out)
        self.assertEqual(out.getvalue(), "")
        self.assertEqual(new_out.getvalue(), "Hello, World!\n")

    def test_custom_stderr(self):
        class Command(BaseCommand):
            requires_system_checks = False

            def handle(self, *args, **options):
                self.stderr.write("Hello, World!")

        err = StringIO()
        command = Command(stderr=err)
        command.execute()
        self.assertEqual(err.getvalue(), "Hello, World!\n")
        err.truncate(0)
        new_err = StringIO()
        command.execute(stderr=new_err)
        self.assertEqual(err.getvalue(), "")
        self.assertEqual(new_err.getvalue(), "Hello, World!\n")

    def test_base_command(self):
        "User BaseCommands can execute when a label is provided"