diff --git a/django/utils/decorators.py b/django/utils/decorators.py index d8814b0d4c..fb12c7fbcd 100644 --- a/django/utils/decorators.py +++ b/django/utils/decorators.py @@ -2,7 +2,7 @@ from functools import partial, update_wrapper, wraps -from asgiref.sync import iscoroutinefunction +from asgiref.sync import iscoroutinefunction, markcoroutinefunction class classonlymethod(classmethod): @@ -52,6 +52,10 @@ def _multi_decorate(decorators, method): _update_method_wrapper(_wrapper, dec) # Preserve any existing attributes of 'method', including the name. update_wrapper(_wrapper, method) + + if iscoroutinefunction(method): + markcoroutinefunction(_wrapper) + return _wrapper diff --git a/docs/releases/5.2.txt b/docs/releases/5.2.txt index f910416716..add5d9506a 100644 --- a/docs/releases/5.2.txt +++ b/docs/releases/5.2.txt @@ -128,7 +128,8 @@ Database backends Decorators ~~~~~~~~~~ -* ... +* :func:`~django.utils.decorators.method_decorator` now supports wrapping + asynchronous view methods. Email ~~~~~ diff --git a/tests/decorators/tests.py b/tests/decorators/tests.py index 58f822f2a5..1f8d623e02 100644 --- a/tests/decorators/tests.py +++ b/tests/decorators/tests.py @@ -1,6 +1,9 @@ +import asyncio from functools import update_wrapper, wraps from unittest import TestCase +from asgiref.sync import iscoroutinefunction + from django.contrib.admin.views.decorators import staff_member_required from django.contrib.auth.decorators import ( login_required, @@ -434,3 +437,262 @@ class MethodDecoratorTests(SimpleTestCase): Test().method() self.assertEqual(func_name, "method") self.assertIsNotNone(func_module) + + +def async_simple_dec(func): + @wraps(func) + async def wrapper(*args, **kwargs): + result = await func(*args, **kwargs) + return f"returned: {result}" + + return wrapper + + +async_simple_dec_m = method_decorator(async_simple_dec) + + +class AsyncMethodDecoratorTests(SimpleTestCase): + """ + Tests for async method_decorator + """ + + async def test_preserve_signature(self): + class Test: + @async_simple_dec_m + async def say(self, msg): + return f"Saying {msg}" + + self.assertEqual(await Test().say("hello"), "returned: Saying hello") + + def test_preserve_attributes(self): + async def func(*args, **kwargs): + await asyncio.sleep(0.01) + return args, kwargs + + def myattr_dec(func): + async def wrapper(*args, **kwargs): + return await func(*args, **kwargs) + + wrapper.myattr = True + return wrapper + + def myattr2_dec(func): + async def wrapper(*args, **kwargs): + return await func(*args, **kwargs) + + wrapper.myattr2 = True + return wrapper + + # Sanity check myattr_dec and myattr2_dec + func = myattr_dec(func) + + self.assertIs(getattr(func, "myattr", False), True) + + func = myattr2_dec(func) + self.assertIs(getattr(func, "myattr2", False), True) + + func = myattr_dec(myattr2_dec(func)) + self.assertIs(getattr(func, "myattr", False), True) + self.assertIs(getattr(func, "myattr2", False), False) + + myattr_dec_m = method_decorator(myattr_dec) + myattr2_dec_m = method_decorator(myattr2_dec) + + # Decorate using method_decorator() on the async method. + class TestPlain: + @myattr_dec_m + @myattr2_dec_m + async def method(self): + "A method" + + # Decorate using method_decorator() on both the class and the method. + # The decorators applied to the methods are applied before the ones + # applied to the class. + @method_decorator(myattr_dec_m, "method") + class TestMethodAndClass: + @method_decorator(myattr2_dec_m) + async def method(self): + "A method" + + # Decorate using an iterable of function decorators. + @method_decorator((myattr_dec, myattr2_dec), "method") + class TestFunctionIterable: + async def method(self): + "A method" + + # Decorate using an iterable of method decorators. + @method_decorator((myattr_dec_m, myattr2_dec_m), "method") + class TestMethodIterable: + async def method(self): + "A method" + + tests = ( + TestPlain, + TestMethodAndClass, + TestFunctionIterable, + TestMethodIterable, + ) + for Test in tests: + with self.subTest(Test=Test): + self.assertIs(getattr(Test().method, "myattr", False), True) + self.assertIs(getattr(Test().method, "myattr2", False), True) + self.assertIs(getattr(Test.method, "myattr", False), True) + self.assertIs(getattr(Test.method, "myattr2", False), True) + self.assertEqual(Test.method.__doc__, "A method") + self.assertEqual(Test.method.__name__, "method") + + async def test_new_attribute(self): + """A decorator that sets a new attribute on the method.""" + + def decorate(func): + func.x = 1 + return func + + class MyClass: + @method_decorator(decorate) + async def method(self): + return True + + obj = MyClass() + self.assertEqual(obj.method.x, 1) + self.assertIs(await obj.method(), True) + + def test_bad_iterable(self): + decorators = {async_simple_dec} + msg = "'set' object is not subscriptable" + with self.assertRaisesMessage(TypeError, msg): + + @method_decorator(decorators, "method") + class TestIterable: + async def method(self): + await asyncio.sleep(0.01) + + async def test_argumented(self): + + class ClsDecAsync: + def __init__(self, myattr): + self.myattr = myattr + + def __call__(self, f): + async def wrapper(): + result = await f() + return f"{result} appending {self.myattr}" + + return update_wrapper(wrapper, f) + + class Test: + @method_decorator(ClsDecAsync(False)) + async def method(self): + return True + + self.assertEqual(await Test().method(), "True appending False") + + async def test_descriptors(self): + class bound_wrapper: + def __init__(self, wrapped): + self.wrapped = wrapped + self.__name__ = wrapped.__name__ + + async def __call__(self, *args, **kwargs): + return await self.wrapped(*args, **kwargs) + + def __get__(self, instance, cls=None): + return self + + class descriptor_wrapper: + def __init__(self, wrapped): + self.wrapped = wrapped + self.__name__ = wrapped.__name__ + + def __get__(self, instance, cls=None): + return bound_wrapper(self.wrapped.__get__(instance, cls)) + + class Test: + @async_simple_dec_m + @descriptor_wrapper + async def method(self, arg): + return arg + + self.assertEqual(await Test().method(1), "returned: 1") + + async def test_class_decoration(self): + """ + @method_decorator can be used to decorate a class and its methods. + """ + + @method_decorator(async_simple_dec, name="method") + class Test: + async def method(self): + return False + + async def not_method(self): + return "a string" + + self.assertEqual(await Test().method(), "returned: False") + self.assertEqual(await Test().not_method(), "a string") + + async def test_tuple_of_decorators(self): + """ + @method_decorator can accept a tuple of decorators. + """ + + def add_question_mark(func): + async def _wrapper(*args, **kwargs): + await asyncio.sleep(0.01) + return await func(*args, **kwargs) + "?" + + return _wrapper + + def add_exclamation_mark(func): + async def _wrapper(*args, **kwargs): + await asyncio.sleep(0.01) + return await func(*args, **kwargs) + "!" + + return _wrapper + + decorators = (add_exclamation_mark, add_question_mark) + + @method_decorator(decorators, name="method") + class TestFirst: + async def method(self): + return "hello world" + + class TestSecond: + @method_decorator(decorators) + async def method(self): + return "world hello" + + self.assertEqual(await TestFirst().method(), "hello world?!") + self.assertEqual(await TestSecond().method(), "world hello?!") + + async def test_wrapper_assignments(self): + """@method_decorator preserves wrapper assignments.""" + func_data = {} + + def decorator(func): + @wraps(func) + async def inner(*args, **kwargs): + func_data["func_name"] = getattr(func, "__name__", None) + func_data["func_module"] = getattr(func, "__module__", None) + return await func(*args, **kwargs) + + return inner + + class Test: + @method_decorator(decorator) + async def method(self): + return "tests" + + await Test().method() + expected = {"func_name": "method", "func_module": "decorators.tests"} + self.assertEqual(func_data, expected) + + async def test_markcoroutinefunction_applied(self): + class Test: + @async_simple_dec_m + async def method(self): + return "tests" + + method = Test().method + self.assertIs(iscoroutinefunction(method), True) + self.assertEqual(await method(), "returned: tests")