From 4115288b4f7bbd694946a1ddef0f0ba85c03f9a1 Mon Sep 17 00:00:00 2001 From: Jon Dufresne Date: Wed, 2 Mar 2016 17:12:56 -0800 Subject: [PATCH] Fixed #26315 -- Allowed call_command() to accept a Command object as the first argument. --- django/core/management/__init__.py | 33 +++++++++++++++++-------- docs/ref/django-admin.txt | 9 +++++-- docs/releases/1.10.txt | 3 +++ tests/admin_scripts/tests.py | 38 ++++++++++++++--------------- tests/auth_tests/test_management.py | 8 +++--- 5 files changed, 56 insertions(+), 35 deletions(-) diff --git a/django/core/management/__init__.py b/django/core/management/__init__.py index f9e11b48a6..2495d38642 100644 --- a/django/core/management/__init__.py +++ b/django/core/management/__init__.py @@ -82,22 +82,35 @@ def call_command(name, *args, **options): This is the primary API you should use for calling specific commands. + `name` may be a string or a command object. Using a string is preferred + unless the command object is required for further processing or testing. + Some examples: call_command('migrate') call_command('shell', plain=True) call_command('sqlmigrate', 'myapp') - """ - # Load the command object. - try: - app_name = get_commands()[name] - except KeyError: - raise CommandError("Unknown command: %r" % name) - if isinstance(app_name, BaseCommand): - # If the command is already loaded, use it directly. - command = app_name + from django.core.management.commands import flush + cmd = flush.Command() + call_command(cmd, verbosity=0, interactive=False) + # Do something with cmd ... + """ + if isinstance(name, BaseCommand): + # Command object passed in. + command = name + name = command.__class__.__module__.split('.')[-1] else: - command = load_command_class(app_name, name) + # Load the command object by name. + try: + app_name = get_commands()[name] + except KeyError: + raise CommandError("Unknown command: %r" % name) + + if isinstance(app_name, BaseCommand): + # If the command is already loaded, use it directly. + command = app_name + else: + command = load_command_class(app_name, name) # Simulate argument parsing to get the option defaults (see #10080 for details). parser = command.create_parser('', name) diff --git a/docs/ref/django-admin.txt b/docs/ref/django-admin.txt index 6a7752f40c..7ba33f5d17 100644 --- a/docs/ref/django-admin.txt +++ b/docs/ref/django-admin.txt @@ -1760,7 +1760,8 @@ Running management commands from your code To call a management command from code use ``call_command``. ``name`` - the name of the command to call. + the name of the command to call or a command object. Passing the name is + preferred unless the object is required for testing. ``*args`` a list of arguments accepted by the command. @@ -1771,8 +1772,11 @@ To call a management command from code use ``call_command``. Examples:: from django.core import management + from django.core.management.commands import loaddata + management.call_command('flush', verbosity=0, interactive=False) management.call_command('loaddata', 'test_data', verbosity=0) + management.call_command(loaddata.Command(), 'test_data', verbosity=0) Note that command options that take no arguments are passed as keywords with ``True`` or ``False``, as you can see with the ``interactive`` option above. @@ -1799,7 +1803,8 @@ value of the ``handle()`` method of the command. .. versionchanged:: 1.10 ``call_command()`` now returns the value received from the - ``command.handle()`` method. + ``command.handle()`` method. It now also accepts a command object as the + first argument. Output redirection ================== diff --git a/docs/releases/1.10.txt b/docs/releases/1.10.txt index 8622c34e89..0d1ea5b5d2 100644 --- a/docs/releases/1.10.txt +++ b/docs/releases/1.10.txt @@ -278,6 +278,9 @@ Management Commands :djadmin:`runserver` does, if the set of migrations on disk don't match the migrations in the database. +* To assist with testing, :func:`~django.core.management.call_command` now + accepts a command object as the first argument. + Migrations ~~~~~~~~~~ diff --git a/tests/admin_scripts/tests.py b/tests/admin_scripts/tests.py index 8a65b136c2..b7b1ecbf97 100644 --- a/tests/admin_scripts/tests.py +++ b/tests/admin_scripts/tests.py @@ -1309,52 +1309,52 @@ class ManageRunserver(AdminScriptTestCase): self.cmd = Command(stdout=self.output) self.cmd.run = monkey_run - def assertServerSettings(self, addr, port, ipv6=None, raw_ipv6=False): + def assertServerSettings(self, addr, port, ipv6=False, raw_ipv6=False): self.assertEqual(self.cmd.addr, addr) self.assertEqual(self.cmd.port, port) self.assertEqual(self.cmd.use_ipv6, ipv6) self.assertEqual(self.cmd._raw_ipv6, raw_ipv6) def test_runserver_addrport(self): - self.cmd.handle() + call_command(self.cmd) self.assertServerSettings('127.0.0.1', '8000') - self.cmd.handle(addrport="1.2.3.4:8000") + call_command(self.cmd, addrport="1.2.3.4:8000") self.assertServerSettings('1.2.3.4', '8000') - self.cmd.handle(addrport="7000") + call_command(self.cmd, addrport="7000") self.assertServerSettings('127.0.0.1', '7000') @unittest.skipUnless(socket.has_ipv6, "platform doesn't support IPv6") def test_runner_addrport_ipv6(self): - self.cmd.handle(addrport="", use_ipv6=True) + call_command(self.cmd, addrport="", use_ipv6=True) self.assertServerSettings('::1', '8000', ipv6=True, raw_ipv6=True) - self.cmd.handle(addrport="7000", use_ipv6=True) + call_command(self.cmd, addrport="7000", use_ipv6=True) self.assertServerSettings('::1', '7000', ipv6=True, raw_ipv6=True) - self.cmd.handle(addrport="[2001:0db8:1234:5678::9]:7000") + call_command(self.cmd, addrport="[2001:0db8:1234:5678::9]:7000") self.assertServerSettings('2001:0db8:1234:5678::9', '7000', ipv6=True, raw_ipv6=True) def test_runner_hostname(self): - self.cmd.handle(addrport="localhost:8000") + call_command(self.cmd, addrport="localhost:8000") self.assertServerSettings('localhost', '8000') - self.cmd.handle(addrport="test.domain.local:7000") + call_command(self.cmd, addrport="test.domain.local:7000") self.assertServerSettings('test.domain.local', '7000') @unittest.skipUnless(socket.has_ipv6, "platform doesn't support IPv6") def test_runner_hostname_ipv6(self): - self.cmd.handle(addrport="test.domain.local:7000", use_ipv6=True) + call_command(self.cmd, addrport="test.domain.local:7000", use_ipv6=True) self.assertServerSettings('test.domain.local', '7000', ipv6=True) def test_runner_ambiguous(self): # Only 4 characters, all of which could be in an ipv6 address - self.cmd.handle(addrport="beef:7654") + call_command(self.cmd, addrport="beef:7654") self.assertServerSettings('beef', '7654') # Uses only characters that could be in an ipv6 address - self.cmd.handle(addrport="deadbeef:7654") + call_command(self.cmd, addrport="deadbeef:7654") self.assertServerSettings('deadbeef', '7654') def test_no_database(self): @@ -1530,7 +1530,7 @@ class CommandTypes(AdminScriptTestCase): out = StringIO() err = StringIO() command = Command(stdout=out, stderr=err) - command.execute() + call_command(command) if color.supports_color(): self.assertIn('Hello, world!\n', out.getvalue()) self.assertIn('Hello, world!\n', err.getvalue()) @@ -1552,14 +1552,14 @@ class CommandTypes(AdminScriptTestCase): out = StringIO() err = StringIO() command = Command(stdout=out, stderr=err, no_color=True) - command.execute() + call_command(command) self.assertEqual(out.getvalue(), 'Hello, world!\n') self.assertEqual(err.getvalue(), 'Hello, world!\n') out = StringIO() err = StringIO() command = Command(stdout=out, stderr=err) - command.execute(no_color=True) + call_command(command, no_color=True) self.assertEqual(out.getvalue(), 'Hello, world!\n') self.assertEqual(err.getvalue(), 'Hello, world!\n') @@ -1572,11 +1572,11 @@ class CommandTypes(AdminScriptTestCase): out = StringIO() command = Command(stdout=out) - command.execute() + call_command(command) self.assertEqual(out.getvalue(), "Hello, World!\n") out.truncate(0) new_out = StringIO() - command.execute(stdout=new_out) + call_command(command, stdout=new_out) self.assertEqual(out.getvalue(), "") self.assertEqual(new_out.getvalue(), "Hello, World!\n") @@ -1589,11 +1589,11 @@ class CommandTypes(AdminScriptTestCase): err = StringIO() command = Command(stderr=err) - command.execute() + call_command(command) self.assertEqual(err.getvalue(), "Hello, World!\n") err.truncate(0) new_err = StringIO() - command.execute(stderr=new_err) + call_command(command, stderr=new_err) self.assertEqual(err.getvalue(), "") self.assertEqual(new_err.getvalue(), "Hello, World!\n") diff --git a/tests/auth_tests/test_management.py b/tests/auth_tests/test_management.py index 84a414a435..fe3963e221 100644 --- a/tests/auth_tests/test_management.py +++ b/tests/auth_tests/test_management.py @@ -342,8 +342,8 @@ class CreatesuperuserManagementCommandTestCase(TestCase): """ sentinel = object() command = createsuperuser.Command() - command.check = lambda: [] - command.execute( + call_command( + command, stdin=sentinel, stdout=six.StringIO(), stderr=six.StringIO(), @@ -355,8 +355,8 @@ class CreatesuperuserManagementCommandTestCase(TestCase): self.assertIs(command.stdin, sentinel) command = createsuperuser.Command() - command.check = lambda: [] - command.execute( + call_command( + command, stdout=six.StringIO(), stderr=six.StringIO(), interactive=False,