1
0
mirror of https://github.com/django/django.git synced 2025-10-23 21:59:11 +00:00

Fixed #34901 -- Added async-compatible interface to session engines.

Thanks Andrew-Chen-Wang for the initial implementation which was posted
to the Django forum thread about asyncifying contrib modules.
This commit is contained in:
Jon Janzen
2023-10-16 18:50:20 -07:00
committed by Mariusz Felisiak
parent 33c06ca0da
commit f5c340684b
12 changed files with 975 additions and 9 deletions

View File

@@ -61,11 +61,19 @@ class SessionTestsMixin:
def test_get_empty(self):
self.assertIsNone(self.session.get("cat"))
async def test_get_empty_async(self):
self.assertIsNone(await self.session.aget("cat"))
def test_store(self):
self.session["cat"] = "dog"
self.assertIs(self.session.modified, True)
self.assertEqual(self.session.pop("cat"), "dog")
async def test_store_async(self):
await self.session.aset("cat", "dog")
self.assertIs(self.session.modified, True)
self.assertEqual(await self.session.apop("cat"), "dog")
def test_pop(self):
self.session["some key"] = "exists"
# Need to reset these to pretend we haven't accessed it:
@@ -77,6 +85,17 @@ class SessionTestsMixin:
self.assertIs(self.session.modified, True)
self.assertIsNone(self.session.get("some key"))
async def test_pop_async(self):
await self.session.aset("some key", "exists")
# Need to reset these to pretend we haven't accessed it:
self.accessed = False
self.modified = False
self.assertEqual(await self.session.apop("some key"), "exists")
self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, True)
self.assertIsNone(await self.session.aget("some key"))
def test_pop_default(self):
self.assertEqual(
self.session.pop("some key", "does not exist"), "does not exist"
@@ -84,6 +103,13 @@ class SessionTestsMixin:
self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, False)
async def test_pop_default_async(self):
self.assertEqual(
await self.session.apop("some key", "does not exist"), "does not exist"
)
self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, False)
def test_pop_default_named_argument(self):
self.assertEqual(
self.session.pop("some key", default="does not exist"), "does not exist"
@@ -91,22 +117,46 @@ class SessionTestsMixin:
self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, False)
async def test_pop_default_named_argument_async(self):
self.assertEqual(
await self.session.apop("some key", default="does not exist"),
"does not exist",
)
self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, False)
def test_pop_no_default_keyerror_raised(self):
with self.assertRaises(KeyError):
self.session.pop("some key")
async def test_pop_no_default_keyerror_raised_async(self):
with self.assertRaises(KeyError):
await self.session.apop("some key")
def test_setdefault(self):
self.assertEqual(self.session.setdefault("foo", "bar"), "bar")
self.assertEqual(self.session.setdefault("foo", "baz"), "bar")
self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, True)
async def test_setdefault_async(self):
self.assertEqual(await self.session.asetdefault("foo", "bar"), "bar")
self.assertEqual(await self.session.asetdefault("foo", "baz"), "bar")
self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, True)
def test_update(self):
self.session.update({"update key": 1})
self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, True)
self.assertEqual(self.session.get("update key", None), 1)
async def test_update_async(self):
await self.session.aupdate({"update key": 1})
self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, True)
self.assertEqual(await self.session.aget("update key", None), 1)
def test_has_key(self):
self.session["some key"] = 1
self.session.modified = False
@@ -115,6 +165,14 @@ class SessionTestsMixin:
self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, False)
async def test_has_key_async(self):
await self.session.aset("some key", 1)
self.session.modified = False
self.session.accessed = False
self.assertIs(await self.session.ahas_key("some key"), True)
self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, False)
def test_values(self):
self.assertEqual(list(self.session.values()), [])
self.assertIs(self.session.accessed, True)
@@ -125,6 +183,16 @@ class SessionTestsMixin:
self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, False)
async def test_values_async(self):
self.assertEqual(list(await self.session.avalues()), [])
self.assertIs(self.session.accessed, True)
await self.session.aset("some key", 1)
self.session.modified = False
self.session.accessed = False
self.assertEqual(list(await self.session.avalues()), [1])
self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, False)
def test_keys(self):
self.session["x"] = 1
self.session.modified = False
@@ -133,6 +201,14 @@ class SessionTestsMixin:
self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, False)
async def test_keys_async(self):
await self.session.aset("x", 1)
self.session.modified = False
self.session.accessed = False
self.assertEqual(list(await self.session.akeys()), ["x"])
self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, False)
def test_items(self):
self.session["x"] = 1
self.session.modified = False
@@ -141,6 +217,14 @@ class SessionTestsMixin:
self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, False)
async def test_items_async(self):
await self.session.aset("x", 1)
self.session.modified = False
self.session.accessed = False
self.assertEqual(list(await self.session.aitems()), [("x", 1)])
self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, False)
def test_clear(self):
self.session["x"] = 1
self.session.modified = False
@@ -155,11 +239,20 @@ class SessionTestsMixin:
self.session.save()
self.assertIs(self.session.exists(self.session.session_key), True)
async def test_save_async(self):
await self.session.asave()
self.assertIs(await self.session.aexists(self.session.session_key), True)
def test_delete(self):
self.session.save()
self.session.delete(self.session.session_key)
self.assertIs(self.session.exists(self.session.session_key), False)
async def test_delete_async(self):
await self.session.asave()
await self.session.adelete(self.session.session_key)
self.assertIs(await self.session.aexists(self.session.session_key), False)
def test_flush(self):
self.session["foo"] = "bar"
self.session.save()
@@ -171,6 +264,17 @@ class SessionTestsMixin:
self.assertIs(self.session.modified, True)
self.assertIs(self.session.accessed, True)
async def test_flush_async(self):
await self.session.aset("foo", "bar")
await self.session.asave()
prev_key = self.session.session_key
await self.session.aflush()
self.assertIs(await self.session.aexists(prev_key), False)
self.assertNotEqual(self.session.session_key, prev_key)
self.assertIsNone(self.session.session_key)
self.assertIs(self.session.modified, True)
self.assertIs(self.session.accessed, True)
def test_cycle(self):
self.session["a"], self.session["b"] = "c", "d"
self.session.save()
@@ -181,6 +285,17 @@ class SessionTestsMixin:
self.assertNotEqual(self.session.session_key, prev_key)
self.assertEqual(list(self.session.items()), prev_data)
async def test_cycle_async(self):
await self.session.aset("a", "c")
await self.session.aset("b", "d")
await self.session.asave()
prev_key = self.session.session_key
prev_data = list(await self.session.aitems())
await self.session.acycle_key()
self.assertIs(await self.session.aexists(prev_key), False)
self.assertNotEqual(self.session.session_key, prev_key)
self.assertEqual(list(await self.session.aitems()), prev_data)
def test_cycle_with_no_session_cache(self):
self.session["a"], self.session["b"] = "c", "d"
self.session.save()
@@ -190,11 +305,26 @@ class SessionTestsMixin:
self.session.cycle_key()
self.assertCountEqual(self.session.items(), prev_data)
async def test_cycle_with_no_session_cache_async(self):
await self.session.aset("a", "c")
await self.session.aset("b", "d")
await self.session.asave()
prev_data = await self.session.aitems()
self.session = self.backend(self.session.session_key)
self.assertIs(hasattr(self.session, "_session_cache"), False)
await self.session.acycle_key()
self.assertCountEqual(await self.session.aitems(), prev_data)
def test_save_doesnt_clear_data(self):
self.session["a"] = "b"
self.session.save()
self.assertEqual(self.session["a"], "b")
async def test_save_doesnt_clear_data_async(self):
await self.session.aset("a", "b")
await self.session.asave()
self.assertEqual(await self.session.aget("a"), "b")
def test_invalid_key(self):
# Submitting an invalid session key (either by guessing, or if the db has
# removed the key) results in a new key being generated.
@@ -209,6 +339,20 @@ class SessionTestsMixin:
# session key; make sure that entry is manually deleted
session.delete("1")
async def test_invalid_key_async(self):
# Submitting an invalid session key (either by guessing, or if the db has
# removed the key) results in a new key being generated.
try:
session = self.backend("1")
await session.asave()
self.assertNotEqual(session.session_key, "1")
self.assertIsNone(await session.aget("cat"))
await session.adelete()
finally:
# Some backends leave a stale cache entry for the invalid
# session key; make sure that entry is manually deleted
await session.adelete("1")
def test_session_key_empty_string_invalid(self):
"""Falsey values (Such as an empty string) are rejected."""
self.session._session_key = ""
@@ -241,6 +385,18 @@ class SessionTestsMixin:
self.session.set_expiry(0)
self.assertEqual(self.session.get_expiry_age(), settings.SESSION_COOKIE_AGE)
async def test_default_expiry_async(self):
# A normal session has a max age equal to settings.
self.assertEqual(
await self.session.aget_expiry_age(), settings.SESSION_COOKIE_AGE
)
# So does a custom session with an idle expiration time of 0 (but it'll
# expire at browser close).
await self.session.aset_expiry(0)
self.assertEqual(
await self.session.aget_expiry_age(), settings.SESSION_COOKIE_AGE
)
def test_custom_expiry_seconds(self):
modification = timezone.now()
@@ -252,6 +408,17 @@ class SessionTestsMixin:
age = self.session.get_expiry_age(modification=modification)
self.assertEqual(age, 10)
async def test_custom_expiry_seconds_async(self):
modification = timezone.now()
await self.session.aset_expiry(10)
date = await self.session.aget_expiry_date(modification=modification)
self.assertEqual(date, modification + timedelta(seconds=10))
age = await self.session.aget_expiry_age(modification=modification)
self.assertEqual(age, 10)
def test_custom_expiry_timedelta(self):
modification = timezone.now()
@@ -269,6 +436,23 @@ class SessionTestsMixin:
age = self.session.get_expiry_age(modification=modification)
self.assertEqual(age, 10)
async def test_custom_expiry_timedelta_async(self):
modification = timezone.now()
# Mock timezone.now, because set_expiry calls it on this code path.
original_now = timezone.now
try:
timezone.now = lambda: modification
await self.session.aset_expiry(timedelta(seconds=10))
finally:
timezone.now = original_now
date = await self.session.aget_expiry_date(modification=modification)
self.assertEqual(date, modification + timedelta(seconds=10))
age = await self.session.aget_expiry_age(modification=modification)
self.assertEqual(age, 10)
def test_custom_expiry_datetime(self):
modification = timezone.now()
@@ -280,12 +464,31 @@ class SessionTestsMixin:
age = self.session.get_expiry_age(modification=modification)
self.assertEqual(age, 10)
async def test_custom_expiry_datetime_async(self):
modification = timezone.now()
await self.session.aset_expiry(modification + timedelta(seconds=10))
date = await self.session.aget_expiry_date(modification=modification)
self.assertEqual(date, modification + timedelta(seconds=10))
age = await self.session.aget_expiry_age(modification=modification)
self.assertEqual(age, 10)
def test_custom_expiry_reset(self):
self.session.set_expiry(None)
self.session.set_expiry(10)
self.session.set_expiry(None)
self.assertEqual(self.session.get_expiry_age(), settings.SESSION_COOKIE_AGE)
async def test_custom_expiry_reset_async(self):
await self.session.aset_expiry(None)
await self.session.aset_expiry(10)
await self.session.aset_expiry(None)
self.assertEqual(
await self.session.aget_expiry_age(), settings.SESSION_COOKIE_AGE
)
def test_get_expire_at_browser_close(self):
# Tests get_expire_at_browser_close with different settings and different
# set_expiry calls
@@ -309,6 +512,29 @@ class SessionTestsMixin:
self.session.set_expiry(None)
self.assertIs(self.session.get_expire_at_browser_close(), True)
async def test_get_expire_at_browser_close_async(self):
# Tests get_expire_at_browser_close with different settings and different
# set_expiry calls
with override_settings(SESSION_EXPIRE_AT_BROWSER_CLOSE=False):
await self.session.aset_expiry(10)
self.assertIs(await self.session.aget_expire_at_browser_close(), False)
await self.session.aset_expiry(0)
self.assertIs(await self.session.aget_expire_at_browser_close(), True)
await self.session.aset_expiry(None)
self.assertIs(await self.session.aget_expire_at_browser_close(), False)
with override_settings(SESSION_EXPIRE_AT_BROWSER_CLOSE=True):
await self.session.aset_expiry(10)
self.assertIs(await self.session.aget_expire_at_browser_close(), False)
await self.session.aset_expiry(0)
self.assertIs(await self.session.aget_expire_at_browser_close(), True)
await self.session.aset_expiry(None)
self.assertIs(await self.session.aget_expire_at_browser_close(), True)
def test_decode(self):
# Ensure we can decode what we encode
data = {"a test key": "a test value"}
@@ -350,6 +576,22 @@ class SessionTestsMixin:
self.session.delete(old_session_key)
self.session.delete(new_session_key)
async def test_actual_expiry_async(self):
old_session_key = None
new_session_key = None
try:
await self.session.aset("foo", "bar")
await self.session.aset_expiry(-timedelta(seconds=10))
await self.session.asave()
old_session_key = self.session.session_key
# With an expiry date in the past, the session expires instantly.
new_session = self.backend(self.session.session_key)
new_session_key = new_session.session_key
self.assertIs(await new_session.ahas_key("foo"), False)
finally:
await self.session.adelete(old_session_key)
await self.session.adelete(new_session_key)
def test_session_load_does_not_create_record(self):
"""
Loading an unknown session key does not create a session record.
@@ -364,6 +606,15 @@ class SessionTestsMixin:
# provided unknown key was cycled, not reused
self.assertNotEqual(session.session_key, "someunknownkey")
async def test_session_load_does_not_create_record_async(self):
session = self.backend("someunknownkey")
await session.aload()
self.assertIsNone(session.session_key)
self.assertIs(await session.aexists(session.session_key), False)
# Provided unknown key was cycled, not reused.
self.assertNotEqual(session.session_key, "someunknownkey")
def test_session_save_does_not_resurrect_session_logged_out_in_other_context(self):
"""
Sessions shouldn't be resurrected by a concurrent request.
@@ -386,6 +637,28 @@ class SessionTestsMixin:
self.assertEqual(s1.load(), {})
async def test_session_asave_does_not_resurrect_session_logged_out_in_other_context(
self,
):
"""Sessions shouldn't be resurrected by a concurrent request."""
# Create new session.
s1 = self.backend()
await s1.aset("test_data", "value1")
await s1.asave(must_create=True)
# Logout in another context.
s2 = self.backend(s1.session_key)
await s2.adelete()
# Modify session in first context.
await s1.aset("test_data", "value2")
with self.assertRaises(UpdateError):
# This should throw an exception as the session is deleted, not
# resurrect the session.
await s1.asave()
self.assertEqual(await s1.aload(), {})
class DatabaseSessionTests(SessionTestsMixin, TestCase):
backend = DatabaseSession
@@ -456,6 +729,25 @@ class DatabaseSessionTests(SessionTestsMixin, TestCase):
# ... and one is deleted.
self.assertEqual(1, self.model.objects.count())
async def test_aclear_expired(self):
self.assertEqual(await self.model.objects.acount(), 0)
# Object in the future.
await self.session.aset("key", "value")
await self.session.aset_expiry(3600)
await self.session.asave()
# Object in the past.
other_session = self.backend()
await other_session.aset("key", "value")
await other_session.aset_expiry(-3600)
await other_session.asave()
# Two sessions are in the database before clearing expired.
self.assertEqual(await self.model.objects.acount(), 2)
await self.session.aclear_expired()
await other_session.aclear_expired()
self.assertEqual(await self.model.objects.acount(), 1)
@override_settings(USE_TZ=True)
class DatabaseSessionWithTimeZoneTests(DatabaseSessionTests):
@@ -491,11 +783,28 @@ class CustomDatabaseSessionTests(DatabaseSessionTests):
self.session.set_expiry(None)
self.assertEqual(self.session.get_expiry_age(), self.custom_session_cookie_age)
async def test_custom_expiry_reset_async(self):
await self.session.aset_expiry(None)
await self.session.aset_expiry(10)
await self.session.aset_expiry(None)
self.assertEqual(
await self.session.aget_expiry_age(), self.custom_session_cookie_age
)
def test_default_expiry(self):
self.assertEqual(self.session.get_expiry_age(), self.custom_session_cookie_age)
self.session.set_expiry(0)
self.assertEqual(self.session.get_expiry_age(), self.custom_session_cookie_age)
async def test_default_expiry_async(self):
self.assertEqual(
await self.session.aget_expiry_age(), self.custom_session_cookie_age
)
await self.session.aset_expiry(0)
self.assertEqual(
await self.session.aget_expiry_age(), self.custom_session_cookie_age
)
class CacheDBSessionTests(SessionTestsMixin, TestCase):
backend = CacheDBSession
@@ -533,6 +842,22 @@ class CacheDBSessionTests(SessionTestsMixin, TestCase):
self.assertEqual(log.message, f"Error saving to cache ({session._cache})")
self.assertEqual(str(log.exc_info[1]), "Faked exception saving to cache")
@override_settings(
CACHES={"default": {"BACKEND": "cache.failing_cache.CacheClass"}}
)
async def test_cache_async_set_failure_non_fatal(self):
"""Failing to write to the cache does not raise errors."""
session = self.backend()
await session.aset("key", "val")
with self.assertLogs("django.contrib.sessions", "ERROR") as cm:
await session.asave()
# A proper ERROR log message was recorded.
log = cm.records[-1]
self.assertEqual(log.message, f"Error saving to cache ({session._cache})")
self.assertEqual(str(log.exc_info[1]), "Faked exception saving to cache")
@override_settings(USE_TZ=True)
class CacheDBSessionWithTimeZoneTests(CacheDBSessionTests):
@@ -673,6 +998,12 @@ class CacheSessionTests(SessionTestsMixin, SimpleTestCase):
self.session.save()
self.assertIsNotNone(caches["default"].get(self.session.cache_key))
async def test_create_and_save_async(self):
self.session = self.backend()
await self.session.acreate()
await self.session.asave()
self.assertIsNotNone(caches["default"].get(await self.session.acache_key()))
class SessionMiddlewareTests(TestCase):
request_factory = RequestFactory()
@@ -899,6 +1230,9 @@ class CookieSessionTests(SessionTestsMixin, SimpleTestCase):
"""
pass
async def test_save_async(self):
pass
def test_cycle(self):
"""
This test tested cycle_key() which would create a new session
@@ -908,11 +1242,17 @@ class CookieSessionTests(SessionTestsMixin, SimpleTestCase):
"""
pass
async def test_cycle_async(self):
pass
@unittest.expectedFailure
def test_actual_expiry(self):
# The cookie backend doesn't handle non-default expiry dates, see #19201
super().test_actual_expiry()
async def test_actual_expiry_async(self):
pass
def test_unpickling_exception(self):
# signed_cookies backend should handle unpickle exceptions gracefully
# by creating a new session
@@ -927,12 +1267,26 @@ class CookieSessionTests(SessionTestsMixin, SimpleTestCase):
def test_session_load_does_not_create_record(self):
pass
@unittest.skip(
"Cookie backend doesn't have an external store to create records in."
)
async def test_session_load_does_not_create_record_async(self):
pass
@unittest.skip(
"CookieSession is stored in the client and there is no way to query it."
)
def test_session_save_does_not_resurrect_session_logged_out_in_other_context(self):
pass
@unittest.skip(
"CookieSession is stored in the client and there is no way to query it."
)
async def test_session_asave_does_not_resurrect_session_logged_out_in_other_context(
self,
):
pass
class ClearSessionsCommandTests(SimpleTestCase):
def test_clearsessions_unsupported(self):
@@ -956,26 +1310,51 @@ class SessionBaseTests(SimpleTestCase):
with self.assertRaisesMessage(NotImplementedError, msg):
self.session.create()
async def test_acreate(self):
msg = self.not_implemented_msg % "a create"
with self.assertRaisesMessage(NotImplementedError, msg):
await self.session.acreate()
def test_delete(self):
msg = self.not_implemented_msg % "a delete"
with self.assertRaisesMessage(NotImplementedError, msg):
self.session.delete()
async def test_adelete(self):
msg = self.not_implemented_msg % "a delete"
with self.assertRaisesMessage(NotImplementedError, msg):
await self.session.adelete()
def test_exists(self):
msg = self.not_implemented_msg % "an exists"
with self.assertRaisesMessage(NotImplementedError, msg):
self.session.exists(None)
async def test_aexists(self):
msg = self.not_implemented_msg % "an exists"
with self.assertRaisesMessage(NotImplementedError, msg):
await self.session.aexists(None)
def test_load(self):
msg = self.not_implemented_msg % "a load"
with self.assertRaisesMessage(NotImplementedError, msg):
self.session.load()
async def test_aload(self):
msg = self.not_implemented_msg % "a load"
with self.assertRaisesMessage(NotImplementedError, msg):
await self.session.aload()
def test_save(self):
msg = self.not_implemented_msg % "a save"
with self.assertRaisesMessage(NotImplementedError, msg):
self.session.save()
async def test_asave(self):
msg = self.not_implemented_msg % "a save"
with self.assertRaisesMessage(NotImplementedError, msg):
await self.session.asave()
def test_test_cookie(self):
self.assertIs(self.session.has_key(self.session.TEST_COOKIE_NAME), False)
self.session.set_test_cookie()
@@ -983,5 +1362,12 @@ class SessionBaseTests(SimpleTestCase):
self.session.delete_test_cookie()
self.assertIs(self.session.has_key(self.session.TEST_COOKIE_NAME), False)
async def test_atest_cookie(self):
self.assertIs(await self.session.ahas_key(self.session.TEST_COOKIE_NAME), False)
await self.session.aset_test_cookie()
self.assertIs(await self.session.atest_cookie_worked(), True)
await self.session.adelete_test_cookie()
self.assertIs(await self.session.ahas_key(self.session.TEST_COOKIE_NAME), False)
def test_is_empty(self):
self.assertIs(self.session.is_empty(), True)