1
0
mirror of https://github.com/django/django.git synced 2025-10-31 09:41:08 +00:00

Fixed #35303 -- Implemented async auth backends and utils.

This commit is contained in:
Jon Janzen
2024-03-31 12:29:10 -07:00
committed by Sarah Boyce
parent 4cad317ff1
commit 50f89ae850
17 changed files with 1285 additions and 61 deletions

View File

@@ -29,6 +29,19 @@ class CustomUserManager(BaseUserManager):
user.save(using=self._db)
return user
async def acreate_user(self, email, date_of_birth, password=None, **fields):
"""See create_user()"""
if not email:
raise ValueError("Users must have an email address")
user = self.model(
email=self.normalize_email(email), date_of_birth=date_of_birth, **fields
)
user.set_password(password)
await user.asave(using=self._db)
return user
def create_superuser(self, email, password, date_of_birth, **fields):
u = self.create_user(
email, password=password, date_of_birth=date_of_birth, **fields

View File

@@ -2,6 +2,8 @@ import sys
from datetime import date
from unittest import mock
from asgiref.sync import sync_to_async
from django.contrib.auth import (
BACKEND_SESSION_KEY,
SESSION_KEY,
@@ -55,17 +57,33 @@ class BaseBackendTest(TestCase):
def test_get_user_permissions(self):
self.assertEqual(self.user.get_user_permissions(), {"user_perm"})
async def test_aget_user_permissions(self):
self.assertEqual(await self.user.aget_user_permissions(), {"user_perm"})
def test_get_group_permissions(self):
self.assertEqual(self.user.get_group_permissions(), {"group_perm"})
async def test_aget_group_permissions(self):
self.assertEqual(await self.user.aget_group_permissions(), {"group_perm"})
def test_get_all_permissions(self):
self.assertEqual(self.user.get_all_permissions(), {"user_perm", "group_perm"})
async def test_aget_all_permissions(self):
self.assertEqual(
await self.user.aget_all_permissions(), {"user_perm", "group_perm"}
)
def test_has_perm(self):
self.assertIs(self.user.has_perm("user_perm"), True)
self.assertIs(self.user.has_perm("group_perm"), True)
self.assertIs(self.user.has_perm("other_perm", TestObj()), False)
async def test_ahas_perm(self):
self.assertIs(await self.user.ahas_perm("user_perm"), True)
self.assertIs(await self.user.ahas_perm("group_perm"), True)
self.assertIs(await self.user.ahas_perm("other_perm", TestObj()), False)
def test_has_perms_perm_list_invalid(self):
msg = "perm_list must be an iterable of permissions."
with self.assertRaisesMessage(ValueError, msg):
@@ -73,6 +91,13 @@ class BaseBackendTest(TestCase):
with self.assertRaisesMessage(ValueError, msg):
self.user.has_perms(object())
async def test_ahas_perms_perm_list_invalid(self):
msg = "perm_list must be an iterable of permissions."
with self.assertRaisesMessage(ValueError, msg):
await self.user.ahas_perms("user_perm")
with self.assertRaisesMessage(ValueError, msg):
await self.user.ahas_perms(object())
class CountingMD5PasswordHasher(MD5PasswordHasher):
"""Hasher that counts how many times it computes a hash."""
@@ -125,6 +150,25 @@ class BaseModelBackendTest:
user.save()
self.assertIs(user.has_perm("auth.test"), False)
async def test_ahas_perm(self):
user = await self.UserModel._default_manager.aget(pk=self.user.pk)
self.assertIs(await user.ahas_perm("auth.test"), False)
user.is_staff = True
await user.asave()
self.assertIs(await user.ahas_perm("auth.test"), False)
user.is_superuser = True
await user.asave()
self.assertIs(await user.ahas_perm("auth.test"), True)
self.assertIs(await user.ahas_module_perms("auth"), True)
user.is_staff = True
user.is_superuser = True
user.is_active = False
await user.asave()
self.assertIs(await user.ahas_perm("auth.test"), False)
def test_custom_perms(self):
user = self.UserModel._default_manager.get(pk=self.user.pk)
content_type = ContentType.objects.get_for_model(Group)
@@ -174,6 +218,55 @@ class BaseModelBackendTest:
self.assertIs(user.has_perm("test"), False)
self.assertIs(user.has_perms(["auth.test2", "auth.test3"]), False)
async def test_acustom_perms(self):
user = await self.UserModel._default_manager.aget(pk=self.user.pk)
content_type = await sync_to_async(ContentType.objects.get_for_model)(Group)
perm = await Permission.objects.acreate(
name="test", content_type=content_type, codename="test"
)
await user.user_permissions.aadd(perm)
# Reloading user to purge the _perm_cache.
user = await self.UserModel._default_manager.aget(pk=self.user.pk)
self.assertEqual(await user.aget_all_permissions(), {"auth.test"})
self.assertEqual(await user.aget_user_permissions(), {"auth.test"})
self.assertEqual(await user.aget_group_permissions(), set())
self.assertIs(await user.ahas_module_perms("Group"), False)
self.assertIs(await user.ahas_module_perms("auth"), True)
perm = await Permission.objects.acreate(
name="test2", content_type=content_type, codename="test2"
)
await user.user_permissions.aadd(perm)
perm = await Permission.objects.acreate(
name="test3", content_type=content_type, codename="test3"
)
await user.user_permissions.aadd(perm)
user = await self.UserModel._default_manager.aget(pk=self.user.pk)
expected_user_perms = {"auth.test2", "auth.test", "auth.test3"}
self.assertEqual(await user.aget_all_permissions(), expected_user_perms)
self.assertIs(await user.ahas_perm("test"), False)
self.assertIs(await user.ahas_perm("auth.test"), True)
self.assertIs(await user.ahas_perms(["auth.test2", "auth.test3"]), True)
perm = await Permission.objects.acreate(
name="test_group", content_type=content_type, codename="test_group"
)
group = await Group.objects.acreate(name="test_group")
await group.permissions.aadd(perm)
await user.groups.aadd(group)
user = await self.UserModel._default_manager.aget(pk=self.user.pk)
self.assertEqual(
await user.aget_all_permissions(), {*expected_user_perms, "auth.test_group"}
)
self.assertEqual(await user.aget_user_permissions(), expected_user_perms)
self.assertEqual(await user.aget_group_permissions(), {"auth.test_group"})
self.assertIs(await user.ahas_perms(["auth.test3", "auth.test_group"]), True)
user = AnonymousUser()
self.assertIs(await user.ahas_perm("test"), False)
self.assertIs(await user.ahas_perms(["auth.test2", "auth.test3"]), False)
def test_has_no_object_perm(self):
"""Regressiontest for #12462"""
user = self.UserModel._default_manager.get(pk=self.user.pk)
@@ -188,6 +281,20 @@ class BaseModelBackendTest:
self.assertIs(user.has_perm("auth.test"), True)
self.assertEqual(user.get_all_permissions(), {"auth.test"})
async def test_ahas_no_object_perm(self):
"""See test_has_no_object_perm()"""
user = await self.UserModel._default_manager.aget(pk=self.user.pk)
content_type = await sync_to_async(ContentType.objects.get_for_model)(Group)
perm = await Permission.objects.acreate(
name="test", content_type=content_type, codename="test"
)
await user.user_permissions.aadd(perm)
self.assertIs(await user.ahas_perm("auth.test", "object"), False)
self.assertEqual(await user.aget_all_permissions("object"), set())
self.assertIs(await user.ahas_perm("auth.test"), True)
self.assertEqual(await user.aget_all_permissions(), {"auth.test"})
def test_anonymous_has_no_permissions(self):
"""
#17903 -- Anonymous users shouldn't have permissions in
@@ -220,6 +327,38 @@ class BaseModelBackendTest:
self.assertEqual(backend.get_user_permissions(user), set())
self.assertEqual(backend.get_group_permissions(user), set())
async def test_aanonymous_has_no_permissions(self):
"""See test_anonymous_has_no_permissions()"""
backend = ModelBackend()
user = await self.UserModel._default_manager.aget(pk=self.user.pk)
content_type = await sync_to_async(ContentType.objects.get_for_model)(Group)
user_perm = await Permission.objects.acreate(
name="test", content_type=content_type, codename="test_user"
)
group_perm = await Permission.objects.acreate(
name="test2", content_type=content_type, codename="test_group"
)
await user.user_permissions.aadd(user_perm)
group = await Group.objects.acreate(name="test_group")
await user.groups.aadd(group)
await group.permissions.aadd(group_perm)
self.assertEqual(
await backend.aget_all_permissions(user),
{"auth.test_user", "auth.test_group"},
)
self.assertEqual(await backend.aget_user_permissions(user), {"auth.test_user"})
self.assertEqual(
await backend.aget_group_permissions(user), {"auth.test_group"}
)
with mock.patch.object(self.UserModel, "is_anonymous", True):
self.assertEqual(await backend.aget_all_permissions(user), set())
self.assertEqual(await backend.aget_user_permissions(user), set())
self.assertEqual(await backend.aget_group_permissions(user), set())
def test_inactive_has_no_permissions(self):
"""
#17903 -- Inactive users shouldn't have permissions in
@@ -254,11 +393,52 @@ class BaseModelBackendTest:
self.assertEqual(backend.get_user_permissions(user), set())
self.assertEqual(backend.get_group_permissions(user), set())
async def test_ainactive_has_no_permissions(self):
"""See test_inactive_has_no_permissions()"""
backend = ModelBackend()
user = await self.UserModel._default_manager.aget(pk=self.user.pk)
content_type = await sync_to_async(ContentType.objects.get_for_model)(Group)
user_perm = await Permission.objects.acreate(
name="test", content_type=content_type, codename="test_user"
)
group_perm = await Permission.objects.acreate(
name="test2", content_type=content_type, codename="test_group"
)
await user.user_permissions.aadd(user_perm)
group = await Group.objects.acreate(name="test_group")
await user.groups.aadd(group)
await group.permissions.aadd(group_perm)
self.assertEqual(
await backend.aget_all_permissions(user),
{"auth.test_user", "auth.test_group"},
)
self.assertEqual(await backend.aget_user_permissions(user), {"auth.test_user"})
self.assertEqual(
await backend.aget_group_permissions(user), {"auth.test_group"}
)
user.is_active = False
await user.asave()
self.assertEqual(await backend.aget_all_permissions(user), set())
self.assertEqual(await backend.aget_user_permissions(user), set())
self.assertEqual(await backend.aget_group_permissions(user), set())
def test_get_all_superuser_permissions(self):
"""A superuser has all permissions. Refs #14795."""
user = self.UserModel._default_manager.get(pk=self.superuser.pk)
self.assertEqual(len(user.get_all_permissions()), len(Permission.objects.all()))
async def test_aget_all_superuser_permissions(self):
"""See test_get_all_superuser_permissions()"""
user = await self.UserModel._default_manager.aget(pk=self.superuser.pk)
self.assertEqual(
len(await user.aget_all_permissions()), await Permission.objects.acount()
)
@override_settings(
PASSWORD_HASHERS=["auth_tests.test_auth_backends.CountingMD5PasswordHasher"]
)
@@ -277,6 +457,24 @@ class BaseModelBackendTest:
authenticate(username="no_such_user", password="test")
self.assertEqual(CountingMD5PasswordHasher.calls, 1)
@override_settings(
PASSWORD_HASHERS=["auth_tests.test_auth_backends.CountingMD5PasswordHasher"]
)
async def test_aauthentication_timing(self):
"""See test_authentication_timing()"""
# Re-set the password, because this tests overrides PASSWORD_HASHERS.
self.user.set_password("test")
await self.user.asave()
CountingMD5PasswordHasher.calls = 0
username = getattr(self.user, self.UserModel.USERNAME_FIELD)
await aauthenticate(username=username, password="test")
self.assertEqual(CountingMD5PasswordHasher.calls, 1)
CountingMD5PasswordHasher.calls = 0
await aauthenticate(username="no_such_user", password="test")
self.assertEqual(CountingMD5PasswordHasher.calls, 1)
@override_settings(
PASSWORD_HASHERS=["auth_tests.test_auth_backends.CountingMD5PasswordHasher"]
)
@@ -320,6 +518,15 @@ class ModelBackendTest(BaseModelBackendTest, TestCase):
self.user.save()
self.assertIsNone(authenticate(**self.user_credentials))
async def test_aauthenticate_inactive(self):
"""
An inactive user can't authenticate.
"""
self.assertEqual(await aauthenticate(**self.user_credentials), self.user)
self.user.is_active = False
await self.user.asave()
self.assertIsNone(await aauthenticate(**self.user_credentials))
@override_settings(AUTH_USER_MODEL="auth_tests.CustomUserWithoutIsActiveField")
def test_authenticate_user_without_is_active_field(self):
"""
@@ -332,6 +539,18 @@ class ModelBackendTest(BaseModelBackendTest, TestCase):
)
self.assertEqual(authenticate(username="test", password="test"), user)
@override_settings(AUTH_USER_MODEL="auth_tests.CustomUserWithoutIsActiveField")
async def test_aauthenticate_user_without_is_active_field(self):
"""
A custom user without an `is_active` field is allowed to authenticate.
"""
user = await CustomUserWithoutIsActiveField.objects._acreate_user(
username="test",
email="test@example.com",
password="test",
)
self.assertEqual(await aauthenticate(username="test", password="test"), user)
@override_settings(AUTH_USER_MODEL="auth_tests.ExtensionUser")
class ExtensionUserModelBackendTest(BaseModelBackendTest, TestCase):
@@ -403,6 +622,15 @@ class CustomUserModelBackendAuthenticateTest(TestCase):
authenticated_user = authenticate(email="test@example.com", password="test")
self.assertEqual(test_user, authenticated_user)
async def test_aauthenticate(self):
test_user = await CustomUser._default_manager.acreate_user(
email="test@example.com", password="test", date_of_birth=date(2006, 4, 25)
)
authenticated_user = await aauthenticate(
email="test@example.com", password="test"
)
self.assertEqual(test_user, authenticated_user)
@override_settings(AUTH_USER_MODEL="auth_tests.UUIDUser")
class UUIDUserTests(TestCase):
@@ -416,6 +644,13 @@ class UUIDUserTests(TestCase):
UUIDUser.objects.get(pk=self.client.session[SESSION_KEY]), user
)
async def test_alogin(self):
"""See test_login()"""
user = await UUIDUser.objects.acreate_user(username="uuid", password="test")
self.assertTrue(await self.client.alogin(username="uuid", password="test"))
session_key = await self.client.session.aget(SESSION_KEY)
self.assertEqual(await UUIDUser.objects.aget(pk=session_key), user)
class TestObj:
pass
@@ -435,9 +670,15 @@ class SimpleRowlevelBackend:
return True
return False
async def ahas_perm(self, user, perm, obj=None):
return self.has_perm(user, perm, obj)
def has_module_perms(self, user, app_label):
return (user.is_anonymous or user.is_active) and app_label == "app1"
async def ahas_module_perms(self, user, app_label):
return self.has_module_perms(user, app_label)
def get_all_permissions(self, user, obj=None):
if not obj:
return [] # We only support row level perms
@@ -452,6 +693,9 @@ class SimpleRowlevelBackend:
else:
return ["simple"]
async def aget_all_permissions(self, user, obj=None):
return self.get_all_permissions(user, obj)
def get_group_permissions(self, user, obj=None):
if not obj:
return # We only support row level perms
@@ -524,10 +768,18 @@ class AnonymousUserBackendTest(SimpleTestCase):
self.assertIs(self.user1.has_perm("perm", TestObj()), False)
self.assertIs(self.user1.has_perm("anon", TestObj()), True)
async def test_ahas_perm(self):
self.assertIs(await self.user1.ahas_perm("perm", TestObj()), False)
self.assertIs(await self.user1.ahas_perm("anon", TestObj()), True)
def test_has_perms(self):
self.assertIs(self.user1.has_perms(["anon"], TestObj()), True)
self.assertIs(self.user1.has_perms(["anon", "perm"], TestObj()), False)
async def test_ahas_perms(self):
self.assertIs(await self.user1.ahas_perms(["anon"], TestObj()), True)
self.assertIs(await self.user1.ahas_perms(["anon", "perm"], TestObj()), False)
def test_has_perms_perm_list_invalid(self):
msg = "perm_list must be an iterable of permissions."
with self.assertRaisesMessage(ValueError, msg):
@@ -535,13 +787,27 @@ class AnonymousUserBackendTest(SimpleTestCase):
with self.assertRaisesMessage(ValueError, msg):
self.user1.has_perms(object())
async def test_ahas_perms_perm_list_invalid(self):
msg = "perm_list must be an iterable of permissions."
with self.assertRaisesMessage(ValueError, msg):
await self.user1.ahas_perms("perm")
with self.assertRaisesMessage(ValueError, msg):
await self.user1.ahas_perms(object())
def test_has_module_perms(self):
self.assertIs(self.user1.has_module_perms("app1"), True)
self.assertIs(self.user1.has_module_perms("app2"), False)
async def test_ahas_module_perms(self):
self.assertIs(await self.user1.ahas_module_perms("app1"), True)
self.assertIs(await self.user1.ahas_module_perms("app2"), False)
def test_get_all_permissions(self):
self.assertEqual(self.user1.get_all_permissions(TestObj()), {"anon"})
async def test_aget_all_permissions(self):
self.assertEqual(await self.user1.aget_all_permissions(TestObj()), {"anon"})
@override_settings(AUTHENTICATION_BACKENDS=[])
class NoBackendsTest(TestCase):
@@ -561,6 +827,14 @@ class NoBackendsTest(TestCase):
with self.assertRaisesMessage(ImproperlyConfigured, msg):
self.user.has_perm(("perm", TestObj()))
async def test_araises_exception(self):
msg = (
"No authentication backends have been defined. "
"Does AUTHENTICATION_BACKENDS contain anything?"
)
with self.assertRaisesMessage(ImproperlyConfigured, msg):
await self.user.ahas_perm(("perm", TestObj()))
@override_settings(
AUTHENTICATION_BACKENDS=["auth_tests.test_auth_backends.SimpleRowlevelBackend"]
@@ -593,12 +867,21 @@ class PermissionDeniedBackend:
def authenticate(self, request, username=None, password=None):
raise PermissionDenied
async def aauthenticate(self, request, username=None, password=None):
raise PermissionDenied
def has_perm(self, user_obj, perm, obj=None):
raise PermissionDenied
async def ahas_perm(self, user_obj, perm, obj=None):
raise PermissionDenied
def has_module_perms(self, user_obj, app_label):
raise PermissionDenied
async def ahas_module_perms(self, user_obj, app_label):
raise PermissionDenied
class PermissionDeniedBackendTest(TestCase):
"""
@@ -631,10 +914,25 @@ class PermissionDeniedBackendTest(TestCase):
[{"password": "********************", "username": "test"}],
)
@modify_settings(AUTHENTICATION_BACKENDS={"prepend": backend})
async def test_aauthenticate_permission_denied(self):
self.assertIsNone(await aauthenticate(username="test", password="test"))
# user_login_failed signal is sent.
self.assertEqual(
self.user_login_failed,
[{"password": "********************", "username": "test"}],
)
@modify_settings(AUTHENTICATION_BACKENDS={"append": backend})
def test_authenticates(self):
self.assertEqual(authenticate(username="test", password="test"), self.user1)
@modify_settings(AUTHENTICATION_BACKENDS={"append": backend})
async def test_aauthenticate(self):
self.assertEqual(
await aauthenticate(username="test", password="test"), self.user1
)
@modify_settings(AUTHENTICATION_BACKENDS={"prepend": backend})
def test_has_perm_denied(self):
content_type = ContentType.objects.get_for_model(Group)
@@ -646,6 +944,17 @@ class PermissionDeniedBackendTest(TestCase):
self.assertIs(self.user1.has_perm("auth.test"), False)
self.assertIs(self.user1.has_module_perms("auth"), False)
@modify_settings(AUTHENTICATION_BACKENDS={"prepend": backend})
async def test_ahas_perm_denied(self):
content_type = await sync_to_async(ContentType.objects.get_for_model)(Group)
perm = await Permission.objects.acreate(
name="test", content_type=content_type, codename="test"
)
await self.user1.user_permissions.aadd(perm)
self.assertIs(await self.user1.ahas_perm("auth.test"), False)
self.assertIs(await self.user1.ahas_module_perms("auth"), False)
@modify_settings(AUTHENTICATION_BACKENDS={"append": backend})
def test_has_perm(self):
content_type = ContentType.objects.get_for_model(Group)
@@ -657,6 +966,17 @@ class PermissionDeniedBackendTest(TestCase):
self.assertIs(self.user1.has_perm("auth.test"), True)
self.assertIs(self.user1.has_module_perms("auth"), True)
@modify_settings(AUTHENTICATION_BACKENDS={"append": backend})
async def test_ahas_perm(self):
content_type = await sync_to_async(ContentType.objects.get_for_model)(Group)
perm = await Permission.objects.acreate(
name="test", content_type=content_type, codename="test"
)
await self.user1.user_permissions.aadd(perm)
self.assertIs(await self.user1.ahas_perm("auth.test"), True)
self.assertIs(await self.user1.ahas_module_perms("auth"), True)
class NewModelBackend(ModelBackend):
pass
@@ -715,6 +1035,10 @@ class TypeErrorBackend:
def authenticate(self, request, username=None, password=None):
raise TypeError
@sensitive_variables("password")
async def aauthenticate(self, request, username=None, password=None):
raise TypeError
class SkippedBackend:
def authenticate(self):

View File

@@ -1,5 +1,3 @@
from asgiref.sync import sync_to_async
from django.conf import settings
from django.contrib.auth import aget_user, get_user, get_user_model
from django.contrib.auth.models import AnonymousUser, User
@@ -44,6 +42,12 @@ class BasicTestCase(TestCase):
u2 = User.objects.create_user("testuser2", "test2@example.com")
self.assertFalse(u2.has_usable_password())
async def test_acreate(self):
u = await User.objects.acreate_user("testuser", "test@example.com", "testpw")
self.assertTrue(u.has_usable_password())
self.assertFalse(await u.acheck_password("bad"))
self.assertTrue(await u.acheck_password("testpw"))
def test_unicode_username(self):
User.objects.create_user("jörg")
User.objects.create_user("Григорий")
@@ -73,6 +77,15 @@ class BasicTestCase(TestCase):
self.assertTrue(super.is_active)
self.assertTrue(super.is_staff)
async def test_asuperuser(self):
"Check the creation and properties of a superuser"
super = await User.objects.acreate_superuser(
"super", "super@example.com", "super"
)
self.assertTrue(super.is_superuser)
self.assertTrue(super.is_active)
self.assertTrue(super.is_staff)
def test_superuser_no_email_or_password(self):
cases = [
{},
@@ -171,13 +184,25 @@ class TestGetUser(TestCase):
self.assertIsInstance(user, User)
self.assertEqual(user.username, created_user.username)
async def test_aget_user(self):
created_user = await sync_to_async(User.objects.create_user)(
async def test_aget_user_fallback_secret(self):
created_user = await User.objects.acreate_user(
"testuser", "test@example.com", "testpw"
)
await self.client.alogin(username="testuser", password="testpw")
request = HttpRequest()
request.session = await self.client.asession()
user = await aget_user(request)
self.assertIsInstance(user, User)
self.assertEqual(user.username, created_user.username)
prev_session_key = request.session.session_key
with override_settings(
SECRET_KEY="newsecret",
SECRET_KEY_FALLBACKS=[settings.SECRET_KEY],
):
user = await aget_user(request)
self.assertIsInstance(user, User)
self.assertEqual(user.username, created_user.username)
self.assertNotEqual(request.session.session_key, prev_session_key)
# Remove the fallback secret.
# The session hash should be updated using the current secret.
with override_settings(SECRET_KEY="newsecret"):
user = await aget_user(request)
self.assertIsInstance(user, User)
self.assertEqual(user.username, created_user.username)

View File

@@ -1,7 +1,5 @@
from asyncio import iscoroutinefunction
from asgiref.sync import sync_to_async
from django.conf import settings
from django.contrib.auth import models
from django.contrib.auth.decorators import (
@@ -374,7 +372,7 @@ class UserPassesTestDecoratorTest(TestCase):
def test_decorator_async_test_func(self):
async def async_test_func(user):
return await sync_to_async(user.has_perms)(["auth_tests.add_customuser"])
return await user.ahas_perms(["auth_tests.add_customuser"])
@user_passes_test(async_test_func)
def sync_view(request):
@@ -410,7 +408,7 @@ class UserPassesTestDecoratorTest(TestCase):
async def test_decorator_async_view_async_test_func(self):
async def async_test_func(user):
return await sync_to_async(user.has_perms)(["auth_tests.add_customuser"])
return await user.ahas_perms(["auth_tests.add_customuser"])
@user_passes_test(async_test_func)
async def async_view(request):

View File

@@ -1,7 +1,5 @@
from unittest import mock
from asgiref.sync import sync_to_async
from django.conf.global_settings import PASSWORD_HASHERS
from django.contrib.auth import get_user_model
from django.contrib.auth.backends import ModelBackend
@@ -30,10 +28,19 @@ class NaturalKeysTestCase(TestCase):
self.assertEqual(User.objects.get_by_natural_key("staff"), staff_user)
self.assertEqual(staff_user.natural_key(), ("staff",))
async def test_auser_natural_key(self):
staff_user = await User.objects.acreate_user(username="staff")
self.assertEqual(await User.objects.aget_by_natural_key("staff"), staff_user)
self.assertEqual(staff_user.natural_key(), ("staff",))
def test_group_natural_key(self):
users_group = Group.objects.create(name="users")
self.assertEqual(Group.objects.get_by_natural_key("users"), users_group)
async def test_agroup_natural_key(self):
users_group = await Group.objects.acreate(name="users")
self.assertEqual(await Group.objects.aget_by_natural_key("users"), users_group)
class LoadDataWithoutNaturalKeysTestCase(TestCase):
fixtures = ["regular.json"]
@@ -157,6 +164,17 @@ class UserManagerTestCase(TransactionTestCase):
is_superuser=False,
)
async def test_acreate_super_user_raises_error_on_false_is_superuser(self):
with self.assertRaisesMessage(
ValueError, "Superuser must have is_superuser=True."
):
await User.objects.acreate_superuser(
username="test",
email="test@test.com",
password="test",
is_superuser=False,
)
def test_create_superuser_raises_error_on_false_is_staff(self):
with self.assertRaisesMessage(ValueError, "Superuser must have is_staff=True."):
User.objects.create_superuser(
@@ -166,6 +184,15 @@ class UserManagerTestCase(TransactionTestCase):
is_staff=False,
)
async def test_acreate_superuser_raises_error_on_false_is_staff(self):
with self.assertRaisesMessage(ValueError, "Superuser must have is_staff=True."):
await User.objects.acreate_superuser(
username="test",
email="test@test.com",
password="test",
is_staff=False,
)
def test_runpython_manager_methods(self):
def forwards(apps, schema_editor):
UserModel = apps.get_model("auth", "User")
@@ -301,9 +328,7 @@ class AbstractUserTestCase(TestCase):
@override_settings(PASSWORD_HASHERS=PASSWORD_HASHERS)
async def test_acheck_password_upgrade(self):
user = await sync_to_async(User.objects.create_user)(
username="user", password="foo"
)
user = await User.objects.acreate_user(username="user", password="foo")
initial_password = user.password
self.assertIs(await user.acheck_password("foo"), True)
hasher = get_hasher("default")
@@ -557,6 +582,12 @@ class AnonymousUserTests(SimpleTestCase):
self.assertEqual(self.user.get_user_permissions(), set())
self.assertEqual(self.user.get_group_permissions(), set())
async def test_properties_async_versions(self):
self.assertEqual(await self.user.groups.acount(), 0)
self.assertEqual(await self.user.user_permissions.acount(), 0)
self.assertEqual(await self.user.aget_user_permissions(), set())
self.assertEqual(await self.user.aget_group_permissions(), set())
def test_str(self):
self.assertEqual(str(self.user), "AnonymousUser")

View File

@@ -1,12 +1,18 @@
from datetime import datetime, timezone
from django.conf import settings
from django.contrib.auth import authenticate
from django.contrib.auth import aauthenticate, authenticate
from django.contrib.auth.backends import RemoteUserBackend
from django.contrib.auth.middleware import RemoteUserMiddleware
from django.contrib.auth.models import User
from django.middleware.csrf import _get_new_csrf_string, _mask_cipher_secret
from django.test import Client, TestCase, modify_settings, override_settings
from django.test import (
AsyncClient,
Client,
TestCase,
modify_settings,
override_settings,
)
@override_settings(ROOT_URLCONF="auth_tests.urls")
@@ -30,6 +36,11 @@ class RemoteUserTest(TestCase):
)
super().setUpClass()
def test_passing_explicit_none(self):
msg = "get_response must be provided."
with self.assertRaisesMessage(ValueError, msg):
RemoteUserMiddleware(None)
def test_no_remote_user(self):
"""Users are not created when remote user is not specified."""
num_users = User.objects.count()
@@ -46,6 +57,18 @@ class RemoteUserTest(TestCase):
self.assertTrue(response.context["user"].is_anonymous)
self.assertEqual(User.objects.count(), num_users)
async def test_no_remote_user_async(self):
"""See test_no_remote_user."""
num_users = await User.objects.acount()
response = await self.async_client.get("/remote_user/")
self.assertTrue(response.context["user"].is_anonymous)
self.assertEqual(await User.objects.acount(), num_users)
response = await self.async_client.get("/remote_user/", **{self.header: ""})
self.assertTrue(response.context["user"].is_anonymous)
self.assertEqual(await User.objects.acount(), num_users)
def test_csrf_validation_passes_after_process_request_login(self):
"""
CSRF check must access the CSRF token from the session or cookie,
@@ -75,6 +98,31 @@ class RemoteUserTest(TestCase):
response = csrf_client.post("/remote_user/", data, **headers)
self.assertEqual(response.status_code, 200)
async def test_csrf_validation_passes_after_process_request_login_async(self):
"""See test_csrf_validation_passes_after_process_request_login."""
csrf_client = AsyncClient(enforce_csrf_checks=True)
csrf_secret = _get_new_csrf_string()
csrf_token = _mask_cipher_secret(csrf_secret)
csrf_token_form = _mask_cipher_secret(csrf_secret)
headers = {self.header: "fakeuser"}
data = {"csrfmiddlewaretoken": csrf_token_form}
# Verify that CSRF is configured for the view
csrf_client.cookies.load({settings.CSRF_COOKIE_NAME: csrf_token})
response = await csrf_client.post("/remote_user/", **headers)
self.assertEqual(response.status_code, 403)
self.assertIn(b"CSRF verification failed.", response.content)
# This request will call django.contrib.auth.alogin() which will call
# django.middleware.csrf.rotate_token() thus changing the value of
# request.META['CSRF_COOKIE'] from the user submitted value set by
# CsrfViewMiddleware.process_request() to the new csrftoken value set
# by rotate_token(). Csrf validation should still pass when the view is
# later processed by CsrfViewMiddleware.process_view()
csrf_client.cookies.load({settings.CSRF_COOKIE_NAME: csrf_token})
response = await csrf_client.post("/remote_user/", data, **headers)
self.assertEqual(response.status_code, 200)
def test_unknown_user(self):
"""
Tests the case where the username passed in the header does not exist
@@ -90,6 +138,22 @@ class RemoteUserTest(TestCase):
response = self.client.get("/remote_user/", **{self.header: "newuser"})
self.assertEqual(User.objects.count(), num_users + 1)
async def test_unknown_user_async(self):
"""See test_unknown_user."""
num_users = await User.objects.acount()
response = await self.async_client.get(
"/remote_user/", **{self.header: "newuser"}
)
self.assertEqual(response.context["user"].username, "newuser")
self.assertEqual(await User.objects.acount(), num_users + 1)
await User.objects.aget(username="newuser")
# Another request with same user should not create any new users.
response = await self.async_client.get(
"/remote_user/", **{self.header: "newuser"}
)
self.assertEqual(await User.objects.acount(), num_users + 1)
def test_known_user(self):
"""
Tests the case where the username passed in the header is a valid User.
@@ -106,6 +170,24 @@ class RemoteUserTest(TestCase):
self.assertEqual(response.context["user"].username, "knownuser2")
self.assertEqual(User.objects.count(), num_users)
async def test_known_user_async(self):
"""See test_known_user."""
await User.objects.acreate(username="knownuser")
await User.objects.acreate(username="knownuser2")
num_users = await User.objects.acount()
response = await self.async_client.get(
"/remote_user/", **{self.header: self.known_user}
)
self.assertEqual(response.context["user"].username, "knownuser")
self.assertEqual(await User.objects.acount(), num_users)
# A different user passed in the headers causes the new user
# to be logged in.
response = await self.async_client.get(
"/remote_user/", **{self.header: self.known_user2}
)
self.assertEqual(response.context["user"].username, "knownuser2")
self.assertEqual(await User.objects.acount(), num_users)
def test_last_login(self):
"""
A user's last_login is set the first time they make a
@@ -128,6 +210,29 @@ class RemoteUserTest(TestCase):
response = self.client.get("/remote_user/", **{self.header: self.known_user})
self.assertEqual(default_login, response.context["user"].last_login)
async def test_last_login_async(self):
"""See test_last_login."""
user = await User.objects.acreate(username="knownuser")
# Set last_login to something so we can determine if it changes.
default_login = datetime(2000, 1, 1)
if settings.USE_TZ:
default_login = default_login.replace(tzinfo=timezone.utc)
user.last_login = default_login
await user.asave()
response = await self.async_client.get(
"/remote_user/", **{self.header: self.known_user}
)
self.assertNotEqual(default_login, response.context["user"].last_login)
user = await User.objects.aget(username="knownuser")
user.last_login = default_login
await user.asave()
response = await self.async_client.get(
"/remote_user/", **{self.header: self.known_user}
)
self.assertEqual(default_login, response.context["user"].last_login)
def test_header_disappears(self):
"""
A logged in user is logged out automatically when
@@ -148,6 +253,25 @@ class RemoteUserTest(TestCase):
response = self.client.get("/remote_user/")
self.assertEqual(response.context["user"].username, "modeluser")
async def test_header_disappears_async(self):
"""See test_header_disappears."""
await User.objects.acreate(username="knownuser")
# Known user authenticates
response = await self.async_client.get(
"/remote_user/", **{self.header: self.known_user}
)
self.assertEqual(response.context["user"].username, "knownuser")
# During the session, the REMOTE_USER header disappears. Should trigger logout.
response = await self.async_client.get("/remote_user/")
self.assertTrue(response.context["user"].is_anonymous)
# verify the remoteuser middleware will not remove a user
# authenticated via another backend
await User.objects.acreate_user(username="modeluser", password="foo")
await self.async_client.alogin(username="modeluser", password="foo")
await aauthenticate(username="modeluser", password="foo")
response = await self.async_client.get("/remote_user/")
self.assertEqual(response.context["user"].username, "modeluser")
def test_user_switch_forces_new_login(self):
"""
If the username in the header changes between requests
@@ -164,11 +288,35 @@ class RemoteUserTest(TestCase):
# In backends that do not create new users, it is '' (anonymous user)
self.assertNotEqual(response.context["user"].username, "knownuser")
async def test_user_switch_forces_new_login_async(self):
"""See test_user_switch_forces_new_login."""
await User.objects.acreate(username="knownuser")
# Known user authenticates
response = await self.async_client.get(
"/remote_user/", **{self.header: self.known_user}
)
self.assertEqual(response.context["user"].username, "knownuser")
# During the session, the REMOTE_USER changes to a different user.
response = await self.async_client.get(
"/remote_user/", **{self.header: "newnewuser"}
)
# The current user is not the prior remote_user.
# In backends that create a new user, username is "newnewuser"
# In backends that do not create new users, it is '' (anonymous user)
self.assertNotEqual(response.context["user"].username, "knownuser")
def test_inactive_user(self):
User.objects.create(username="knownuser", is_active=False)
response = self.client.get("/remote_user/", **{self.header: "knownuser"})
self.assertTrue(response.context["user"].is_anonymous)
async def test_inactive_user_async(self):
await User.objects.acreate(username="knownuser", is_active=False)
response = await self.async_client.get(
"/remote_user/", **{self.header: "knownuser"}
)
self.assertTrue(response.context["user"].is_anonymous)
class RemoteUserNoCreateBackend(RemoteUserBackend):
"""Backend that doesn't create unknown users."""
@@ -190,6 +338,14 @@ class RemoteUserNoCreateTest(RemoteUserTest):
self.assertTrue(response.context["user"].is_anonymous)
self.assertEqual(User.objects.count(), num_users)
async def test_unknown_user_async(self):
num_users = await User.objects.acount()
response = await self.async_client.get(
"/remote_user/", **{self.header: "newuser"}
)
self.assertTrue(response.context["user"].is_anonymous)
self.assertEqual(await User.objects.acount(), num_users)
class AllowAllUsersRemoteUserBackendTest(RemoteUserTest):
"""Backend that allows inactive users."""
@@ -201,6 +357,13 @@ class AllowAllUsersRemoteUserBackendTest(RemoteUserTest):
response = self.client.get("/remote_user/", **{self.header: self.known_user})
self.assertEqual(response.context["user"].username, user.username)
async def test_inactive_user_async(self):
user = await User.objects.acreate(username="knownuser", is_active=False)
response = await self.async_client.get(
"/remote_user/", **{self.header: self.known_user}
)
self.assertEqual(response.context["user"].username, user.username)
class CustomRemoteUserBackend(RemoteUserBackend):
"""
@@ -311,3 +474,16 @@ class PersistentRemoteUserTest(RemoteUserTest):
response = self.client.get("/remote_user/")
self.assertFalse(response.context["user"].is_anonymous)
self.assertEqual(response.context["user"].username, "knownuser")
async def test_header_disappears_async(self):
"""See test_header_disappears."""
await User.objects.acreate(username="knownuser")
# Known user authenticates
response = await self.async_client.get(
"/remote_user/", **{self.header: self.known_user}
)
self.assertEqual(response.context["user"].username, "knownuser")
# Should stay logged in if the REMOTE_USER header disappears.
response = await self.async_client.get("/remote_user/")
self.assertFalse(response.context["user"].is_anonymous)
self.assertEqual(response.context["user"].username, "knownuser")