From 50f89ae850f6b4e35819fe725a08c7e579bfd099 Mon Sep 17 00:00:00 2001 From: Jon Janzen Date: Sun, 31 Mar 2024 12:29:10 -0700 Subject: [PATCH] Fixed #35303 -- Implemented async auth backends and utils. --- django/contrib/auth/__init__.py | 172 +++++++++-- django/contrib/auth/backends.py | 111 +++++++ django/contrib/auth/base_user.py | 3 + django/contrib/auth/decorators.py | 2 +- django/contrib/auth/middleware.py | 85 +++++- django/contrib/auth/models.py | 134 ++++++++- docs/ref/contrib/auth.txt | 144 ++++++++- docs/releases/5.2.txt | 30 ++ docs/topics/auth/customizing.txt | 29 ++ tests/async/test_async_auth.py | 31 ++ tests/auth_tests/models/custom_user.py | 13 + tests/auth_tests/test_auth_backends.py | 324 +++++++++++++++++++++ tests/auth_tests/test_basic.py | 39 ++- tests/auth_tests/test_decorators.py | 6 +- tests/auth_tests/test_models.py | 41 ++- tests/auth_tests/test_remote_user.py | 180 +++++++++++- tests/deprecation/test_middleware_mixin.py | 2 - 17 files changed, 1285 insertions(+), 61 deletions(-) diff --git a/django/contrib/auth/__init__.py b/django/contrib/auth/__init__.py index 3db1445d9e..689567ca6c 100644 --- a/django/contrib/auth/__init__.py +++ b/django/contrib/auth/__init__.py @@ -1,8 +1,6 @@ import inspect import re -from asgiref.sync import sync_to_async - from django.apps import apps as django_apps from django.conf import settings from django.core.exceptions import ImproperlyConfigured, PermissionDenied @@ -40,6 +38,39 @@ def get_backends(): return _get_backends(return_tuples=False) +def _get_compatible_backends(request, **credentials): + for backend, backend_path in _get_backends(return_tuples=True): + backend_signature = inspect.signature(backend.authenticate) + try: + backend_signature.bind(request, **credentials) + except TypeError: + # This backend doesn't accept these credentials as arguments. Try + # the next one. + continue + yield backend, backend_path + + +def _get_backend_from_user(user, backend=None): + try: + backend = backend or user.backend + except AttributeError: + backends = _get_backends(return_tuples=True) + if len(backends) == 1: + _, backend = backends[0] + else: + raise ValueError( + "You have multiple authentication backends configured and " + "therefore must provide the `backend` argument or set the " + "`backend` attribute on the user." + ) + else: + if not isinstance(backend, str): + raise TypeError( + "backend must be a dotted import path string (got %r)." % backend + ) + return backend + + @sensitive_variables("credentials") def _clean_credentials(credentials): """ @@ -62,19 +93,21 @@ def _get_user_session_key(request): return get_user_model()._meta.pk.to_python(request.session[SESSION_KEY]) +async def _aget_user_session_key(request): + # This value in the session is always serialized to a string, so we need + # to convert it back to Python whenever we access it. + session_key = await request.session.aget(SESSION_KEY) + if session_key is None: + raise KeyError() + return get_user_model()._meta.pk.to_python(session_key) + + @sensitive_variables("credentials") def authenticate(request=None, **credentials): """ If the given credentials are valid, return a User object. """ - for backend, backend_path in _get_backends(return_tuples=True): - backend_signature = inspect.signature(backend.authenticate) - try: - backend_signature.bind(request, **credentials) - except TypeError: - # This backend doesn't accept these credentials as arguments. Try - # the next one. - continue + for backend, backend_path in _get_compatible_backends(request, **credentials): try: user = backend.authenticate(request, **credentials) except PermissionDenied: @@ -96,7 +129,23 @@ def authenticate(request=None, **credentials): @sensitive_variables("credentials") async def aauthenticate(request=None, **credentials): """See authenticate().""" - return await sync_to_async(authenticate)(request, **credentials) + for backend, backend_path in _get_compatible_backends(request, **credentials): + try: + user = await backend.aauthenticate(request, **credentials) + except PermissionDenied: + # This backend says to stop in our tracks - this user should not be + # allowed in at all. + break + if user is None: + continue + # Annotate the user object with the path of the backend. + user.backend = backend_path + return user + + # The credentials supplied are invalid to all backends, fire signal. + await user_login_failed.asend( + sender=__name__, credentials=_clean_credentials(credentials), request=request + ) def login(request, user, backend=None): @@ -125,23 +174,7 @@ def login(request, user, backend=None): else: request.session.cycle_key() - try: - backend = backend or user.backend - except AttributeError: - backends = _get_backends(return_tuples=True) - if len(backends) == 1: - _, backend = backends[0] - else: - raise ValueError( - "You have multiple authentication backends configured and " - "therefore must provide the `backend` argument or set the " - "`backend` attribute on the user." - ) - else: - if not isinstance(backend, str): - raise TypeError( - "backend must be a dotted import path string (got %r)." % backend - ) + backend = _get_backend_from_user(user=user, backend=backend) request.session[SESSION_KEY] = user._meta.pk.value_to_string(user) request.session[BACKEND_SESSION_KEY] = backend @@ -154,7 +187,36 @@ def login(request, user, backend=None): async def alogin(request, user, backend=None): """See login().""" - return await sync_to_async(login)(request, user, backend) + session_auth_hash = "" + if user is None: + user = await request.auser() + if hasattr(user, "get_session_auth_hash"): + session_auth_hash = user.get_session_auth_hash() + + if await request.session.ahas_key(SESSION_KEY): + if await _aget_user_session_key(request) != user.pk or ( + session_auth_hash + and not constant_time_compare( + await request.session.aget(HASH_SESSION_KEY, ""), + session_auth_hash, + ) + ): + # To avoid reusing another user's session, create a new, empty + # session if the existing session corresponds to a different + # authenticated user. + await request.session.aflush() + else: + await request.session.acycle_key() + + backend = _get_backend_from_user(user=user, backend=backend) + + await request.session.aset(SESSION_KEY, user._meta.pk.value_to_string(user)) + await request.session.aset(BACKEND_SESSION_KEY, backend) + await request.session.aset(HASH_SESSION_KEY, session_auth_hash) + if hasattr(request, "user"): + request.user = user + rotate_token(request) + await user_logged_in.asend(sender=user.__class__, request=request, user=user) def logout(request): @@ -177,7 +239,19 @@ def logout(request): async def alogout(request): """See logout().""" - return await sync_to_async(logout)(request) + # Dispatch the signal before the user is logged out so the receivers have a + # chance to find out *who* logged out. + user = getattr(request, "auser", None) + if user is not None: + user = await user() + if not getattr(user, "is_authenticated", True): + user = None + await user_logged_out.asend(sender=user.__class__, request=request, user=user) + await request.session.aflush() + if hasattr(request, "user"): + from django.contrib.auth.models import AnonymousUser + + request.user = AnonymousUser() def get_user_model(): @@ -243,7 +317,43 @@ def get_user(request): async def aget_user(request): """See get_user().""" - return await sync_to_async(get_user)(request) + from .models import AnonymousUser + + user = None + try: + user_id = await _aget_user_session_key(request) + backend_path = await request.session.aget(BACKEND_SESSION_KEY) + except KeyError: + pass + else: + if backend_path in settings.AUTHENTICATION_BACKENDS: + backend = load_backend(backend_path) + user = await backend.aget_user(user_id) + # Verify the session + if hasattr(user, "get_session_auth_hash"): + session_hash = await request.session.aget(HASH_SESSION_KEY) + if not session_hash: + session_hash_verified = False + else: + session_auth_hash = user.get_session_auth_hash() + session_hash_verified = session_hash and constant_time_compare( + session_hash, user.get_session_auth_hash() + ) + if not session_hash_verified: + # If the current secret does not verify the session, try + # with the fallback secrets and stop when a matching one is + # found. + if session_hash and any( + constant_time_compare(session_hash, fallback_auth_hash) + for fallback_auth_hash in user.get_session_auth_fallback_hash() + ): + await request.session.acycle_key() + await request.session.aset(HASH_SESSION_KEY, session_auth_hash) + else: + await request.session.aflush() + user = None + + return user or AnonymousUser() def get_permission_codename(action, opts): diff --git a/django/contrib/auth/backends.py b/django/contrib/auth/backends.py index dd3c2e527b..f14fb3e96f 100644 --- a/django/contrib/auth/backends.py +++ b/django/contrib/auth/backends.py @@ -1,3 +1,5 @@ +from asgiref.sync import sync_to_async + from django.contrib.auth import get_user_model from django.contrib.auth.models import Permission from django.db.models import Exists, OuterRef, Q @@ -9,24 +11,45 @@ class BaseBackend: def authenticate(self, request, **kwargs): return None + async def aauthenticate(self, request, **kwargs): + return await sync_to_async(self.authenticate)(request, **kwargs) + def get_user(self, user_id): return None + async def aget_user(self, user_id): + return await sync_to_async(self.get_user)(user_id) + def get_user_permissions(self, user_obj, obj=None): return set() + async def aget_user_permissions(self, user_obj, obj=None): + return await sync_to_async(self.get_user_permissions)(user_obj, obj) + def get_group_permissions(self, user_obj, obj=None): return set() + async def aget_group_permissions(self, user_obj, obj=None): + return await sync_to_async(self.get_group_permissions)(user_obj, obj) + def get_all_permissions(self, user_obj, obj=None): return { *self.get_user_permissions(user_obj, obj=obj), *self.get_group_permissions(user_obj, obj=obj), } + async def aget_all_permissions(self, user_obj, obj=None): + return { + *await self.aget_user_permissions(user_obj, obj=obj), + *await self.aget_group_permissions(user_obj, obj=obj), + } + def has_perm(self, user_obj, perm, obj=None): return perm in self.get_all_permissions(user_obj, obj=obj) + async def ahas_perm(self, user_obj, perm, obj=None): + return perm in await self.aget_all_permissions(user_obj, obj) + class ModelBackend(BaseBackend): """ @@ -48,6 +71,23 @@ class ModelBackend(BaseBackend): if user.check_password(password) and self.user_can_authenticate(user): return user + async def aauthenticate(self, request, username=None, password=None, **kwargs): + if username is None: + username = kwargs.get(UserModel.USERNAME_FIELD) + if username is None or password is None: + return + try: + user = await UserModel._default_manager.aget_by_natural_key(username) + except UserModel.DoesNotExist: + # Run the default password hasher once to reduce the timing + # difference between an existing and a nonexistent user (#20760). + UserModel().set_password(password) + else: + if await user.acheck_password(password) and self.user_can_authenticate( + user + ): + return user + def user_can_authenticate(self, user): """ Reject users with is_active=False. Custom user models that don't have @@ -84,6 +124,25 @@ class ModelBackend(BaseBackend): ) return getattr(user_obj, perm_cache_name) + async def _aget_permissions(self, user_obj, obj, from_name): + """See _get_permissions().""" + if not user_obj.is_active or user_obj.is_anonymous or obj is not None: + return set() + + perm_cache_name = "_%s_perm_cache" % from_name + if not hasattr(user_obj, perm_cache_name): + if user_obj.is_superuser: + perms = Permission.objects.all() + else: + perms = getattr(self, "_get_%s_permissions" % from_name)(user_obj) + perms = perms.values_list("content_type__app_label", "codename").order_by() + setattr( + user_obj, + perm_cache_name, + {"%s.%s" % (ct, name) async for ct, name in perms}, + ) + return getattr(user_obj, perm_cache_name) + def get_user_permissions(self, user_obj, obj=None): """ Return a set of permission strings the user `user_obj` has from their @@ -91,6 +150,10 @@ class ModelBackend(BaseBackend): """ return self._get_permissions(user_obj, obj, "user") + async def aget_user_permissions(self, user_obj, obj=None): + """See get_user_permissions().""" + return await self._aget_permissions(user_obj, obj, "user") + def get_group_permissions(self, user_obj, obj=None): """ Return a set of permission strings the user `user_obj` has from the @@ -98,6 +161,10 @@ class ModelBackend(BaseBackend): """ return self._get_permissions(user_obj, obj, "group") + async def aget_group_permissions(self, user_obj, obj=None): + """See get_group_permissions().""" + return await self._aget_permissions(user_obj, obj, "group") + def get_all_permissions(self, user_obj, obj=None): if not user_obj.is_active or user_obj.is_anonymous or obj is not None: return set() @@ -108,6 +175,9 @@ class ModelBackend(BaseBackend): def has_perm(self, user_obj, perm, obj=None): return user_obj.is_active and super().has_perm(user_obj, perm, obj=obj) + async def ahas_perm(self, user_obj, perm, obj=None): + return user_obj.is_active and await super().ahas_perm(user_obj, perm, obj=obj) + def has_module_perms(self, user_obj, app_label): """ Return True if user_obj has any permissions in the given app_label. @@ -117,6 +187,13 @@ class ModelBackend(BaseBackend): for perm in self.get_all_permissions(user_obj) ) + async def ahas_module_perms(self, user_obj, app_label): + """See has_module_perms()""" + return user_obj.is_active and any( + perm[: perm.index(".")] == app_label + for perm in await self.aget_all_permissions(user_obj) + ) + def with_perm(self, perm, is_active=True, include_superusers=True, obj=None): """ Return users that have permission "perm". By default, filter out @@ -159,6 +236,13 @@ class ModelBackend(BaseBackend): return None return user if self.user_can_authenticate(user) else None + async def aget_user(self, user_id): + try: + user = await UserModel._default_manager.aget(pk=user_id) + except UserModel.DoesNotExist: + return None + return user if self.user_can_authenticate(user) else None + class AllowAllUsersModelBackend(ModelBackend): def user_can_authenticate(self, user): @@ -210,6 +294,29 @@ class RemoteUserBackend(ModelBackend): user = self.configure_user(request, user, created=created) return user if self.user_can_authenticate(user) else None + async def aauthenticate(self, request, remote_user): + """See authenticate().""" + if not remote_user: + return + created = False + user = None + username = self.clean_username(remote_user) + + # Note that this could be accomplished in one try-except clause, but + # instead we use get_or_create when creating unknown users since it has + # built-in safeguards for multiple threads. + if self.create_unknown_user: + user, created = await UserModel._default_manager.aget_or_create( + **{UserModel.USERNAME_FIELD: username} + ) + else: + try: + user = await UserModel._default_manager.aget_by_natural_key(username) + except UserModel.DoesNotExist: + pass + user = await self.aconfigure_user(request, user, created=created) + return user if self.user_can_authenticate(user) else None + def clean_username(self, username): """ Perform any cleaning on the "username" prior to using it to get or @@ -227,6 +334,10 @@ class RemoteUserBackend(ModelBackend): """ return user + async def aconfigure_user(self, request, user, created=True): + """See configure_user()""" + return await sync_to_async(self.configure_user)(request, user, created) + class AllowAllUsersRemoteUserBackend(RemoteUserBackend): def user_can_authenticate(self, user): diff --git a/django/contrib/auth/base_user.py b/django/contrib/auth/base_user.py index 0c9538d69d..5bb88ac4dd 100644 --- a/django/contrib/auth/base_user.py +++ b/django/contrib/auth/base_user.py @@ -36,6 +36,9 @@ class BaseUserManager(models.Manager): def get_by_natural_key(self, username): return self.get(**{self.model.USERNAME_FIELD: username}) + async def aget_by_natural_key(self, username): + return await self.aget(**{self.model.USERNAME_FIELD: username}) + class AbstractBaseUser(models.Model): password = models.CharField(_("password"), max_length=128) diff --git a/django/contrib/auth/decorators.py b/django/contrib/auth/decorators.py index 78e76a9ae9..77fbc79855 100644 --- a/django/contrib/auth/decorators.py +++ b/django/contrib/auth/decorators.py @@ -111,7 +111,7 @@ def permission_required(perm, login_url=None, raise_exception=False): async def check_perms(user): # First check if the user has the permission (even anon users). - if await sync_to_async(user.has_perms)(perms): + if await user.ahas_perms(perms): return True # In case the 403 handler should be called raise the exception. if raise_exception: diff --git a/django/contrib/auth/middleware.py b/django/contrib/auth/middleware.py index cb409ee778..85f58ec9a5 100644 --- a/django/contrib/auth/middleware.py +++ b/django/contrib/auth/middleware.py @@ -1,6 +1,8 @@ from functools import partial from urllib.parse import urlsplit +from asgiref.sync import iscoroutinefunction, markcoroutinefunction + from django.conf import settings from django.contrib import auth from django.contrib.auth import REDIRECT_FIELD_NAME, load_backend @@ -88,7 +90,7 @@ class LoginRequiredMiddleware(MiddlewareMixin): ) -class RemoteUserMiddleware(MiddlewareMixin): +class RemoteUserMiddleware: """ Middleware for utilizing web-server-provided authentication. @@ -102,13 +104,27 @@ class RemoteUserMiddleware(MiddlewareMixin): different header. """ + sync_capable = True + async_capable = True + + def __init__(self, get_response): + if get_response is None: + raise ValueError("get_response must be provided.") + self.get_response = get_response + self.is_async = iscoroutinefunction(get_response) + if self.is_async: + markcoroutinefunction(self) + super().__init__() + # Name of request header to grab username from. This will be the key as # used in the request.META dictionary, i.e. the normalization of headers to # all uppercase and the addition of "HTTP_" prefix apply. header = "REMOTE_USER" force_logout_if_no_header = True - def process_request(self, request): + def __call__(self, request): + if self.is_async: + return self.__acall__(request) # AuthenticationMiddleware is required so that request.user exists. if not hasattr(request, "user"): raise ImproperlyConfigured( @@ -126,13 +142,13 @@ class RemoteUserMiddleware(MiddlewareMixin): # AnonymousUser by the AuthenticationMiddleware). if self.force_logout_if_no_header and request.user.is_authenticated: self._remove_invalid_user(request) - return + return self.get_response(request) # If the user is already authenticated and that user is the user we are # getting passed in the headers, then the correct user is already # persisted in the session and we don't need to continue. if request.user.is_authenticated: if request.user.get_username() == self.clean_username(username, request): - return + return self.get_response(request) else: # An authenticated user is associated with the request, but # it does not match the authorized user in the header. @@ -146,6 +162,51 @@ class RemoteUserMiddleware(MiddlewareMixin): # by logging the user in. request.user = user auth.login(request, user) + return self.get_response(request) + + async def __acall__(self, request): + # AuthenticationMiddleware is required so that request.user exists. + if not hasattr(request, "user"): + raise ImproperlyConfigured( + "The Django remote user auth middleware requires the" + " authentication middleware to be installed. Edit your" + " MIDDLEWARE setting to insert" + " 'django.contrib.auth.middleware.AuthenticationMiddleware'" + " before the RemoteUserMiddleware class." + ) + try: + username = request.META["HTTP_" + self.header] + except KeyError: + # If specified header doesn't exist then remove any existing + # authenticated remote-user, or return (leaving request.user set to + # AnonymousUser by the AuthenticationMiddleware). + if self.force_logout_if_no_header: + user = await request.auser() + if user.is_authenticated: + await self._aremove_invalid_user(request) + return await self.get_response(request) + user = await request.auser() + # If the user is already authenticated and that user is the user we are + # getting passed in the headers, then the correct user is already + # persisted in the session and we don't need to continue. + if user.is_authenticated: + if user.get_username() == self.clean_username(username, request): + return await self.get_response(request) + else: + # An authenticated user is associated with the request, but + # it does not match the authorized user in the header. + await self._aremove_invalid_user(request) + + # We are seeing this user for the first time in this session, attempt + # to authenticate the user. + user = await auth.aauthenticate(request, remote_user=username) + if user: + # User is valid. Set request.user and persist user in the session + # by logging the user in. + request.user = user + await auth.alogin(request, user) + + return await self.get_response(request) def clean_username(self, username, request): """ @@ -176,6 +237,22 @@ class RemoteUserMiddleware(MiddlewareMixin): if isinstance(stored_backend, RemoteUserBackend): auth.logout(request) + async def _aremove_invalid_user(self, request): + """ + Remove the current authenticated user in the request which is invalid + but only if the user is authenticated via the RemoteUserBackend. + """ + try: + stored_backend = load_backend( + await request.session.aget(auth.BACKEND_SESSION_KEY, "") + ) + except ImportError: + # Backend failed to load. + await auth.alogout(request) + else: + if isinstance(stored_backend, RemoteUserBackend): + await auth.alogout(request) + class PersistentRemoteUserMiddleware(RemoteUserMiddleware): """ diff --git a/django/contrib/auth/models.py b/django/contrib/auth/models.py index e5ef1bb523..d4a8dd902b 100644 --- a/django/contrib/auth/models.py +++ b/django/contrib/auth/models.py @@ -95,6 +95,9 @@ class GroupManager(models.Manager): def get_by_natural_key(self, name): return self.get(name=name) + async def aget_by_natural_key(self, name): + return await self.aget(name=name) + class Group(models.Model): """ @@ -137,10 +140,7 @@ class Group(models.Model): class UserManager(BaseUserManager): use_in_migrations = True - def _create_user(self, username, email, password, **extra_fields): - """ - Create and save a user with the given username, email, and password. - """ + def _create_user_object(self, username, email, password, **extra_fields): if not username: raise ValueError("The given username must be set") email = self.normalize_email(email) @@ -153,14 +153,32 @@ class UserManager(BaseUserManager): username = GlobalUserModel.normalize_username(username) user = self.model(username=username, email=email, **extra_fields) user.password = make_password(password) + return user + + def _create_user(self, username, email, password, **extra_fields): + """ + Create and save a user with the given username, email, and password. + """ + user = self._create_user_object(username, email, password, **extra_fields) user.save(using=self._db) return user + async def _acreate_user(self, username, email, password, **extra_fields): + """See _create_user()""" + user = self._create_user_object(username, email, password, **extra_fields) + await user.asave(using=self._db) + return user + def create_user(self, username, email=None, password=None, **extra_fields): extra_fields.setdefault("is_staff", False) extra_fields.setdefault("is_superuser", False) return self._create_user(username, email, password, **extra_fields) + async def acreate_user(self, username, email=None, password=None, **extra_fields): + extra_fields.setdefault("is_staff", False) + extra_fields.setdefault("is_superuser", False) + return await self._acreate_user(username, email, password, **extra_fields) + def create_superuser(self, username, email=None, password=None, **extra_fields): extra_fields.setdefault("is_staff", True) extra_fields.setdefault("is_superuser", True) @@ -172,6 +190,19 @@ class UserManager(BaseUserManager): return self._create_user(username, email, password, **extra_fields) + async def acreate_superuser( + self, username, email=None, password=None, **extra_fields + ): + extra_fields.setdefault("is_staff", True) + extra_fields.setdefault("is_superuser", True) + + if extra_fields.get("is_staff") is not True: + raise ValueError("Superuser must have is_staff=True.") + if extra_fields.get("is_superuser") is not True: + raise ValueError("Superuser must have is_superuser=True.") + + return await self._acreate_user(username, email, password, **extra_fields) + def with_perm( self, perm, is_active=True, include_superusers=True, backend=None, obj=None ): @@ -210,6 +241,15 @@ def _user_get_permissions(user, obj, from_name): return permissions +async def _auser_get_permissions(user, obj, from_name): + permissions = set() + name = "aget_%s_permissions" % from_name + for backend in auth.get_backends(): + if hasattr(backend, name): + permissions.update(await getattr(backend, name)(user, obj)) + return permissions + + def _user_has_perm(user, perm, obj): """ A backend can raise `PermissionDenied` to short-circuit permission checking. @@ -225,6 +265,19 @@ def _user_has_perm(user, perm, obj): return False +async def _auser_has_perm(user, perm, obj): + """See _user_has_perm()""" + for backend in auth.get_backends(): + if not hasattr(backend, "ahas_perm"): + continue + try: + if await backend.ahas_perm(user, perm, obj): + return True + except PermissionDenied: + return False + return False + + def _user_has_module_perms(user, app_label): """ A backend can raise `PermissionDenied` to short-circuit permission checking. @@ -240,6 +293,19 @@ def _user_has_module_perms(user, app_label): return False +async def _auser_has_module_perms(user, app_label): + """See _user_has_module_perms()""" + for backend in auth.get_backends(): + if not hasattr(backend, "ahas_module_perms"): + continue + try: + if await backend.ahas_module_perms(user, app_label): + return True + except PermissionDenied: + return False + return False + + class PermissionsMixin(models.Model): """ Add the fields and methods necessary to support the Group and Permission @@ -285,6 +351,10 @@ class PermissionsMixin(models.Model): """ return _user_get_permissions(self, obj, "user") + async def aget_user_permissions(self, obj=None): + """See get_user_permissions()""" + return await _auser_get_permissions(self, obj, "user") + def get_group_permissions(self, obj=None): """ Return a list of permission strings that this user has through their @@ -293,9 +363,16 @@ class PermissionsMixin(models.Model): """ return _user_get_permissions(self, obj, "group") + async def aget_group_permissions(self, obj=None): + """See get_group_permissions()""" + return await _auser_get_permissions(self, obj, "group") + def get_all_permissions(self, obj=None): return _user_get_permissions(self, obj, "all") + async def aget_all_permissions(self, obj=None): + return await _auser_get_permissions(self, obj, "all") + def has_perm(self, perm, obj=None): """ Return True if the user has the specified permission. Query all @@ -311,6 +388,15 @@ class PermissionsMixin(models.Model): # Otherwise we need to check the backends. return _user_has_perm(self, perm, obj) + async def ahas_perm(self, perm, obj=None): + """See has_perm()""" + # Active superusers have all permissions. + if self.is_active and self.is_superuser: + return True + + # Otherwise we need to check the backends. + return await _auser_has_perm(self, perm, obj) + def has_perms(self, perm_list, obj=None): """ Return True if the user has each of the specified permissions. If @@ -320,6 +406,15 @@ class PermissionsMixin(models.Model): raise ValueError("perm_list must be an iterable of permissions.") return all(self.has_perm(perm, obj) for perm in perm_list) + async def ahas_perms(self, perm_list, obj=None): + """See has_perms()""" + if not isinstance(perm_list, Iterable) or isinstance(perm_list, str): + raise ValueError("perm_list must be an iterable of permissions.") + for perm in perm_list: + if not await self.ahas_perm(perm, obj): + return False + return True + def has_module_perms(self, app_label): """ Return True if the user has any permissions in the given app label. @@ -331,6 +426,14 @@ class PermissionsMixin(models.Model): return _user_has_module_perms(self, app_label) + async def ahas_module_perms(self, app_label): + """See has_module_perms()""" + # Active superusers have all permissions. + if self.is_active and self.is_superuser: + return True + + return await _auser_has_module_perms(self, app_label) + class AbstractUser(AbstractBaseUser, PermissionsMixin): """ @@ -471,23 +574,46 @@ class AnonymousUser: def get_user_permissions(self, obj=None): return _user_get_permissions(self, obj, "user") + async def aget_user_permissions(self, obj=None): + return await _auser_get_permissions(self, obj, "user") + def get_group_permissions(self, obj=None): return set() + async def aget_group_permissions(self, obj=None): + return self.get_group_permissions(obj) + def get_all_permissions(self, obj=None): return _user_get_permissions(self, obj, "all") + async def aget_all_permissions(self, obj=None): + return await _auser_get_permissions(self, obj, "all") + def has_perm(self, perm, obj=None): return _user_has_perm(self, perm, obj=obj) + async def ahas_perm(self, perm, obj=None): + return await _auser_has_perm(self, perm, obj=obj) + def has_perms(self, perm_list, obj=None): if not isinstance(perm_list, Iterable) or isinstance(perm_list, str): raise ValueError("perm_list must be an iterable of permissions.") return all(self.has_perm(perm, obj) for perm in perm_list) + async def ahas_perms(self, perm_list, obj=None): + if not isinstance(perm_list, Iterable) or isinstance(perm_list, str): + raise ValueError("perm_list must be an iterable of permissions.") + for perm in perm_list: + if not await self.ahas_perm(perm, obj): + return False + return True + def has_module_perms(self, module): return _user_has_module_perms(self, module) + async def ahas_module_perms(self, module): + return await _auser_has_module_perms(self, module) + @property def is_anonymous(self): return True diff --git a/docs/ref/contrib/auth.txt b/docs/ref/contrib/auth.txt index d5fc724b54..c8699a2913 100644 --- a/docs/ref/contrib/auth.txt +++ b/docs/ref/contrib/auth.txt @@ -197,13 +197,23 @@ Methods been called for this user. .. method:: get_user_permissions(obj=None) + .. method:: aget_user_permissions(obj=None) + + *Asynchronous version*: ``aget_user_permissions()`` Returns a set of permission strings that the user has directly. If ``obj`` is passed in, only returns the user permissions for this specific object. + .. versionchanged:: 5.2 + + ``aget_user_permissions()`` method was added. + .. method:: get_group_permissions(obj=None) + .. method:: aget_group_permissions(obj=None) + + *Asynchronous version*: ``aget_group_permissions()`` Returns a set of permission strings that the user has, through their groups. @@ -211,7 +221,14 @@ Methods If ``obj`` is passed in, only returns the group permissions for this specific object. + .. versionchanged:: 5.2 + + ``aget_group_permissions()`` method was added. + .. method:: get_all_permissions(obj=None) + .. method:: aget_all_permissions(obj=None) + + *Asynchronous version*: ``aget_all_permissions()`` Returns a set of permission strings that the user has, both through group and user permissions. @@ -219,7 +236,14 @@ Methods If ``obj`` is passed in, only returns the permissions for this specific object. + .. versionchanged:: 5.2 + + ``aget_all_permissions()`` method was added. + .. method:: has_perm(perm, obj=None) + .. method:: ahas_perm(perm, obj=None) + + *Asynchronous version*: ``ahas_perm()`` Returns ``True`` if the user has the specified permission, where perm is in the format ``"."``. (see @@ -230,7 +254,14 @@ Methods If ``obj`` is passed in, this method won't check for a permission for the model, but for this specific object. + .. versionchanged:: 5.2 + + ``ahas_perm()`` method was added. + .. method:: has_perms(perm_list, obj=None) + .. method:: ahas_perms(perm_list, obj=None) + + *Asynchronous version*: ``ahas_perms()`` Returns ``True`` if the user has each of the specified permissions, where each perm is in the format @@ -241,13 +272,24 @@ Methods If ``obj`` is passed in, this method won't check for permissions for the model, but for the specific object. + .. versionchanged:: 5.2 + + ``ahas_perms()`` method was added. + .. method:: has_module_perms(package_name) + .. method:: ahas_module_perms(package_name) + + *Asynchronous version*: ``ahas_module_perms()`` Returns ``True`` if the user has any permissions in the given package (the Django app label). If the user is inactive, this method will always return ``False``. For an active superuser, this method will always return ``True``. + .. versionchanged:: 5.2 + + ``ahas_module_perms()`` method was added. + .. method:: email_user(subject, message, from_email=None, **kwargs) Sends an email to the user. If ``from_email`` is ``None``, Django uses @@ -264,6 +306,9 @@ Manager methods by :class:`~django.contrib.auth.models.BaseUserManager`): .. method:: create_user(username, email=None, password=None, **extra_fields) + .. method:: acreate_user(username, email=None, password=None, **extra_fields) + + *Asynchronous version*: ``acreate_user()`` Creates, saves and returns a :class:`~django.contrib.auth.models.User`. @@ -285,11 +330,22 @@ Manager methods See :ref:`Creating users ` for example usage. + .. versionchanged:: 5.2 + + ``acreate_user()`` method was added. + .. method:: create_superuser(username, email=None, password=None, **extra_fields) + .. method:: acreate_superuser(username, email=None, password=None, **extra_fields) + + *Asynchronous version*: ``acreate_superuser()`` Same as :meth:`create_user`, but sets :attr:`~models.User.is_staff` and :attr:`~models.User.is_superuser` to ``True``. + .. versionchanged:: 5.2 + + ``acreate_superuser()`` method was added. + .. method:: with_perm(perm, is_active=True, include_superusers=True, backend=None, obj=None) Returns users that have the given permission ``perm`` either in the @@ -499,23 +555,51 @@ The following backends are available in :mod:`django.contrib.auth.backends`: methods. By default, it will reject any user and provide no permissions. .. method:: get_user_permissions(user_obj, obj=None) + .. method:: aget_user_permissions(user_obj, obj=None) + + *Asynchronous version*: ``aget_user_permissions()`` Returns an empty set. + .. versionchanged:: 5.2 + + ``aget_user_permissions()`` function was added. + .. method:: get_group_permissions(user_obj, obj=None) + .. method:: aget_group_permissions(user_obj, obj=None) + + *Asynchronous version*: ``aget_group_permissions()`` Returns an empty set. + .. versionchanged:: 5.2 + + ``aget_group_permissions()`` function was added. + .. method:: get_all_permissions(user_obj, obj=None) + .. method:: aget_all_permissions(user_obj, obj=None) + + *Asynchronous version*: ``aget_all_permissions()`` Uses :meth:`get_user_permissions` and :meth:`get_group_permissions` to get the set of permission strings the ``user_obj`` has. + .. versionchanged:: 5.2 + + ``aget_all_permissions()`` function was added. + .. method:: has_perm(user_obj, perm, obj=None) + .. method:: ahas_perm(user_obj, perm, obj=None) + + *Asynchronous version*: ``ahas_perm()`` Uses :meth:`get_all_permissions` to check if ``user_obj`` has the permission string ``perm``. + .. versionchanged:: 5.2 + + ``ahas_perm()`` function was added. + .. class:: ModelBackend This is the default authentication backend used by Django. It @@ -539,6 +623,9 @@ The following backends are available in :mod:`django.contrib.auth.backends`: unlike others methods it returns an empty queryset if ``obj is not None``. .. method:: authenticate(request, username=None, password=None, **kwargs) + .. method:: aauthenticate(request, username=None, password=None, **kwargs) + + *Asynchronous version*: ``aauthenticate()`` Tries to authenticate ``username`` with ``password`` by calling :meth:`User.check_password @@ -552,38 +639,77 @@ The following backends are available in :mod:`django.contrib.auth.backends`: if it wasn't provided to :func:`~django.contrib.auth.authenticate` (which passes it on to the backend). + .. versionchanged:: 5.2 + + ``aauthenticate()`` function was added. + .. method:: get_user_permissions(user_obj, obj=None) + .. method:: aget_user_permissions(user_obj, obj=None) + + *Asynchronous version*: ``aget_user_permissions()`` Returns the set of permission strings the ``user_obj`` has from their own user permissions. Returns an empty set if :attr:`~django.contrib.auth.models.AbstractBaseUser.is_anonymous` or :attr:`~django.contrib.auth.models.CustomUser.is_active` is ``False``. + .. versionchanged:: 5.2 + + ``aget_user_permissions()`` function was added. + .. method:: get_group_permissions(user_obj, obj=None) + .. method:: aget_group_permissions(user_obj, obj=None) + + *Asynchronous version*: ``aget_group_permissions()`` Returns the set of permission strings the ``user_obj`` has from the permissions of the groups they belong. Returns an empty set if :attr:`~django.contrib.auth.models.AbstractBaseUser.is_anonymous` or :attr:`~django.contrib.auth.models.CustomUser.is_active` is ``False``. + .. versionchanged:: 5.2 + + ``aget_group_permissions()`` function was added. + .. method:: get_all_permissions(user_obj, obj=None) + .. method:: aget_all_permissions(user_obj, obj=None) + + *Asynchronous version*: ``aget_all_permissions()`` Returns the set of permission strings the ``user_obj`` has, including both user permissions and group permissions. Returns an empty set if :attr:`~django.contrib.auth.models.AbstractBaseUser.is_anonymous` or :attr:`~django.contrib.auth.models.CustomUser.is_active` is ``False``. + + .. versionchanged:: 5.2 + + ``aget_all_permissions()`` function was added. .. method:: has_perm(user_obj, perm, obj=None) + .. method:: ahas_perm(user_obj, perm, obj=None) + + *Asynchronous version*: ``ahas_perm()`` Uses :meth:`get_all_permissions` to check if ``user_obj`` has the permission string ``perm``. Returns ``False`` if the user is not :attr:`~django.contrib.auth.models.CustomUser.is_active`. + .. versionchanged:: 5.2 + + ``ahas_perm()`` function was added. + .. method:: has_module_perms(user_obj, app_label) + .. method:: ahas_module_perms(user_obj, app_label) + + *Asynchronous version*: ``ahas_module_perms()`` Returns whether the ``user_obj`` has any permissions on the app ``app_label``. + .. versionchanged:: 5.2 + + ``ahas_module_perms()`` function was added. + .. method:: user_can_authenticate() Returns whether the user is allowed to authenticate. To match the @@ -637,6 +763,9 @@ The following backends are available in :mod:`django.contrib.auth.backends`: created if not already in the database Defaults to ``True``. .. method:: authenticate(request, remote_user) + .. method:: aauthenticate(request, remote_user) + + *Asynchronous version*: ``aauthenticate()`` The username passed as ``remote_user`` is considered trusted. This method returns the user object with the given username, creating a new @@ -651,6 +780,10 @@ The following backends are available in :mod:`django.contrib.auth.backends`: if it wasn't provided to :func:`~django.contrib.auth.authenticate` (which passes it on to the backend). + .. versionchanged:: 5.2 + + ``aauthenticate()`` function was added. + .. method:: clean_username(username) Performs any cleaning on the ``username`` (e.g. stripping LDAP DN @@ -658,12 +791,17 @@ The following backends are available in :mod:`django.contrib.auth.backends`: the cleaned username. .. method:: configure_user(request, user, created=True) + .. method:: aconfigure_user(request, user, created=True) + + *Asynchronous version*: ``aconfigure_user()`` Configures the user on each authentication attempt. This method is called immediately after fetching or creating the user being authenticated, and can be used to perform custom setup actions, such as setting the user's groups based on attributes in an LDAP directory. - Returns the user object. + Returns the user object. When fetching or creating an user is called + from a synchronous context, ``configure_user`` is called, + ``aconfigure_user`` is called from async contexts. The setup can be performed either once when the user is created (``created`` is ``True``) or on existing users (``created`` is @@ -674,6 +812,10 @@ The following backends are available in :mod:`django.contrib.auth.backends`: if it wasn't provided to :func:`~django.contrib.auth.authenticate` (which passes it on to the backend). + .. versionchanged:: 5.2 + + ``aconfigure_user()`` function was added. + .. method:: user_can_authenticate() Returns whether the user is allowed to authenticate. This method diff --git a/docs/releases/5.2.txt b/docs/releases/5.2.txt index c445e02694..9aa232b902 100644 --- a/docs/releases/5.2.txt +++ b/docs/releases/5.2.txt @@ -52,6 +52,36 @@ Minor features * The default iteration count for the PBKDF2 password hasher is increased from 870,000 to 1,000,000. +* The following new asynchronous methods on are now provided, using an ``a`` + prefix: + + * :meth:`.UserManager.acreate_user` + * :meth:`.UserManager.acreate_superuser` + * :meth:`.BaseUserManager.aget_by_natural_key` + * :meth:`.User.aget_user_permissions()` + * :meth:`.User.aget_all_permissions()` + * :meth:`.User.aget_group_permissions()` + * :meth:`.User.ahas_perm()` + * :meth:`.User.ahas_perms()` + * :meth:`.User.ahas_module_perms()` + * :meth:`.User.aget_user_permissions()` + * :meth:`.User.aget_group_permissions()` + * :meth:`.User.ahas_perm()` + * :meth:`.ModelBackend.aauthenticate()` + * :meth:`.ModelBackend.aget_user_permissions()` + * :meth:`.ModelBackend.aget_group_permissions()` + * :meth:`.ModelBackend.aget_all_permissions()` + * :meth:`.ModelBackend.ahas_perm()` + * :meth:`.ModelBackend.ahas_module_perms()` + * :meth:`.RemoteUserBackend.aauthenticate()` + * :meth:`.RemoteUserBackend.aconfigure_user()` + +* Auth backends can now provide async implementations which are used when + calling async auth functions (e.g. + :func:`~.django.contrib.auth.aauthenticate`) to reduce context-switching which + improves performance. See :ref:`adding an async interface + ` for more details. + :mod:`django.contrib.contenttypes` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/topics/auth/customizing.txt b/docs/topics/auth/customizing.txt index f41b10fb4a..6fdcd136c0 100644 --- a/docs/topics/auth/customizing.txt +++ b/docs/topics/auth/customizing.txt @@ -790,10 +790,17 @@ utility methods: email address. .. method:: models.BaseUserManager.get_by_natural_key(username) + .. method:: models.BaseUserManager.aget_by_natural_key(username) + + *Asynchronous version*: ``aget_by_natural_key()`` Retrieves a user instance using the contents of the field nominated by ``USERNAME_FIELD``. + .. versionchanged:: 5.2 + + ``aget_by_natural_key()`` method was added. + Extending Django's default ``User`` ----------------------------------- @@ -1186,3 +1193,25 @@ Finally, specify the custom model as the default user model for your project using the :setting:`AUTH_USER_MODEL` setting in your ``settings.py``:: AUTH_USER_MODEL = "customauth.MyUser" + +.. _writing-authentication-backends-async-interface: + +Adding an async interface +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. versionadded:: 5.2 + +To optimize performance when called from an async context authentication, +backends can implement async versions of each function - ``aget_user(user_id)`` +and ``aauthenticate(request, **credentials)``. When an authentication backend +extends ``BaseBackend`` and async versions of these functions are not provided, +they will be automatically synthesized with ``sync_to_async``. This has +:ref:`performance penalties `. + +While an async interface is optional, a synchronous interface is always +required. There is no automatic synthesis for a synchronous interface if an +async interface is implemented. + +Django's out-of-the-box authentication backends have native async support. If +these native backends are extended take special care to make sure the async +versions of modified functions are modified as well. diff --git a/tests/async/test_async_auth.py b/tests/async/test_async_auth.py index f6551c63ee..37884d13a6 100644 --- a/tests/async/test_async_auth.py +++ b/tests/async/test_async_auth.py @@ -33,9 +33,40 @@ class AsyncAuthTest(TestCase): self.assertIsInstance(user, User) self.assertEqual(user.username, self.test_user.username) + async def test_changed_password_invalidates_aget_user(self): + request = HttpRequest() + request.session = await self.client.asession() + await alogin(request, self.test_user) + + self.test_user.set_password("new_password") + await self.test_user.asave() + + user = await aget_user(request) + + self.assertIsNotNone(user) + self.assertTrue(user.is_anonymous) + # Session should be flushed. + self.assertIsNone(request.session.session_key) + + async def test_alogin_new_user(self): + request = HttpRequest() + request.session = await self.client.asession() + await alogin(request, self.test_user) + second_user = await User.objects.acreate_user( + "testuser2", "test2@example.com", "testpw2" + ) + await alogin(request, second_user) + user = await aget_user(request) + self.assertIsInstance(user, User) + self.assertEqual(user.username, second_user.username) + async def test_alogin_without_user(self): + async def auser(): + return self.test_user + request = HttpRequest() request.user = self.test_user + request.auser = auser request.session = await self.client.asession() await alogin(request, None) user = await aget_user(request) diff --git a/tests/auth_tests/models/custom_user.py b/tests/auth_tests/models/custom_user.py index b9938681ca..4586e452cd 100644 --- a/tests/auth_tests/models/custom_user.py +++ b/tests/auth_tests/models/custom_user.py @@ -29,6 +29,19 @@ class CustomUserManager(BaseUserManager): user.save(using=self._db) return user + async def acreate_user(self, email, date_of_birth, password=None, **fields): + """See create_user()""" + if not email: + raise ValueError("Users must have an email address") + + user = self.model( + email=self.normalize_email(email), date_of_birth=date_of_birth, **fields + ) + + user.set_password(password) + await user.asave(using=self._db) + return user + def create_superuser(self, email, password, date_of_birth, **fields): u = self.create_user( email, password=password, date_of_birth=date_of_birth, **fields diff --git a/tests/auth_tests/test_auth_backends.py b/tests/auth_tests/test_auth_backends.py index 3b4f40e6e0..b612d27ab0 100644 --- a/tests/auth_tests/test_auth_backends.py +++ b/tests/auth_tests/test_auth_backends.py @@ -2,6 +2,8 @@ import sys from datetime import date from unittest import mock +from asgiref.sync import sync_to_async + from django.contrib.auth import ( BACKEND_SESSION_KEY, SESSION_KEY, @@ -55,17 +57,33 @@ class BaseBackendTest(TestCase): def test_get_user_permissions(self): self.assertEqual(self.user.get_user_permissions(), {"user_perm"}) + async def test_aget_user_permissions(self): + self.assertEqual(await self.user.aget_user_permissions(), {"user_perm"}) + def test_get_group_permissions(self): self.assertEqual(self.user.get_group_permissions(), {"group_perm"}) + async def test_aget_group_permissions(self): + self.assertEqual(await self.user.aget_group_permissions(), {"group_perm"}) + def test_get_all_permissions(self): self.assertEqual(self.user.get_all_permissions(), {"user_perm", "group_perm"}) + async def test_aget_all_permissions(self): + self.assertEqual( + await self.user.aget_all_permissions(), {"user_perm", "group_perm"} + ) + def test_has_perm(self): self.assertIs(self.user.has_perm("user_perm"), True) self.assertIs(self.user.has_perm("group_perm"), True) self.assertIs(self.user.has_perm("other_perm", TestObj()), False) + async def test_ahas_perm(self): + self.assertIs(await self.user.ahas_perm("user_perm"), True) + self.assertIs(await self.user.ahas_perm("group_perm"), True) + self.assertIs(await self.user.ahas_perm("other_perm", TestObj()), False) + def test_has_perms_perm_list_invalid(self): msg = "perm_list must be an iterable of permissions." with self.assertRaisesMessage(ValueError, msg): @@ -73,6 +91,13 @@ class BaseBackendTest(TestCase): with self.assertRaisesMessage(ValueError, msg): self.user.has_perms(object()) + async def test_ahas_perms_perm_list_invalid(self): + msg = "perm_list must be an iterable of permissions." + with self.assertRaisesMessage(ValueError, msg): + await self.user.ahas_perms("user_perm") + with self.assertRaisesMessage(ValueError, msg): + await self.user.ahas_perms(object()) + class CountingMD5PasswordHasher(MD5PasswordHasher): """Hasher that counts how many times it computes a hash.""" @@ -125,6 +150,25 @@ class BaseModelBackendTest: user.save() self.assertIs(user.has_perm("auth.test"), False) + async def test_ahas_perm(self): + user = await self.UserModel._default_manager.aget(pk=self.user.pk) + self.assertIs(await user.ahas_perm("auth.test"), False) + + user.is_staff = True + await user.asave() + self.assertIs(await user.ahas_perm("auth.test"), False) + + user.is_superuser = True + await user.asave() + self.assertIs(await user.ahas_perm("auth.test"), True) + self.assertIs(await user.ahas_module_perms("auth"), True) + + user.is_staff = True + user.is_superuser = True + user.is_active = False + await user.asave() + self.assertIs(await user.ahas_perm("auth.test"), False) + def test_custom_perms(self): user = self.UserModel._default_manager.get(pk=self.user.pk) content_type = ContentType.objects.get_for_model(Group) @@ -174,6 +218,55 @@ class BaseModelBackendTest: self.assertIs(user.has_perm("test"), False) self.assertIs(user.has_perms(["auth.test2", "auth.test3"]), False) + async def test_acustom_perms(self): + user = await self.UserModel._default_manager.aget(pk=self.user.pk) + content_type = await sync_to_async(ContentType.objects.get_for_model)(Group) + perm = await Permission.objects.acreate( + name="test", content_type=content_type, codename="test" + ) + await user.user_permissions.aadd(perm) + + # Reloading user to purge the _perm_cache. + user = await self.UserModel._default_manager.aget(pk=self.user.pk) + self.assertEqual(await user.aget_all_permissions(), {"auth.test"}) + self.assertEqual(await user.aget_user_permissions(), {"auth.test"}) + self.assertEqual(await user.aget_group_permissions(), set()) + self.assertIs(await user.ahas_module_perms("Group"), False) + self.assertIs(await user.ahas_module_perms("auth"), True) + + perm = await Permission.objects.acreate( + name="test2", content_type=content_type, codename="test2" + ) + await user.user_permissions.aadd(perm) + perm = await Permission.objects.acreate( + name="test3", content_type=content_type, codename="test3" + ) + await user.user_permissions.aadd(perm) + user = await self.UserModel._default_manager.aget(pk=self.user.pk) + expected_user_perms = {"auth.test2", "auth.test", "auth.test3"} + self.assertEqual(await user.aget_all_permissions(), expected_user_perms) + self.assertIs(await user.ahas_perm("test"), False) + self.assertIs(await user.ahas_perm("auth.test"), True) + self.assertIs(await user.ahas_perms(["auth.test2", "auth.test3"]), True) + + perm = await Permission.objects.acreate( + name="test_group", content_type=content_type, codename="test_group" + ) + group = await Group.objects.acreate(name="test_group") + await group.permissions.aadd(perm) + await user.groups.aadd(group) + user = await self.UserModel._default_manager.aget(pk=self.user.pk) + self.assertEqual( + await user.aget_all_permissions(), {*expected_user_perms, "auth.test_group"} + ) + self.assertEqual(await user.aget_user_permissions(), expected_user_perms) + self.assertEqual(await user.aget_group_permissions(), {"auth.test_group"}) + self.assertIs(await user.ahas_perms(["auth.test3", "auth.test_group"]), True) + + user = AnonymousUser() + self.assertIs(await user.ahas_perm("test"), False) + self.assertIs(await user.ahas_perms(["auth.test2", "auth.test3"]), False) + def test_has_no_object_perm(self): """Regressiontest for #12462""" user = self.UserModel._default_manager.get(pk=self.user.pk) @@ -188,6 +281,20 @@ class BaseModelBackendTest: self.assertIs(user.has_perm("auth.test"), True) self.assertEqual(user.get_all_permissions(), {"auth.test"}) + async def test_ahas_no_object_perm(self): + """See test_has_no_object_perm()""" + user = await self.UserModel._default_manager.aget(pk=self.user.pk) + content_type = await sync_to_async(ContentType.objects.get_for_model)(Group) + perm = await Permission.objects.acreate( + name="test", content_type=content_type, codename="test" + ) + await user.user_permissions.aadd(perm) + + self.assertIs(await user.ahas_perm("auth.test", "object"), False) + self.assertEqual(await user.aget_all_permissions("object"), set()) + self.assertIs(await user.ahas_perm("auth.test"), True) + self.assertEqual(await user.aget_all_permissions(), {"auth.test"}) + def test_anonymous_has_no_permissions(self): """ #17903 -- Anonymous users shouldn't have permissions in @@ -220,6 +327,38 @@ class BaseModelBackendTest: self.assertEqual(backend.get_user_permissions(user), set()) self.assertEqual(backend.get_group_permissions(user), set()) + async def test_aanonymous_has_no_permissions(self): + """See test_anonymous_has_no_permissions()""" + backend = ModelBackend() + + user = await self.UserModel._default_manager.aget(pk=self.user.pk) + content_type = await sync_to_async(ContentType.objects.get_for_model)(Group) + user_perm = await Permission.objects.acreate( + name="test", content_type=content_type, codename="test_user" + ) + group_perm = await Permission.objects.acreate( + name="test2", content_type=content_type, codename="test_group" + ) + await user.user_permissions.aadd(user_perm) + + group = await Group.objects.acreate(name="test_group") + await user.groups.aadd(group) + await group.permissions.aadd(group_perm) + + self.assertEqual( + await backend.aget_all_permissions(user), + {"auth.test_user", "auth.test_group"}, + ) + self.assertEqual(await backend.aget_user_permissions(user), {"auth.test_user"}) + self.assertEqual( + await backend.aget_group_permissions(user), {"auth.test_group"} + ) + + with mock.patch.object(self.UserModel, "is_anonymous", True): + self.assertEqual(await backend.aget_all_permissions(user), set()) + self.assertEqual(await backend.aget_user_permissions(user), set()) + self.assertEqual(await backend.aget_group_permissions(user), set()) + def test_inactive_has_no_permissions(self): """ #17903 -- Inactive users shouldn't have permissions in @@ -254,11 +393,52 @@ class BaseModelBackendTest: self.assertEqual(backend.get_user_permissions(user), set()) self.assertEqual(backend.get_group_permissions(user), set()) + async def test_ainactive_has_no_permissions(self): + """See test_inactive_has_no_permissions()""" + backend = ModelBackend() + + user = await self.UserModel._default_manager.aget(pk=self.user.pk) + content_type = await sync_to_async(ContentType.objects.get_for_model)(Group) + user_perm = await Permission.objects.acreate( + name="test", content_type=content_type, codename="test_user" + ) + group_perm = await Permission.objects.acreate( + name="test2", content_type=content_type, codename="test_group" + ) + await user.user_permissions.aadd(user_perm) + + group = await Group.objects.acreate(name="test_group") + await user.groups.aadd(group) + await group.permissions.aadd(group_perm) + + self.assertEqual( + await backend.aget_all_permissions(user), + {"auth.test_user", "auth.test_group"}, + ) + self.assertEqual(await backend.aget_user_permissions(user), {"auth.test_user"}) + self.assertEqual( + await backend.aget_group_permissions(user), {"auth.test_group"} + ) + + user.is_active = False + await user.asave() + + self.assertEqual(await backend.aget_all_permissions(user), set()) + self.assertEqual(await backend.aget_user_permissions(user), set()) + self.assertEqual(await backend.aget_group_permissions(user), set()) + def test_get_all_superuser_permissions(self): """A superuser has all permissions. Refs #14795.""" user = self.UserModel._default_manager.get(pk=self.superuser.pk) self.assertEqual(len(user.get_all_permissions()), len(Permission.objects.all())) + async def test_aget_all_superuser_permissions(self): + """See test_get_all_superuser_permissions()""" + user = await self.UserModel._default_manager.aget(pk=self.superuser.pk) + self.assertEqual( + len(await user.aget_all_permissions()), await Permission.objects.acount() + ) + @override_settings( PASSWORD_HASHERS=["auth_tests.test_auth_backends.CountingMD5PasswordHasher"] ) @@ -277,6 +457,24 @@ class BaseModelBackendTest: authenticate(username="no_such_user", password="test") self.assertEqual(CountingMD5PasswordHasher.calls, 1) + @override_settings( + PASSWORD_HASHERS=["auth_tests.test_auth_backends.CountingMD5PasswordHasher"] + ) + async def test_aauthentication_timing(self): + """See test_authentication_timing()""" + # Re-set the password, because this tests overrides PASSWORD_HASHERS. + self.user.set_password("test") + await self.user.asave() + + CountingMD5PasswordHasher.calls = 0 + username = getattr(self.user, self.UserModel.USERNAME_FIELD) + await aauthenticate(username=username, password="test") + self.assertEqual(CountingMD5PasswordHasher.calls, 1) + + CountingMD5PasswordHasher.calls = 0 + await aauthenticate(username="no_such_user", password="test") + self.assertEqual(CountingMD5PasswordHasher.calls, 1) + @override_settings( PASSWORD_HASHERS=["auth_tests.test_auth_backends.CountingMD5PasswordHasher"] ) @@ -320,6 +518,15 @@ class ModelBackendTest(BaseModelBackendTest, TestCase): self.user.save() self.assertIsNone(authenticate(**self.user_credentials)) + async def test_aauthenticate_inactive(self): + """ + An inactive user can't authenticate. + """ + self.assertEqual(await aauthenticate(**self.user_credentials), self.user) + self.user.is_active = False + await self.user.asave() + self.assertIsNone(await aauthenticate(**self.user_credentials)) + @override_settings(AUTH_USER_MODEL="auth_tests.CustomUserWithoutIsActiveField") def test_authenticate_user_without_is_active_field(self): """ @@ -332,6 +539,18 @@ class ModelBackendTest(BaseModelBackendTest, TestCase): ) self.assertEqual(authenticate(username="test", password="test"), user) + @override_settings(AUTH_USER_MODEL="auth_tests.CustomUserWithoutIsActiveField") + async def test_aauthenticate_user_without_is_active_field(self): + """ + A custom user without an `is_active` field is allowed to authenticate. + """ + user = await CustomUserWithoutIsActiveField.objects._acreate_user( + username="test", + email="test@example.com", + password="test", + ) + self.assertEqual(await aauthenticate(username="test", password="test"), user) + @override_settings(AUTH_USER_MODEL="auth_tests.ExtensionUser") class ExtensionUserModelBackendTest(BaseModelBackendTest, TestCase): @@ -403,6 +622,15 @@ class CustomUserModelBackendAuthenticateTest(TestCase): authenticated_user = authenticate(email="test@example.com", password="test") self.assertEqual(test_user, authenticated_user) + async def test_aauthenticate(self): + test_user = await CustomUser._default_manager.acreate_user( + email="test@example.com", password="test", date_of_birth=date(2006, 4, 25) + ) + authenticated_user = await aauthenticate( + email="test@example.com", password="test" + ) + self.assertEqual(test_user, authenticated_user) + @override_settings(AUTH_USER_MODEL="auth_tests.UUIDUser") class UUIDUserTests(TestCase): @@ -416,6 +644,13 @@ class UUIDUserTests(TestCase): UUIDUser.objects.get(pk=self.client.session[SESSION_KEY]), user ) + async def test_alogin(self): + """See test_login()""" + user = await UUIDUser.objects.acreate_user(username="uuid", password="test") + self.assertTrue(await self.client.alogin(username="uuid", password="test")) + session_key = await self.client.session.aget(SESSION_KEY) + self.assertEqual(await UUIDUser.objects.aget(pk=session_key), user) + class TestObj: pass @@ -435,9 +670,15 @@ class SimpleRowlevelBackend: return True return False + async def ahas_perm(self, user, perm, obj=None): + return self.has_perm(user, perm, obj) + def has_module_perms(self, user, app_label): return (user.is_anonymous or user.is_active) and app_label == "app1" + async def ahas_module_perms(self, user, app_label): + return self.has_module_perms(user, app_label) + def get_all_permissions(self, user, obj=None): if not obj: return [] # We only support row level perms @@ -452,6 +693,9 @@ class SimpleRowlevelBackend: else: return ["simple"] + async def aget_all_permissions(self, user, obj=None): + return self.get_all_permissions(user, obj) + def get_group_permissions(self, user, obj=None): if not obj: return # We only support row level perms @@ -524,10 +768,18 @@ class AnonymousUserBackendTest(SimpleTestCase): self.assertIs(self.user1.has_perm("perm", TestObj()), False) self.assertIs(self.user1.has_perm("anon", TestObj()), True) + async def test_ahas_perm(self): + self.assertIs(await self.user1.ahas_perm("perm", TestObj()), False) + self.assertIs(await self.user1.ahas_perm("anon", TestObj()), True) + def test_has_perms(self): self.assertIs(self.user1.has_perms(["anon"], TestObj()), True) self.assertIs(self.user1.has_perms(["anon", "perm"], TestObj()), False) + async def test_ahas_perms(self): + self.assertIs(await self.user1.ahas_perms(["anon"], TestObj()), True) + self.assertIs(await self.user1.ahas_perms(["anon", "perm"], TestObj()), False) + def test_has_perms_perm_list_invalid(self): msg = "perm_list must be an iterable of permissions." with self.assertRaisesMessage(ValueError, msg): @@ -535,13 +787,27 @@ class AnonymousUserBackendTest(SimpleTestCase): with self.assertRaisesMessage(ValueError, msg): self.user1.has_perms(object()) + async def test_ahas_perms_perm_list_invalid(self): + msg = "perm_list must be an iterable of permissions." + with self.assertRaisesMessage(ValueError, msg): + await self.user1.ahas_perms("perm") + with self.assertRaisesMessage(ValueError, msg): + await self.user1.ahas_perms(object()) + def test_has_module_perms(self): self.assertIs(self.user1.has_module_perms("app1"), True) self.assertIs(self.user1.has_module_perms("app2"), False) + async def test_ahas_module_perms(self): + self.assertIs(await self.user1.ahas_module_perms("app1"), True) + self.assertIs(await self.user1.ahas_module_perms("app2"), False) + def test_get_all_permissions(self): self.assertEqual(self.user1.get_all_permissions(TestObj()), {"anon"}) + async def test_aget_all_permissions(self): + self.assertEqual(await self.user1.aget_all_permissions(TestObj()), {"anon"}) + @override_settings(AUTHENTICATION_BACKENDS=[]) class NoBackendsTest(TestCase): @@ -561,6 +827,14 @@ class NoBackendsTest(TestCase): with self.assertRaisesMessage(ImproperlyConfigured, msg): self.user.has_perm(("perm", TestObj())) + async def test_araises_exception(self): + msg = ( + "No authentication backends have been defined. " + "Does AUTHENTICATION_BACKENDS contain anything?" + ) + with self.assertRaisesMessage(ImproperlyConfigured, msg): + await self.user.ahas_perm(("perm", TestObj())) + @override_settings( AUTHENTICATION_BACKENDS=["auth_tests.test_auth_backends.SimpleRowlevelBackend"] @@ -593,12 +867,21 @@ class PermissionDeniedBackend: def authenticate(self, request, username=None, password=None): raise PermissionDenied + async def aauthenticate(self, request, username=None, password=None): + raise PermissionDenied + def has_perm(self, user_obj, perm, obj=None): raise PermissionDenied + async def ahas_perm(self, user_obj, perm, obj=None): + raise PermissionDenied + def has_module_perms(self, user_obj, app_label): raise PermissionDenied + async def ahas_module_perms(self, user_obj, app_label): + raise PermissionDenied + class PermissionDeniedBackendTest(TestCase): """ @@ -631,10 +914,25 @@ class PermissionDeniedBackendTest(TestCase): [{"password": "********************", "username": "test"}], ) + @modify_settings(AUTHENTICATION_BACKENDS={"prepend": backend}) + async def test_aauthenticate_permission_denied(self): + self.assertIsNone(await aauthenticate(username="test", password="test")) + # user_login_failed signal is sent. + self.assertEqual( + self.user_login_failed, + [{"password": "********************", "username": "test"}], + ) + @modify_settings(AUTHENTICATION_BACKENDS={"append": backend}) def test_authenticates(self): self.assertEqual(authenticate(username="test", password="test"), self.user1) + @modify_settings(AUTHENTICATION_BACKENDS={"append": backend}) + async def test_aauthenticate(self): + self.assertEqual( + await aauthenticate(username="test", password="test"), self.user1 + ) + @modify_settings(AUTHENTICATION_BACKENDS={"prepend": backend}) def test_has_perm_denied(self): content_type = ContentType.objects.get_for_model(Group) @@ -646,6 +944,17 @@ class PermissionDeniedBackendTest(TestCase): self.assertIs(self.user1.has_perm("auth.test"), False) self.assertIs(self.user1.has_module_perms("auth"), False) + @modify_settings(AUTHENTICATION_BACKENDS={"prepend": backend}) + async def test_ahas_perm_denied(self): + content_type = await sync_to_async(ContentType.objects.get_for_model)(Group) + perm = await Permission.objects.acreate( + name="test", content_type=content_type, codename="test" + ) + await self.user1.user_permissions.aadd(perm) + + self.assertIs(await self.user1.ahas_perm("auth.test"), False) + self.assertIs(await self.user1.ahas_module_perms("auth"), False) + @modify_settings(AUTHENTICATION_BACKENDS={"append": backend}) def test_has_perm(self): content_type = ContentType.objects.get_for_model(Group) @@ -657,6 +966,17 @@ class PermissionDeniedBackendTest(TestCase): self.assertIs(self.user1.has_perm("auth.test"), True) self.assertIs(self.user1.has_module_perms("auth"), True) + @modify_settings(AUTHENTICATION_BACKENDS={"append": backend}) + async def test_ahas_perm(self): + content_type = await sync_to_async(ContentType.objects.get_for_model)(Group) + perm = await Permission.objects.acreate( + name="test", content_type=content_type, codename="test" + ) + await self.user1.user_permissions.aadd(perm) + + self.assertIs(await self.user1.ahas_perm("auth.test"), True) + self.assertIs(await self.user1.ahas_module_perms("auth"), True) + class NewModelBackend(ModelBackend): pass @@ -715,6 +1035,10 @@ class TypeErrorBackend: def authenticate(self, request, username=None, password=None): raise TypeError + @sensitive_variables("password") + async def aauthenticate(self, request, username=None, password=None): + raise TypeError + class SkippedBackend: def authenticate(self): diff --git a/tests/auth_tests/test_basic.py b/tests/auth_tests/test_basic.py index d7a7750b54..8d54e187fc 100644 --- a/tests/auth_tests/test_basic.py +++ b/tests/auth_tests/test_basic.py @@ -1,5 +1,3 @@ -from asgiref.sync import sync_to_async - from django.conf import settings from django.contrib.auth import aget_user, get_user, get_user_model from django.contrib.auth.models import AnonymousUser, User @@ -44,6 +42,12 @@ class BasicTestCase(TestCase): u2 = User.objects.create_user("testuser2", "test2@example.com") self.assertFalse(u2.has_usable_password()) + async def test_acreate(self): + u = await User.objects.acreate_user("testuser", "test@example.com", "testpw") + self.assertTrue(u.has_usable_password()) + self.assertFalse(await u.acheck_password("bad")) + self.assertTrue(await u.acheck_password("testpw")) + def test_unicode_username(self): User.objects.create_user("jörg") User.objects.create_user("Григорий") @@ -73,6 +77,15 @@ class BasicTestCase(TestCase): self.assertTrue(super.is_active) self.assertTrue(super.is_staff) + async def test_asuperuser(self): + "Check the creation and properties of a superuser" + super = await User.objects.acreate_superuser( + "super", "super@example.com", "super" + ) + self.assertTrue(super.is_superuser) + self.assertTrue(super.is_active) + self.assertTrue(super.is_staff) + def test_superuser_no_email_or_password(self): cases = [ {}, @@ -171,13 +184,25 @@ class TestGetUser(TestCase): self.assertIsInstance(user, User) self.assertEqual(user.username, created_user.username) - async def test_aget_user(self): - created_user = await sync_to_async(User.objects.create_user)( + async def test_aget_user_fallback_secret(self): + created_user = await User.objects.acreate_user( "testuser", "test@example.com", "testpw" ) await self.client.alogin(username="testuser", password="testpw") request = HttpRequest() request.session = await self.client.asession() - user = await aget_user(request) - self.assertIsInstance(user, User) - self.assertEqual(user.username, created_user.username) + prev_session_key = request.session.session_key + with override_settings( + SECRET_KEY="newsecret", + SECRET_KEY_FALLBACKS=[settings.SECRET_KEY], + ): + user = await aget_user(request) + self.assertIsInstance(user, User) + self.assertEqual(user.username, created_user.username) + self.assertNotEqual(request.session.session_key, prev_session_key) + # Remove the fallback secret. + # The session hash should be updated using the current secret. + with override_settings(SECRET_KEY="newsecret"): + user = await aget_user(request) + self.assertIsInstance(user, User) + self.assertEqual(user.username, created_user.username) diff --git a/tests/auth_tests/test_decorators.py b/tests/auth_tests/test_decorators.py index e585b28bd5..fa2672beb4 100644 --- a/tests/auth_tests/test_decorators.py +++ b/tests/auth_tests/test_decorators.py @@ -1,7 +1,5 @@ from asyncio import iscoroutinefunction -from asgiref.sync import sync_to_async - from django.conf import settings from django.contrib.auth import models from django.contrib.auth.decorators import ( @@ -374,7 +372,7 @@ class UserPassesTestDecoratorTest(TestCase): def test_decorator_async_test_func(self): async def async_test_func(user): - return await sync_to_async(user.has_perms)(["auth_tests.add_customuser"]) + return await user.ahas_perms(["auth_tests.add_customuser"]) @user_passes_test(async_test_func) def sync_view(request): @@ -410,7 +408,7 @@ class UserPassesTestDecoratorTest(TestCase): async def test_decorator_async_view_async_test_func(self): async def async_test_func(user): - return await sync_to_async(user.has_perms)(["auth_tests.add_customuser"]) + return await user.ahas_perms(["auth_tests.add_customuser"]) @user_passes_test(async_test_func) async def async_view(request): diff --git a/tests/auth_tests/test_models.py b/tests/auth_tests/test_models.py index 983424843c..a3e7a3205b 100644 --- a/tests/auth_tests/test_models.py +++ b/tests/auth_tests/test_models.py @@ -1,7 +1,5 @@ from unittest import mock -from asgiref.sync import sync_to_async - from django.conf.global_settings import PASSWORD_HASHERS from django.contrib.auth import get_user_model from django.contrib.auth.backends import ModelBackend @@ -30,10 +28,19 @@ class NaturalKeysTestCase(TestCase): self.assertEqual(User.objects.get_by_natural_key("staff"), staff_user) self.assertEqual(staff_user.natural_key(), ("staff",)) + async def test_auser_natural_key(self): + staff_user = await User.objects.acreate_user(username="staff") + self.assertEqual(await User.objects.aget_by_natural_key("staff"), staff_user) + self.assertEqual(staff_user.natural_key(), ("staff",)) + def test_group_natural_key(self): users_group = Group.objects.create(name="users") self.assertEqual(Group.objects.get_by_natural_key("users"), users_group) + async def test_agroup_natural_key(self): + users_group = await Group.objects.acreate(name="users") + self.assertEqual(await Group.objects.aget_by_natural_key("users"), users_group) + class LoadDataWithoutNaturalKeysTestCase(TestCase): fixtures = ["regular.json"] @@ -157,6 +164,17 @@ class UserManagerTestCase(TransactionTestCase): is_superuser=False, ) + async def test_acreate_super_user_raises_error_on_false_is_superuser(self): + with self.assertRaisesMessage( + ValueError, "Superuser must have is_superuser=True." + ): + await User.objects.acreate_superuser( + username="test", + email="test@test.com", + password="test", + is_superuser=False, + ) + def test_create_superuser_raises_error_on_false_is_staff(self): with self.assertRaisesMessage(ValueError, "Superuser must have is_staff=True."): User.objects.create_superuser( @@ -166,6 +184,15 @@ class UserManagerTestCase(TransactionTestCase): is_staff=False, ) + async def test_acreate_superuser_raises_error_on_false_is_staff(self): + with self.assertRaisesMessage(ValueError, "Superuser must have is_staff=True."): + await User.objects.acreate_superuser( + username="test", + email="test@test.com", + password="test", + is_staff=False, + ) + def test_runpython_manager_methods(self): def forwards(apps, schema_editor): UserModel = apps.get_model("auth", "User") @@ -301,9 +328,7 @@ class AbstractUserTestCase(TestCase): @override_settings(PASSWORD_HASHERS=PASSWORD_HASHERS) async def test_acheck_password_upgrade(self): - user = await sync_to_async(User.objects.create_user)( - username="user", password="foo" - ) + user = await User.objects.acreate_user(username="user", password="foo") initial_password = user.password self.assertIs(await user.acheck_password("foo"), True) hasher = get_hasher("default") @@ -557,6 +582,12 @@ class AnonymousUserTests(SimpleTestCase): self.assertEqual(self.user.get_user_permissions(), set()) self.assertEqual(self.user.get_group_permissions(), set()) + async def test_properties_async_versions(self): + self.assertEqual(await self.user.groups.acount(), 0) + self.assertEqual(await self.user.user_permissions.acount(), 0) + self.assertEqual(await self.user.aget_user_permissions(), set()) + self.assertEqual(await self.user.aget_group_permissions(), set()) + def test_str(self): self.assertEqual(str(self.user), "AnonymousUser") diff --git a/tests/auth_tests/test_remote_user.py b/tests/auth_tests/test_remote_user.py index d3cf4b9da5..85de931c1a 100644 --- a/tests/auth_tests/test_remote_user.py +++ b/tests/auth_tests/test_remote_user.py @@ -1,12 +1,18 @@ from datetime import datetime, timezone from django.conf import settings -from django.contrib.auth import authenticate +from django.contrib.auth import aauthenticate, authenticate from django.contrib.auth.backends import RemoteUserBackend from django.contrib.auth.middleware import RemoteUserMiddleware from django.contrib.auth.models import User from django.middleware.csrf import _get_new_csrf_string, _mask_cipher_secret -from django.test import Client, TestCase, modify_settings, override_settings +from django.test import ( + AsyncClient, + Client, + TestCase, + modify_settings, + override_settings, +) @override_settings(ROOT_URLCONF="auth_tests.urls") @@ -30,6 +36,11 @@ class RemoteUserTest(TestCase): ) super().setUpClass() + def test_passing_explicit_none(self): + msg = "get_response must be provided." + with self.assertRaisesMessage(ValueError, msg): + RemoteUserMiddleware(None) + def test_no_remote_user(self): """Users are not created when remote user is not specified.""" num_users = User.objects.count() @@ -46,6 +57,18 @@ class RemoteUserTest(TestCase): self.assertTrue(response.context["user"].is_anonymous) self.assertEqual(User.objects.count(), num_users) + async def test_no_remote_user_async(self): + """See test_no_remote_user.""" + num_users = await User.objects.acount() + + response = await self.async_client.get("/remote_user/") + self.assertTrue(response.context["user"].is_anonymous) + self.assertEqual(await User.objects.acount(), num_users) + + response = await self.async_client.get("/remote_user/", **{self.header: ""}) + self.assertTrue(response.context["user"].is_anonymous) + self.assertEqual(await User.objects.acount(), num_users) + def test_csrf_validation_passes_after_process_request_login(self): """ CSRF check must access the CSRF token from the session or cookie, @@ -75,6 +98,31 @@ class RemoteUserTest(TestCase): response = csrf_client.post("/remote_user/", data, **headers) self.assertEqual(response.status_code, 200) + async def test_csrf_validation_passes_after_process_request_login_async(self): + """See test_csrf_validation_passes_after_process_request_login.""" + csrf_client = AsyncClient(enforce_csrf_checks=True) + csrf_secret = _get_new_csrf_string() + csrf_token = _mask_cipher_secret(csrf_secret) + csrf_token_form = _mask_cipher_secret(csrf_secret) + headers = {self.header: "fakeuser"} + data = {"csrfmiddlewaretoken": csrf_token_form} + + # Verify that CSRF is configured for the view + csrf_client.cookies.load({settings.CSRF_COOKIE_NAME: csrf_token}) + response = await csrf_client.post("/remote_user/", **headers) + self.assertEqual(response.status_code, 403) + self.assertIn(b"CSRF verification failed.", response.content) + + # This request will call django.contrib.auth.alogin() which will call + # django.middleware.csrf.rotate_token() thus changing the value of + # request.META['CSRF_COOKIE'] from the user submitted value set by + # CsrfViewMiddleware.process_request() to the new csrftoken value set + # by rotate_token(). Csrf validation should still pass when the view is + # later processed by CsrfViewMiddleware.process_view() + csrf_client.cookies.load({settings.CSRF_COOKIE_NAME: csrf_token}) + response = await csrf_client.post("/remote_user/", data, **headers) + self.assertEqual(response.status_code, 200) + def test_unknown_user(self): """ Tests the case where the username passed in the header does not exist @@ -90,6 +138,22 @@ class RemoteUserTest(TestCase): response = self.client.get("/remote_user/", **{self.header: "newuser"}) self.assertEqual(User.objects.count(), num_users + 1) + async def test_unknown_user_async(self): + """See test_unknown_user.""" + num_users = await User.objects.acount() + response = await self.async_client.get( + "/remote_user/", **{self.header: "newuser"} + ) + self.assertEqual(response.context["user"].username, "newuser") + self.assertEqual(await User.objects.acount(), num_users + 1) + await User.objects.aget(username="newuser") + + # Another request with same user should not create any new users. + response = await self.async_client.get( + "/remote_user/", **{self.header: "newuser"} + ) + self.assertEqual(await User.objects.acount(), num_users + 1) + def test_known_user(self): """ Tests the case where the username passed in the header is a valid User. @@ -106,6 +170,24 @@ class RemoteUserTest(TestCase): self.assertEqual(response.context["user"].username, "knownuser2") self.assertEqual(User.objects.count(), num_users) + async def test_known_user_async(self): + """See test_known_user.""" + await User.objects.acreate(username="knownuser") + await User.objects.acreate(username="knownuser2") + num_users = await User.objects.acount() + response = await self.async_client.get( + "/remote_user/", **{self.header: self.known_user} + ) + self.assertEqual(response.context["user"].username, "knownuser") + self.assertEqual(await User.objects.acount(), num_users) + # A different user passed in the headers causes the new user + # to be logged in. + response = await self.async_client.get( + "/remote_user/", **{self.header: self.known_user2} + ) + self.assertEqual(response.context["user"].username, "knownuser2") + self.assertEqual(await User.objects.acount(), num_users) + def test_last_login(self): """ A user's last_login is set the first time they make a @@ -128,6 +210,29 @@ class RemoteUserTest(TestCase): response = self.client.get("/remote_user/", **{self.header: self.known_user}) self.assertEqual(default_login, response.context["user"].last_login) + async def test_last_login_async(self): + """See test_last_login.""" + user = await User.objects.acreate(username="knownuser") + # Set last_login to something so we can determine if it changes. + default_login = datetime(2000, 1, 1) + if settings.USE_TZ: + default_login = default_login.replace(tzinfo=timezone.utc) + user.last_login = default_login + await user.asave() + + response = await self.async_client.get( + "/remote_user/", **{self.header: self.known_user} + ) + self.assertNotEqual(default_login, response.context["user"].last_login) + + user = await User.objects.aget(username="knownuser") + user.last_login = default_login + await user.asave() + response = await self.async_client.get( + "/remote_user/", **{self.header: self.known_user} + ) + self.assertEqual(default_login, response.context["user"].last_login) + def test_header_disappears(self): """ A logged in user is logged out automatically when @@ -148,6 +253,25 @@ class RemoteUserTest(TestCase): response = self.client.get("/remote_user/") self.assertEqual(response.context["user"].username, "modeluser") + async def test_header_disappears_async(self): + """See test_header_disappears.""" + await User.objects.acreate(username="knownuser") + # Known user authenticates + response = await self.async_client.get( + "/remote_user/", **{self.header: self.known_user} + ) + self.assertEqual(response.context["user"].username, "knownuser") + # During the session, the REMOTE_USER header disappears. Should trigger logout. + response = await self.async_client.get("/remote_user/") + self.assertTrue(response.context["user"].is_anonymous) + # verify the remoteuser middleware will not remove a user + # authenticated via another backend + await User.objects.acreate_user(username="modeluser", password="foo") + await self.async_client.alogin(username="modeluser", password="foo") + await aauthenticate(username="modeluser", password="foo") + response = await self.async_client.get("/remote_user/") + self.assertEqual(response.context["user"].username, "modeluser") + def test_user_switch_forces_new_login(self): """ If the username in the header changes between requests @@ -164,11 +288,35 @@ class RemoteUserTest(TestCase): # In backends that do not create new users, it is '' (anonymous user) self.assertNotEqual(response.context["user"].username, "knownuser") + async def test_user_switch_forces_new_login_async(self): + """See test_user_switch_forces_new_login.""" + await User.objects.acreate(username="knownuser") + # Known user authenticates + response = await self.async_client.get( + "/remote_user/", **{self.header: self.known_user} + ) + self.assertEqual(response.context["user"].username, "knownuser") + # During the session, the REMOTE_USER changes to a different user. + response = await self.async_client.get( + "/remote_user/", **{self.header: "newnewuser"} + ) + # The current user is not the prior remote_user. + # In backends that create a new user, username is "newnewuser" + # In backends that do not create new users, it is '' (anonymous user) + self.assertNotEqual(response.context["user"].username, "knownuser") + def test_inactive_user(self): User.objects.create(username="knownuser", is_active=False) response = self.client.get("/remote_user/", **{self.header: "knownuser"}) self.assertTrue(response.context["user"].is_anonymous) + async def test_inactive_user_async(self): + await User.objects.acreate(username="knownuser", is_active=False) + response = await self.async_client.get( + "/remote_user/", **{self.header: "knownuser"} + ) + self.assertTrue(response.context["user"].is_anonymous) + class RemoteUserNoCreateBackend(RemoteUserBackend): """Backend that doesn't create unknown users.""" @@ -190,6 +338,14 @@ class RemoteUserNoCreateTest(RemoteUserTest): self.assertTrue(response.context["user"].is_anonymous) self.assertEqual(User.objects.count(), num_users) + async def test_unknown_user_async(self): + num_users = await User.objects.acount() + response = await self.async_client.get( + "/remote_user/", **{self.header: "newuser"} + ) + self.assertTrue(response.context["user"].is_anonymous) + self.assertEqual(await User.objects.acount(), num_users) + class AllowAllUsersRemoteUserBackendTest(RemoteUserTest): """Backend that allows inactive users.""" @@ -201,6 +357,13 @@ class AllowAllUsersRemoteUserBackendTest(RemoteUserTest): response = self.client.get("/remote_user/", **{self.header: self.known_user}) self.assertEqual(response.context["user"].username, user.username) + async def test_inactive_user_async(self): + user = await User.objects.acreate(username="knownuser", is_active=False) + response = await self.async_client.get( + "/remote_user/", **{self.header: self.known_user} + ) + self.assertEqual(response.context["user"].username, user.username) + class CustomRemoteUserBackend(RemoteUserBackend): """ @@ -311,3 +474,16 @@ class PersistentRemoteUserTest(RemoteUserTest): response = self.client.get("/remote_user/") self.assertFalse(response.context["user"].is_anonymous) self.assertEqual(response.context["user"].username, "knownuser") + + async def test_header_disappears_async(self): + """See test_header_disappears.""" + await User.objects.acreate(username="knownuser") + # Known user authenticates + response = await self.async_client.get( + "/remote_user/", **{self.header: self.known_user} + ) + self.assertEqual(response.context["user"].username, "knownuser") + # Should stay logged in if the REMOTE_USER header disappears. + response = await self.async_client.get("/remote_user/") + self.assertFalse(response.context["user"].is_anonymous) + self.assertEqual(response.context["user"].username, "knownuser") diff --git a/tests/deprecation/test_middleware_mixin.py b/tests/deprecation/test_middleware_mixin.py index f4eafc14e3..7e86832e2b 100644 --- a/tests/deprecation/test_middleware_mixin.py +++ b/tests/deprecation/test_middleware_mixin.py @@ -6,7 +6,6 @@ from django.contrib.admindocs.middleware import XViewMiddleware from django.contrib.auth.middleware import ( AuthenticationMiddleware, LoginRequiredMiddleware, - RemoteUserMiddleware, ) from django.contrib.flatpages.middleware import FlatpageFallbackMiddleware from django.contrib.messages.middleware import MessageMiddleware @@ -48,7 +47,6 @@ class MiddlewareMixinTests(SimpleTestCase): LocaleMiddleware, MessageMiddleware, RedirectFallbackMiddleware, - RemoteUserMiddleware, SecurityMiddleware, SessionMiddleware, UpdateCacheMiddleware,