mirror of
https://github.com/django/django.git
synced 2025-10-31 09:41:08 +00:00
Fixed #35303 -- Implemented async auth backends and utils.
This commit is contained in:
@@ -1,8 +1,6 @@
|
||||
import inspect
|
||||
import re
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
|
||||
from django.apps import apps as django_apps
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ImproperlyConfigured, PermissionDenied
|
||||
@@ -40,6 +38,39 @@ def get_backends():
|
||||
return _get_backends(return_tuples=False)
|
||||
|
||||
|
||||
def _get_compatible_backends(request, **credentials):
|
||||
for backend, backend_path in _get_backends(return_tuples=True):
|
||||
backend_signature = inspect.signature(backend.authenticate)
|
||||
try:
|
||||
backend_signature.bind(request, **credentials)
|
||||
except TypeError:
|
||||
# This backend doesn't accept these credentials as arguments. Try
|
||||
# the next one.
|
||||
continue
|
||||
yield backend, backend_path
|
||||
|
||||
|
||||
def _get_backend_from_user(user, backend=None):
|
||||
try:
|
||||
backend = backend or user.backend
|
||||
except AttributeError:
|
||||
backends = _get_backends(return_tuples=True)
|
||||
if len(backends) == 1:
|
||||
_, backend = backends[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
"You have multiple authentication backends configured and "
|
||||
"therefore must provide the `backend` argument or set the "
|
||||
"`backend` attribute on the user."
|
||||
)
|
||||
else:
|
||||
if not isinstance(backend, str):
|
||||
raise TypeError(
|
||||
"backend must be a dotted import path string (got %r)." % backend
|
||||
)
|
||||
return backend
|
||||
|
||||
|
||||
@sensitive_variables("credentials")
|
||||
def _clean_credentials(credentials):
|
||||
"""
|
||||
@@ -62,19 +93,21 @@ def _get_user_session_key(request):
|
||||
return get_user_model()._meta.pk.to_python(request.session[SESSION_KEY])
|
||||
|
||||
|
||||
async def _aget_user_session_key(request):
|
||||
# This value in the session is always serialized to a string, so we need
|
||||
# to convert it back to Python whenever we access it.
|
||||
session_key = await request.session.aget(SESSION_KEY)
|
||||
if session_key is None:
|
||||
raise KeyError()
|
||||
return get_user_model()._meta.pk.to_python(session_key)
|
||||
|
||||
|
||||
@sensitive_variables("credentials")
|
||||
def authenticate(request=None, **credentials):
|
||||
"""
|
||||
If the given credentials are valid, return a User object.
|
||||
"""
|
||||
for backend, backend_path in _get_backends(return_tuples=True):
|
||||
backend_signature = inspect.signature(backend.authenticate)
|
||||
try:
|
||||
backend_signature.bind(request, **credentials)
|
||||
except TypeError:
|
||||
# This backend doesn't accept these credentials as arguments. Try
|
||||
# the next one.
|
||||
continue
|
||||
for backend, backend_path in _get_compatible_backends(request, **credentials):
|
||||
try:
|
||||
user = backend.authenticate(request, **credentials)
|
||||
except PermissionDenied:
|
||||
@@ -96,7 +129,23 @@ def authenticate(request=None, **credentials):
|
||||
@sensitive_variables("credentials")
|
||||
async def aauthenticate(request=None, **credentials):
|
||||
"""See authenticate()."""
|
||||
return await sync_to_async(authenticate)(request, **credentials)
|
||||
for backend, backend_path in _get_compatible_backends(request, **credentials):
|
||||
try:
|
||||
user = await backend.aauthenticate(request, **credentials)
|
||||
except PermissionDenied:
|
||||
# This backend says to stop in our tracks - this user should not be
|
||||
# allowed in at all.
|
||||
break
|
||||
if user is None:
|
||||
continue
|
||||
# Annotate the user object with the path of the backend.
|
||||
user.backend = backend_path
|
||||
return user
|
||||
|
||||
# The credentials supplied are invalid to all backends, fire signal.
|
||||
await user_login_failed.asend(
|
||||
sender=__name__, credentials=_clean_credentials(credentials), request=request
|
||||
)
|
||||
|
||||
|
||||
def login(request, user, backend=None):
|
||||
@@ -125,23 +174,7 @@ def login(request, user, backend=None):
|
||||
else:
|
||||
request.session.cycle_key()
|
||||
|
||||
try:
|
||||
backend = backend or user.backend
|
||||
except AttributeError:
|
||||
backends = _get_backends(return_tuples=True)
|
||||
if len(backends) == 1:
|
||||
_, backend = backends[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
"You have multiple authentication backends configured and "
|
||||
"therefore must provide the `backend` argument or set the "
|
||||
"`backend` attribute on the user."
|
||||
)
|
||||
else:
|
||||
if not isinstance(backend, str):
|
||||
raise TypeError(
|
||||
"backend must be a dotted import path string (got %r)." % backend
|
||||
)
|
||||
backend = _get_backend_from_user(user=user, backend=backend)
|
||||
|
||||
request.session[SESSION_KEY] = user._meta.pk.value_to_string(user)
|
||||
request.session[BACKEND_SESSION_KEY] = backend
|
||||
@@ -154,7 +187,36 @@ def login(request, user, backend=None):
|
||||
|
||||
async def alogin(request, user, backend=None):
|
||||
"""See login()."""
|
||||
return await sync_to_async(login)(request, user, backend)
|
||||
session_auth_hash = ""
|
||||
if user is None:
|
||||
user = await request.auser()
|
||||
if hasattr(user, "get_session_auth_hash"):
|
||||
session_auth_hash = user.get_session_auth_hash()
|
||||
|
||||
if await request.session.ahas_key(SESSION_KEY):
|
||||
if await _aget_user_session_key(request) != user.pk or (
|
||||
session_auth_hash
|
||||
and not constant_time_compare(
|
||||
await request.session.aget(HASH_SESSION_KEY, ""),
|
||||
session_auth_hash,
|
||||
)
|
||||
):
|
||||
# To avoid reusing another user's session, create a new, empty
|
||||
# session if the existing session corresponds to a different
|
||||
# authenticated user.
|
||||
await request.session.aflush()
|
||||
else:
|
||||
await request.session.acycle_key()
|
||||
|
||||
backend = _get_backend_from_user(user=user, backend=backend)
|
||||
|
||||
await request.session.aset(SESSION_KEY, user._meta.pk.value_to_string(user))
|
||||
await request.session.aset(BACKEND_SESSION_KEY, backend)
|
||||
await request.session.aset(HASH_SESSION_KEY, session_auth_hash)
|
||||
if hasattr(request, "user"):
|
||||
request.user = user
|
||||
rotate_token(request)
|
||||
await user_logged_in.asend(sender=user.__class__, request=request, user=user)
|
||||
|
||||
|
||||
def logout(request):
|
||||
@@ -177,7 +239,19 @@ def logout(request):
|
||||
|
||||
async def alogout(request):
|
||||
"""See logout()."""
|
||||
return await sync_to_async(logout)(request)
|
||||
# Dispatch the signal before the user is logged out so the receivers have a
|
||||
# chance to find out *who* logged out.
|
||||
user = getattr(request, "auser", None)
|
||||
if user is not None:
|
||||
user = await user()
|
||||
if not getattr(user, "is_authenticated", True):
|
||||
user = None
|
||||
await user_logged_out.asend(sender=user.__class__, request=request, user=user)
|
||||
await request.session.aflush()
|
||||
if hasattr(request, "user"):
|
||||
from django.contrib.auth.models import AnonymousUser
|
||||
|
||||
request.user = AnonymousUser()
|
||||
|
||||
|
||||
def get_user_model():
|
||||
@@ -243,7 +317,43 @@ def get_user(request):
|
||||
|
||||
async def aget_user(request):
|
||||
"""See get_user()."""
|
||||
return await sync_to_async(get_user)(request)
|
||||
from .models import AnonymousUser
|
||||
|
||||
user = None
|
||||
try:
|
||||
user_id = await _aget_user_session_key(request)
|
||||
backend_path = await request.session.aget(BACKEND_SESSION_KEY)
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
if backend_path in settings.AUTHENTICATION_BACKENDS:
|
||||
backend = load_backend(backend_path)
|
||||
user = await backend.aget_user(user_id)
|
||||
# Verify the session
|
||||
if hasattr(user, "get_session_auth_hash"):
|
||||
session_hash = await request.session.aget(HASH_SESSION_KEY)
|
||||
if not session_hash:
|
||||
session_hash_verified = False
|
||||
else:
|
||||
session_auth_hash = user.get_session_auth_hash()
|
||||
session_hash_verified = session_hash and constant_time_compare(
|
||||
session_hash, user.get_session_auth_hash()
|
||||
)
|
||||
if not session_hash_verified:
|
||||
# If the current secret does not verify the session, try
|
||||
# with the fallback secrets and stop when a matching one is
|
||||
# found.
|
||||
if session_hash and any(
|
||||
constant_time_compare(session_hash, fallback_auth_hash)
|
||||
for fallback_auth_hash in user.get_session_auth_fallback_hash()
|
||||
):
|
||||
await request.session.acycle_key()
|
||||
await request.session.aset(HASH_SESSION_KEY, session_auth_hash)
|
||||
else:
|
||||
await request.session.aflush()
|
||||
user = None
|
||||
|
||||
return user or AnonymousUser()
|
||||
|
||||
|
||||
def get_permission_codename(action, opts):
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from asgiref.sync import sync_to_async
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.contrib.auth.models import Permission
|
||||
from django.db.models import Exists, OuterRef, Q
|
||||
@@ -9,24 +11,45 @@ class BaseBackend:
|
||||
def authenticate(self, request, **kwargs):
|
||||
return None
|
||||
|
||||
async def aauthenticate(self, request, **kwargs):
|
||||
return await sync_to_async(self.authenticate)(request, **kwargs)
|
||||
|
||||
def get_user(self, user_id):
|
||||
return None
|
||||
|
||||
async def aget_user(self, user_id):
|
||||
return await sync_to_async(self.get_user)(user_id)
|
||||
|
||||
def get_user_permissions(self, user_obj, obj=None):
|
||||
return set()
|
||||
|
||||
async def aget_user_permissions(self, user_obj, obj=None):
|
||||
return await sync_to_async(self.get_user_permissions)(user_obj, obj)
|
||||
|
||||
def get_group_permissions(self, user_obj, obj=None):
|
||||
return set()
|
||||
|
||||
async def aget_group_permissions(self, user_obj, obj=None):
|
||||
return await sync_to_async(self.get_group_permissions)(user_obj, obj)
|
||||
|
||||
def get_all_permissions(self, user_obj, obj=None):
|
||||
return {
|
||||
*self.get_user_permissions(user_obj, obj=obj),
|
||||
*self.get_group_permissions(user_obj, obj=obj),
|
||||
}
|
||||
|
||||
async def aget_all_permissions(self, user_obj, obj=None):
|
||||
return {
|
||||
*await self.aget_user_permissions(user_obj, obj=obj),
|
||||
*await self.aget_group_permissions(user_obj, obj=obj),
|
||||
}
|
||||
|
||||
def has_perm(self, user_obj, perm, obj=None):
|
||||
return perm in self.get_all_permissions(user_obj, obj=obj)
|
||||
|
||||
async def ahas_perm(self, user_obj, perm, obj=None):
|
||||
return perm in await self.aget_all_permissions(user_obj, obj)
|
||||
|
||||
|
||||
class ModelBackend(BaseBackend):
|
||||
"""
|
||||
@@ -48,6 +71,23 @@ class ModelBackend(BaseBackend):
|
||||
if user.check_password(password) and self.user_can_authenticate(user):
|
||||
return user
|
||||
|
||||
async def aauthenticate(self, request, username=None, password=None, **kwargs):
|
||||
if username is None:
|
||||
username = kwargs.get(UserModel.USERNAME_FIELD)
|
||||
if username is None or password is None:
|
||||
return
|
||||
try:
|
||||
user = await UserModel._default_manager.aget_by_natural_key(username)
|
||||
except UserModel.DoesNotExist:
|
||||
# Run the default password hasher once to reduce the timing
|
||||
# difference between an existing and a nonexistent user (#20760).
|
||||
UserModel().set_password(password)
|
||||
else:
|
||||
if await user.acheck_password(password) and self.user_can_authenticate(
|
||||
user
|
||||
):
|
||||
return user
|
||||
|
||||
def user_can_authenticate(self, user):
|
||||
"""
|
||||
Reject users with is_active=False. Custom user models that don't have
|
||||
@@ -84,6 +124,25 @@ class ModelBackend(BaseBackend):
|
||||
)
|
||||
return getattr(user_obj, perm_cache_name)
|
||||
|
||||
async def _aget_permissions(self, user_obj, obj, from_name):
|
||||
"""See _get_permissions()."""
|
||||
if not user_obj.is_active or user_obj.is_anonymous or obj is not None:
|
||||
return set()
|
||||
|
||||
perm_cache_name = "_%s_perm_cache" % from_name
|
||||
if not hasattr(user_obj, perm_cache_name):
|
||||
if user_obj.is_superuser:
|
||||
perms = Permission.objects.all()
|
||||
else:
|
||||
perms = getattr(self, "_get_%s_permissions" % from_name)(user_obj)
|
||||
perms = perms.values_list("content_type__app_label", "codename").order_by()
|
||||
setattr(
|
||||
user_obj,
|
||||
perm_cache_name,
|
||||
{"%s.%s" % (ct, name) async for ct, name in perms},
|
||||
)
|
||||
return getattr(user_obj, perm_cache_name)
|
||||
|
||||
def get_user_permissions(self, user_obj, obj=None):
|
||||
"""
|
||||
Return a set of permission strings the user `user_obj` has from their
|
||||
@@ -91,6 +150,10 @@ class ModelBackend(BaseBackend):
|
||||
"""
|
||||
return self._get_permissions(user_obj, obj, "user")
|
||||
|
||||
async def aget_user_permissions(self, user_obj, obj=None):
|
||||
"""See get_user_permissions()."""
|
||||
return await self._aget_permissions(user_obj, obj, "user")
|
||||
|
||||
def get_group_permissions(self, user_obj, obj=None):
|
||||
"""
|
||||
Return a set of permission strings the user `user_obj` has from the
|
||||
@@ -98,6 +161,10 @@ class ModelBackend(BaseBackend):
|
||||
"""
|
||||
return self._get_permissions(user_obj, obj, "group")
|
||||
|
||||
async def aget_group_permissions(self, user_obj, obj=None):
|
||||
"""See get_group_permissions()."""
|
||||
return await self._aget_permissions(user_obj, obj, "group")
|
||||
|
||||
def get_all_permissions(self, user_obj, obj=None):
|
||||
if not user_obj.is_active or user_obj.is_anonymous or obj is not None:
|
||||
return set()
|
||||
@@ -108,6 +175,9 @@ class ModelBackend(BaseBackend):
|
||||
def has_perm(self, user_obj, perm, obj=None):
|
||||
return user_obj.is_active and super().has_perm(user_obj, perm, obj=obj)
|
||||
|
||||
async def ahas_perm(self, user_obj, perm, obj=None):
|
||||
return user_obj.is_active and await super().ahas_perm(user_obj, perm, obj=obj)
|
||||
|
||||
def has_module_perms(self, user_obj, app_label):
|
||||
"""
|
||||
Return True if user_obj has any permissions in the given app_label.
|
||||
@@ -117,6 +187,13 @@ class ModelBackend(BaseBackend):
|
||||
for perm in self.get_all_permissions(user_obj)
|
||||
)
|
||||
|
||||
async def ahas_module_perms(self, user_obj, app_label):
|
||||
"""See has_module_perms()"""
|
||||
return user_obj.is_active and any(
|
||||
perm[: perm.index(".")] == app_label
|
||||
for perm in await self.aget_all_permissions(user_obj)
|
||||
)
|
||||
|
||||
def with_perm(self, perm, is_active=True, include_superusers=True, obj=None):
|
||||
"""
|
||||
Return users that have permission "perm". By default, filter out
|
||||
@@ -159,6 +236,13 @@ class ModelBackend(BaseBackend):
|
||||
return None
|
||||
return user if self.user_can_authenticate(user) else None
|
||||
|
||||
async def aget_user(self, user_id):
|
||||
try:
|
||||
user = await UserModel._default_manager.aget(pk=user_id)
|
||||
except UserModel.DoesNotExist:
|
||||
return None
|
||||
return user if self.user_can_authenticate(user) else None
|
||||
|
||||
|
||||
class AllowAllUsersModelBackend(ModelBackend):
|
||||
def user_can_authenticate(self, user):
|
||||
@@ -210,6 +294,29 @@ class RemoteUserBackend(ModelBackend):
|
||||
user = self.configure_user(request, user, created=created)
|
||||
return user if self.user_can_authenticate(user) else None
|
||||
|
||||
async def aauthenticate(self, request, remote_user):
|
||||
"""See authenticate()."""
|
||||
if not remote_user:
|
||||
return
|
||||
created = False
|
||||
user = None
|
||||
username = self.clean_username(remote_user)
|
||||
|
||||
# Note that this could be accomplished in one try-except clause, but
|
||||
# instead we use get_or_create when creating unknown users since it has
|
||||
# built-in safeguards for multiple threads.
|
||||
if self.create_unknown_user:
|
||||
user, created = await UserModel._default_manager.aget_or_create(
|
||||
**{UserModel.USERNAME_FIELD: username}
|
||||
)
|
||||
else:
|
||||
try:
|
||||
user = await UserModel._default_manager.aget_by_natural_key(username)
|
||||
except UserModel.DoesNotExist:
|
||||
pass
|
||||
user = await self.aconfigure_user(request, user, created=created)
|
||||
return user if self.user_can_authenticate(user) else None
|
||||
|
||||
def clean_username(self, username):
|
||||
"""
|
||||
Perform any cleaning on the "username" prior to using it to get or
|
||||
@@ -227,6 +334,10 @@ class RemoteUserBackend(ModelBackend):
|
||||
"""
|
||||
return user
|
||||
|
||||
async def aconfigure_user(self, request, user, created=True):
|
||||
"""See configure_user()"""
|
||||
return await sync_to_async(self.configure_user)(request, user, created)
|
||||
|
||||
|
||||
class AllowAllUsersRemoteUserBackend(RemoteUserBackend):
|
||||
def user_can_authenticate(self, user):
|
||||
|
||||
@@ -36,6 +36,9 @@ class BaseUserManager(models.Manager):
|
||||
def get_by_natural_key(self, username):
|
||||
return self.get(**{self.model.USERNAME_FIELD: username})
|
||||
|
||||
async def aget_by_natural_key(self, username):
|
||||
return await self.aget(**{self.model.USERNAME_FIELD: username})
|
||||
|
||||
|
||||
class AbstractBaseUser(models.Model):
|
||||
password = models.CharField(_("password"), max_length=128)
|
||||
|
||||
@@ -111,7 +111,7 @@ def permission_required(perm, login_url=None, raise_exception=False):
|
||||
|
||||
async def check_perms(user):
|
||||
# First check if the user has the permission (even anon users).
|
||||
if await sync_to_async(user.has_perms)(perms):
|
||||
if await user.ahas_perms(perms):
|
||||
return True
|
||||
# In case the 403 handler should be called raise the exception.
|
||||
if raise_exception:
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from functools import partial
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
from asgiref.sync import iscoroutinefunction, markcoroutinefunction
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib import auth
|
||||
from django.contrib.auth import REDIRECT_FIELD_NAME, load_backend
|
||||
@@ -88,7 +90,7 @@ class LoginRequiredMiddleware(MiddlewareMixin):
|
||||
)
|
||||
|
||||
|
||||
class RemoteUserMiddleware(MiddlewareMixin):
|
||||
class RemoteUserMiddleware:
|
||||
"""
|
||||
Middleware for utilizing web-server-provided authentication.
|
||||
|
||||
@@ -102,13 +104,27 @@ class RemoteUserMiddleware(MiddlewareMixin):
|
||||
different header.
|
||||
"""
|
||||
|
||||
sync_capable = True
|
||||
async_capable = True
|
||||
|
||||
def __init__(self, get_response):
|
||||
if get_response is None:
|
||||
raise ValueError("get_response must be provided.")
|
||||
self.get_response = get_response
|
||||
self.is_async = iscoroutinefunction(get_response)
|
||||
if self.is_async:
|
||||
markcoroutinefunction(self)
|
||||
super().__init__()
|
||||
|
||||
# Name of request header to grab username from. This will be the key as
|
||||
# used in the request.META dictionary, i.e. the normalization of headers to
|
||||
# all uppercase and the addition of "HTTP_" prefix apply.
|
||||
header = "REMOTE_USER"
|
||||
force_logout_if_no_header = True
|
||||
|
||||
def process_request(self, request):
|
||||
def __call__(self, request):
|
||||
if self.is_async:
|
||||
return self.__acall__(request)
|
||||
# AuthenticationMiddleware is required so that request.user exists.
|
||||
if not hasattr(request, "user"):
|
||||
raise ImproperlyConfigured(
|
||||
@@ -126,13 +142,13 @@ class RemoteUserMiddleware(MiddlewareMixin):
|
||||
# AnonymousUser by the AuthenticationMiddleware).
|
||||
if self.force_logout_if_no_header and request.user.is_authenticated:
|
||||
self._remove_invalid_user(request)
|
||||
return
|
||||
return self.get_response(request)
|
||||
# If the user is already authenticated and that user is the user we are
|
||||
# getting passed in the headers, then the correct user is already
|
||||
# persisted in the session and we don't need to continue.
|
||||
if request.user.is_authenticated:
|
||||
if request.user.get_username() == self.clean_username(username, request):
|
||||
return
|
||||
return self.get_response(request)
|
||||
else:
|
||||
# An authenticated user is associated with the request, but
|
||||
# it does not match the authorized user in the header.
|
||||
@@ -146,6 +162,51 @@ class RemoteUserMiddleware(MiddlewareMixin):
|
||||
# by logging the user in.
|
||||
request.user = user
|
||||
auth.login(request, user)
|
||||
return self.get_response(request)
|
||||
|
||||
async def __acall__(self, request):
|
||||
# AuthenticationMiddleware is required so that request.user exists.
|
||||
if not hasattr(request, "user"):
|
||||
raise ImproperlyConfigured(
|
||||
"The Django remote user auth middleware requires the"
|
||||
" authentication middleware to be installed. Edit your"
|
||||
" MIDDLEWARE setting to insert"
|
||||
" 'django.contrib.auth.middleware.AuthenticationMiddleware'"
|
||||
" before the RemoteUserMiddleware class."
|
||||
)
|
||||
try:
|
||||
username = request.META["HTTP_" + self.header]
|
||||
except KeyError:
|
||||
# If specified header doesn't exist then remove any existing
|
||||
# authenticated remote-user, or return (leaving request.user set to
|
||||
# AnonymousUser by the AuthenticationMiddleware).
|
||||
if self.force_logout_if_no_header:
|
||||
user = await request.auser()
|
||||
if user.is_authenticated:
|
||||
await self._aremove_invalid_user(request)
|
||||
return await self.get_response(request)
|
||||
user = await request.auser()
|
||||
# If the user is already authenticated and that user is the user we are
|
||||
# getting passed in the headers, then the correct user is already
|
||||
# persisted in the session and we don't need to continue.
|
||||
if user.is_authenticated:
|
||||
if user.get_username() == self.clean_username(username, request):
|
||||
return await self.get_response(request)
|
||||
else:
|
||||
# An authenticated user is associated with the request, but
|
||||
# it does not match the authorized user in the header.
|
||||
await self._aremove_invalid_user(request)
|
||||
|
||||
# We are seeing this user for the first time in this session, attempt
|
||||
# to authenticate the user.
|
||||
user = await auth.aauthenticate(request, remote_user=username)
|
||||
if user:
|
||||
# User is valid. Set request.user and persist user in the session
|
||||
# by logging the user in.
|
||||
request.user = user
|
||||
await auth.alogin(request, user)
|
||||
|
||||
return await self.get_response(request)
|
||||
|
||||
def clean_username(self, username, request):
|
||||
"""
|
||||
@@ -176,6 +237,22 @@ class RemoteUserMiddleware(MiddlewareMixin):
|
||||
if isinstance(stored_backend, RemoteUserBackend):
|
||||
auth.logout(request)
|
||||
|
||||
async def _aremove_invalid_user(self, request):
|
||||
"""
|
||||
Remove the current authenticated user in the request which is invalid
|
||||
but only if the user is authenticated via the RemoteUserBackend.
|
||||
"""
|
||||
try:
|
||||
stored_backend = load_backend(
|
||||
await request.session.aget(auth.BACKEND_SESSION_KEY, "")
|
||||
)
|
||||
except ImportError:
|
||||
# Backend failed to load.
|
||||
await auth.alogout(request)
|
||||
else:
|
||||
if isinstance(stored_backend, RemoteUserBackend):
|
||||
await auth.alogout(request)
|
||||
|
||||
|
||||
class PersistentRemoteUserMiddleware(RemoteUserMiddleware):
|
||||
"""
|
||||
|
||||
@@ -95,6 +95,9 @@ class GroupManager(models.Manager):
|
||||
def get_by_natural_key(self, name):
|
||||
return self.get(name=name)
|
||||
|
||||
async def aget_by_natural_key(self, name):
|
||||
return await self.aget(name=name)
|
||||
|
||||
|
||||
class Group(models.Model):
|
||||
"""
|
||||
@@ -137,10 +140,7 @@ class Group(models.Model):
|
||||
class UserManager(BaseUserManager):
|
||||
use_in_migrations = True
|
||||
|
||||
def _create_user(self, username, email, password, **extra_fields):
|
||||
"""
|
||||
Create and save a user with the given username, email, and password.
|
||||
"""
|
||||
def _create_user_object(self, username, email, password, **extra_fields):
|
||||
if not username:
|
||||
raise ValueError("The given username must be set")
|
||||
email = self.normalize_email(email)
|
||||
@@ -153,14 +153,32 @@ class UserManager(BaseUserManager):
|
||||
username = GlobalUserModel.normalize_username(username)
|
||||
user = self.model(username=username, email=email, **extra_fields)
|
||||
user.password = make_password(password)
|
||||
return user
|
||||
|
||||
def _create_user(self, username, email, password, **extra_fields):
|
||||
"""
|
||||
Create and save a user with the given username, email, and password.
|
||||
"""
|
||||
user = self._create_user_object(username, email, password, **extra_fields)
|
||||
user.save(using=self._db)
|
||||
return user
|
||||
|
||||
async def _acreate_user(self, username, email, password, **extra_fields):
|
||||
"""See _create_user()"""
|
||||
user = self._create_user_object(username, email, password, **extra_fields)
|
||||
await user.asave(using=self._db)
|
||||
return user
|
||||
|
||||
def create_user(self, username, email=None, password=None, **extra_fields):
|
||||
extra_fields.setdefault("is_staff", False)
|
||||
extra_fields.setdefault("is_superuser", False)
|
||||
return self._create_user(username, email, password, **extra_fields)
|
||||
|
||||
async def acreate_user(self, username, email=None, password=None, **extra_fields):
|
||||
extra_fields.setdefault("is_staff", False)
|
||||
extra_fields.setdefault("is_superuser", False)
|
||||
return await self._acreate_user(username, email, password, **extra_fields)
|
||||
|
||||
def create_superuser(self, username, email=None, password=None, **extra_fields):
|
||||
extra_fields.setdefault("is_staff", True)
|
||||
extra_fields.setdefault("is_superuser", True)
|
||||
@@ -172,6 +190,19 @@ class UserManager(BaseUserManager):
|
||||
|
||||
return self._create_user(username, email, password, **extra_fields)
|
||||
|
||||
async def acreate_superuser(
|
||||
self, username, email=None, password=None, **extra_fields
|
||||
):
|
||||
extra_fields.setdefault("is_staff", True)
|
||||
extra_fields.setdefault("is_superuser", True)
|
||||
|
||||
if extra_fields.get("is_staff") is not True:
|
||||
raise ValueError("Superuser must have is_staff=True.")
|
||||
if extra_fields.get("is_superuser") is not True:
|
||||
raise ValueError("Superuser must have is_superuser=True.")
|
||||
|
||||
return await self._acreate_user(username, email, password, **extra_fields)
|
||||
|
||||
def with_perm(
|
||||
self, perm, is_active=True, include_superusers=True, backend=None, obj=None
|
||||
):
|
||||
@@ -210,6 +241,15 @@ def _user_get_permissions(user, obj, from_name):
|
||||
return permissions
|
||||
|
||||
|
||||
async def _auser_get_permissions(user, obj, from_name):
|
||||
permissions = set()
|
||||
name = "aget_%s_permissions" % from_name
|
||||
for backend in auth.get_backends():
|
||||
if hasattr(backend, name):
|
||||
permissions.update(await getattr(backend, name)(user, obj))
|
||||
return permissions
|
||||
|
||||
|
||||
def _user_has_perm(user, perm, obj):
|
||||
"""
|
||||
A backend can raise `PermissionDenied` to short-circuit permission checking.
|
||||
@@ -225,6 +265,19 @@ def _user_has_perm(user, perm, obj):
|
||||
return False
|
||||
|
||||
|
||||
async def _auser_has_perm(user, perm, obj):
|
||||
"""See _user_has_perm()"""
|
||||
for backend in auth.get_backends():
|
||||
if not hasattr(backend, "ahas_perm"):
|
||||
continue
|
||||
try:
|
||||
if await backend.ahas_perm(user, perm, obj):
|
||||
return True
|
||||
except PermissionDenied:
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
def _user_has_module_perms(user, app_label):
|
||||
"""
|
||||
A backend can raise `PermissionDenied` to short-circuit permission checking.
|
||||
@@ -240,6 +293,19 @@ def _user_has_module_perms(user, app_label):
|
||||
return False
|
||||
|
||||
|
||||
async def _auser_has_module_perms(user, app_label):
|
||||
"""See _user_has_module_perms()"""
|
||||
for backend in auth.get_backends():
|
||||
if not hasattr(backend, "ahas_module_perms"):
|
||||
continue
|
||||
try:
|
||||
if await backend.ahas_module_perms(user, app_label):
|
||||
return True
|
||||
except PermissionDenied:
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
class PermissionsMixin(models.Model):
|
||||
"""
|
||||
Add the fields and methods necessary to support the Group and Permission
|
||||
@@ -285,6 +351,10 @@ class PermissionsMixin(models.Model):
|
||||
"""
|
||||
return _user_get_permissions(self, obj, "user")
|
||||
|
||||
async def aget_user_permissions(self, obj=None):
|
||||
"""See get_user_permissions()"""
|
||||
return await _auser_get_permissions(self, obj, "user")
|
||||
|
||||
def get_group_permissions(self, obj=None):
|
||||
"""
|
||||
Return a list of permission strings that this user has through their
|
||||
@@ -293,9 +363,16 @@ class PermissionsMixin(models.Model):
|
||||
"""
|
||||
return _user_get_permissions(self, obj, "group")
|
||||
|
||||
async def aget_group_permissions(self, obj=None):
|
||||
"""See get_group_permissions()"""
|
||||
return await _auser_get_permissions(self, obj, "group")
|
||||
|
||||
def get_all_permissions(self, obj=None):
|
||||
return _user_get_permissions(self, obj, "all")
|
||||
|
||||
async def aget_all_permissions(self, obj=None):
|
||||
return await _auser_get_permissions(self, obj, "all")
|
||||
|
||||
def has_perm(self, perm, obj=None):
|
||||
"""
|
||||
Return True if the user has the specified permission. Query all
|
||||
@@ -311,6 +388,15 @@ class PermissionsMixin(models.Model):
|
||||
# Otherwise we need to check the backends.
|
||||
return _user_has_perm(self, perm, obj)
|
||||
|
||||
async def ahas_perm(self, perm, obj=None):
|
||||
"""See has_perm()"""
|
||||
# Active superusers have all permissions.
|
||||
if self.is_active and self.is_superuser:
|
||||
return True
|
||||
|
||||
# Otherwise we need to check the backends.
|
||||
return await _auser_has_perm(self, perm, obj)
|
||||
|
||||
def has_perms(self, perm_list, obj=None):
|
||||
"""
|
||||
Return True if the user has each of the specified permissions. If
|
||||
@@ -320,6 +406,15 @@ class PermissionsMixin(models.Model):
|
||||
raise ValueError("perm_list must be an iterable of permissions.")
|
||||
return all(self.has_perm(perm, obj) for perm in perm_list)
|
||||
|
||||
async def ahas_perms(self, perm_list, obj=None):
|
||||
"""See has_perms()"""
|
||||
if not isinstance(perm_list, Iterable) or isinstance(perm_list, str):
|
||||
raise ValueError("perm_list must be an iterable of permissions.")
|
||||
for perm in perm_list:
|
||||
if not await self.ahas_perm(perm, obj):
|
||||
return False
|
||||
return True
|
||||
|
||||
def has_module_perms(self, app_label):
|
||||
"""
|
||||
Return True if the user has any permissions in the given app label.
|
||||
@@ -331,6 +426,14 @@ class PermissionsMixin(models.Model):
|
||||
|
||||
return _user_has_module_perms(self, app_label)
|
||||
|
||||
async def ahas_module_perms(self, app_label):
|
||||
"""See has_module_perms()"""
|
||||
# Active superusers have all permissions.
|
||||
if self.is_active and self.is_superuser:
|
||||
return True
|
||||
|
||||
return await _auser_has_module_perms(self, app_label)
|
||||
|
||||
|
||||
class AbstractUser(AbstractBaseUser, PermissionsMixin):
|
||||
"""
|
||||
@@ -471,23 +574,46 @@ class AnonymousUser:
|
||||
def get_user_permissions(self, obj=None):
|
||||
return _user_get_permissions(self, obj, "user")
|
||||
|
||||
async def aget_user_permissions(self, obj=None):
|
||||
return await _auser_get_permissions(self, obj, "user")
|
||||
|
||||
def get_group_permissions(self, obj=None):
|
||||
return set()
|
||||
|
||||
async def aget_group_permissions(self, obj=None):
|
||||
return self.get_group_permissions(obj)
|
||||
|
||||
def get_all_permissions(self, obj=None):
|
||||
return _user_get_permissions(self, obj, "all")
|
||||
|
||||
async def aget_all_permissions(self, obj=None):
|
||||
return await _auser_get_permissions(self, obj, "all")
|
||||
|
||||
def has_perm(self, perm, obj=None):
|
||||
return _user_has_perm(self, perm, obj=obj)
|
||||
|
||||
async def ahas_perm(self, perm, obj=None):
|
||||
return await _auser_has_perm(self, perm, obj=obj)
|
||||
|
||||
def has_perms(self, perm_list, obj=None):
|
||||
if not isinstance(perm_list, Iterable) or isinstance(perm_list, str):
|
||||
raise ValueError("perm_list must be an iterable of permissions.")
|
||||
return all(self.has_perm(perm, obj) for perm in perm_list)
|
||||
|
||||
async def ahas_perms(self, perm_list, obj=None):
|
||||
if not isinstance(perm_list, Iterable) or isinstance(perm_list, str):
|
||||
raise ValueError("perm_list must be an iterable of permissions.")
|
||||
for perm in perm_list:
|
||||
if not await self.ahas_perm(perm, obj):
|
||||
return False
|
||||
return True
|
||||
|
||||
def has_module_perms(self, module):
|
||||
return _user_has_module_perms(self, module)
|
||||
|
||||
async def ahas_module_perms(self, module):
|
||||
return await _auser_has_module_perms(self, module)
|
||||
|
||||
@property
|
||||
def is_anonymous(self):
|
||||
return True
|
||||
|
||||
Reference in New Issue
Block a user