From 606bcf89b1a7d9f54d3c7b6e1c79e0bfa1de7b7f Mon Sep 17 00:00:00 2001 From: Jake Howard Date: Fri, 27 Sep 2024 11:27:17 +0100 Subject: [PATCH] Fixed DEP 10 -- Implement Tasks interface --- django/conf/global_settings.py | 5 + django/tasks/__init__.py | 53 ++++ django/tasks/backends/__init__.py | 0 django/tasks/backends/base.py | 125 +++++++++ django/tasks/backends/dummy.py | 60 +++++ django/tasks/backends/immediate.py | 84 +++++++ django/tasks/checks.py | 14 ++ django/tasks/exceptions.py | 15 ++ django/tasks/signal_handlers.py | 49 ++++ django/tasks/signals.py | 4 + django/tasks/task.py | 267 ++++++++++++++++++++ django/tasks/utils.py | 84 +++++++ tests/tasks/__init__.py | 0 tests/tasks/is_global_function_fixture.py | 27 ++ tests/tasks/tasks.py | 71 ++++++ tests/tasks/test_dummy_backend.py | 233 +++++++++++++++++ tests/tasks/test_immediate_backend.py | 294 ++++++++++++++++++++++ tests/tasks/test_tasks.py | 254 +++++++++++++++++++ tests/tasks/test_utils.py | 178 +++++++++++++ 19 files changed, 1817 insertions(+) create mode 100644 django/tasks/__init__.py create mode 100644 django/tasks/backends/__init__.py create mode 100644 django/tasks/backends/base.py create mode 100644 django/tasks/backends/dummy.py create mode 100644 django/tasks/backends/immediate.py create mode 100644 django/tasks/checks.py create mode 100644 django/tasks/exceptions.py create mode 100644 django/tasks/signal_handlers.py create mode 100644 django/tasks/signals.py create mode 100644 django/tasks/task.py create mode 100644 django/tasks/utils.py create mode 100644 tests/tasks/__init__.py create mode 100644 tests/tasks/is_global_function_fixture.py create mode 100644 tests/tasks/tasks.py create mode 100644 tests/tasks/test_dummy_backend.py create mode 100644 tests/tasks/test_immediate_backend.py create mode 100644 tests/tasks/test_tasks.py create mode 100644 tests/tasks/test_utils.py diff --git a/django/conf/global_settings.py b/django/conf/global_settings.py index f4535acb09..1d20f84f5b 100644 --- a/django/conf/global_settings.py +++ b/django/conf/global_settings.py @@ -667,3 +667,8 @@ SECURE_REDIRECT_EXEMPT = [] SECURE_REFERRER_POLICY = "same-origin" SECURE_SSL_HOST = None SECURE_SSL_REDIRECT = False + +######### +# TASKS # +######### +TASKS = {"default": {"BACKEND": "django.tasks.backends.immediate.ImmediateBackend"}} diff --git a/django/tasks/__init__.py b/django/tasks/__init__.py new file mode 100644 index 0000000000..a57f284a98 --- /dev/null +++ b/django/tasks/__init__.py @@ -0,0 +1,53 @@ +from django.utils.connection import BaseConnectionHandler, ConnectionProxy +from django.utils.module_loading import import_string + +from . import checks, signal_handlers # noqa +from .backends.base import BaseTaskBackend +from .exceptions import InvalidTaskBackendError +from .task import ( + DEFAULT_QUEUE_NAME, + DEFAULT_TASK_BACKEND_ALIAS, + ResultStatus, + Task, + TaskResult, + task, +) + +__all__ = [ + "tasks", + "DEFAULT_TASK_BACKEND_ALIAS", + "DEFAULT_QUEUE_NAME", + "task", + "ResultStatus", + "Task", + "TaskResult", +] + + +class TasksHandler(BaseConnectionHandler): + settings_name = "TASKS" + exception_class = InvalidTaskBackendError + + def create_connection(self, alias): + params = self.settings[alias].copy() + + # Added back to allow a backend to self-identify + params["ALIAS"] = alias + + backend = params["BACKEND"] + + try: + backend_cls = import_string(backend) + except ImportError as e: + raise InvalidTaskBackendError( + f"Could not find backend '{backend}': {e}" + ) from e + + return backend_cls(params) + + +tasks = TasksHandler() + +default_task_backend: BaseTaskBackend = ConnectionProxy( + tasks, DEFAULT_TASK_BACKEND_ALIAS +) diff --git a/django/tasks/backends/__init__.py b/django/tasks/backends/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/django/tasks/backends/base.py b/django/tasks/backends/base.py new file mode 100644 index 0000000000..4555103a8f --- /dev/null +++ b/django/tasks/backends/base.py @@ -0,0 +1,125 @@ +from abc import ABCMeta, abstractmethod +from inspect import iscoroutinefunction + +from asgiref.sync import sync_to_async + +from django.core.checks import messages +from django.db import connections +from django.tasks.exceptions import InvalidTaskError +from django.tasks.task import MAX_PRIORITY, MIN_PRIORITY, Task +from django.tasks.utils import is_global_function +from django.test.testcases import _DatabaseFailure +from django.utils import timezone + + +class BaseTaskBackend(metaclass=ABCMeta): + task_class = Task + + supports_defer = False + """Can tasks be enqueued with the run_after attribute""" + + supports_async_task = False + """Can coroutines be enqueued""" + + supports_get_result = False + """Can results be retrieved after the fact (from **any** thread / process)""" + + def __init__(self, options): + from django.tasks import DEFAULT_QUEUE_NAME + + self.alias = options["ALIAS"] + self.queues = set(options.get("QUEUES", [DEFAULT_QUEUE_NAME])) + self.enqueue_on_commit = bool(options.get("ENQUEUE_ON_COMMIT", True)) + + def _get_enqueue_on_commit_for_task(self, task): + """ + Determine the correct `enqueue_on_commit` setting to use for a given task. + + If the task defines it, use that, otherwise, fall back to the backend. + """ + # If this project doesn't use a database, there's nothing to commit to + if not connections.settings: + return False + + # If connections are disabled during tests, there's nothing to commit to + for conn in connections.all(): + if isinstance(conn.connect, _DatabaseFailure): + return False + + if isinstance(task.enqueue_on_commit, bool): + return task.enqueue_on_commit + + return self.enqueue_on_commit + + def validate_task(self, task): + """ + Determine whether the provided task is one which can be executed by the backend. + """ + if not is_global_function(task.func): + raise InvalidTaskError( + "Task function must be a globally importable function" + ) + + if not self.supports_async_task and iscoroutinefunction(task.func): + raise InvalidTaskError("Backend does not support async tasks") + + if ( + task.priority < MIN_PRIORITY + or task.priority > MAX_PRIORITY + or int(task.priority) != task.priority + ): + raise InvalidTaskError( + f"priority must be a whole number between {MIN_PRIORITY} and " + f"{MAX_PRIORITY}" + ) + + if not self.supports_defer and task.run_after is not None: + raise InvalidTaskError("Backend does not support run_after") + + if task.run_after is not None and not timezone.is_aware(task.run_after): + raise InvalidTaskError("run_after must be an aware datetime") + + if self.queues and task.queue_name not in self.queues: + raise InvalidTaskError( + f"Queue '{task.queue_name}' is not valid for backend" + ) + + @abstractmethod + def enqueue(self, task, args, kwargs): + """ + Queue up a task to be executed + """ + ... + + async def aenqueue(self, task, args, kwargs): + """ + Queue up a task function (or coroutine) to be executed + """ + return await sync_to_async(self.enqueue, thread_sensitive=True)( + task=task, args=args, kwargs=kwargs + ) + + def get_result(self, result_id): + """ + Retrieve a result by its id (if one exists). + If one doesn't, raises ResultDoesNotExist. + """ + raise NotImplementedError( + "This backend does not support retrieving or refreshing results." + ) + + async def aget_result(self, result_id): + """ + Queue up a task function (or coroutine) to be executed + """ + return await sync_to_async(self.get_result, thread_sensitive=True)( + result_id=result_id + ) + + def check(self, **kwargs): + if self.enqueue_on_commit and not connections.settings: + yield messages.CheckMessage( + messages.ERROR, + "`ENQUEUE_ON_COMMIT` cannot be used when no databases are configured", + hint="Set `ENQUEUE_ON_COMMIT` to False", + ) diff --git a/django/tasks/backends/dummy.py b/django/tasks/backends/dummy.py new file mode 100644 index 0000000000..d6f063d239 --- /dev/null +++ b/django/tasks/backends/dummy.py @@ -0,0 +1,60 @@ +from copy import deepcopy +from functools import partial +from uuid import uuid4 + +from django.db import transaction +from django.tasks.exceptions import ResultDoesNotExist +from django.tasks.signals import task_enqueued +from django.tasks.task import ResultStatus, TaskResult +from django.tasks.utils import json_normalize +from django.utils import timezone + +from .base import BaseTaskBackend + + +class DummyBackend(BaseTaskBackend): + supports_defer = True + supports_async_task = True + + def __init__(self, options) -> None: + super().__init__(options) + + self.results = [] + + def _store_result(self, result) -> None: + object.__setattr__(result, "enqueued_at", timezone.now()) + self.results.append(result) + task_enqueued.send(type(self), task_result=result) + + def enqueue(self, task, args, kwargs) -> TaskResult: + self.validate_task(task) + + result = TaskResult( + task=task, + id=str(uuid4()), + status=ResultStatus.NEW, + enqueued_at=None, + started_at=None, + finished_at=None, + args=json_normalize(args), + kwargs=json_normalize(kwargs), + backend=self.alias, + ) + + if self._get_enqueue_on_commit_for_task(task) is not False: + transaction.on_commit(partial(self._store_result, result)) + else: + self._store_result(result) + + # Copy the task to prevent mutation issues + return deepcopy(result) + + # Don't set `supports_get_result` as the results are scoped to the current thread + def get_result(self, result_id): + try: + return next(result for result in self.results if result.id == result_id) + except StopIteration: + raise ResultDoesNotExist(result_id) from None + + def clear(self): + self.results.clear() diff --git a/django/tasks/backends/immediate.py b/django/tasks/backends/immediate.py new file mode 100644 index 0000000000..ca5f32478b --- /dev/null +++ b/django/tasks/backends/immediate.py @@ -0,0 +1,84 @@ +import logging +from functools import partial +from inspect import iscoroutinefunction +from uuid import uuid4 + +from asgiref.sync import async_to_sync + +from django.db import transaction +from django.tasks.signals import task_enqueued, task_finished +from django.tasks.task import ResultStatus, TaskResult +from django.tasks.utils import exception_to_dict, json_normalize +from django.utils import timezone + +from .base import BaseTaskBackend + +logger = logging.getLogger(__name__) + + +class ImmediateBackend(BaseTaskBackend): + supports_async_task = True + + def _execute_task(self, task_result): + """ + Execute the task for the given `TaskResult`, mutating it with the outcome + """ + object.__setattr__(task_result, "enqueued_at", timezone.now()) + task_enqueued.send(type(self), task_result=task_result) + + task = task_result.task + + calling_task_func = ( + async_to_sync(task.func) if iscoroutinefunction(task.func) else task.func + ) + + object.__setattr__(task_result, "started_at", timezone.now()) + try: + object.__setattr__( + task_result, + "_return_value", + json_normalize( + calling_task_func(*task_result.args, **task_result.kwargs) + ), + ) + except BaseException as e: + # If the user tried to terminate, let them + if isinstance(e, KeyboardInterrupt): + raise + + object.__setattr__(task_result, "finished_at", timezone.now()) + try: + object.__setattr__(task_result, "_exception_data", exception_to_dict(e)) + except Exception: + logger.exception("Task id=%s unable to save exception", task_result.id) + + object.__setattr__(task_result, "status", ResultStatus.FAILED) + + task_finished.send(type(self), task_result=task_result) + else: + object.__setattr__(task_result, "finished_at", timezone.now()) + object.__setattr__(task_result, "status", ResultStatus.COMPLETE) + + task_finished.send(type(self), task_result=task_result) + + def enqueue(self, task, args, kwargs): + self.validate_task(task) + + task_result = TaskResult( + task=task, + id=str(uuid4()), + status=ResultStatus.NEW, + enqueued_at=None, + started_at=None, + finished_at=None, + args=json_normalize(args), + kwargs=json_normalize(kwargs), + backend=self.alias, + ) + + if self._get_enqueue_on_commit_for_task(task) is not False: + transaction.on_commit(partial(self._execute_task, task_result)) + else: + self._execute_task(task_result) + + return task_result diff --git a/django/tasks/checks.py b/django/tasks/checks.py new file mode 100644 index 0000000000..fd87f8602d --- /dev/null +++ b/django/tasks/checks.py @@ -0,0 +1,14 @@ +from django.core import checks + + +@checks.register +def check_tasks(app_configs=None, **kwargs): + """Checks all registered task backends.""" + + from django.tasks import tasks + + for backend in tasks.all(): + try: + yield from backend.check() + except NotImplementedError: + pass diff --git a/django/tasks/exceptions.py b/django/tasks/exceptions.py new file mode 100644 index 0000000000..ec1ab3823f --- /dev/null +++ b/django/tasks/exceptions.py @@ -0,0 +1,15 @@ +from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist + + +class InvalidTaskError(Exception): + """ + The provided task function is invalid. + """ + + +class InvalidTaskBackendError(ImproperlyConfigured): + pass + + +class ResultDoesNotExist(ObjectDoesNotExist): + pass diff --git a/django/tasks/signal_handlers.py b/django/tasks/signal_handlers.py new file mode 100644 index 0000000000..c5d3fbeb09 --- /dev/null +++ b/django/tasks/signal_handlers.py @@ -0,0 +1,49 @@ +import logging + +from asgiref.local import Local + +from django.core.signals import setting_changed +from django.dispatch import receiver + +from .signals import task_enqueued, task_finished +from .task import ResultStatus + +logger = logging.getLogger("django.tasks") + + +@receiver(setting_changed) +def clear_tasks_handlers(*, setting: str, **kwargs: dict) -> None: + """ + Reset the connection handler whenever the settings change + """ + if setting == "TASKS": + from django.tasks import tasks + + tasks._settings = tasks.settings = tasks.configure_settings(None) + tasks._connections = Local() + + +@receiver(task_enqueued) +def log_task_enqueued(sender, task_result, **kwargs): + logger.debug( + "Task id=%s path=%s enqueued backend=%s", + task_result.id, + task_result.task.module_path, + task_result.backend, + ) + + +@receiver(task_finished) +def log_task_finished(sender, task_result, **kwargs): + if task_result.status == ResultStatus.FAILED: + # Use `.exception` to integrate with error monitoring tools (eg Sentry) + log_method = logger.exception + else: + log_method = logger.info + + log_method( + "Task id=%s path=%s state=%s", + task_result.id, + task_result.task.module_path, + task_result.status, + ) diff --git a/django/tasks/signals.py b/django/tasks/signals.py new file mode 100644 index 0000000000..be1cf30b78 --- /dev/null +++ b/django/tasks/signals.py @@ -0,0 +1,4 @@ +from django.dispatch import Signal + +task_enqueued = Signal() +task_finished = Signal() diff --git a/django/tasks/task.py b/django/tasks/task.py new file mode 100644 index 0000000000..006045858b --- /dev/null +++ b/django/tasks/task.py @@ -0,0 +1,267 @@ +from dataclasses import dataclass, field, replace +from datetime import datetime, timedelta +from inspect import iscoroutinefunction +from typing import Any, Callable, Dict, Optional + +from asgiref.sync import async_to_sync, sync_to_async + +from django.db.models.enums import TextChoices +from django.utils import timezone +from django.utils.translation import gettext_lazy as _ + +from .exceptions import ResultDoesNotExist +from .utils import exception_from_dict, get_module_path + +DEFAULT_TASK_BACKEND_ALIAS = "default" +DEFAULT_QUEUE_NAME = "default" +MIN_PRIORITY = -100 +MAX_PRIORITY = 100 +DEFAULT_PRIORITY = 0 + +TASK_REFRESH_ATTRS = { + "_exception_data", + "_return_value", + "finished_at", + "started_at", + "status", + "enqueued_at", +} + + +class ResultStatus(TextChoices): + NEW = ("NEW", _("New")) + RUNNING = ("RUNNING", _("Running")) + FAILED = ("FAILED", _("Failed")) + COMPLETE = ("COMPLETE", _("Complete")) + + +@dataclass(frozen=True) +class Task: + priority: int + """The priority of the task""" + + func: Callable + """The task function""" + + backend: str + """The name of the backend the task will run on""" + + queue_name: str = DEFAULT_QUEUE_NAME + """The name of the queue the task will run on""" + + run_after: Optional[datetime] = None + """The earliest this task will run""" + + enqueue_on_commit: Optional[bool] = None + """ + Whether the task will be enqueued when the current transaction commits, + immediately, or whatever the backend decides + """ + + def __post_init__(self): + self.get_backend().validate_task(self) + + @property + def name(self): + """ + An identifier for the task + """ + return self.func.__name__ + + def using( + self, + *, + priority=None, + queue_name=None, + run_after=None, + backend=None, + ): + """ + Create a new task with modified defaults + """ + + changes = {} + + if priority is not None: + changes["priority"] = priority + if queue_name is not None: + changes["queue_name"] = queue_name + if run_after is not None: + if isinstance(run_after, timedelta): + changes["run_after"] = timezone.now() + run_after + else: + changes["run_after"] = run_after + if backend is not None: + changes["backend"] = backend + + return replace(self, **changes) + + def enqueue(self, *args, **kwargs): + """ + Queue up the task to be executed + """ + return self.get_backend().enqueue(self, args, kwargs) + + async def aenqueue(self, *args, **kwargs): + """ + Queue up a task function (or coroutine) to be executed + """ + return await self.get_backend().aenqueue(self, args, kwargs) + + def get_result(self, result_id): + """ + Retrieve the result for a task of this type by its id (if one exists). + If one doesn't, or is the wrong type, raises ResultDoesNotExist. + """ + result = self.get_backend().get_result(result_id) + + if result.task.func != self.func: + raise ResultDoesNotExist + + return result + + async def aget_result(self, result_id): + """ + Retrieve the result for a task of this type by its id (if one exists). + If one doesn't, or is the wrong type, raises ResultDoesNotExist. + """ + result = await self.get_backend().aget_result(result_id) + + if result.task.func != self.func: + raise ResultDoesNotExist + + return result + + def call(self, *args, **kwargs): + if iscoroutinefunction(self.func): + return async_to_sync(self.func)(*args, **kwargs) + return self.func(*args, **kwargs) + + async def acall(self, *args, **kwargs): + if iscoroutinefunction(self.func): + return await self.func(*args, **kwargs) + return await sync_to_async(self.func)(*args, **kwargs) + + def get_backend(self): + from . import tasks + + return tasks[self.backend] + + @property + def module_path(self): + return get_module_path(self.func) + + +# Implementation +def task( + function=None, + *, + priority=DEFAULT_PRIORITY, + queue_name=DEFAULT_QUEUE_NAME, + backend=DEFAULT_TASK_BACKEND_ALIAS, + enqueue_on_commit=None, +): + """ + A decorator used to create a task. + """ + from . import tasks + + def wrapper(f): + return tasks[backend].task_class( + priority=priority, + func=f, + queue_name=queue_name, + backend=backend, + enqueue_on_commit=enqueue_on_commit, + ) + + if function: + return wrapper(function) + + return wrapper + + +@dataclass(frozen=True) +class TaskResult: + task: Task + """The task for which this is a result""" + + id: str + """A unique identifier for the task result""" + + status: ResultStatus + """The status of the running task""" + + enqueued_at: Optional[datetime] + """The time this task was enqueued""" + + started_at: Optional[datetime] + """The time this task was started""" + + finished_at: Optional[datetime] + """The time this task was finished""" + + args: list + """The arguments to pass to the task function""" + + kwargs: Dict[str, Any] + """The keyword arguments to pass to the task function""" + + backend: str + """The name of the backend the task will run on""" + + _return_value: Optional[Any] = field(init=False, default=None) + _exception_data: Optional[Dict[str, Any]] = field(init=False, default=None) + + @property + def exception(self): + return ( + exception_from_dict(self._exception_data) + if self.status == ResultStatus.FAILED and self._exception_data is not None + else None + ) + + @property + def traceback(self): + """ + Return the string representation of the traceback of the task if it failed + """ + return ( + self._exception_data["exc_traceback"] + if self.status == ResultStatus.FAILED and self._exception_data is not None + else None + ) + + @property + def return_value(self): + """ + Get the return value of the task. + + If the task didn't complete successfully, an exception is raised. + This is to distinguish against the task returning None. + """ + if self.status == ResultStatus.FAILED: + raise ValueError("Task failed") + + elif self.status != ResultStatus.COMPLETE: + raise ValueError("Task has not finished yet") + + return self._return_value + + def refresh(self): + """ + Reload the cached task data from the task store + """ + refreshed_task = self.task.get_backend().get_result(self.id) + + for attr in TASK_REFRESH_ATTRS: + object.__setattr__(self, attr, getattr(refreshed_task, attr)) + + async def arefresh(self): + """ + Reload the cached task data from the task store + """ + refreshed_task = await self.task.get_backend().aget_result(self.id) + + for attr in TASK_REFRESH_ATTRS: + object.__setattr__(self, attr, getattr(refreshed_task, attr)) diff --git a/django/tasks/utils.py b/django/tasks/utils.py new file mode 100644 index 0000000000..c527ea2972 --- /dev/null +++ b/django/tasks/utils.py @@ -0,0 +1,84 @@ +import inspect +import json +import time +from collections import deque +from functools import wraps +from traceback import format_exception + +from django.utils.module_loading import import_string + + +def is_global_function(func): + if not inspect.isfunction(func) or inspect.isbuiltin(func): + return False + + if "" in func.__qualname__: + return False + + return True + + +def is_json_serializable(obj): + """ + Determine, as efficiently as possible, whether an object is JSON-serializable. + """ + try: + # HACK: JSON-encode an object, without loading it all into memory + deque(json.JSONEncoder().iterencode(obj), maxlen=0) + return True + except (TypeError, OverflowError): + return False + + +def json_normalize(obj): + """ + Round-trip encode object as JSON to normalize types. + """ + return json.loads(json.dumps(obj)) + + +def retry(*, retries=3, backoff_delay=0.1): + """ + Retry the given code `retries` times, raising the final error. + + `backoff_delay` can be used to add a delay between attempts. + """ + + def wrapper(f): + @wraps(f) + def inner_wrapper(*args, **kwargs): + for attempt in range(1, retries + 1): + try: + return f(*args, **kwargs) + except KeyboardInterrupt: + # Let the user ctrl-C out of the program without a retry + raise + except BaseException: + if attempt == retries: + raise + time.sleep(backoff_delay) + + return inner_wrapper + + return wrapper + + +def get_module_path(val): + return f"{val.__module__}.{val.__qualname__}" + + +def exception_to_dict(exc): + return { + "exc_type": get_module_path(type(exc)), + "exc_args": json_normalize(exc.args), + "exc_traceback": "".join(format_exception(type(exc), exc, exc.__traceback__)), + } + + +def exception_from_dict(exc_data): + exc_class = import_string(exc_data["exc_type"]) + + if not inspect.isclass(exc_class) or not issubclass(exc_class, BaseException): + raise TypeError(f"{type(exc_class)} is not an exception") + + return exc_class(*exc_data["exc_args"]) diff --git a/tests/tasks/__init__.py b/tests/tasks/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/tasks/is_global_function_fixture.py b/tests/tasks/is_global_function_fixture.py new file mode 100644 index 0000000000..44bd423bee --- /dev/null +++ b/tests/tasks/is_global_function_fixture.py @@ -0,0 +1,27 @@ +""" +This file is used to test function is considered global even if it's not defined yet +because it's covered by a decorator. +""" + +from django.tasks.utils import is_global_function + + +@is_global_function +def really_global_function() -> None: + pass + + +inner_func_is_global_function = None + + +def main() -> None: + global inner_func_is_global_function + + @is_global_function + def inner_func() -> None: + pass + + inner_func_is_global_function = inner_func + + +main() diff --git a/tests/tasks/tasks.py b/tests/tasks/tasks.py new file mode 100644 index 0000000000..49dfa76301 --- /dev/null +++ b/tests/tasks/tasks.py @@ -0,0 +1,71 @@ +import time + +from django.tasks import task + + +@task() +def noop_task(*args: tuple, **kwargs: dict) -> None: + return None + + +@task +def noop_task_from_bare_decorator(*args: tuple, **kwargs: dict) -> None: + return None + + +@task() +async def noop_task_async(*args: tuple, **kwargs: dict) -> None: + return None + + +@task() +def calculate_meaning_of_life() -> int: + return 42 + + +@task() +def failing_task_value_error() -> None: + raise ValueError("This task failed due to ValueError") + + +@task() +def failing_task_system_exit() -> None: + raise SystemExit("This task failed due to SystemExit") + + +@task() +def failing_task_keyboard_interrupt() -> None: + raise KeyboardInterrupt("This task failed due to KeyboardInterrupt") + + +@task() +def complex_exception() -> None: + raise ValueError(ValueError("This task failed")) + + +@task() +def exit_task() -> None: + exit(1) + + +@task(enqueue_on_commit=True) +def enqueue_on_commit_task() -> None: + pass + + +@task(enqueue_on_commit=False) +def never_enqueue_on_commit_task() -> None: + pass + + +@task() +def hang() -> None: + """ + Do nothing for 5 minutes + """ + time.sleep(300) + + +@task() +def sleep_for(seconds: float) -> None: + time.sleep(seconds) diff --git a/tests/tasks/test_dummy_backend.py b/tests/tasks/test_dummy_backend.py new file mode 100644 index 0000000000..75a7f9020f --- /dev/null +++ b/tests/tasks/test_dummy_backend.py @@ -0,0 +1,233 @@ +from typing import cast + +from django.db import transaction +from django.tasks import ResultStatus, Task, default_task_backend, tasks +from django.tasks.backends.dummy import DummyBackend +from django.tasks.exceptions import ResultDoesNotExist +from django.test import SimpleTestCase, TransactionTestCase, override_settings + +from . import tasks as test_tasks + + +@override_settings( + TASKS={"default": {"BACKEND": "django.tasks.backends.dummy.DummyBackend"}} +) +class DummyBackendTestCase(SimpleTestCase): + def setUp(self): + default_task_backend.clear() # type:ignore[attr-defined] + + def test_using_correct_backend(self): + self.assertEqual(default_task_backend, tasks["default"]) + self.assertIsInstance(tasks["default"], DummyBackend) + + def test_enqueue_task(self): + for task in [test_tasks.noop_task, test_tasks.noop_task_async]: + with self.subTest(task): + result = cast(Task, task).enqueue(1, two=3) + + self.assertEqual(result.status, ResultStatus.NEW) + self.assertIsNone(result.started_at) + self.assertIsNone(result.finished_at) + with self.assertRaisesMessage(ValueError, "Task has not finished yet"): + result.return_value # noqa:B018 + self.assertEqual(result.task, task) + self.assertEqual(result.args, [1]) + self.assertEqual(result.kwargs, {"two": 3}) + + self.assertIn( + result, default_task_backend.results + ) # type:ignore[attr-defined] + + async def test_enqueue_task_async(self): + for task in [test_tasks.noop_task, test_tasks.noop_task_async]: + with self.subTest(task): + result = await cast(Task, task).aenqueue() + + self.assertEqual(result.status, ResultStatus.NEW) + self.assertIsNone(result.started_at) + self.assertIsNone(result.finished_at) + with self.assertRaisesMessage(ValueError, "Task has not finished yet"): + result.return_value # noqa:B018 + self.assertEqual(result.task, task) + self.assertEqual(result.args, []) + self.assertEqual(result.kwargs, {}) + + self.assertIn( + result, default_task_backend.results + ) # type:ignore[attr-defined] + + def test_get_result(self): + result = default_task_backend.enqueue(test_tasks.noop_task, (), {}) + + new_result = default_task_backend.get_result(result.id) + + self.assertEqual(result, new_result) + + async def test_get_result_async(self): + result = await default_task_backend.aenqueue(test_tasks.noop_task, (), {}) + + new_result = await default_task_backend.aget_result(result.id) + + self.assertEqual(result, new_result) + + def test_refresh_result(self): + result = default_task_backend.enqueue( + test_tasks.calculate_meaning_of_life, (), {} + ) + + enqueued_result = default_task_backend.results[0] # type:ignore[attr-defined] + object.__setattr__(enqueued_result, "status", ResultStatus.COMPLETE) + + self.assertEqual(result.status, ResultStatus.NEW) + result.refresh() + self.assertEqual(result.status, ResultStatus.COMPLETE) + + async def test_refresh_result_async(self): + result = await default_task_backend.aenqueue( + test_tasks.calculate_meaning_of_life, (), {} + ) + + enqueued_result = default_task_backend.results[0] # type:ignore[attr-defined] + object.__setattr__(enqueued_result, "status", ResultStatus.COMPLETE) + + self.assertEqual(result.status, ResultStatus.NEW) + await result.arefresh() + self.assertEqual(result.status, ResultStatus.COMPLETE) + + async def test_get_missing_result(self): + with self.assertRaises(ResultDoesNotExist): + default_task_backend.get_result("123") + + with self.assertRaises(ResultDoesNotExist): + await default_task_backend.aget_result("123") + + def test_enqueue_on_commit(self): + self.assertFalse( + default_task_backend._get_enqueue_on_commit_for_task( + test_tasks.enqueue_on_commit_task + ) + ) + + def test_enqueue_logs(self): + with self.assertLogs("django.tasks", level="DEBUG") as captured_logs: + result = test_tasks.noop_task.enqueue() + + self.assertEqual(len(captured_logs.output), 1) + self.assertIn("enqueued", captured_logs.output[0]) + self.assertIn(result.id, captured_logs.output[0]) + + +class DummyBackendTransactionTestCase(TransactionTestCase): + @override_settings( + TASKS={ + "default": { + "BACKEND": "django.tasks.backends.dummy.DummyBackend", + "ENQUEUE_ON_COMMIT": True, + } + } + ) + def test_wait_until_transaction_commit(self): + self.assertTrue(default_task_backend.enqueue_on_commit) + self.assertTrue( + default_task_backend._get_enqueue_on_commit_for_task(test_tasks.noop_task) + ) + + with transaction.atomic(): + test_tasks.noop_task.enqueue() + + self.assertEqual( + len(default_task_backend.results), 0 + ) # type:ignore[attr-defined] + + self.assertEqual( + len(default_task_backend.results), 1 + ) # type:ignore[attr-defined] + + @override_settings( + TASKS={ + "default": { + "BACKEND": "django.tasks.backends.dummy.DummyBackend", + "ENQUEUE_ON_COMMIT": False, + } + } + ) + def test_doesnt_wait_until_transaction_commit(self): + self.assertFalse(default_task_backend.enqueue_on_commit) + self.assertFalse( + default_task_backend._get_enqueue_on_commit_for_task(test_tasks.noop_task) + ) + + with transaction.atomic(): + result = test_tasks.noop_task.enqueue() + + self.assertIsNotNone(result.enqueued_at) + + self.assertEqual( + len(default_task_backend.results), 1 + ) # type:ignore[attr-defined] + + self.assertEqual( + len(default_task_backend.results), 1 + ) # type:ignore[attr-defined] + + @override_settings( + TASKS={ + "default": { + "BACKEND": "django.tasks.backends.dummy.DummyBackend", + } + } + ) + def test_wait_until_transaction_by_default(self): + self.assertTrue(default_task_backend.enqueue_on_commit) + self.assertTrue( + default_task_backend._get_enqueue_on_commit_for_task(test_tasks.noop_task) + ) + + with transaction.atomic(): + result = test_tasks.noop_task.enqueue() + + self.assertIsNone(result.enqueued_at) + + self.assertEqual( + len(default_task_backend.results), 0 + ) # type:ignore[attr-defined] + + self.assertEqual( + len(default_task_backend.results), 1 + ) # type:ignore[attr-defined] + self.assertIsNone(result.enqueued_at) + result.refresh() + self.assertIsNotNone(result.enqueued_at) + + @override_settings( + TASKS={ + "default": { + "BACKEND": "django.tasks.backends.dummy.DummyBackend", + "ENQUEUE_ON_COMMIT": False, + } + } + ) + def test_task_specific_enqueue_on_commit(self): + self.assertFalse(default_task_backend.enqueue_on_commit) + self.assertTrue(test_tasks.enqueue_on_commit_task.enqueue_on_commit) + self.assertTrue( + default_task_backend._get_enqueue_on_commit_for_task( + test_tasks.enqueue_on_commit_task + ) + ) + + with transaction.atomic(): + result = test_tasks.enqueue_on_commit_task.enqueue() + + self.assertIsNone(result.enqueued_at) + + self.assertEqual( + len(default_task_backend.results), 0 + ) # type:ignore[attr-defined] + + self.assertEqual( + len(default_task_backend.results), 1 + ) # type:ignore[attr-defined] + self.assertIsNone(result.enqueued_at) + result.refresh() + self.assertIsNotNone(result.enqueued_at) diff --git a/tests/tasks/test_immediate_backend.py b/tests/tasks/test_immediate_backend.py new file mode 100644 index 0000000000..822414742e --- /dev/null +++ b/tests/tasks/test_immediate_backend.py @@ -0,0 +1,294 @@ +from typing import cast + +from django.db import transaction +from django.tasks import ResultStatus, Task, default_task_backend, tasks +from django.tasks.backends.immediate import ImmediateBackend +from django.tasks.exceptions import InvalidTaskError +from django.test import SimpleTestCase, TransactionTestCase, override_settings +from django.utils import timezone + +from . import tasks as test_tasks + + +@override_settings( + TASKS={"default": {"BACKEND": "django.tasks.backends.immediate.ImmediateBackend"}} +) +class ImmediateBackendTestCase(SimpleTestCase): + def test_using_correct_backend(self): + self.assertEqual(default_task_backend, tasks["default"]) + self.assertIsInstance(tasks["default"], ImmediateBackend) + + def test_enqueue_task(self): + for task in [test_tasks.noop_task, test_tasks.noop_task_async]: + with self.subTest(task): + result = cast(Task, task).enqueue(1, two=3) + + self.assertEqual(result.status, ResultStatus.COMPLETE) + self.assertIsNotNone(result.started_at) + self.assertIsNotNone(result.finished_at) + self.assertGreaterEqual( + result.started_at, result.enqueued_at + ) # type:ignore[arg-type, misc] + self.assertGreaterEqual( + result.finished_at, result.started_at + ) # type:ignore[arg-type, misc] + self.assertIsNone(result.return_value) + self.assertEqual(result.task, task) + self.assertEqual(result.args, [1]) + self.assertEqual(result.kwargs, {"two": 3}) + + async def test_enqueue_task_async(self): + for task in [test_tasks.noop_task, test_tasks.noop_task_async]: + with self.subTest(task): + result = await cast(Task, task).aenqueue() + + self.assertEqual(result.status, ResultStatus.COMPLETE) + self.assertIsNotNone(result.started_at) + self.assertIsNotNone(result.finished_at) + self.assertGreaterEqual( + result.started_at, result.enqueued_at + ) # type:ignore[arg-type, misc] + self.assertGreaterEqual( + result.finished_at, result.started_at + ) # type:ignore[arg-type, misc] + self.assertIsNone(result.return_value) + self.assertEqual(result.task, task) + self.assertEqual(result.args, []) + self.assertEqual(result.kwargs, {}) + + def test_catches_exception(self): + test_data = [ + ( + test_tasks.failing_task_value_error, # task function + ValueError, # expected exception + "This task failed due to ValueError", # expected message + ), + ( + test_tasks.failing_task_system_exit, + SystemExit, + "This task failed due to SystemExit", + ), + ] + for task, exception, message in test_data: + with ( + self.subTest(task), + self.assertLogs("django.tasks", level="ERROR") as captured_logs, + ): + result = task.enqueue() + + # assert logging + self.assertEqual(len(captured_logs.output), 1) + self.assertIn(message, captured_logs.output[0]) + + # assert result + self.assertEqual(result.status, ResultStatus.FAILED) + self.assertIsNotNone(result.started_at) + self.assertIsNotNone(result.finished_at) + self.assertGreaterEqual( + result.started_at, result.enqueued_at + ) # type:ignore[arg-type, misc] + self.assertGreaterEqual( + result.finished_at, result.started_at + ) # type:ignore[arg-type, misc] + self.assertIsInstance(result.exception, exception) + self.assertTrue( + result.traceback + and result.traceback.endswith(f"{exception.__name__}: {message}\n") + ) + self.assertEqual(result.task, task) + self.assertEqual(result.args, []) + self.assertEqual(result.kwargs, {}) + + def test_throws_keyboard_interrupt(self): + with self.assertRaises(KeyboardInterrupt): + with self.assertLogs("django.tasks", level="ERROR") as captured_logs: + default_task_backend.enqueue( + test_tasks.failing_task_keyboard_interrupt, [], {} + ) + + # assert logging + self.assertEqual(len(captured_logs.output), 0) + + def test_complex_exception(self): + with self.assertLogs("django.tasks", level="ERROR"): + result = test_tasks.complex_exception.enqueue() + + self.assertEqual(result.status, ResultStatus.FAILED) + self.assertIsNotNone(result.started_at) + self.assertIsNotNone(result.finished_at) + self.assertGreaterEqual( + result.started_at, result.enqueued_at + ) # type:ignore[arg-type,misc] + self.assertGreaterEqual( + result.finished_at, result.started_at + ) # type:ignore[arg-type,misc] + + self.assertIsNone(result._return_value) + self.assertIsNone(result.traceback) + + self.assertEqual(result.task, test_tasks.complex_exception) + self.assertEqual(result.args, []) + self.assertEqual(result.kwargs, {}) + + def test_result(self): + result = default_task_backend.enqueue( + test_tasks.calculate_meaning_of_life, [], {} + ) + + self.assertEqual(result.status, ResultStatus.COMPLETE) + self.assertEqual(result.return_value, 42) + + async def test_result_async(self): + result = await default_task_backend.aenqueue( + test_tasks.calculate_meaning_of_life, [], {} + ) + + self.assertEqual(result.status, ResultStatus.COMPLETE) + self.assertEqual(result.return_value, 42) + + async def test_cannot_get_result(self): + with self.assertRaisesMessage( + NotImplementedError, + "This backend does not support retrieving or refreshing results.", + ): + default_task_backend.get_result("123") + + with self.assertRaisesMessage( + NotImplementedError, + "This backend does not support retrieving or refreshing results.", + ): + await default_task_backend.aget_result(123) # type:ignore[arg-type] + + async def test_cannot_refresh_result(self): + result = await default_task_backend.aenqueue( + test_tasks.calculate_meaning_of_life, (), {} + ) + + with self.assertRaisesMessage( + NotImplementedError, + "This backend does not support retrieving or refreshing results.", + ): + await result.arefresh() + + with self.assertRaisesMessage( + NotImplementedError, + "This backend does not support retrieving or refreshing results.", + ): + result.refresh() + + def test_cannot_pass_run_after(self): + with self.assertRaisesMessage( + InvalidTaskError, + "Backend does not support run_after", + ): + default_task_backend.validate_task( + test_tasks.failing_task_value_error.using(run_after=timezone.now()) + ) + + def test_enqueue_on_commit(self): + self.assertFalse( + default_task_backend._get_enqueue_on_commit_for_task( + test_tasks.enqueue_on_commit_task + ) + ) + + def test_enqueue_logs(self): + with self.assertLogs("django.tasks", level="DEBUG") as captured_logs: + result = test_tasks.noop_task.enqueue() + + self.assertIn("enqueued", captured_logs.output[0]) + self.assertIn(result.id, captured_logs.output[0]) + + +class ImmediateBackendTransactionTestCase(TransactionTestCase): + @override_settings( + TASKS={ + "default": { + "BACKEND": "django.tasks.backends.immediate.ImmediateBackend", + "ENQUEUE_ON_COMMIT": True, + } + } + ) + def test_wait_until_transaction_commit(self): + self.assertTrue(default_task_backend.enqueue_on_commit) + self.assertTrue( + default_task_backend._get_enqueue_on_commit_for_task(test_tasks.noop_task) + ) + + with transaction.atomic(): + result = test_tasks.noop_task.enqueue() + + self.assertIsNone(result.enqueued_at) + self.assertEqual(result.status, ResultStatus.NEW) + + self.assertEqual(result.status, ResultStatus.COMPLETE) + self.assertIsNotNone(result.enqueued_at) + + @override_settings( + TASKS={ + "default": { + "BACKEND": "django.tasks.backends.immediate.ImmediateBackend", + "ENQUEUE_ON_COMMIT": False, + } + } + ) + def test_doesnt_wait_until_transaction_commit(self): + self.assertFalse(default_task_backend.enqueue_on_commit) + self.assertFalse( + default_task_backend._get_enqueue_on_commit_for_task(test_tasks.noop_task) + ) + + with transaction.atomic(): + result = test_tasks.noop_task.enqueue() + + self.assertIsNotNone(result.enqueued_at) + + self.assertEqual(result.status, ResultStatus.COMPLETE) + + self.assertEqual(result.status, ResultStatus.COMPLETE) + + @override_settings( + TASKS={ + "default": { + "BACKEND": "django.tasks.backends.immediate.ImmediateBackend", + } + } + ) + def test_wait_until_transaction_by_default(self): + self.assertTrue(default_task_backend.enqueue_on_commit) + self.assertTrue( + default_task_backend._get_enqueue_on_commit_for_task(test_tasks.noop_task) + ) + + with transaction.atomic(): + result = test_tasks.noop_task.enqueue() + + self.assertIsNone(result.enqueued_at) + self.assertEqual(result.status, ResultStatus.NEW) + + self.assertEqual(result.status, ResultStatus.COMPLETE) + + @override_settings( + TASKS={ + "default": { + "BACKEND": "django.tasks.backends.immediate.ImmediateBackend", + "ENQUEUE_ON_COMMIT": False, + } + } + ) + def test_task_specific_enqueue_on_commit(self): + self.assertFalse(default_task_backend.enqueue_on_commit) + self.assertTrue(test_tasks.enqueue_on_commit_task.enqueue_on_commit) + self.assertTrue( + default_task_backend._get_enqueue_on_commit_for_task( + test_tasks.enqueue_on_commit_task + ) + ) + + with transaction.atomic(): + result = test_tasks.enqueue_on_commit_task.enqueue() + + self.assertIsNone(result.enqueued_at) + self.assertEqual(result.status, ResultStatus.NEW) + + self.assertEqual(result.status, ResultStatus.COMPLETE) diff --git a/tests/tasks/test_tasks.py b/tests/tasks/test_tasks.py new file mode 100644 index 0000000000..ef458a7173 --- /dev/null +++ b/tests/tasks/test_tasks.py @@ -0,0 +1,254 @@ +import dataclasses +from datetime import datetime, timedelta + +from django.tasks import ( + DEFAULT_QUEUE_NAME, + ResultStatus, + Task, + default_task_backend, + task, + tasks, +) +from django.tasks.backends.dummy import DummyBackend +from django.tasks.backends.immediate import ImmediateBackend +from django.tasks.exceptions import ( + InvalidTaskBackendError, + InvalidTaskError, + ResultDoesNotExist, +) +from django.tasks.task import MAX_PRIORITY, MIN_PRIORITY +from django.test import SimpleTestCase, override_settings +from django.utils import timezone +from django.utils.module_loading import import_string + +from . import tasks as test_tasks + + +@override_settings( + TASKS={ + "default": { + "BACKEND": "django.tasks.backends.dummy.DummyBackend", + "QUEUES": ["default", "queue_1"], + }, + "immediate": {"BACKEND": "django.tasks.backends.immediate.ImmediateBackend"}, + "missing": {"BACKEND": "does.not.exist"}, + }, + USE_TZ=True, +) +class TaskTestCase(SimpleTestCase): + def setUp(self): + default_task_backend.clear() # type:ignore[attr-defined] + + def test_using_correct_backend(self): + self.assertEqual(default_task_backend, tasks["default"]) + self.assertIsInstance(tasks["default"], DummyBackend) + + def test_task_decorator(self): + self.assertIsInstance(test_tasks.noop_task, Task) + self.assertIsInstance(test_tasks.noop_task_async, Task) + self.assertIsInstance(test_tasks.noop_task_from_bare_decorator, Task) + + def test_enqueue_task(self): + result = test_tasks.noop_task.enqueue() + + self.assertEqual(result.status, ResultStatus.NEW) + self.assertEqual(result.task, test_tasks.noop_task) + self.assertEqual(result.args, []) + self.assertEqual(result.kwargs, {}) + + self.assertEqual( + default_task_backend.results, [result] + ) # type:ignore[attr-defined] + + async def test_enqueue_task_async(self): + result = await test_tasks.noop_task.aenqueue() + + self.assertEqual(result.status, ResultStatus.NEW) + self.assertEqual(result.task, test_tasks.noop_task) + self.assertEqual(result.args, []) + self.assertEqual(result.kwargs, {}) + + self.assertEqual( + default_task_backend.results, [result] + ) # type:ignore[attr-defined] + + def test_using_priority(self): + self.assertEqual(test_tasks.noop_task.priority, 0) + self.assertEqual(test_tasks.noop_task.using(priority=1).priority, 1) + self.assertEqual(test_tasks.noop_task.priority, 0) + + def test_using_queue_name(self): + self.assertEqual(test_tasks.noop_task.queue_name, DEFAULT_QUEUE_NAME) + self.assertEqual( + test_tasks.noop_task.using(queue_name="queue_1").queue_name, "queue_1" + ) + self.assertEqual(test_tasks.noop_task.queue_name, DEFAULT_QUEUE_NAME) + + def test_using_run_after(self): + now = timezone.now() + + self.assertIsNone(test_tasks.noop_task.run_after) + self.assertEqual(test_tasks.noop_task.using(run_after=now).run_after, now) + self.assertIsInstance( + test_tasks.noop_task.using(run_after=timedelta(hours=1)).run_after, + datetime, + ) + self.assertIsNone(test_tasks.noop_task.run_after) + + def test_using_unknown_backend(self): + self.assertEqual(test_tasks.noop_task.backend, "default") + + with self.assertRaisesMessage( + InvalidTaskBackendError, "The connection 'unknown' doesn't exist." + ): + test_tasks.noop_task.using(backend="unknown") + + def test_using_missing_backend(self): + self.assertEqual(test_tasks.noop_task.backend, "default") + + with self.assertRaisesMessage( + InvalidTaskBackendError, + "Could not find backend 'does.not.exist': No module named 'does'", + ): + test_tasks.noop_task.using(backend="missing") + + def test_using_creates_new_instance(self): + new_task = test_tasks.noop_task.using() + + self.assertEqual(new_task, test_tasks.noop_task) + self.assertIsNot(new_task, test_tasks.noop_task) + + def test_chained_using(self): + now = timezone.now() + + run_after_task = test_tasks.noop_task.using(run_after=now) + self.assertEqual(run_after_task.run_after, now) + + priority_task = run_after_task.using(priority=10) + self.assertEqual(priority_task.priority, 10) + self.assertEqual(priority_task.run_after, now) + + self.assertEqual(run_after_task.priority, 0) + + async def test_refresh_result(self): + result = await test_tasks.noop_task.aenqueue() + + original_result = dataclasses.asdict(result) + + result.refresh() + + self.assertEqual(dataclasses.asdict(result), original_result) + + await result.arefresh() + + self.assertEqual(dataclasses.asdict(result), original_result) + + def test_naive_datetime(self): + with self.assertRaisesMessage( + InvalidTaskError, "run_after must be an aware datetime" + ): + test_tasks.noop_task.using(run_after=datetime.now()) + + def test_invalid_priority(self): + with self.assertRaisesMessage( + InvalidTaskError, + f"priority must be a whole number between {MIN_PRIORITY} and " + f"{MAX_PRIORITY}", + ): + test_tasks.noop_task.using(priority=-101) + + with self.assertRaisesMessage( + InvalidTaskError, + f"priority must be a whole number between {MIN_PRIORITY} and " + f"{MAX_PRIORITY}", + ): + test_tasks.noop_task.using(priority=101) + + with self.assertRaisesMessage( + InvalidTaskError, + f"priority must be a whole number between {MIN_PRIORITY} and " + f"{MAX_PRIORITY}", + ): + test_tasks.noop_task.using(priority=3.1) # type:ignore[arg-type] + + test_tasks.noop_task.using(priority=100) + test_tasks.noop_task.using(priority=-100) + test_tasks.noop_task.using(priority=0) + + def test_call_task(self): + self.assertEqual(test_tasks.calculate_meaning_of_life.call(), 42) + + async def test_call_task_async(self): + self.assertEqual(await test_tasks.calculate_meaning_of_life.acall(), 42) + + async def test_call_async_task(self): + self.assertIsNone(await test_tasks.noop_task_async.acall()) + + def test_call_async_task_sync(self): + self.assertIsNone(test_tasks.noop_task_async.call()) + + def test_get_result(self): + result = default_task_backend.enqueue(test_tasks.noop_task, (), {}) + + new_result = test_tasks.noop_task.get_result(result.id) + + self.assertEqual(result, new_result) + + async def test_get_result_async(self): + result = await default_task_backend.aenqueue(test_tasks.noop_task, (), {}) + + new_result = await test_tasks.noop_task.aget_result(result.id) + + self.assertEqual(result, new_result) + + async def test_get_missing_result(self): + with self.assertRaises(ResultDoesNotExist): + test_tasks.noop_task.get_result("123") + + with self.assertRaises(ResultDoesNotExist): + await test_tasks.noop_task.aget_result("123") + + def test_get_incorrect_result(self): + result = default_task_backend.enqueue(test_tasks.noop_task_async, (), {}) + with self.assertRaises(ResultDoesNotExist): + test_tasks.noop_task.get_result(result.id) + + async def test_get_incorrect_result_async(self): + result = await default_task_backend.aenqueue(test_tasks.noop_task_async, (), {}) + with self.assertRaises(ResultDoesNotExist): + await test_tasks.noop_task.aget_result(result.id) + + def test_invalid_function(self): + for invalid_function in [any, self.test_invalid_function]: + with self.subTest(invalid_function): + with self.assertRaisesMessage( + InvalidTaskError, + "Task function must be a globally importable function", + ): + task()(invalid_function) # type:ignore[arg-type] + + def test_get_backend(self): + self.assertEqual(test_tasks.noop_task.backend, "default") + self.assertIsInstance(test_tasks.noop_task.get_backend(), DummyBackend) + + immediate_task = test_tasks.noop_task.using(backend="immediate") + self.assertEqual(immediate_task.backend, "immediate") + self.assertIsInstance(immediate_task.get_backend(), ImmediateBackend) + + def test_name(self): + self.assertEqual(test_tasks.noop_task.name, "noop_task") + self.assertEqual(test_tasks.noop_task_async.name, "noop_task_async") + + def test_module_path(self): + self.assertEqual(test_tasks.noop_task.module_path, "tasks.tasks.noop_task") + self.assertEqual( + test_tasks.noop_task_async.module_path, "tasks.tasks.noop_task_async" + ) + + self.assertIs( + import_string(test_tasks.noop_task.module_path), test_tasks.noop_task + ) + self.assertIs( + import_string(test_tasks.noop_task_async.module_path), + test_tasks.noop_task_async, + ) diff --git a/tests/tasks/test_utils.py b/tests/tasks/test_utils.py new file mode 100644 index 0000000000..c22e21ec20 --- /dev/null +++ b/tests/tasks/test_utils.py @@ -0,0 +1,178 @@ +import datetime +import hashlib +import optparse +import subprocess +from typing import List +from unittest.mock import Mock + +from django.core.exceptions import ImproperlyConfigured +from django.tasks import utils +from django.tasks.exceptions import InvalidTaskError +from django.test import SimpleTestCase + +from . import tasks as test_tasks + + +class IsGlobalFunctionTestCase(SimpleTestCase): + def test_builtin(self): + self.assertFalse(utils.is_global_function(any)) + self.assertFalse(utils.is_global_function(isinstance)) + + def test_from_module(self): + self.assertTrue(utils.is_global_function(subprocess.run)) + self.assertTrue(utils.is_global_function(subprocess.check_output)) + self.assertTrue(utils.is_global_function(test_tasks.noop_task.func)) + + def test_private_function(self): + def private_function(): + pass + + self.assertFalse(utils.is_global_function(private_function)) + + def test_coroutine(self): + self.assertTrue(utils.is_global_function(test_tasks.noop_task_async.func)) + + def test_method(self): + self.assertFalse(utils.is_global_function(self.test_method)) + self.assertFalse(utils.is_global_function(self.setUp)) + + def test_lambda(self): + self.assertFalse(utils.is_global_function(lambda: True)) + + def test_uninitialised_method(self): + # This import has to be here, so the module is loaded during the test + from . import is_global_function_fixture + + self.assertTrue(is_global_function_fixture.really_global_function) + self.assertIsNotNone(is_global_function_fixture.inner_func_is_global_function) + self.assertFalse(is_global_function_fixture.inner_func_is_global_function) + + +class IsJSONSerializableTestCase(SimpleTestCase): + def test_serializable(self): + for example in [123, 12.3, "123", {"123": 456}, [], None]: + with self.subTest(example): + self.assertTrue(utils.is_json_serializable(example)) + + def test_not_serializable(self): + for example in [ + self, + any, + datetime.datetime.now(), + ]: + with self.subTest(example): + self.assertFalse(utils.is_json_serializable(example)) + + +class JSONNormalizeTestCase(SimpleTestCase): + def test_round_trip(self): + self.assertEqual(utils.json_normalize({}), {}) + self.assertEqual(utils.json_normalize([]), []) + self.assertEqual(utils.json_normalize(()), []) + self.assertEqual(utils.json_normalize({"foo": ()}), {"foo": []}) + + def test_encode_error(self): + for example in [self, any, datetime.datetime.now()]: + with self.subTest(example): + self.assertFalse(utils.is_json_serializable(example)) + self.assertRaises(TypeError, utils.json_normalize, example) + + +class RetryTestCase(SimpleTestCase): + def test_retry(self): + sentinel = Mock(side_effect=ValueError("")) + + with self.assertRaises(ValueError): + utils.retry()(sentinel)() + + self.assertEqual(sentinel.call_count, 3) + + def test_keeps_return_value(self): + self.assertTrue(utils.retry()(lambda: True)()) + self.assertFalse(utils.retry()(lambda: False)()) + + +class ExceptionSerializationTestCase(SimpleTestCase): + def test_serialize_exceptions(self): + for exc in [ + ValueError(10), + SyntaxError("Wrong"), + ImproperlyConfigured("It's wrong"), + InvalidTaskError(""), + SystemExit(), + ]: + with self.subTest(exc): + data = utils.exception_to_dict(exc) + self.assertEqual(utils.json_normalize(data), data) + self.assertEqual( + set(data.keys()), {"exc_type", "exc_args", "exc_traceback"} + ) + exception = utils.exception_from_dict(data) + self.assertIsInstance(exception, type(exc)) + self.assertEqual(exception.args, exc.args) + + # Check that the exception traceback contains a minimal traceback + msg = str(exc.args[0]) if exc.args else "" + traceback = data["exc_traceback"] + self.assertIn(exc.__class__.__name__, traceback) + self.assertIn(msg, traceback) + + def test_serialize_full_traceback(self): + try: + # Using optparse to generate an error because: + # - it's pure python + # - it's easy to trip down + # - it's unlikely to change ever + optparse.OptionParser(option_list=[1]) # type: ignore + except Exception as e: + traceback = utils.exception_to_dict(e)["exc_traceback"] + # The test is willingly fuzzy to ward against changes in the + # traceback formatting + self.assertIn("traceback", traceback.lower()) + self.assertIn("line", traceback.lower()) + self.assertIn(optparse.__file__, traceback) + self.assertTrue( + traceback.endswith("TypeError: not an Option instance: 1\n") + ) + + def test_serialize_traceback_from_c_module(self): + try: + # Same as test_serialize_full_traceback, but uses hashlib + # because it's in C, not in Python + hashlib.md5(1) # type: ignore + except Exception as e: + traceback = utils.exception_to_dict(e)["exc_traceback"] + self.assertIn("traceback", traceback.lower()) + self.assertTrue( + traceback.endswith( + "TypeError: object supporting the buffer API required\n" + ) + ) + self.assertIn("hashlib.md5(1)", traceback) + + def test_cannot_deserialize_non_exception(self): + serialized_exceptions: List[utils.SerializedExceptionDict] = [ + { + "exc_type": "subprocess.check_output", + "exc_args": ["exit", "1"], + "exc_traceback": "", + }, + {"exc_type": "True", "exc_args": [], "exc_traceback": ""}, + {"exc_type": "math.pi", "exc_args": [], "exc_traceback": ""}, + {"exc_type": __name__, "exc_args": [], "exc_traceback": ""}, + { + "exc_type": utils.get_module_path(type(self)), + "exc_args": [], + "exc_traceback": "", + }, + { + "exc_type": utils.get_module_path(Mock), + "exc_args": [], + "exc_traceback": "", + }, + ] + + for data in serialized_exceptions: + with self.subTest(data): + with self.assertRaises((TypeError, ImportError)): + utils.exception_from_dict(data)