1
0
mirror of https://github.com/django/django.git synced 2024-12-22 09:05:43 +00:00

Fixed #35303 -- Implemented async auth backends and utils.

This commit is contained in:
Jon Janzen 2024-03-31 12:29:10 -07:00 committed by Sarah Boyce
parent 4cad317ff1
commit 50f89ae850
17 changed files with 1285 additions and 61 deletions

View File

@ -1,8 +1,6 @@
import inspect import inspect
import re import re
from asgiref.sync import sync_to_async
from django.apps import apps as django_apps from django.apps import apps as django_apps
from django.conf import settings from django.conf import settings
from django.core.exceptions import ImproperlyConfigured, PermissionDenied from django.core.exceptions import ImproperlyConfigured, PermissionDenied
@ -40,6 +38,39 @@ def get_backends():
return _get_backends(return_tuples=False) return _get_backends(return_tuples=False)
def _get_compatible_backends(request, **credentials):
for backend, backend_path in _get_backends(return_tuples=True):
backend_signature = inspect.signature(backend.authenticate)
try:
backend_signature.bind(request, **credentials)
except TypeError:
# This backend doesn't accept these credentials as arguments. Try
# the next one.
continue
yield backend, backend_path
def _get_backend_from_user(user, backend=None):
try:
backend = backend or user.backend
except AttributeError:
backends = _get_backends(return_tuples=True)
if len(backends) == 1:
_, backend = backends[0]
else:
raise ValueError(
"You have multiple authentication backends configured and "
"therefore must provide the `backend` argument or set the "
"`backend` attribute on the user."
)
else:
if not isinstance(backend, str):
raise TypeError(
"backend must be a dotted import path string (got %r)." % backend
)
return backend
@sensitive_variables("credentials") @sensitive_variables("credentials")
def _clean_credentials(credentials): def _clean_credentials(credentials):
""" """
@ -62,19 +93,21 @@ def _get_user_session_key(request):
return get_user_model()._meta.pk.to_python(request.session[SESSION_KEY]) return get_user_model()._meta.pk.to_python(request.session[SESSION_KEY])
async def _aget_user_session_key(request):
# This value in the session is always serialized to a string, so we need
# to convert it back to Python whenever we access it.
session_key = await request.session.aget(SESSION_KEY)
if session_key is None:
raise KeyError()
return get_user_model()._meta.pk.to_python(session_key)
@sensitive_variables("credentials") @sensitive_variables("credentials")
def authenticate(request=None, **credentials): def authenticate(request=None, **credentials):
""" """
If the given credentials are valid, return a User object. If the given credentials are valid, return a User object.
""" """
for backend, backend_path in _get_backends(return_tuples=True): for backend, backend_path in _get_compatible_backends(request, **credentials):
backend_signature = inspect.signature(backend.authenticate)
try:
backend_signature.bind(request, **credentials)
except TypeError:
# This backend doesn't accept these credentials as arguments. Try
# the next one.
continue
try: try:
user = backend.authenticate(request, **credentials) user = backend.authenticate(request, **credentials)
except PermissionDenied: except PermissionDenied:
@ -96,7 +129,23 @@ def authenticate(request=None, **credentials):
@sensitive_variables("credentials") @sensitive_variables("credentials")
async def aauthenticate(request=None, **credentials): async def aauthenticate(request=None, **credentials):
"""See authenticate().""" """See authenticate()."""
return await sync_to_async(authenticate)(request, **credentials) for backend, backend_path in _get_compatible_backends(request, **credentials):
try:
user = await backend.aauthenticate(request, **credentials)
except PermissionDenied:
# This backend says to stop in our tracks - this user should not be
# allowed in at all.
break
if user is None:
continue
# Annotate the user object with the path of the backend.
user.backend = backend_path
return user
# The credentials supplied are invalid to all backends, fire signal.
await user_login_failed.asend(
sender=__name__, credentials=_clean_credentials(credentials), request=request
)
def login(request, user, backend=None): def login(request, user, backend=None):
@ -125,23 +174,7 @@ def login(request, user, backend=None):
else: else:
request.session.cycle_key() request.session.cycle_key()
try: backend = _get_backend_from_user(user=user, backend=backend)
backend = backend or user.backend
except AttributeError:
backends = _get_backends(return_tuples=True)
if len(backends) == 1:
_, backend = backends[0]
else:
raise ValueError(
"You have multiple authentication backends configured and "
"therefore must provide the `backend` argument or set the "
"`backend` attribute on the user."
)
else:
if not isinstance(backend, str):
raise TypeError(
"backend must be a dotted import path string (got %r)." % backend
)
request.session[SESSION_KEY] = user._meta.pk.value_to_string(user) request.session[SESSION_KEY] = user._meta.pk.value_to_string(user)
request.session[BACKEND_SESSION_KEY] = backend request.session[BACKEND_SESSION_KEY] = backend
@ -154,7 +187,36 @@ def login(request, user, backend=None):
async def alogin(request, user, backend=None): async def alogin(request, user, backend=None):
"""See login().""" """See login()."""
return await sync_to_async(login)(request, user, backend) session_auth_hash = ""
if user is None:
user = await request.auser()
if hasattr(user, "get_session_auth_hash"):
session_auth_hash = user.get_session_auth_hash()
if await request.session.ahas_key(SESSION_KEY):
if await _aget_user_session_key(request) != user.pk or (
session_auth_hash
and not constant_time_compare(
await request.session.aget(HASH_SESSION_KEY, ""),
session_auth_hash,
)
):
# To avoid reusing another user's session, create a new, empty
# session if the existing session corresponds to a different
# authenticated user.
await request.session.aflush()
else:
await request.session.acycle_key()
backend = _get_backend_from_user(user=user, backend=backend)
await request.session.aset(SESSION_KEY, user._meta.pk.value_to_string(user))
await request.session.aset(BACKEND_SESSION_KEY, backend)
await request.session.aset(HASH_SESSION_KEY, session_auth_hash)
if hasattr(request, "user"):
request.user = user
rotate_token(request)
await user_logged_in.asend(sender=user.__class__, request=request, user=user)
def logout(request): def logout(request):
@ -177,7 +239,19 @@ def logout(request):
async def alogout(request): async def alogout(request):
"""See logout().""" """See logout()."""
return await sync_to_async(logout)(request) # Dispatch the signal before the user is logged out so the receivers have a
# chance to find out *who* logged out.
user = getattr(request, "auser", None)
if user is not None:
user = await user()
if not getattr(user, "is_authenticated", True):
user = None
await user_logged_out.asend(sender=user.__class__, request=request, user=user)
await request.session.aflush()
if hasattr(request, "user"):
from django.contrib.auth.models import AnonymousUser
request.user = AnonymousUser()
def get_user_model(): def get_user_model():
@ -243,7 +317,43 @@ def get_user(request):
async def aget_user(request): async def aget_user(request):
"""See get_user().""" """See get_user()."""
return await sync_to_async(get_user)(request) from .models import AnonymousUser
user = None
try:
user_id = await _aget_user_session_key(request)
backend_path = await request.session.aget(BACKEND_SESSION_KEY)
except KeyError:
pass
else:
if backend_path in settings.AUTHENTICATION_BACKENDS:
backend = load_backend(backend_path)
user = await backend.aget_user(user_id)
# Verify the session
if hasattr(user, "get_session_auth_hash"):
session_hash = await request.session.aget(HASH_SESSION_KEY)
if not session_hash:
session_hash_verified = False
else:
session_auth_hash = user.get_session_auth_hash()
session_hash_verified = session_hash and constant_time_compare(
session_hash, user.get_session_auth_hash()
)
if not session_hash_verified:
# If the current secret does not verify the session, try
# with the fallback secrets and stop when a matching one is
# found.
if session_hash and any(
constant_time_compare(session_hash, fallback_auth_hash)
for fallback_auth_hash in user.get_session_auth_fallback_hash()
):
await request.session.acycle_key()
await request.session.aset(HASH_SESSION_KEY, session_auth_hash)
else:
await request.session.aflush()
user = None
return user or AnonymousUser()
def get_permission_codename(action, opts): def get_permission_codename(action, opts):

View File

@ -1,3 +1,5 @@
from asgiref.sync import sync_to_async
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.contrib.auth.models import Permission from django.contrib.auth.models import Permission
from django.db.models import Exists, OuterRef, Q from django.db.models import Exists, OuterRef, Q
@ -9,24 +11,45 @@ class BaseBackend:
def authenticate(self, request, **kwargs): def authenticate(self, request, **kwargs):
return None return None
async def aauthenticate(self, request, **kwargs):
return await sync_to_async(self.authenticate)(request, **kwargs)
def get_user(self, user_id): def get_user(self, user_id):
return None return None
async def aget_user(self, user_id):
return await sync_to_async(self.get_user)(user_id)
def get_user_permissions(self, user_obj, obj=None): def get_user_permissions(self, user_obj, obj=None):
return set() return set()
async def aget_user_permissions(self, user_obj, obj=None):
return await sync_to_async(self.get_user_permissions)(user_obj, obj)
def get_group_permissions(self, user_obj, obj=None): def get_group_permissions(self, user_obj, obj=None):
return set() return set()
async def aget_group_permissions(self, user_obj, obj=None):
return await sync_to_async(self.get_group_permissions)(user_obj, obj)
def get_all_permissions(self, user_obj, obj=None): def get_all_permissions(self, user_obj, obj=None):
return { return {
*self.get_user_permissions(user_obj, obj=obj), *self.get_user_permissions(user_obj, obj=obj),
*self.get_group_permissions(user_obj, obj=obj), *self.get_group_permissions(user_obj, obj=obj),
} }
async def aget_all_permissions(self, user_obj, obj=None):
return {
*await self.aget_user_permissions(user_obj, obj=obj),
*await self.aget_group_permissions(user_obj, obj=obj),
}
def has_perm(self, user_obj, perm, obj=None): def has_perm(self, user_obj, perm, obj=None):
return perm in self.get_all_permissions(user_obj, obj=obj) return perm in self.get_all_permissions(user_obj, obj=obj)
async def ahas_perm(self, user_obj, perm, obj=None):
return perm in await self.aget_all_permissions(user_obj, obj)
class ModelBackend(BaseBackend): class ModelBackend(BaseBackend):
""" """
@ -48,6 +71,23 @@ class ModelBackend(BaseBackend):
if user.check_password(password) and self.user_can_authenticate(user): if user.check_password(password) and self.user_can_authenticate(user):
return user return user
async def aauthenticate(self, request, username=None, password=None, **kwargs):
if username is None:
username = kwargs.get(UserModel.USERNAME_FIELD)
if username is None or password is None:
return
try:
user = await UserModel._default_manager.aget_by_natural_key(username)
except UserModel.DoesNotExist:
# Run the default password hasher once to reduce the timing
# difference between an existing and a nonexistent user (#20760).
UserModel().set_password(password)
else:
if await user.acheck_password(password) and self.user_can_authenticate(
user
):
return user
def user_can_authenticate(self, user): def user_can_authenticate(self, user):
""" """
Reject users with is_active=False. Custom user models that don't have Reject users with is_active=False. Custom user models that don't have
@ -84,6 +124,25 @@ class ModelBackend(BaseBackend):
) )
return getattr(user_obj, perm_cache_name) return getattr(user_obj, perm_cache_name)
async def _aget_permissions(self, user_obj, obj, from_name):
"""See _get_permissions()."""
if not user_obj.is_active or user_obj.is_anonymous or obj is not None:
return set()
perm_cache_name = "_%s_perm_cache" % from_name
if not hasattr(user_obj, perm_cache_name):
if user_obj.is_superuser:
perms = Permission.objects.all()
else:
perms = getattr(self, "_get_%s_permissions" % from_name)(user_obj)
perms = perms.values_list("content_type__app_label", "codename").order_by()
setattr(
user_obj,
perm_cache_name,
{"%s.%s" % (ct, name) async for ct, name in perms},
)
return getattr(user_obj, perm_cache_name)
def get_user_permissions(self, user_obj, obj=None): def get_user_permissions(self, user_obj, obj=None):
""" """
Return a set of permission strings the user `user_obj` has from their Return a set of permission strings the user `user_obj` has from their
@ -91,6 +150,10 @@ class ModelBackend(BaseBackend):
""" """
return self._get_permissions(user_obj, obj, "user") return self._get_permissions(user_obj, obj, "user")
async def aget_user_permissions(self, user_obj, obj=None):
"""See get_user_permissions()."""
return await self._aget_permissions(user_obj, obj, "user")
def get_group_permissions(self, user_obj, obj=None): def get_group_permissions(self, user_obj, obj=None):
""" """
Return a set of permission strings the user `user_obj` has from the Return a set of permission strings the user `user_obj` has from the
@ -98,6 +161,10 @@ class ModelBackend(BaseBackend):
""" """
return self._get_permissions(user_obj, obj, "group") return self._get_permissions(user_obj, obj, "group")
async def aget_group_permissions(self, user_obj, obj=None):
"""See get_group_permissions()."""
return await self._aget_permissions(user_obj, obj, "group")
def get_all_permissions(self, user_obj, obj=None): def get_all_permissions(self, user_obj, obj=None):
if not user_obj.is_active or user_obj.is_anonymous or obj is not None: if not user_obj.is_active or user_obj.is_anonymous or obj is not None:
return set() return set()
@ -108,6 +175,9 @@ class ModelBackend(BaseBackend):
def has_perm(self, user_obj, perm, obj=None): def has_perm(self, user_obj, perm, obj=None):
return user_obj.is_active and super().has_perm(user_obj, perm, obj=obj) return user_obj.is_active and super().has_perm(user_obj, perm, obj=obj)
async def ahas_perm(self, user_obj, perm, obj=None):
return user_obj.is_active and await super().ahas_perm(user_obj, perm, obj=obj)
def has_module_perms(self, user_obj, app_label): def has_module_perms(self, user_obj, app_label):
""" """
Return True if user_obj has any permissions in the given app_label. Return True if user_obj has any permissions in the given app_label.
@ -117,6 +187,13 @@ class ModelBackend(BaseBackend):
for perm in self.get_all_permissions(user_obj) for perm in self.get_all_permissions(user_obj)
) )
async def ahas_module_perms(self, user_obj, app_label):
"""See has_module_perms()"""
return user_obj.is_active and any(
perm[: perm.index(".")] == app_label
for perm in await self.aget_all_permissions(user_obj)
)
def with_perm(self, perm, is_active=True, include_superusers=True, obj=None): def with_perm(self, perm, is_active=True, include_superusers=True, obj=None):
""" """
Return users that have permission "perm". By default, filter out Return users that have permission "perm". By default, filter out
@ -159,6 +236,13 @@ class ModelBackend(BaseBackend):
return None return None
return user if self.user_can_authenticate(user) else None return user if self.user_can_authenticate(user) else None
async def aget_user(self, user_id):
try:
user = await UserModel._default_manager.aget(pk=user_id)
except UserModel.DoesNotExist:
return None
return user if self.user_can_authenticate(user) else None
class AllowAllUsersModelBackend(ModelBackend): class AllowAllUsersModelBackend(ModelBackend):
def user_can_authenticate(self, user): def user_can_authenticate(self, user):
@ -210,6 +294,29 @@ class RemoteUserBackend(ModelBackend):
user = self.configure_user(request, user, created=created) user = self.configure_user(request, user, created=created)
return user if self.user_can_authenticate(user) else None return user if self.user_can_authenticate(user) else None
async def aauthenticate(self, request, remote_user):
"""See authenticate()."""
if not remote_user:
return
created = False
user = None
username = self.clean_username(remote_user)
# Note that this could be accomplished in one try-except clause, but
# instead we use get_or_create when creating unknown users since it has
# built-in safeguards for multiple threads.
if self.create_unknown_user:
user, created = await UserModel._default_manager.aget_or_create(
**{UserModel.USERNAME_FIELD: username}
)
else:
try:
user = await UserModel._default_manager.aget_by_natural_key(username)
except UserModel.DoesNotExist:
pass
user = await self.aconfigure_user(request, user, created=created)
return user if self.user_can_authenticate(user) else None
def clean_username(self, username): def clean_username(self, username):
""" """
Perform any cleaning on the "username" prior to using it to get or Perform any cleaning on the "username" prior to using it to get or
@ -227,6 +334,10 @@ class RemoteUserBackend(ModelBackend):
""" """
return user return user
async def aconfigure_user(self, request, user, created=True):
"""See configure_user()"""
return await sync_to_async(self.configure_user)(request, user, created)
class AllowAllUsersRemoteUserBackend(RemoteUserBackend): class AllowAllUsersRemoteUserBackend(RemoteUserBackend):
def user_can_authenticate(self, user): def user_can_authenticate(self, user):

View File

@ -36,6 +36,9 @@ class BaseUserManager(models.Manager):
def get_by_natural_key(self, username): def get_by_natural_key(self, username):
return self.get(**{self.model.USERNAME_FIELD: username}) return self.get(**{self.model.USERNAME_FIELD: username})
async def aget_by_natural_key(self, username):
return await self.aget(**{self.model.USERNAME_FIELD: username})
class AbstractBaseUser(models.Model): class AbstractBaseUser(models.Model):
password = models.CharField(_("password"), max_length=128) password = models.CharField(_("password"), max_length=128)

View File

@ -111,7 +111,7 @@ def permission_required(perm, login_url=None, raise_exception=False):
async def check_perms(user): async def check_perms(user):
# First check if the user has the permission (even anon users). # First check if the user has the permission (even anon users).
if await sync_to_async(user.has_perms)(perms): if await user.ahas_perms(perms):
return True return True
# In case the 403 handler should be called raise the exception. # In case the 403 handler should be called raise the exception.
if raise_exception: if raise_exception:

View File

@ -1,6 +1,8 @@
from functools import partial from functools import partial
from urllib.parse import urlsplit from urllib.parse import urlsplit
from asgiref.sync import iscoroutinefunction, markcoroutinefunction
from django.conf import settings from django.conf import settings
from django.contrib import auth from django.contrib import auth
from django.contrib.auth import REDIRECT_FIELD_NAME, load_backend from django.contrib.auth import REDIRECT_FIELD_NAME, load_backend
@ -88,7 +90,7 @@ class LoginRequiredMiddleware(MiddlewareMixin):
) )
class RemoteUserMiddleware(MiddlewareMixin): class RemoteUserMiddleware:
""" """
Middleware for utilizing web-server-provided authentication. Middleware for utilizing web-server-provided authentication.
@ -102,13 +104,27 @@ class RemoteUserMiddleware(MiddlewareMixin):
different header. different header.
""" """
sync_capable = True
async_capable = True
def __init__(self, get_response):
if get_response is None:
raise ValueError("get_response must be provided.")
self.get_response = get_response
self.is_async = iscoroutinefunction(get_response)
if self.is_async:
markcoroutinefunction(self)
super().__init__()
# Name of request header to grab username from. This will be the key as # Name of request header to grab username from. This will be the key as
# used in the request.META dictionary, i.e. the normalization of headers to # used in the request.META dictionary, i.e. the normalization of headers to
# all uppercase and the addition of "HTTP_" prefix apply. # all uppercase and the addition of "HTTP_" prefix apply.
header = "REMOTE_USER" header = "REMOTE_USER"
force_logout_if_no_header = True force_logout_if_no_header = True
def process_request(self, request): def __call__(self, request):
if self.is_async:
return self.__acall__(request)
# AuthenticationMiddleware is required so that request.user exists. # AuthenticationMiddleware is required so that request.user exists.
if not hasattr(request, "user"): if not hasattr(request, "user"):
raise ImproperlyConfigured( raise ImproperlyConfigured(
@ -126,13 +142,13 @@ class RemoteUserMiddleware(MiddlewareMixin):
# AnonymousUser by the AuthenticationMiddleware). # AnonymousUser by the AuthenticationMiddleware).
if self.force_logout_if_no_header and request.user.is_authenticated: if self.force_logout_if_no_header and request.user.is_authenticated:
self._remove_invalid_user(request) self._remove_invalid_user(request)
return return self.get_response(request)
# If the user is already authenticated and that user is the user we are # If the user is already authenticated and that user is the user we are
# getting passed in the headers, then the correct user is already # getting passed in the headers, then the correct user is already
# persisted in the session and we don't need to continue. # persisted in the session and we don't need to continue.
if request.user.is_authenticated: if request.user.is_authenticated:
if request.user.get_username() == self.clean_username(username, request): if request.user.get_username() == self.clean_username(username, request):
return return self.get_response(request)
else: else:
# An authenticated user is associated with the request, but # An authenticated user is associated with the request, but
# it does not match the authorized user in the header. # it does not match the authorized user in the header.
@ -146,6 +162,51 @@ class RemoteUserMiddleware(MiddlewareMixin):
# by logging the user in. # by logging the user in.
request.user = user request.user = user
auth.login(request, user) auth.login(request, user)
return self.get_response(request)
async def __acall__(self, request):
# AuthenticationMiddleware is required so that request.user exists.
if not hasattr(request, "user"):
raise ImproperlyConfigured(
"The Django remote user auth middleware requires the"
" authentication middleware to be installed. Edit your"
" MIDDLEWARE setting to insert"
" 'django.contrib.auth.middleware.AuthenticationMiddleware'"
" before the RemoteUserMiddleware class."
)
try:
username = request.META["HTTP_" + self.header]
except KeyError:
# If specified header doesn't exist then remove any existing
# authenticated remote-user, or return (leaving request.user set to
# AnonymousUser by the AuthenticationMiddleware).
if self.force_logout_if_no_header:
user = await request.auser()
if user.is_authenticated:
await self._aremove_invalid_user(request)
return await self.get_response(request)
user = await request.auser()
# If the user is already authenticated and that user is the user we are
# getting passed in the headers, then the correct user is already
# persisted in the session and we don't need to continue.
if user.is_authenticated:
if user.get_username() == self.clean_username(username, request):
return await self.get_response(request)
else:
# An authenticated user is associated with the request, but
# it does not match the authorized user in the header.
await self._aremove_invalid_user(request)
# We are seeing this user for the first time in this session, attempt
# to authenticate the user.
user = await auth.aauthenticate(request, remote_user=username)
if user:
# User is valid. Set request.user and persist user in the session
# by logging the user in.
request.user = user
await auth.alogin(request, user)
return await self.get_response(request)
def clean_username(self, username, request): def clean_username(self, username, request):
""" """
@ -176,6 +237,22 @@ class RemoteUserMiddleware(MiddlewareMixin):
if isinstance(stored_backend, RemoteUserBackend): if isinstance(stored_backend, RemoteUserBackend):
auth.logout(request) auth.logout(request)
async def _aremove_invalid_user(self, request):
"""
Remove the current authenticated user in the request which is invalid
but only if the user is authenticated via the RemoteUserBackend.
"""
try:
stored_backend = load_backend(
await request.session.aget(auth.BACKEND_SESSION_KEY, "")
)
except ImportError:
# Backend failed to load.
await auth.alogout(request)
else:
if isinstance(stored_backend, RemoteUserBackend):
await auth.alogout(request)
class PersistentRemoteUserMiddleware(RemoteUserMiddleware): class PersistentRemoteUserMiddleware(RemoteUserMiddleware):
""" """

View File

@ -95,6 +95,9 @@ class GroupManager(models.Manager):
def get_by_natural_key(self, name): def get_by_natural_key(self, name):
return self.get(name=name) return self.get(name=name)
async def aget_by_natural_key(self, name):
return await self.aget(name=name)
class Group(models.Model): class Group(models.Model):
""" """
@ -137,10 +140,7 @@ class Group(models.Model):
class UserManager(BaseUserManager): class UserManager(BaseUserManager):
use_in_migrations = True use_in_migrations = True
def _create_user(self, username, email, password, **extra_fields): def _create_user_object(self, username, email, password, **extra_fields):
"""
Create and save a user with the given username, email, and password.
"""
if not username: if not username:
raise ValueError("The given username must be set") raise ValueError("The given username must be set")
email = self.normalize_email(email) email = self.normalize_email(email)
@ -153,14 +153,32 @@ class UserManager(BaseUserManager):
username = GlobalUserModel.normalize_username(username) username = GlobalUserModel.normalize_username(username)
user = self.model(username=username, email=email, **extra_fields) user = self.model(username=username, email=email, **extra_fields)
user.password = make_password(password) user.password = make_password(password)
return user
def _create_user(self, username, email, password, **extra_fields):
"""
Create and save a user with the given username, email, and password.
"""
user = self._create_user_object(username, email, password, **extra_fields)
user.save(using=self._db) user.save(using=self._db)
return user return user
async def _acreate_user(self, username, email, password, **extra_fields):
"""See _create_user()"""
user = self._create_user_object(username, email, password, **extra_fields)
await user.asave(using=self._db)
return user
def create_user(self, username, email=None, password=None, **extra_fields): def create_user(self, username, email=None, password=None, **extra_fields):
extra_fields.setdefault("is_staff", False) extra_fields.setdefault("is_staff", False)
extra_fields.setdefault("is_superuser", False) extra_fields.setdefault("is_superuser", False)
return self._create_user(username, email, password, **extra_fields) return self._create_user(username, email, password, **extra_fields)
async def acreate_user(self, username, email=None, password=None, **extra_fields):
extra_fields.setdefault("is_staff", False)
extra_fields.setdefault("is_superuser", False)
return await self._acreate_user(username, email, password, **extra_fields)
def create_superuser(self, username, email=None, password=None, **extra_fields): def create_superuser(self, username, email=None, password=None, **extra_fields):
extra_fields.setdefault("is_staff", True) extra_fields.setdefault("is_staff", True)
extra_fields.setdefault("is_superuser", True) extra_fields.setdefault("is_superuser", True)
@ -172,6 +190,19 @@ class UserManager(BaseUserManager):
return self._create_user(username, email, password, **extra_fields) return self._create_user(username, email, password, **extra_fields)
async def acreate_superuser(
self, username, email=None, password=None, **extra_fields
):
extra_fields.setdefault("is_staff", True)
extra_fields.setdefault("is_superuser", True)
if extra_fields.get("is_staff") is not True:
raise ValueError("Superuser must have is_staff=True.")
if extra_fields.get("is_superuser") is not True:
raise ValueError("Superuser must have is_superuser=True.")
return await self._acreate_user(username, email, password, **extra_fields)
def with_perm( def with_perm(
self, perm, is_active=True, include_superusers=True, backend=None, obj=None self, perm, is_active=True, include_superusers=True, backend=None, obj=None
): ):
@ -210,6 +241,15 @@ def _user_get_permissions(user, obj, from_name):
return permissions return permissions
async def _auser_get_permissions(user, obj, from_name):
permissions = set()
name = "aget_%s_permissions" % from_name
for backend in auth.get_backends():
if hasattr(backend, name):
permissions.update(await getattr(backend, name)(user, obj))
return permissions
def _user_has_perm(user, perm, obj): def _user_has_perm(user, perm, obj):
""" """
A backend can raise `PermissionDenied` to short-circuit permission checking. A backend can raise `PermissionDenied` to short-circuit permission checking.
@ -225,6 +265,19 @@ def _user_has_perm(user, perm, obj):
return False return False
async def _auser_has_perm(user, perm, obj):
"""See _user_has_perm()"""
for backend in auth.get_backends():
if not hasattr(backend, "ahas_perm"):
continue
try:
if await backend.ahas_perm(user, perm, obj):
return True
except PermissionDenied:
return False
return False
def _user_has_module_perms(user, app_label): def _user_has_module_perms(user, app_label):
""" """
A backend can raise `PermissionDenied` to short-circuit permission checking. A backend can raise `PermissionDenied` to short-circuit permission checking.
@ -240,6 +293,19 @@ def _user_has_module_perms(user, app_label):
return False return False
async def _auser_has_module_perms(user, app_label):
"""See _user_has_module_perms()"""
for backend in auth.get_backends():
if not hasattr(backend, "ahas_module_perms"):
continue
try:
if await backend.ahas_module_perms(user, app_label):
return True
except PermissionDenied:
return False
return False
class PermissionsMixin(models.Model): class PermissionsMixin(models.Model):
""" """
Add the fields and methods necessary to support the Group and Permission Add the fields and methods necessary to support the Group and Permission
@ -285,6 +351,10 @@ class PermissionsMixin(models.Model):
""" """
return _user_get_permissions(self, obj, "user") return _user_get_permissions(self, obj, "user")
async def aget_user_permissions(self, obj=None):
"""See get_user_permissions()"""
return await _auser_get_permissions(self, obj, "user")
def get_group_permissions(self, obj=None): def get_group_permissions(self, obj=None):
""" """
Return a list of permission strings that this user has through their Return a list of permission strings that this user has through their
@ -293,9 +363,16 @@ class PermissionsMixin(models.Model):
""" """
return _user_get_permissions(self, obj, "group") return _user_get_permissions(self, obj, "group")
async def aget_group_permissions(self, obj=None):
"""See get_group_permissions()"""
return await _auser_get_permissions(self, obj, "group")
def get_all_permissions(self, obj=None): def get_all_permissions(self, obj=None):
return _user_get_permissions(self, obj, "all") return _user_get_permissions(self, obj, "all")
async def aget_all_permissions(self, obj=None):
return await _auser_get_permissions(self, obj, "all")
def has_perm(self, perm, obj=None): def has_perm(self, perm, obj=None):
""" """
Return True if the user has the specified permission. Query all Return True if the user has the specified permission. Query all
@ -311,6 +388,15 @@ class PermissionsMixin(models.Model):
# Otherwise we need to check the backends. # Otherwise we need to check the backends.
return _user_has_perm(self, perm, obj) return _user_has_perm(self, perm, obj)
async def ahas_perm(self, perm, obj=None):
"""See has_perm()"""
# Active superusers have all permissions.
if self.is_active and self.is_superuser:
return True
# Otherwise we need to check the backends.
return await _auser_has_perm(self, perm, obj)
def has_perms(self, perm_list, obj=None): def has_perms(self, perm_list, obj=None):
""" """
Return True if the user has each of the specified permissions. If Return True if the user has each of the specified permissions. If
@ -320,6 +406,15 @@ class PermissionsMixin(models.Model):
raise ValueError("perm_list must be an iterable of permissions.") raise ValueError("perm_list must be an iterable of permissions.")
return all(self.has_perm(perm, obj) for perm in perm_list) return all(self.has_perm(perm, obj) for perm in perm_list)
async def ahas_perms(self, perm_list, obj=None):
"""See has_perms()"""
if not isinstance(perm_list, Iterable) or isinstance(perm_list, str):
raise ValueError("perm_list must be an iterable of permissions.")
for perm in perm_list:
if not await self.ahas_perm(perm, obj):
return False
return True
def has_module_perms(self, app_label): def has_module_perms(self, app_label):
""" """
Return True if the user has any permissions in the given app label. Return True if the user has any permissions in the given app label.
@ -331,6 +426,14 @@ class PermissionsMixin(models.Model):
return _user_has_module_perms(self, app_label) return _user_has_module_perms(self, app_label)
async def ahas_module_perms(self, app_label):
"""See has_module_perms()"""
# Active superusers have all permissions.
if self.is_active and self.is_superuser:
return True
return await _auser_has_module_perms(self, app_label)
class AbstractUser(AbstractBaseUser, PermissionsMixin): class AbstractUser(AbstractBaseUser, PermissionsMixin):
""" """
@ -471,23 +574,46 @@ class AnonymousUser:
def get_user_permissions(self, obj=None): def get_user_permissions(self, obj=None):
return _user_get_permissions(self, obj, "user") return _user_get_permissions(self, obj, "user")
async def aget_user_permissions(self, obj=None):
return await _auser_get_permissions(self, obj, "user")
def get_group_permissions(self, obj=None): def get_group_permissions(self, obj=None):
return set() return set()
async def aget_group_permissions(self, obj=None):
return self.get_group_permissions(obj)
def get_all_permissions(self, obj=None): def get_all_permissions(self, obj=None):
return _user_get_permissions(self, obj, "all") return _user_get_permissions(self, obj, "all")
async def aget_all_permissions(self, obj=None):
return await _auser_get_permissions(self, obj, "all")
def has_perm(self, perm, obj=None): def has_perm(self, perm, obj=None):
return _user_has_perm(self, perm, obj=obj) return _user_has_perm(self, perm, obj=obj)
async def ahas_perm(self, perm, obj=None):
return await _auser_has_perm(self, perm, obj=obj)
def has_perms(self, perm_list, obj=None): def has_perms(self, perm_list, obj=None):
if not isinstance(perm_list, Iterable) or isinstance(perm_list, str): if not isinstance(perm_list, Iterable) or isinstance(perm_list, str):
raise ValueError("perm_list must be an iterable of permissions.") raise ValueError("perm_list must be an iterable of permissions.")
return all(self.has_perm(perm, obj) for perm in perm_list) return all(self.has_perm(perm, obj) for perm in perm_list)
async def ahas_perms(self, perm_list, obj=None):
if not isinstance(perm_list, Iterable) or isinstance(perm_list, str):
raise ValueError("perm_list must be an iterable of permissions.")
for perm in perm_list:
if not await self.ahas_perm(perm, obj):
return False
return True
def has_module_perms(self, module): def has_module_perms(self, module):
return _user_has_module_perms(self, module) return _user_has_module_perms(self, module)
async def ahas_module_perms(self, module):
return await _auser_has_module_perms(self, module)
@property @property
def is_anonymous(self): def is_anonymous(self):
return True return True

View File

@ -197,13 +197,23 @@ Methods
been called for this user. been called for this user.
.. method:: get_user_permissions(obj=None) .. method:: get_user_permissions(obj=None)
.. method:: aget_user_permissions(obj=None)
*Asynchronous version*: ``aget_user_permissions()``
Returns a set of permission strings that the user has directly. Returns a set of permission strings that the user has directly.
If ``obj`` is passed in, only returns the user permissions for this If ``obj`` is passed in, only returns the user permissions for this
specific object. specific object.
.. versionchanged:: 5.2
``aget_user_permissions()`` method was added.
.. method:: get_group_permissions(obj=None) .. method:: get_group_permissions(obj=None)
.. method:: aget_group_permissions(obj=None)
*Asynchronous version*: ``aget_group_permissions()``
Returns a set of permission strings that the user has, through their Returns a set of permission strings that the user has, through their
groups. groups.
@ -211,7 +221,14 @@ Methods
If ``obj`` is passed in, only returns the group permissions for If ``obj`` is passed in, only returns the group permissions for
this specific object. this specific object.
.. versionchanged:: 5.2
``aget_group_permissions()`` method was added.
.. method:: get_all_permissions(obj=None) .. method:: get_all_permissions(obj=None)
.. method:: aget_all_permissions(obj=None)
*Asynchronous version*: ``aget_all_permissions()``
Returns a set of permission strings that the user has, both through Returns a set of permission strings that the user has, both through
group and user permissions. group and user permissions.
@ -219,7 +236,14 @@ Methods
If ``obj`` is passed in, only returns the permissions for this If ``obj`` is passed in, only returns the permissions for this
specific object. specific object.
.. versionchanged:: 5.2
``aget_all_permissions()`` method was added.
.. method:: has_perm(perm, obj=None) .. method:: has_perm(perm, obj=None)
.. method:: ahas_perm(perm, obj=None)
*Asynchronous version*: ``ahas_perm()``
Returns ``True`` if the user has the specified permission, where perm Returns ``True`` if the user has the specified permission, where perm
is in the format ``"<app label>.<permission codename>"``. (see is in the format ``"<app label>.<permission codename>"``. (see
@ -230,7 +254,14 @@ Methods
If ``obj`` is passed in, this method won't check for a permission for If ``obj`` is passed in, this method won't check for a permission for
the model, but for this specific object. the model, but for this specific object.
.. versionchanged:: 5.2
``ahas_perm()`` method was added.
.. method:: has_perms(perm_list, obj=None) .. method:: has_perms(perm_list, obj=None)
.. method:: ahas_perms(perm_list, obj=None)
*Asynchronous version*: ``ahas_perms()``
Returns ``True`` if the user has each of the specified permissions, Returns ``True`` if the user has each of the specified permissions,
where each perm is in the format where each perm is in the format
@ -241,13 +272,24 @@ Methods
If ``obj`` is passed in, this method won't check for permissions for If ``obj`` is passed in, this method won't check for permissions for
the model, but for the specific object. the model, but for the specific object.
.. versionchanged:: 5.2
``ahas_perms()`` method was added.
.. method:: has_module_perms(package_name) .. method:: has_module_perms(package_name)
.. method:: ahas_module_perms(package_name)
*Asynchronous version*: ``ahas_module_perms()``
Returns ``True`` if the user has any permissions in the given package Returns ``True`` if the user has any permissions in the given package
(the Django app label). If the user is inactive, this method will (the Django app label). If the user is inactive, this method will
always return ``False``. For an active superuser, this method will always return ``False``. For an active superuser, this method will
always return ``True``. always return ``True``.
.. versionchanged:: 5.2
``ahas_module_perms()`` method was added.
.. method:: email_user(subject, message, from_email=None, **kwargs) .. method:: email_user(subject, message, from_email=None, **kwargs)
Sends an email to the user. If ``from_email`` is ``None``, Django uses Sends an email to the user. If ``from_email`` is ``None``, Django uses
@ -264,6 +306,9 @@ Manager methods
by :class:`~django.contrib.auth.models.BaseUserManager`): by :class:`~django.contrib.auth.models.BaseUserManager`):
.. method:: create_user(username, email=None, password=None, **extra_fields) .. method:: create_user(username, email=None, password=None, **extra_fields)
.. method:: acreate_user(username, email=None, password=None, **extra_fields)
*Asynchronous version*: ``acreate_user()``
Creates, saves and returns a :class:`~django.contrib.auth.models.User`. Creates, saves and returns a :class:`~django.contrib.auth.models.User`.
@ -285,11 +330,22 @@ Manager methods
See :ref:`Creating users <topics-auth-creating-users>` for example usage. See :ref:`Creating users <topics-auth-creating-users>` for example usage.
.. versionchanged:: 5.2
``acreate_user()`` method was added.
.. method:: create_superuser(username, email=None, password=None, **extra_fields) .. method:: create_superuser(username, email=None, password=None, **extra_fields)
.. method:: acreate_superuser(username, email=None, password=None, **extra_fields)
*Asynchronous version*: ``acreate_superuser()``
Same as :meth:`create_user`, but sets :attr:`~models.User.is_staff` and Same as :meth:`create_user`, but sets :attr:`~models.User.is_staff` and
:attr:`~models.User.is_superuser` to ``True``. :attr:`~models.User.is_superuser` to ``True``.
.. versionchanged:: 5.2
``acreate_superuser()`` method was added.
.. method:: with_perm(perm, is_active=True, include_superusers=True, backend=None, obj=None) .. method:: with_perm(perm, is_active=True, include_superusers=True, backend=None, obj=None)
Returns users that have the given permission ``perm`` either in the Returns users that have the given permission ``perm`` either in the
@ -499,23 +555,51 @@ The following backends are available in :mod:`django.contrib.auth.backends`:
methods. By default, it will reject any user and provide no permissions. methods. By default, it will reject any user and provide no permissions.
.. method:: get_user_permissions(user_obj, obj=None) .. method:: get_user_permissions(user_obj, obj=None)
.. method:: aget_user_permissions(user_obj, obj=None)
*Asynchronous version*: ``aget_user_permissions()``
Returns an empty set. Returns an empty set.
.. versionchanged:: 5.2
``aget_user_permissions()`` function was added.
.. method:: get_group_permissions(user_obj, obj=None) .. method:: get_group_permissions(user_obj, obj=None)
.. method:: aget_group_permissions(user_obj, obj=None)
*Asynchronous version*: ``aget_group_permissions()``
Returns an empty set. Returns an empty set.
.. versionchanged:: 5.2
``aget_group_permissions()`` function was added.
.. method:: get_all_permissions(user_obj, obj=None) .. method:: get_all_permissions(user_obj, obj=None)
.. method:: aget_all_permissions(user_obj, obj=None)
*Asynchronous version*: ``aget_all_permissions()``
Uses :meth:`get_user_permissions` and :meth:`get_group_permissions` to Uses :meth:`get_user_permissions` and :meth:`get_group_permissions` to
get the set of permission strings the ``user_obj`` has. get the set of permission strings the ``user_obj`` has.
.. versionchanged:: 5.2
``aget_all_permissions()`` function was added.
.. method:: has_perm(user_obj, perm, obj=None) .. method:: has_perm(user_obj, perm, obj=None)
.. method:: ahas_perm(user_obj, perm, obj=None)
*Asynchronous version*: ``ahas_perm()``
Uses :meth:`get_all_permissions` to check if ``user_obj`` has the Uses :meth:`get_all_permissions` to check if ``user_obj`` has the
permission string ``perm``. permission string ``perm``.
.. versionchanged:: 5.2
``ahas_perm()`` function was added.
.. class:: ModelBackend .. class:: ModelBackend
This is the default authentication backend used by Django. It This is the default authentication backend used by Django. It
@ -539,6 +623,9 @@ The following backends are available in :mod:`django.contrib.auth.backends`:
unlike others methods it returns an empty queryset if ``obj is not None``. unlike others methods it returns an empty queryset if ``obj is not None``.
.. method:: authenticate(request, username=None, password=None, **kwargs) .. method:: authenticate(request, username=None, password=None, **kwargs)
.. method:: aauthenticate(request, username=None, password=None, **kwargs)
*Asynchronous version*: ``aauthenticate()``
Tries to authenticate ``username`` with ``password`` by calling Tries to authenticate ``username`` with ``password`` by calling
:meth:`User.check_password :meth:`User.check_password
@ -552,38 +639,77 @@ The following backends are available in :mod:`django.contrib.auth.backends`:
if it wasn't provided to :func:`~django.contrib.auth.authenticate` if it wasn't provided to :func:`~django.contrib.auth.authenticate`
(which passes it on to the backend). (which passes it on to the backend).
.. versionchanged:: 5.2
``aauthenticate()`` function was added.
.. method:: get_user_permissions(user_obj, obj=None) .. method:: get_user_permissions(user_obj, obj=None)
.. method:: aget_user_permissions(user_obj, obj=None)
*Asynchronous version*: ``aget_user_permissions()``
Returns the set of permission strings the ``user_obj`` has from their Returns the set of permission strings the ``user_obj`` has from their
own user permissions. Returns an empty set if own user permissions. Returns an empty set if
:attr:`~django.contrib.auth.models.AbstractBaseUser.is_anonymous` or :attr:`~django.contrib.auth.models.AbstractBaseUser.is_anonymous` or
:attr:`~django.contrib.auth.models.CustomUser.is_active` is ``False``. :attr:`~django.contrib.auth.models.CustomUser.is_active` is ``False``.
.. versionchanged:: 5.2
``aget_user_permissions()`` function was added.
.. method:: get_group_permissions(user_obj, obj=None) .. method:: get_group_permissions(user_obj, obj=None)
.. method:: aget_group_permissions(user_obj, obj=None)
*Asynchronous version*: ``aget_group_permissions()``
Returns the set of permission strings the ``user_obj`` has from the Returns the set of permission strings the ``user_obj`` has from the
permissions of the groups they belong. Returns an empty set if permissions of the groups they belong. Returns an empty set if
:attr:`~django.contrib.auth.models.AbstractBaseUser.is_anonymous` or :attr:`~django.contrib.auth.models.AbstractBaseUser.is_anonymous` or
:attr:`~django.contrib.auth.models.CustomUser.is_active` is ``False``. :attr:`~django.contrib.auth.models.CustomUser.is_active` is ``False``.
.. versionchanged:: 5.2
``aget_group_permissions()`` function was added.
.. method:: get_all_permissions(user_obj, obj=None) .. method:: get_all_permissions(user_obj, obj=None)
.. method:: aget_all_permissions(user_obj, obj=None)
*Asynchronous version*: ``aget_all_permissions()``
Returns the set of permission strings the ``user_obj`` has, including both Returns the set of permission strings the ``user_obj`` has, including both
user permissions and group permissions. Returns an empty set if user permissions and group permissions. Returns an empty set if
:attr:`~django.contrib.auth.models.AbstractBaseUser.is_anonymous` or :attr:`~django.contrib.auth.models.AbstractBaseUser.is_anonymous` or
:attr:`~django.contrib.auth.models.CustomUser.is_active` is ``False``. :attr:`~django.contrib.auth.models.CustomUser.is_active` is ``False``.
.. versionchanged:: 5.2
``aget_all_permissions()`` function was added.
.. method:: has_perm(user_obj, perm, obj=None) .. method:: has_perm(user_obj, perm, obj=None)
.. method:: ahas_perm(user_obj, perm, obj=None)
*Asynchronous version*: ``ahas_perm()``
Uses :meth:`get_all_permissions` to check if ``user_obj`` has the Uses :meth:`get_all_permissions` to check if ``user_obj`` has the
permission string ``perm``. Returns ``False`` if the user is not permission string ``perm``. Returns ``False`` if the user is not
:attr:`~django.contrib.auth.models.CustomUser.is_active`. :attr:`~django.contrib.auth.models.CustomUser.is_active`.
.. versionchanged:: 5.2
``ahas_perm()`` function was added.
.. method:: has_module_perms(user_obj, app_label) .. method:: has_module_perms(user_obj, app_label)
.. method:: ahas_module_perms(user_obj, app_label)
*Asynchronous version*: ``ahas_module_perms()``
Returns whether the ``user_obj`` has any permissions on the app Returns whether the ``user_obj`` has any permissions on the app
``app_label``. ``app_label``.
.. versionchanged:: 5.2
``ahas_module_perms()`` function was added.
.. method:: user_can_authenticate() .. method:: user_can_authenticate()
Returns whether the user is allowed to authenticate. To match the Returns whether the user is allowed to authenticate. To match the
@ -637,6 +763,9 @@ The following backends are available in :mod:`django.contrib.auth.backends`:
created if not already in the database Defaults to ``True``. created if not already in the database Defaults to ``True``.
.. method:: authenticate(request, remote_user) .. method:: authenticate(request, remote_user)
.. method:: aauthenticate(request, remote_user)
*Asynchronous version*: ``aauthenticate()``
The username passed as ``remote_user`` is considered trusted. This The username passed as ``remote_user`` is considered trusted. This
method returns the user object with the given username, creating a new method returns the user object with the given username, creating a new
@ -651,6 +780,10 @@ The following backends are available in :mod:`django.contrib.auth.backends`:
if it wasn't provided to :func:`~django.contrib.auth.authenticate` if it wasn't provided to :func:`~django.contrib.auth.authenticate`
(which passes it on to the backend). (which passes it on to the backend).
.. versionchanged:: 5.2
``aauthenticate()`` function was added.
.. method:: clean_username(username) .. method:: clean_username(username)
Performs any cleaning on the ``username`` (e.g. stripping LDAP DN Performs any cleaning on the ``username`` (e.g. stripping LDAP DN
@ -658,12 +791,17 @@ The following backends are available in :mod:`django.contrib.auth.backends`:
the cleaned username. the cleaned username.
.. method:: configure_user(request, user, created=True) .. method:: configure_user(request, user, created=True)
.. method:: aconfigure_user(request, user, created=True)
*Asynchronous version*: ``aconfigure_user()``
Configures the user on each authentication attempt. This method is Configures the user on each authentication attempt. This method is
called immediately after fetching or creating the user being called immediately after fetching or creating the user being
authenticated, and can be used to perform custom setup actions, such as authenticated, and can be used to perform custom setup actions, such as
setting the user's groups based on attributes in an LDAP directory. setting the user's groups based on attributes in an LDAP directory.
Returns the user object. Returns the user object. When fetching or creating an user is called
from a synchronous context, ``configure_user`` is called,
``aconfigure_user`` is called from async contexts.
The setup can be performed either once when the user is created The setup can be performed either once when the user is created
(``created`` is ``True``) or on existing users (``created`` is (``created`` is ``True``) or on existing users (``created`` is
@ -674,6 +812,10 @@ The following backends are available in :mod:`django.contrib.auth.backends`:
if it wasn't provided to :func:`~django.contrib.auth.authenticate` if it wasn't provided to :func:`~django.contrib.auth.authenticate`
(which passes it on to the backend). (which passes it on to the backend).
.. versionchanged:: 5.2
``aconfigure_user()`` function was added.
.. method:: user_can_authenticate() .. method:: user_can_authenticate()
Returns whether the user is allowed to authenticate. This method Returns whether the user is allowed to authenticate. This method

View File

@ -52,6 +52,36 @@ Minor features
* The default iteration count for the PBKDF2 password hasher is increased from * The default iteration count for the PBKDF2 password hasher is increased from
870,000 to 1,000,000. 870,000 to 1,000,000.
* The following new asynchronous methods on are now provided, using an ``a``
prefix:
* :meth:`.UserManager.acreate_user`
* :meth:`.UserManager.acreate_superuser`
* :meth:`.BaseUserManager.aget_by_natural_key`
* :meth:`.User.aget_user_permissions()`
* :meth:`.User.aget_all_permissions()`
* :meth:`.User.aget_group_permissions()`
* :meth:`.User.ahas_perm()`
* :meth:`.User.ahas_perms()`
* :meth:`.User.ahas_module_perms()`
* :meth:`.User.aget_user_permissions()`
* :meth:`.User.aget_group_permissions()`
* :meth:`.User.ahas_perm()`
* :meth:`.ModelBackend.aauthenticate()`
* :meth:`.ModelBackend.aget_user_permissions()`
* :meth:`.ModelBackend.aget_group_permissions()`
* :meth:`.ModelBackend.aget_all_permissions()`
* :meth:`.ModelBackend.ahas_perm()`
* :meth:`.ModelBackend.ahas_module_perms()`
* :meth:`.RemoteUserBackend.aauthenticate()`
* :meth:`.RemoteUserBackend.aconfigure_user()`
* Auth backends can now provide async implementations which are used when
calling async auth functions (e.g.
:func:`~.django.contrib.auth.aauthenticate`) to reduce context-switching which
improves performance. See :ref:`adding an async interface
<writing-authentication-backends-async-interface>` for more details.
:mod:`django.contrib.contenttypes` :mod:`django.contrib.contenttypes`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -790,10 +790,17 @@ utility methods:
email address. email address.
.. method:: models.BaseUserManager.get_by_natural_key(username) .. method:: models.BaseUserManager.get_by_natural_key(username)
.. method:: models.BaseUserManager.aget_by_natural_key(username)
*Asynchronous version*: ``aget_by_natural_key()``
Retrieves a user instance using the contents of the field Retrieves a user instance using the contents of the field
nominated by ``USERNAME_FIELD``. nominated by ``USERNAME_FIELD``.
.. versionchanged:: 5.2
``aget_by_natural_key()`` method was added.
Extending Django's default ``User`` Extending Django's default ``User``
----------------------------------- -----------------------------------
@ -1186,3 +1193,25 @@ Finally, specify the custom model as the default user model for your project
using the :setting:`AUTH_USER_MODEL` setting in your ``settings.py``:: using the :setting:`AUTH_USER_MODEL` setting in your ``settings.py``::
AUTH_USER_MODEL = "customauth.MyUser" AUTH_USER_MODEL = "customauth.MyUser"
.. _writing-authentication-backends-async-interface:
Adding an async interface
~~~~~~~~~~~~~~~~~~~~~~~~~
.. versionadded:: 5.2
To optimize performance when called from an async context authentication,
backends can implement async versions of each function - ``aget_user(user_id)``
and ``aauthenticate(request, **credentials)``. When an authentication backend
extends ``BaseBackend`` and async versions of these functions are not provided,
they will be automatically synthesized with ``sync_to_async``. This has
:ref:`performance penalties <async_performance>`.
While an async interface is optional, a synchronous interface is always
required. There is no automatic synthesis for a synchronous interface if an
async interface is implemented.
Django's out-of-the-box authentication backends have native async support. If
these native backends are extended take special care to make sure the async
versions of modified functions are modified as well.

View File

@ -33,9 +33,40 @@ class AsyncAuthTest(TestCase):
self.assertIsInstance(user, User) self.assertIsInstance(user, User)
self.assertEqual(user.username, self.test_user.username) self.assertEqual(user.username, self.test_user.username)
async def test_changed_password_invalidates_aget_user(self):
request = HttpRequest()
request.session = await self.client.asession()
await alogin(request, self.test_user)
self.test_user.set_password("new_password")
await self.test_user.asave()
user = await aget_user(request)
self.assertIsNotNone(user)
self.assertTrue(user.is_anonymous)
# Session should be flushed.
self.assertIsNone(request.session.session_key)
async def test_alogin_new_user(self):
request = HttpRequest()
request.session = await self.client.asession()
await alogin(request, self.test_user)
second_user = await User.objects.acreate_user(
"testuser2", "test2@example.com", "testpw2"
)
await alogin(request, second_user)
user = await aget_user(request)
self.assertIsInstance(user, User)
self.assertEqual(user.username, second_user.username)
async def test_alogin_without_user(self): async def test_alogin_without_user(self):
async def auser():
return self.test_user
request = HttpRequest() request = HttpRequest()
request.user = self.test_user request.user = self.test_user
request.auser = auser
request.session = await self.client.asession() request.session = await self.client.asession()
await alogin(request, None) await alogin(request, None)
user = await aget_user(request) user = await aget_user(request)

View File

@ -29,6 +29,19 @@ class CustomUserManager(BaseUserManager):
user.save(using=self._db) user.save(using=self._db)
return user return user
async def acreate_user(self, email, date_of_birth, password=None, **fields):
"""See create_user()"""
if not email:
raise ValueError("Users must have an email address")
user = self.model(
email=self.normalize_email(email), date_of_birth=date_of_birth, **fields
)
user.set_password(password)
await user.asave(using=self._db)
return user
def create_superuser(self, email, password, date_of_birth, **fields): def create_superuser(self, email, password, date_of_birth, **fields):
u = self.create_user( u = self.create_user(
email, password=password, date_of_birth=date_of_birth, **fields email, password=password, date_of_birth=date_of_birth, **fields

View File

@ -2,6 +2,8 @@ import sys
from datetime import date from datetime import date
from unittest import mock from unittest import mock
from asgiref.sync import sync_to_async
from django.contrib.auth import ( from django.contrib.auth import (
BACKEND_SESSION_KEY, BACKEND_SESSION_KEY,
SESSION_KEY, SESSION_KEY,
@ -55,17 +57,33 @@ class BaseBackendTest(TestCase):
def test_get_user_permissions(self): def test_get_user_permissions(self):
self.assertEqual(self.user.get_user_permissions(), {"user_perm"}) self.assertEqual(self.user.get_user_permissions(), {"user_perm"})
async def test_aget_user_permissions(self):
self.assertEqual(await self.user.aget_user_permissions(), {"user_perm"})
def test_get_group_permissions(self): def test_get_group_permissions(self):
self.assertEqual(self.user.get_group_permissions(), {"group_perm"}) self.assertEqual(self.user.get_group_permissions(), {"group_perm"})
async def test_aget_group_permissions(self):
self.assertEqual(await self.user.aget_group_permissions(), {"group_perm"})
def test_get_all_permissions(self): def test_get_all_permissions(self):
self.assertEqual(self.user.get_all_permissions(), {"user_perm", "group_perm"}) self.assertEqual(self.user.get_all_permissions(), {"user_perm", "group_perm"})
async def test_aget_all_permissions(self):
self.assertEqual(
await self.user.aget_all_permissions(), {"user_perm", "group_perm"}
)
def test_has_perm(self): def test_has_perm(self):
self.assertIs(self.user.has_perm("user_perm"), True) self.assertIs(self.user.has_perm("user_perm"), True)
self.assertIs(self.user.has_perm("group_perm"), True) self.assertIs(self.user.has_perm("group_perm"), True)
self.assertIs(self.user.has_perm("other_perm", TestObj()), False) self.assertIs(self.user.has_perm("other_perm", TestObj()), False)
async def test_ahas_perm(self):
self.assertIs(await self.user.ahas_perm("user_perm"), True)
self.assertIs(await self.user.ahas_perm("group_perm"), True)
self.assertIs(await self.user.ahas_perm("other_perm", TestObj()), False)
def test_has_perms_perm_list_invalid(self): def test_has_perms_perm_list_invalid(self):
msg = "perm_list must be an iterable of permissions." msg = "perm_list must be an iterable of permissions."
with self.assertRaisesMessage(ValueError, msg): with self.assertRaisesMessage(ValueError, msg):
@ -73,6 +91,13 @@ class BaseBackendTest(TestCase):
with self.assertRaisesMessage(ValueError, msg): with self.assertRaisesMessage(ValueError, msg):
self.user.has_perms(object()) self.user.has_perms(object())
async def test_ahas_perms_perm_list_invalid(self):
msg = "perm_list must be an iterable of permissions."
with self.assertRaisesMessage(ValueError, msg):
await self.user.ahas_perms("user_perm")
with self.assertRaisesMessage(ValueError, msg):
await self.user.ahas_perms(object())
class CountingMD5PasswordHasher(MD5PasswordHasher): class CountingMD5PasswordHasher(MD5PasswordHasher):
"""Hasher that counts how many times it computes a hash.""" """Hasher that counts how many times it computes a hash."""
@ -125,6 +150,25 @@ class BaseModelBackendTest:
user.save() user.save()
self.assertIs(user.has_perm("auth.test"), False) self.assertIs(user.has_perm("auth.test"), False)
async def test_ahas_perm(self):
user = await self.UserModel._default_manager.aget(pk=self.user.pk)
self.assertIs(await user.ahas_perm("auth.test"), False)
user.is_staff = True
await user.asave()
self.assertIs(await user.ahas_perm("auth.test"), False)
user.is_superuser = True
await user.asave()
self.assertIs(await user.ahas_perm("auth.test"), True)
self.assertIs(await user.ahas_module_perms("auth"), True)
user.is_staff = True
user.is_superuser = True
user.is_active = False
await user.asave()
self.assertIs(await user.ahas_perm("auth.test"), False)
def test_custom_perms(self): def test_custom_perms(self):
user = self.UserModel._default_manager.get(pk=self.user.pk) user = self.UserModel._default_manager.get(pk=self.user.pk)
content_type = ContentType.objects.get_for_model(Group) content_type = ContentType.objects.get_for_model(Group)
@ -174,6 +218,55 @@ class BaseModelBackendTest:
self.assertIs(user.has_perm("test"), False) self.assertIs(user.has_perm("test"), False)
self.assertIs(user.has_perms(["auth.test2", "auth.test3"]), False) self.assertIs(user.has_perms(["auth.test2", "auth.test3"]), False)
async def test_acustom_perms(self):
user = await self.UserModel._default_manager.aget(pk=self.user.pk)
content_type = await sync_to_async(ContentType.objects.get_for_model)(Group)
perm = await Permission.objects.acreate(
name="test", content_type=content_type, codename="test"
)
await user.user_permissions.aadd(perm)
# Reloading user to purge the _perm_cache.
user = await self.UserModel._default_manager.aget(pk=self.user.pk)
self.assertEqual(await user.aget_all_permissions(), {"auth.test"})
self.assertEqual(await user.aget_user_permissions(), {"auth.test"})
self.assertEqual(await user.aget_group_permissions(), set())
self.assertIs(await user.ahas_module_perms("Group"), False)
self.assertIs(await user.ahas_module_perms("auth"), True)
perm = await Permission.objects.acreate(
name="test2", content_type=content_type, codename="test2"
)
await user.user_permissions.aadd(perm)
perm = await Permission.objects.acreate(
name="test3", content_type=content_type, codename="test3"
)
await user.user_permissions.aadd(perm)
user = await self.UserModel._default_manager.aget(pk=self.user.pk)
expected_user_perms = {"auth.test2", "auth.test", "auth.test3"}
self.assertEqual(await user.aget_all_permissions(), expected_user_perms)
self.assertIs(await user.ahas_perm("test"), False)
self.assertIs(await user.ahas_perm("auth.test"), True)
self.assertIs(await user.ahas_perms(["auth.test2", "auth.test3"]), True)
perm = await Permission.objects.acreate(
name="test_group", content_type=content_type, codename="test_group"
)
group = await Group.objects.acreate(name="test_group")
await group.permissions.aadd(perm)
await user.groups.aadd(group)
user = await self.UserModel._default_manager.aget(pk=self.user.pk)
self.assertEqual(
await user.aget_all_permissions(), {*expected_user_perms, "auth.test_group"}
)
self.assertEqual(await user.aget_user_permissions(), expected_user_perms)
self.assertEqual(await user.aget_group_permissions(), {"auth.test_group"})
self.assertIs(await user.ahas_perms(["auth.test3", "auth.test_group"]), True)
user = AnonymousUser()
self.assertIs(await user.ahas_perm("test"), False)
self.assertIs(await user.ahas_perms(["auth.test2", "auth.test3"]), False)
def test_has_no_object_perm(self): def test_has_no_object_perm(self):
"""Regressiontest for #12462""" """Regressiontest for #12462"""
user = self.UserModel._default_manager.get(pk=self.user.pk) user = self.UserModel._default_manager.get(pk=self.user.pk)
@ -188,6 +281,20 @@ class BaseModelBackendTest:
self.assertIs(user.has_perm("auth.test"), True) self.assertIs(user.has_perm("auth.test"), True)
self.assertEqual(user.get_all_permissions(), {"auth.test"}) self.assertEqual(user.get_all_permissions(), {"auth.test"})
async def test_ahas_no_object_perm(self):
"""See test_has_no_object_perm()"""
user = await self.UserModel._default_manager.aget(pk=self.user.pk)
content_type = await sync_to_async(ContentType.objects.get_for_model)(Group)
perm = await Permission.objects.acreate(
name="test", content_type=content_type, codename="test"
)
await user.user_permissions.aadd(perm)
self.assertIs(await user.ahas_perm("auth.test", "object"), False)
self.assertEqual(await user.aget_all_permissions("object"), set())
self.assertIs(await user.ahas_perm("auth.test"), True)
self.assertEqual(await user.aget_all_permissions(), {"auth.test"})
def test_anonymous_has_no_permissions(self): def test_anonymous_has_no_permissions(self):
""" """
#17903 -- Anonymous users shouldn't have permissions in #17903 -- Anonymous users shouldn't have permissions in
@ -220,6 +327,38 @@ class BaseModelBackendTest:
self.assertEqual(backend.get_user_permissions(user), set()) self.assertEqual(backend.get_user_permissions(user), set())
self.assertEqual(backend.get_group_permissions(user), set()) self.assertEqual(backend.get_group_permissions(user), set())
async def test_aanonymous_has_no_permissions(self):
"""See test_anonymous_has_no_permissions()"""
backend = ModelBackend()
user = await self.UserModel._default_manager.aget(pk=self.user.pk)
content_type = await sync_to_async(ContentType.objects.get_for_model)(Group)
user_perm = await Permission.objects.acreate(
name="test", content_type=content_type, codename="test_user"
)
group_perm = await Permission.objects.acreate(
name="test2", content_type=content_type, codename="test_group"
)
await user.user_permissions.aadd(user_perm)
group = await Group.objects.acreate(name="test_group")
await user.groups.aadd(group)
await group.permissions.aadd(group_perm)
self.assertEqual(
await backend.aget_all_permissions(user),
{"auth.test_user", "auth.test_group"},
)
self.assertEqual(await backend.aget_user_permissions(user), {"auth.test_user"})
self.assertEqual(
await backend.aget_group_permissions(user), {"auth.test_group"}
)
with mock.patch.object(self.UserModel, "is_anonymous", True):
self.assertEqual(await backend.aget_all_permissions(user), set())
self.assertEqual(await backend.aget_user_permissions(user), set())
self.assertEqual(await backend.aget_group_permissions(user), set())
def test_inactive_has_no_permissions(self): def test_inactive_has_no_permissions(self):
""" """
#17903 -- Inactive users shouldn't have permissions in #17903 -- Inactive users shouldn't have permissions in
@ -254,11 +393,52 @@ class BaseModelBackendTest:
self.assertEqual(backend.get_user_permissions(user), set()) self.assertEqual(backend.get_user_permissions(user), set())
self.assertEqual(backend.get_group_permissions(user), set()) self.assertEqual(backend.get_group_permissions(user), set())
async def test_ainactive_has_no_permissions(self):
"""See test_inactive_has_no_permissions()"""
backend = ModelBackend()
user = await self.UserModel._default_manager.aget(pk=self.user.pk)
content_type = await sync_to_async(ContentType.objects.get_for_model)(Group)
user_perm = await Permission.objects.acreate(
name="test", content_type=content_type, codename="test_user"
)
group_perm = await Permission.objects.acreate(
name="test2", content_type=content_type, codename="test_group"
)
await user.user_permissions.aadd(user_perm)
group = await Group.objects.acreate(name="test_group")
await user.groups.aadd(group)
await group.permissions.aadd(group_perm)
self.assertEqual(
await backend.aget_all_permissions(user),
{"auth.test_user", "auth.test_group"},
)
self.assertEqual(await backend.aget_user_permissions(user), {"auth.test_user"})
self.assertEqual(
await backend.aget_group_permissions(user), {"auth.test_group"}
)
user.is_active = False
await user.asave()
self.assertEqual(await backend.aget_all_permissions(user), set())
self.assertEqual(await backend.aget_user_permissions(user), set())
self.assertEqual(await backend.aget_group_permissions(user), set())
def test_get_all_superuser_permissions(self): def test_get_all_superuser_permissions(self):
"""A superuser has all permissions. Refs #14795.""" """A superuser has all permissions. Refs #14795."""
user = self.UserModel._default_manager.get(pk=self.superuser.pk) user = self.UserModel._default_manager.get(pk=self.superuser.pk)
self.assertEqual(len(user.get_all_permissions()), len(Permission.objects.all())) self.assertEqual(len(user.get_all_permissions()), len(Permission.objects.all()))
async def test_aget_all_superuser_permissions(self):
"""See test_get_all_superuser_permissions()"""
user = await self.UserModel._default_manager.aget(pk=self.superuser.pk)
self.assertEqual(
len(await user.aget_all_permissions()), await Permission.objects.acount()
)
@override_settings( @override_settings(
PASSWORD_HASHERS=["auth_tests.test_auth_backends.CountingMD5PasswordHasher"] PASSWORD_HASHERS=["auth_tests.test_auth_backends.CountingMD5PasswordHasher"]
) )
@ -277,6 +457,24 @@ class BaseModelBackendTest:
authenticate(username="no_such_user", password="test") authenticate(username="no_such_user", password="test")
self.assertEqual(CountingMD5PasswordHasher.calls, 1) self.assertEqual(CountingMD5PasswordHasher.calls, 1)
@override_settings(
PASSWORD_HASHERS=["auth_tests.test_auth_backends.CountingMD5PasswordHasher"]
)
async def test_aauthentication_timing(self):
"""See test_authentication_timing()"""
# Re-set the password, because this tests overrides PASSWORD_HASHERS.
self.user.set_password("test")
await self.user.asave()
CountingMD5PasswordHasher.calls = 0
username = getattr(self.user, self.UserModel.USERNAME_FIELD)
await aauthenticate(username=username, password="test")
self.assertEqual(CountingMD5PasswordHasher.calls, 1)
CountingMD5PasswordHasher.calls = 0
await aauthenticate(username="no_such_user", password="test")
self.assertEqual(CountingMD5PasswordHasher.calls, 1)
@override_settings( @override_settings(
PASSWORD_HASHERS=["auth_tests.test_auth_backends.CountingMD5PasswordHasher"] PASSWORD_HASHERS=["auth_tests.test_auth_backends.CountingMD5PasswordHasher"]
) )
@ -320,6 +518,15 @@ class ModelBackendTest(BaseModelBackendTest, TestCase):
self.user.save() self.user.save()
self.assertIsNone(authenticate(**self.user_credentials)) self.assertIsNone(authenticate(**self.user_credentials))
async def test_aauthenticate_inactive(self):
"""
An inactive user can't authenticate.
"""
self.assertEqual(await aauthenticate(**self.user_credentials), self.user)
self.user.is_active = False
await self.user.asave()
self.assertIsNone(await aauthenticate(**self.user_credentials))
@override_settings(AUTH_USER_MODEL="auth_tests.CustomUserWithoutIsActiveField") @override_settings(AUTH_USER_MODEL="auth_tests.CustomUserWithoutIsActiveField")
def test_authenticate_user_without_is_active_field(self): def test_authenticate_user_without_is_active_field(self):
""" """
@ -332,6 +539,18 @@ class ModelBackendTest(BaseModelBackendTest, TestCase):
) )
self.assertEqual(authenticate(username="test", password="test"), user) self.assertEqual(authenticate(username="test", password="test"), user)
@override_settings(AUTH_USER_MODEL="auth_tests.CustomUserWithoutIsActiveField")
async def test_aauthenticate_user_without_is_active_field(self):
"""
A custom user without an `is_active` field is allowed to authenticate.
"""
user = await CustomUserWithoutIsActiveField.objects._acreate_user(
username="test",
email="test@example.com",
password="test",
)
self.assertEqual(await aauthenticate(username="test", password="test"), user)
@override_settings(AUTH_USER_MODEL="auth_tests.ExtensionUser") @override_settings(AUTH_USER_MODEL="auth_tests.ExtensionUser")
class ExtensionUserModelBackendTest(BaseModelBackendTest, TestCase): class ExtensionUserModelBackendTest(BaseModelBackendTest, TestCase):
@ -403,6 +622,15 @@ class CustomUserModelBackendAuthenticateTest(TestCase):
authenticated_user = authenticate(email="test@example.com", password="test") authenticated_user = authenticate(email="test@example.com", password="test")
self.assertEqual(test_user, authenticated_user) self.assertEqual(test_user, authenticated_user)
async def test_aauthenticate(self):
test_user = await CustomUser._default_manager.acreate_user(
email="test@example.com", password="test", date_of_birth=date(2006, 4, 25)
)
authenticated_user = await aauthenticate(
email="test@example.com", password="test"
)
self.assertEqual(test_user, authenticated_user)
@override_settings(AUTH_USER_MODEL="auth_tests.UUIDUser") @override_settings(AUTH_USER_MODEL="auth_tests.UUIDUser")
class UUIDUserTests(TestCase): class UUIDUserTests(TestCase):
@ -416,6 +644,13 @@ class UUIDUserTests(TestCase):
UUIDUser.objects.get(pk=self.client.session[SESSION_KEY]), user UUIDUser.objects.get(pk=self.client.session[SESSION_KEY]), user
) )
async def test_alogin(self):
"""See test_login()"""
user = await UUIDUser.objects.acreate_user(username="uuid", password="test")
self.assertTrue(await self.client.alogin(username="uuid", password="test"))
session_key = await self.client.session.aget(SESSION_KEY)
self.assertEqual(await UUIDUser.objects.aget(pk=session_key), user)
class TestObj: class TestObj:
pass pass
@ -435,9 +670,15 @@ class SimpleRowlevelBackend:
return True return True
return False return False
async def ahas_perm(self, user, perm, obj=None):
return self.has_perm(user, perm, obj)
def has_module_perms(self, user, app_label): def has_module_perms(self, user, app_label):
return (user.is_anonymous or user.is_active) and app_label == "app1" return (user.is_anonymous or user.is_active) and app_label == "app1"
async def ahas_module_perms(self, user, app_label):
return self.has_module_perms(user, app_label)
def get_all_permissions(self, user, obj=None): def get_all_permissions(self, user, obj=None):
if not obj: if not obj:
return [] # We only support row level perms return [] # We only support row level perms
@ -452,6 +693,9 @@ class SimpleRowlevelBackend:
else: else:
return ["simple"] return ["simple"]
async def aget_all_permissions(self, user, obj=None):
return self.get_all_permissions(user, obj)
def get_group_permissions(self, user, obj=None): def get_group_permissions(self, user, obj=None):
if not obj: if not obj:
return # We only support row level perms return # We only support row level perms
@ -524,10 +768,18 @@ class AnonymousUserBackendTest(SimpleTestCase):
self.assertIs(self.user1.has_perm("perm", TestObj()), False) self.assertIs(self.user1.has_perm("perm", TestObj()), False)
self.assertIs(self.user1.has_perm("anon", TestObj()), True) self.assertIs(self.user1.has_perm("anon", TestObj()), True)
async def test_ahas_perm(self):
self.assertIs(await self.user1.ahas_perm("perm", TestObj()), False)
self.assertIs(await self.user1.ahas_perm("anon", TestObj()), True)
def test_has_perms(self): def test_has_perms(self):
self.assertIs(self.user1.has_perms(["anon"], TestObj()), True) self.assertIs(self.user1.has_perms(["anon"], TestObj()), True)
self.assertIs(self.user1.has_perms(["anon", "perm"], TestObj()), False) self.assertIs(self.user1.has_perms(["anon", "perm"], TestObj()), False)
async def test_ahas_perms(self):
self.assertIs(await self.user1.ahas_perms(["anon"], TestObj()), True)
self.assertIs(await self.user1.ahas_perms(["anon", "perm"], TestObj()), False)
def test_has_perms_perm_list_invalid(self): def test_has_perms_perm_list_invalid(self):
msg = "perm_list must be an iterable of permissions." msg = "perm_list must be an iterable of permissions."
with self.assertRaisesMessage(ValueError, msg): with self.assertRaisesMessage(ValueError, msg):
@ -535,13 +787,27 @@ class AnonymousUserBackendTest(SimpleTestCase):
with self.assertRaisesMessage(ValueError, msg): with self.assertRaisesMessage(ValueError, msg):
self.user1.has_perms(object()) self.user1.has_perms(object())
async def test_ahas_perms_perm_list_invalid(self):
msg = "perm_list must be an iterable of permissions."
with self.assertRaisesMessage(ValueError, msg):
await self.user1.ahas_perms("perm")
with self.assertRaisesMessage(ValueError, msg):
await self.user1.ahas_perms(object())
def test_has_module_perms(self): def test_has_module_perms(self):
self.assertIs(self.user1.has_module_perms("app1"), True) self.assertIs(self.user1.has_module_perms("app1"), True)
self.assertIs(self.user1.has_module_perms("app2"), False) self.assertIs(self.user1.has_module_perms("app2"), False)
async def test_ahas_module_perms(self):
self.assertIs(await self.user1.ahas_module_perms("app1"), True)
self.assertIs(await self.user1.ahas_module_perms("app2"), False)
def test_get_all_permissions(self): def test_get_all_permissions(self):
self.assertEqual(self.user1.get_all_permissions(TestObj()), {"anon"}) self.assertEqual(self.user1.get_all_permissions(TestObj()), {"anon"})
async def test_aget_all_permissions(self):
self.assertEqual(await self.user1.aget_all_permissions(TestObj()), {"anon"})
@override_settings(AUTHENTICATION_BACKENDS=[]) @override_settings(AUTHENTICATION_BACKENDS=[])
class NoBackendsTest(TestCase): class NoBackendsTest(TestCase):
@ -561,6 +827,14 @@ class NoBackendsTest(TestCase):
with self.assertRaisesMessage(ImproperlyConfigured, msg): with self.assertRaisesMessage(ImproperlyConfigured, msg):
self.user.has_perm(("perm", TestObj())) self.user.has_perm(("perm", TestObj()))
async def test_araises_exception(self):
msg = (
"No authentication backends have been defined. "
"Does AUTHENTICATION_BACKENDS contain anything?"
)
with self.assertRaisesMessage(ImproperlyConfigured, msg):
await self.user.ahas_perm(("perm", TestObj()))
@override_settings( @override_settings(
AUTHENTICATION_BACKENDS=["auth_tests.test_auth_backends.SimpleRowlevelBackend"] AUTHENTICATION_BACKENDS=["auth_tests.test_auth_backends.SimpleRowlevelBackend"]
@ -593,12 +867,21 @@ class PermissionDeniedBackend:
def authenticate(self, request, username=None, password=None): def authenticate(self, request, username=None, password=None):
raise PermissionDenied raise PermissionDenied
async def aauthenticate(self, request, username=None, password=None):
raise PermissionDenied
def has_perm(self, user_obj, perm, obj=None): def has_perm(self, user_obj, perm, obj=None):
raise PermissionDenied raise PermissionDenied
async def ahas_perm(self, user_obj, perm, obj=None):
raise PermissionDenied
def has_module_perms(self, user_obj, app_label): def has_module_perms(self, user_obj, app_label):
raise PermissionDenied raise PermissionDenied
async def ahas_module_perms(self, user_obj, app_label):
raise PermissionDenied
class PermissionDeniedBackendTest(TestCase): class PermissionDeniedBackendTest(TestCase):
""" """
@ -631,10 +914,25 @@ class PermissionDeniedBackendTest(TestCase):
[{"password": "********************", "username": "test"}], [{"password": "********************", "username": "test"}],
) )
@modify_settings(AUTHENTICATION_BACKENDS={"prepend": backend})
async def test_aauthenticate_permission_denied(self):
self.assertIsNone(await aauthenticate(username="test", password="test"))
# user_login_failed signal is sent.
self.assertEqual(
self.user_login_failed,
[{"password": "********************", "username": "test"}],
)
@modify_settings(AUTHENTICATION_BACKENDS={"append": backend}) @modify_settings(AUTHENTICATION_BACKENDS={"append": backend})
def test_authenticates(self): def test_authenticates(self):
self.assertEqual(authenticate(username="test", password="test"), self.user1) self.assertEqual(authenticate(username="test", password="test"), self.user1)
@modify_settings(AUTHENTICATION_BACKENDS={"append": backend})
async def test_aauthenticate(self):
self.assertEqual(
await aauthenticate(username="test", password="test"), self.user1
)
@modify_settings(AUTHENTICATION_BACKENDS={"prepend": backend}) @modify_settings(AUTHENTICATION_BACKENDS={"prepend": backend})
def test_has_perm_denied(self): def test_has_perm_denied(self):
content_type = ContentType.objects.get_for_model(Group) content_type = ContentType.objects.get_for_model(Group)
@ -646,6 +944,17 @@ class PermissionDeniedBackendTest(TestCase):
self.assertIs(self.user1.has_perm("auth.test"), False) self.assertIs(self.user1.has_perm("auth.test"), False)
self.assertIs(self.user1.has_module_perms("auth"), False) self.assertIs(self.user1.has_module_perms("auth"), False)
@modify_settings(AUTHENTICATION_BACKENDS={"prepend": backend})
async def test_ahas_perm_denied(self):
content_type = await sync_to_async(ContentType.objects.get_for_model)(Group)
perm = await Permission.objects.acreate(
name="test", content_type=content_type, codename="test"
)
await self.user1.user_permissions.aadd(perm)
self.assertIs(await self.user1.ahas_perm("auth.test"), False)
self.assertIs(await self.user1.ahas_module_perms("auth"), False)
@modify_settings(AUTHENTICATION_BACKENDS={"append": backend}) @modify_settings(AUTHENTICATION_BACKENDS={"append": backend})
def test_has_perm(self): def test_has_perm(self):
content_type = ContentType.objects.get_for_model(Group) content_type = ContentType.objects.get_for_model(Group)
@ -657,6 +966,17 @@ class PermissionDeniedBackendTest(TestCase):
self.assertIs(self.user1.has_perm("auth.test"), True) self.assertIs(self.user1.has_perm("auth.test"), True)
self.assertIs(self.user1.has_module_perms("auth"), True) self.assertIs(self.user1.has_module_perms("auth"), True)
@modify_settings(AUTHENTICATION_BACKENDS={"append": backend})
async def test_ahas_perm(self):
content_type = await sync_to_async(ContentType.objects.get_for_model)(Group)
perm = await Permission.objects.acreate(
name="test", content_type=content_type, codename="test"
)
await self.user1.user_permissions.aadd(perm)
self.assertIs(await self.user1.ahas_perm("auth.test"), True)
self.assertIs(await self.user1.ahas_module_perms("auth"), True)
class NewModelBackend(ModelBackend): class NewModelBackend(ModelBackend):
pass pass
@ -715,6 +1035,10 @@ class TypeErrorBackend:
def authenticate(self, request, username=None, password=None): def authenticate(self, request, username=None, password=None):
raise TypeError raise TypeError
@sensitive_variables("password")
async def aauthenticate(self, request, username=None, password=None):
raise TypeError
class SkippedBackend: class SkippedBackend:
def authenticate(self): def authenticate(self):

View File

@ -1,5 +1,3 @@
from asgiref.sync import sync_to_async
from django.conf import settings from django.conf import settings
from django.contrib.auth import aget_user, get_user, get_user_model from django.contrib.auth import aget_user, get_user, get_user_model
from django.contrib.auth.models import AnonymousUser, User from django.contrib.auth.models import AnonymousUser, User
@ -44,6 +42,12 @@ class BasicTestCase(TestCase):
u2 = User.objects.create_user("testuser2", "test2@example.com") u2 = User.objects.create_user("testuser2", "test2@example.com")
self.assertFalse(u2.has_usable_password()) self.assertFalse(u2.has_usable_password())
async def test_acreate(self):
u = await User.objects.acreate_user("testuser", "test@example.com", "testpw")
self.assertTrue(u.has_usable_password())
self.assertFalse(await u.acheck_password("bad"))
self.assertTrue(await u.acheck_password("testpw"))
def test_unicode_username(self): def test_unicode_username(self):
User.objects.create_user("jörg") User.objects.create_user("jörg")
User.objects.create_user("Григорий") User.objects.create_user("Григорий")
@ -73,6 +77,15 @@ class BasicTestCase(TestCase):
self.assertTrue(super.is_active) self.assertTrue(super.is_active)
self.assertTrue(super.is_staff) self.assertTrue(super.is_staff)
async def test_asuperuser(self):
"Check the creation and properties of a superuser"
super = await User.objects.acreate_superuser(
"super", "super@example.com", "super"
)
self.assertTrue(super.is_superuser)
self.assertTrue(super.is_active)
self.assertTrue(super.is_staff)
def test_superuser_no_email_or_password(self): def test_superuser_no_email_or_password(self):
cases = [ cases = [
{}, {},
@ -171,13 +184,25 @@ class TestGetUser(TestCase):
self.assertIsInstance(user, User) self.assertIsInstance(user, User)
self.assertEqual(user.username, created_user.username) self.assertEqual(user.username, created_user.username)
async def test_aget_user(self): async def test_aget_user_fallback_secret(self):
created_user = await sync_to_async(User.objects.create_user)( created_user = await User.objects.acreate_user(
"testuser", "test@example.com", "testpw" "testuser", "test@example.com", "testpw"
) )
await self.client.alogin(username="testuser", password="testpw") await self.client.alogin(username="testuser", password="testpw")
request = HttpRequest() request = HttpRequest()
request.session = await self.client.asession() request.session = await self.client.asession()
user = await aget_user(request) prev_session_key = request.session.session_key
self.assertIsInstance(user, User) with override_settings(
self.assertEqual(user.username, created_user.username) SECRET_KEY="newsecret",
SECRET_KEY_FALLBACKS=[settings.SECRET_KEY],
):
user = await aget_user(request)
self.assertIsInstance(user, User)
self.assertEqual(user.username, created_user.username)
self.assertNotEqual(request.session.session_key, prev_session_key)
# Remove the fallback secret.
# The session hash should be updated using the current secret.
with override_settings(SECRET_KEY="newsecret"):
user = await aget_user(request)
self.assertIsInstance(user, User)
self.assertEqual(user.username, created_user.username)

View File

@ -1,7 +1,5 @@
from asyncio import iscoroutinefunction from asyncio import iscoroutinefunction
from asgiref.sync import sync_to_async
from django.conf import settings from django.conf import settings
from django.contrib.auth import models from django.contrib.auth import models
from django.contrib.auth.decorators import ( from django.contrib.auth.decorators import (
@ -374,7 +372,7 @@ class UserPassesTestDecoratorTest(TestCase):
def test_decorator_async_test_func(self): def test_decorator_async_test_func(self):
async def async_test_func(user): async def async_test_func(user):
return await sync_to_async(user.has_perms)(["auth_tests.add_customuser"]) return await user.ahas_perms(["auth_tests.add_customuser"])
@user_passes_test(async_test_func) @user_passes_test(async_test_func)
def sync_view(request): def sync_view(request):
@ -410,7 +408,7 @@ class UserPassesTestDecoratorTest(TestCase):
async def test_decorator_async_view_async_test_func(self): async def test_decorator_async_view_async_test_func(self):
async def async_test_func(user): async def async_test_func(user):
return await sync_to_async(user.has_perms)(["auth_tests.add_customuser"]) return await user.ahas_perms(["auth_tests.add_customuser"])
@user_passes_test(async_test_func) @user_passes_test(async_test_func)
async def async_view(request): async def async_view(request):

View File

@ -1,7 +1,5 @@
from unittest import mock from unittest import mock
from asgiref.sync import sync_to_async
from django.conf.global_settings import PASSWORD_HASHERS from django.conf.global_settings import PASSWORD_HASHERS
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.contrib.auth.backends import ModelBackend from django.contrib.auth.backends import ModelBackend
@ -30,10 +28,19 @@ class NaturalKeysTestCase(TestCase):
self.assertEqual(User.objects.get_by_natural_key("staff"), staff_user) self.assertEqual(User.objects.get_by_natural_key("staff"), staff_user)
self.assertEqual(staff_user.natural_key(), ("staff",)) self.assertEqual(staff_user.natural_key(), ("staff",))
async def test_auser_natural_key(self):
staff_user = await User.objects.acreate_user(username="staff")
self.assertEqual(await User.objects.aget_by_natural_key("staff"), staff_user)
self.assertEqual(staff_user.natural_key(), ("staff",))
def test_group_natural_key(self): def test_group_natural_key(self):
users_group = Group.objects.create(name="users") users_group = Group.objects.create(name="users")
self.assertEqual(Group.objects.get_by_natural_key("users"), users_group) self.assertEqual(Group.objects.get_by_natural_key("users"), users_group)
async def test_agroup_natural_key(self):
users_group = await Group.objects.acreate(name="users")
self.assertEqual(await Group.objects.aget_by_natural_key("users"), users_group)
class LoadDataWithoutNaturalKeysTestCase(TestCase): class LoadDataWithoutNaturalKeysTestCase(TestCase):
fixtures = ["regular.json"] fixtures = ["regular.json"]
@ -157,6 +164,17 @@ class UserManagerTestCase(TransactionTestCase):
is_superuser=False, is_superuser=False,
) )
async def test_acreate_super_user_raises_error_on_false_is_superuser(self):
with self.assertRaisesMessage(
ValueError, "Superuser must have is_superuser=True."
):
await User.objects.acreate_superuser(
username="test",
email="test@test.com",
password="test",
is_superuser=False,
)
def test_create_superuser_raises_error_on_false_is_staff(self): def test_create_superuser_raises_error_on_false_is_staff(self):
with self.assertRaisesMessage(ValueError, "Superuser must have is_staff=True."): with self.assertRaisesMessage(ValueError, "Superuser must have is_staff=True."):
User.objects.create_superuser( User.objects.create_superuser(
@ -166,6 +184,15 @@ class UserManagerTestCase(TransactionTestCase):
is_staff=False, is_staff=False,
) )
async def test_acreate_superuser_raises_error_on_false_is_staff(self):
with self.assertRaisesMessage(ValueError, "Superuser must have is_staff=True."):
await User.objects.acreate_superuser(
username="test",
email="test@test.com",
password="test",
is_staff=False,
)
def test_runpython_manager_methods(self): def test_runpython_manager_methods(self):
def forwards(apps, schema_editor): def forwards(apps, schema_editor):
UserModel = apps.get_model("auth", "User") UserModel = apps.get_model("auth", "User")
@ -301,9 +328,7 @@ class AbstractUserTestCase(TestCase):
@override_settings(PASSWORD_HASHERS=PASSWORD_HASHERS) @override_settings(PASSWORD_HASHERS=PASSWORD_HASHERS)
async def test_acheck_password_upgrade(self): async def test_acheck_password_upgrade(self):
user = await sync_to_async(User.objects.create_user)( user = await User.objects.acreate_user(username="user", password="foo")
username="user", password="foo"
)
initial_password = user.password initial_password = user.password
self.assertIs(await user.acheck_password("foo"), True) self.assertIs(await user.acheck_password("foo"), True)
hasher = get_hasher("default") hasher = get_hasher("default")
@ -557,6 +582,12 @@ class AnonymousUserTests(SimpleTestCase):
self.assertEqual(self.user.get_user_permissions(), set()) self.assertEqual(self.user.get_user_permissions(), set())
self.assertEqual(self.user.get_group_permissions(), set()) self.assertEqual(self.user.get_group_permissions(), set())
async def test_properties_async_versions(self):
self.assertEqual(await self.user.groups.acount(), 0)
self.assertEqual(await self.user.user_permissions.acount(), 0)
self.assertEqual(await self.user.aget_user_permissions(), set())
self.assertEqual(await self.user.aget_group_permissions(), set())
def test_str(self): def test_str(self):
self.assertEqual(str(self.user), "AnonymousUser") self.assertEqual(str(self.user), "AnonymousUser")

View File

@ -1,12 +1,18 @@
from datetime import datetime, timezone from datetime import datetime, timezone
from django.conf import settings from django.conf import settings
from django.contrib.auth import authenticate from django.contrib.auth import aauthenticate, authenticate
from django.contrib.auth.backends import RemoteUserBackend from django.contrib.auth.backends import RemoteUserBackend
from django.contrib.auth.middleware import RemoteUserMiddleware from django.contrib.auth.middleware import RemoteUserMiddleware
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.middleware.csrf import _get_new_csrf_string, _mask_cipher_secret from django.middleware.csrf import _get_new_csrf_string, _mask_cipher_secret
from django.test import Client, TestCase, modify_settings, override_settings from django.test import (
AsyncClient,
Client,
TestCase,
modify_settings,
override_settings,
)
@override_settings(ROOT_URLCONF="auth_tests.urls") @override_settings(ROOT_URLCONF="auth_tests.urls")
@ -30,6 +36,11 @@ class RemoteUserTest(TestCase):
) )
super().setUpClass() super().setUpClass()
def test_passing_explicit_none(self):
msg = "get_response must be provided."
with self.assertRaisesMessage(ValueError, msg):
RemoteUserMiddleware(None)
def test_no_remote_user(self): def test_no_remote_user(self):
"""Users are not created when remote user is not specified.""" """Users are not created when remote user is not specified."""
num_users = User.objects.count() num_users = User.objects.count()
@ -46,6 +57,18 @@ class RemoteUserTest(TestCase):
self.assertTrue(response.context["user"].is_anonymous) self.assertTrue(response.context["user"].is_anonymous)
self.assertEqual(User.objects.count(), num_users) self.assertEqual(User.objects.count(), num_users)
async def test_no_remote_user_async(self):
"""See test_no_remote_user."""
num_users = await User.objects.acount()
response = await self.async_client.get("/remote_user/")
self.assertTrue(response.context["user"].is_anonymous)
self.assertEqual(await User.objects.acount(), num_users)
response = await self.async_client.get("/remote_user/", **{self.header: ""})
self.assertTrue(response.context["user"].is_anonymous)
self.assertEqual(await User.objects.acount(), num_users)
def test_csrf_validation_passes_after_process_request_login(self): def test_csrf_validation_passes_after_process_request_login(self):
""" """
CSRF check must access the CSRF token from the session or cookie, CSRF check must access the CSRF token from the session or cookie,
@ -75,6 +98,31 @@ class RemoteUserTest(TestCase):
response = csrf_client.post("/remote_user/", data, **headers) response = csrf_client.post("/remote_user/", data, **headers)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
async def test_csrf_validation_passes_after_process_request_login_async(self):
"""See test_csrf_validation_passes_after_process_request_login."""
csrf_client = AsyncClient(enforce_csrf_checks=True)
csrf_secret = _get_new_csrf_string()
csrf_token = _mask_cipher_secret(csrf_secret)
csrf_token_form = _mask_cipher_secret(csrf_secret)
headers = {self.header: "fakeuser"}
data = {"csrfmiddlewaretoken": csrf_token_form}
# Verify that CSRF is configured for the view
csrf_client.cookies.load({settings.CSRF_COOKIE_NAME: csrf_token})
response = await csrf_client.post("/remote_user/", **headers)
self.assertEqual(response.status_code, 403)
self.assertIn(b"CSRF verification failed.", response.content)
# This request will call django.contrib.auth.alogin() which will call
# django.middleware.csrf.rotate_token() thus changing the value of
# request.META['CSRF_COOKIE'] from the user submitted value set by
# CsrfViewMiddleware.process_request() to the new csrftoken value set
# by rotate_token(). Csrf validation should still pass when the view is
# later processed by CsrfViewMiddleware.process_view()
csrf_client.cookies.load({settings.CSRF_COOKIE_NAME: csrf_token})
response = await csrf_client.post("/remote_user/", data, **headers)
self.assertEqual(response.status_code, 200)
def test_unknown_user(self): def test_unknown_user(self):
""" """
Tests the case where the username passed in the header does not exist Tests the case where the username passed in the header does not exist
@ -90,6 +138,22 @@ class RemoteUserTest(TestCase):
response = self.client.get("/remote_user/", **{self.header: "newuser"}) response = self.client.get("/remote_user/", **{self.header: "newuser"})
self.assertEqual(User.objects.count(), num_users + 1) self.assertEqual(User.objects.count(), num_users + 1)
async def test_unknown_user_async(self):
"""See test_unknown_user."""
num_users = await User.objects.acount()
response = await self.async_client.get(
"/remote_user/", **{self.header: "newuser"}
)
self.assertEqual(response.context["user"].username, "newuser")
self.assertEqual(await User.objects.acount(), num_users + 1)
await User.objects.aget(username="newuser")
# Another request with same user should not create any new users.
response = await self.async_client.get(
"/remote_user/", **{self.header: "newuser"}
)
self.assertEqual(await User.objects.acount(), num_users + 1)
def test_known_user(self): def test_known_user(self):
""" """
Tests the case where the username passed in the header is a valid User. Tests the case where the username passed in the header is a valid User.
@ -106,6 +170,24 @@ class RemoteUserTest(TestCase):
self.assertEqual(response.context["user"].username, "knownuser2") self.assertEqual(response.context["user"].username, "knownuser2")
self.assertEqual(User.objects.count(), num_users) self.assertEqual(User.objects.count(), num_users)
async def test_known_user_async(self):
"""See test_known_user."""
await User.objects.acreate(username="knownuser")
await User.objects.acreate(username="knownuser2")
num_users = await User.objects.acount()
response = await self.async_client.get(
"/remote_user/", **{self.header: self.known_user}
)
self.assertEqual(response.context["user"].username, "knownuser")
self.assertEqual(await User.objects.acount(), num_users)
# A different user passed in the headers causes the new user
# to be logged in.
response = await self.async_client.get(
"/remote_user/", **{self.header: self.known_user2}
)
self.assertEqual(response.context["user"].username, "knownuser2")
self.assertEqual(await User.objects.acount(), num_users)
def test_last_login(self): def test_last_login(self):
""" """
A user's last_login is set the first time they make a A user's last_login is set the first time they make a
@ -128,6 +210,29 @@ class RemoteUserTest(TestCase):
response = self.client.get("/remote_user/", **{self.header: self.known_user}) response = self.client.get("/remote_user/", **{self.header: self.known_user})
self.assertEqual(default_login, response.context["user"].last_login) self.assertEqual(default_login, response.context["user"].last_login)
async def test_last_login_async(self):
"""See test_last_login."""
user = await User.objects.acreate(username="knownuser")
# Set last_login to something so we can determine if it changes.
default_login = datetime(2000, 1, 1)
if settings.USE_TZ:
default_login = default_login.replace(tzinfo=timezone.utc)
user.last_login = default_login
await user.asave()
response = await self.async_client.get(
"/remote_user/", **{self.header: self.known_user}
)
self.assertNotEqual(default_login, response.context["user"].last_login)
user = await User.objects.aget(username="knownuser")
user.last_login = default_login
await user.asave()
response = await self.async_client.get(
"/remote_user/", **{self.header: self.known_user}
)
self.assertEqual(default_login, response.context["user"].last_login)
def test_header_disappears(self): def test_header_disappears(self):
""" """
A logged in user is logged out automatically when A logged in user is logged out automatically when
@ -148,6 +253,25 @@ class RemoteUserTest(TestCase):
response = self.client.get("/remote_user/") response = self.client.get("/remote_user/")
self.assertEqual(response.context["user"].username, "modeluser") self.assertEqual(response.context["user"].username, "modeluser")
async def test_header_disappears_async(self):
"""See test_header_disappears."""
await User.objects.acreate(username="knownuser")
# Known user authenticates
response = await self.async_client.get(
"/remote_user/", **{self.header: self.known_user}
)
self.assertEqual(response.context["user"].username, "knownuser")
# During the session, the REMOTE_USER header disappears. Should trigger logout.
response = await self.async_client.get("/remote_user/")
self.assertTrue(response.context["user"].is_anonymous)
# verify the remoteuser middleware will not remove a user
# authenticated via another backend
await User.objects.acreate_user(username="modeluser", password="foo")
await self.async_client.alogin(username="modeluser", password="foo")
await aauthenticate(username="modeluser", password="foo")
response = await self.async_client.get("/remote_user/")
self.assertEqual(response.context["user"].username, "modeluser")
def test_user_switch_forces_new_login(self): def test_user_switch_forces_new_login(self):
""" """
If the username in the header changes between requests If the username in the header changes between requests
@ -164,11 +288,35 @@ class RemoteUserTest(TestCase):
# In backends that do not create new users, it is '' (anonymous user) # In backends that do not create new users, it is '' (anonymous user)
self.assertNotEqual(response.context["user"].username, "knownuser") self.assertNotEqual(response.context["user"].username, "knownuser")
async def test_user_switch_forces_new_login_async(self):
"""See test_user_switch_forces_new_login."""
await User.objects.acreate(username="knownuser")
# Known user authenticates
response = await self.async_client.get(
"/remote_user/", **{self.header: self.known_user}
)
self.assertEqual(response.context["user"].username, "knownuser")
# During the session, the REMOTE_USER changes to a different user.
response = await self.async_client.get(
"/remote_user/", **{self.header: "newnewuser"}
)
# The current user is not the prior remote_user.
# In backends that create a new user, username is "newnewuser"
# In backends that do not create new users, it is '' (anonymous user)
self.assertNotEqual(response.context["user"].username, "knownuser")
def test_inactive_user(self): def test_inactive_user(self):
User.objects.create(username="knownuser", is_active=False) User.objects.create(username="knownuser", is_active=False)
response = self.client.get("/remote_user/", **{self.header: "knownuser"}) response = self.client.get("/remote_user/", **{self.header: "knownuser"})
self.assertTrue(response.context["user"].is_anonymous) self.assertTrue(response.context["user"].is_anonymous)
async def test_inactive_user_async(self):
await User.objects.acreate(username="knownuser", is_active=False)
response = await self.async_client.get(
"/remote_user/", **{self.header: "knownuser"}
)
self.assertTrue(response.context["user"].is_anonymous)
class RemoteUserNoCreateBackend(RemoteUserBackend): class RemoteUserNoCreateBackend(RemoteUserBackend):
"""Backend that doesn't create unknown users.""" """Backend that doesn't create unknown users."""
@ -190,6 +338,14 @@ class RemoteUserNoCreateTest(RemoteUserTest):
self.assertTrue(response.context["user"].is_anonymous) self.assertTrue(response.context["user"].is_anonymous)
self.assertEqual(User.objects.count(), num_users) self.assertEqual(User.objects.count(), num_users)
async def test_unknown_user_async(self):
num_users = await User.objects.acount()
response = await self.async_client.get(
"/remote_user/", **{self.header: "newuser"}
)
self.assertTrue(response.context["user"].is_anonymous)
self.assertEqual(await User.objects.acount(), num_users)
class AllowAllUsersRemoteUserBackendTest(RemoteUserTest): class AllowAllUsersRemoteUserBackendTest(RemoteUserTest):
"""Backend that allows inactive users.""" """Backend that allows inactive users."""
@ -201,6 +357,13 @@ class AllowAllUsersRemoteUserBackendTest(RemoteUserTest):
response = self.client.get("/remote_user/", **{self.header: self.known_user}) response = self.client.get("/remote_user/", **{self.header: self.known_user})
self.assertEqual(response.context["user"].username, user.username) self.assertEqual(response.context["user"].username, user.username)
async def test_inactive_user_async(self):
user = await User.objects.acreate(username="knownuser", is_active=False)
response = await self.async_client.get(
"/remote_user/", **{self.header: self.known_user}
)
self.assertEqual(response.context["user"].username, user.username)
class CustomRemoteUserBackend(RemoteUserBackend): class CustomRemoteUserBackend(RemoteUserBackend):
""" """
@ -311,3 +474,16 @@ class PersistentRemoteUserTest(RemoteUserTest):
response = self.client.get("/remote_user/") response = self.client.get("/remote_user/")
self.assertFalse(response.context["user"].is_anonymous) self.assertFalse(response.context["user"].is_anonymous)
self.assertEqual(response.context["user"].username, "knownuser") self.assertEqual(response.context["user"].username, "knownuser")
async def test_header_disappears_async(self):
"""See test_header_disappears."""
await User.objects.acreate(username="knownuser")
# Known user authenticates
response = await self.async_client.get(
"/remote_user/", **{self.header: self.known_user}
)
self.assertEqual(response.context["user"].username, "knownuser")
# Should stay logged in if the REMOTE_USER header disappears.
response = await self.async_client.get("/remote_user/")
self.assertFalse(response.context["user"].is_anonymous)
self.assertEqual(response.context["user"].username, "knownuser")

View File

@ -6,7 +6,6 @@ from django.contrib.admindocs.middleware import XViewMiddleware
from django.contrib.auth.middleware import ( from django.contrib.auth.middleware import (
AuthenticationMiddleware, AuthenticationMiddleware,
LoginRequiredMiddleware, LoginRequiredMiddleware,
RemoteUserMiddleware,
) )
from django.contrib.flatpages.middleware import FlatpageFallbackMiddleware from django.contrib.flatpages.middleware import FlatpageFallbackMiddleware
from django.contrib.messages.middleware import MessageMiddleware from django.contrib.messages.middleware import MessageMiddleware
@ -48,7 +47,6 @@ class MiddlewareMixinTests(SimpleTestCase):
LocaleMiddleware, LocaleMiddleware,
MessageMiddleware, MessageMiddleware,
RedirectFallbackMiddleware, RedirectFallbackMiddleware,
RemoteUserMiddleware,
SecurityMiddleware, SecurityMiddleware,
SessionMiddleware, SessionMiddleware,
UpdateCacheMiddleware, UpdateCacheMiddleware,