1
0
mirror of https://github.com/django/django.git synced 2025-03-31 11:37:06 +00:00

Initial application of last patch from #13376 and made tests pass

This commit is contained in:
Andrew Miller 2024-06-19 20:37:38 +01:00
parent 38ad710aba
commit 45b94ce077
12 changed files with 321 additions and 36 deletions

View File

@ -19,7 +19,7 @@ class MessageFailure(Exception):
pass
def add_message(request, level, message, extra_tags="", fail_silently=False):
def add_message(request, level, message, extra_tags="", fail_silently=False, restrictions=[]):
"""
Attempt to add a message to the request using the 'messages' app.
"""
@ -37,7 +37,7 @@ def add_message(request, level, message, extra_tags="", fail_silently=False):
"django.contrib.messages.middleware.MessageMiddleware"
)
else:
return messages.add(level, message, extra_tags)
return messages.add(level, message, extra_tags, restrictions = restrictions)
def get_messages(request):

View File

@ -0,0 +1,105 @@
import time as time_provider
class Restriction:
JSON_SEPARATOR = '@'
def __init__(self, name):
self.type = name
def is_expired(self):
"""
indicates whether given message should be removed
"""
raise NotImplementedError ()
def on_display(self):
"""
called when iterated - does nothing by default
"""
pass
def to_json(self):
"""
returns json representation of restriction
"""
raise NotImplementedError
@classmethod
def from_json_param(cls, *args):
"""
returns restriction on the basis of data encoded in json
"""
raise NotImplementedError
def __cmp__(self, other):
if self.__eq__(other): return 0
return -1
class TimeRestriction(Restriction):
JSON_TYPE_CODE = 't'
def __init__(self, seconds):
"""
seconds - expiration time since now
"""
Restriction.__init__(self, 'time')
created = time_provider.time()
self.expires = created + int(seconds)
def set_expirity_time(self, expiration_time):
"""
Sets expilcity expiration time
"""
self.expires = int(expiration_time)
def is_expired(self):
return self.expires < time_provider.time()
def __eq__(self, other):
return self.type == other.type and not bool(self.expires ^ other.expires)
def __hash__(self):
return self.expires
def to_json(self):
return '%s%s%s' % (self.JSON_TYPE_CODE, self.JSON_SEPARATOR, self.expires)
@classmethod
def from_json_param(cls, expirity_time):
ret = TimeRestriction(0)
ret.set_expirity_time(expirity_time)
return ret
class AmountRestriction(Restriction):
JSON_TYPE_CODE = 'a'
def __init__(self, amount):
assert int(amount) >= 0
Restriction.__init__(self, 'amount')
self.can_be_shown = int(amount)
def on_display(self):
self.can_be_shown -= 1
def is_expired(self):
return int(self.can_be_shown) <= 0
def __eq__(self, other):
return self.type == other.type and not bool(self.can_be_shown ^ other.can_be_shown)
def __hash__(self):
return self.can_be_shown
def __repr__(self):
return self.to_json()
def to_json(self):
return '%s%s%s' % (self.JSON_TYPE_CODE, self.JSON_SEPARATOR, self.can_be_shown)
@classmethod
def from_json_param(cls, amount):
return AmountRestriction(amount)

View File

@ -1,10 +1,29 @@
from django.conf import settings
from django.contrib.messages import constants, utils
from django.contrib.messages import constants, utils, restrictions as res
from django.utils.functional import SimpleLazyObject
LEVEL_TAGS = SimpleLazyObject(utils.get_level_tags)
class RestrictionsContainer(list):
res_map = {res.AmountRestriction.JSON_TYPE_CODE: res.AmountRestriction,
res.TimeRestriction.JSON_TYPE_CODE: res.TimeRestriction}
def to_json_obj(self):
return [r.to_json() for r in self]
@classmethod
def create_from_josn(cls, enc_restrictions):
#set_trace()
ret = []
for r in enc_restrictions:
restriction_type, values = r.split(res.Restriction.JSON_SEPARATOR)
ret.append(cls.res_map[restriction_type].from_json_param(values))
return RestrictionsContainer(ret)
def __eq__(self, other):
return set(self) == set(other)
class Message:
"""
Represent an actual message that can be stored in any of the supported
@ -12,10 +31,14 @@ class Message:
or template.
"""
def __init__(self, level, message, extra_tags=None):
def __init__(self, level, message, extra_tags=None, restrictions = []):
self.level = int(level)
self.message = message
self.extra_tags = extra_tags
self.restrictions = restrictions or list([res.AmountRestriction(1)])
self.restrictions = RestrictionsContainer(self.restrictions)
# if not given any restriction - one show by default
# todo: self.restrictions =
def _prepare(self):
"""
@ -28,11 +51,14 @@ class Message:
def __eq__(self, other):
if not isinstance(other, Message):
return NotImplemented
return self.level == other.level and self.message == other.message
return self.level == other.level and self.message == other.message and self.restrictions == other.restrictions
def __str__(self):
return str(self.message)
def __hash__(self):
return hash(self.message)
def __repr__(self):
extra_tags = f", extra_tags={self.extra_tags!r}" if self.extra_tags else ""
return f"Message(level={self.level}, message={self.message!r}{extra_tags})"
@ -45,6 +71,14 @@ class Message:
def level_tag(self):
return LEVEL_TAGS.get(self.level, "")
def active(self):
for r in self.restrictions:
if r.is_expired(): return False
return True
def on_display(self):
for r in self.restrictions:
r.on_display()
class BaseStorage:
"""
@ -62,14 +96,33 @@ class BaseStorage:
super().__init__(*args, **kwargs)
def __len__(self):
return len(self._loaded_messages) + len(self._queued_messages)
# in case that there was a call for render template which would cause iterating throught messages,
# and then (e.g. in some middleware, would be call for iterating through messages (e.g. by iterating of context['messages'])
# TODO: implement a way to access messages without affecting calling __iter__ method
all_msgs = set(self._loaded_messages + self._queued_messages)
return len(all_msgs)
#return len(self._loaded_messages) + len(self._queued_messages)
def __iter__(self):
self.used = True
if self._queued_messages:
self._loaded_messages.extend(self._queued_messages)
self._queued_messages = []
return iter(self._loaded_messages)
if not self.used:
self.used = True
if self._queued_messages:
self._loaded_messages.extend(self._queued_messages)
self._queued_messages = []
active_messages = []
for message in self._loaded_messages:
if isinstance(message, Message):
if message.active():
active_messages.append(message)
else:
active_messages.append(message)
for x in active_messages:
if isinstance(x, Message):
x.on_display()
# self._queued_messages.extend(m for m in active_messages if m not in self._queued_messages)
self._queued_messages = active_messages
return iter(self._queued_messages)
def __contains__(self, item):
return item in self._loaded_messages or item in self._queued_messages
@ -77,6 +130,11 @@ class BaseStorage:
def __repr__(self):
return f"<{self.__class__.__qualname__}: request={self.request!r}>"
def filter_store(self, messages, response, *args, **kwargs):
''' stores only active messages from given messages in storage '''
filtered_messages = [x for x in messages if x.active()]
return self._store(filtered_messages, response, *args, **kwargs)
@property
def _loaded_messages(self):
"""
@ -132,14 +190,16 @@ class BaseStorage:
If the backend has yet to be iterated, store previously stored messages
again. Otherwise, only store messages added after the last iteration.
"""
# if used or used and added, then _queued_messages contains all messages that should be saved
# if added, then save: all messages currently stored and added ones
self._prepare_messages(self._queued_messages)
if self.used:
return self._store(self._queued_messages, response)
return self.filter_store(self._queued_messages, response)
elif self.added_new:
messages = self._loaded_messages + self._queued_messages
return self._store(messages, response)
return self.filter_store(messages, response)
def add(self, level, message, extra_tags=""):
def add(self, level, message, extra_tags="", restrictions=[]):
"""
Queue a message to be stored.
@ -154,7 +214,7 @@ class BaseStorage:
return
# Add the message.
self.added_new = True
message = Message(level, message, extra_tags=extra_tags)
message = Message(level, message, extra_tags=extra_tags, restrictions=restrictions)
self._queued_messages.append(message)
def _get_level(self):

View File

@ -2,11 +2,11 @@ import binascii
import json
from django.conf import settings
from django.contrib.messages.storage.base import BaseStorage, Message
from django.contrib.messages.storage.base import BaseStorage, Message, RestrictionsContainer
from django.core import signing
from django.http import SimpleCookie
from django.utils.safestring import SafeData, mark_safe
from django.contrib.messages.restrictions import AmountRestriction, TimeRestriction
class MessageEncoder(json.JSONEncoder):
"""
@ -22,6 +22,7 @@ class MessageEncoder(json.JSONEncoder):
message = [self.message_key, is_safedata, obj.level, obj.message]
if obj.extra_tags is not None:
message.append(obj.extra_tags)
message.append(obj.restrictions.to_json_obj())
return message
return super().default(obj)
@ -30,13 +31,19 @@ class MessageDecoder(json.JSONDecoder):
"""
Decode JSON that includes serialized ``Message`` instances.
"""
def create_message(self, *vars):
''' creates message on the basis of encoded data '''
args = vars[:-1]
restrictions = vars[-1]
restrictions = RestrictionsContainer.create_from_josn(restrictions)
return Message(*args, **{'restrictions': restrictions})
def process_messages(self, obj):
if isinstance(obj, list) and obj:
if obj[0] == MessageEncoder.message_key:
if obj[1]:
obj[3] = mark_safe(obj[3])
return Message(*obj[2:])
return self.create_message(*obj[2:])
return [self.process_messages(item) for item in obj]
if isinstance(obj, dict):
return {key: self.process_messages(value) for key, value in obj.items()}

View File

@ -4,5 +4,6 @@ from .api import get_messages
class MessagesTestMixin:
def assertMessages(self, response, expected_messages, *, ordered=True):
request_messages = list(get_messages(response.wsgi_request))
[i.on_display() for i in expected_messages]
assertion = self.assertEqual if ordered else self.assertCountEqual
assertion(request_messages, expected_messages)

View File

@ -1,6 +1,7 @@
from django.contrib.messages import Message, constants, get_level, set_level
from django.contrib.messages.api import MessageFailure
from django.contrib.messages.constants import DEFAULT_LEVELS
from django.contrib.messages.restrictions import AmountRestriction
from django.contrib.messages.storage import default_storage
from django.http import HttpRequest, HttpResponse
from django.test import modify_settings, override_settings
@ -158,7 +159,7 @@ class BaseTests:
response = self.client.post(add_url, data, follow=True)
self.assertRedirects(response, show_url)
self.assertIn("messages", response.context)
messages = [Message(self.levels[level], msg) for msg in data["messages"]]
messages = [Message(self.levels[level], msg, restrictions=[AmountRestriction(0)]) for msg in data["messages"]]
self.assertEqual(list(response.context["messages"]), messages)
for msg in data["messages"]:
self.assertContains(response, msg)
@ -201,7 +202,7 @@ class BaseTests:
messages = []
for level in ("debug", "info", "success", "warning", "error"):
messages.extend(
Message(self.levels[level], msg) for msg in data["messages"]
Message(self.levels[level], msg, restrictions=[AmountRestriction(0)]) for msg in data["messages"]
)
add_url = reverse("add_message", args=(level,))
self.client.post(add_url, data)

View File

@ -1,6 +1,7 @@
import random
from django.contrib.messages import constants
from django.contrib.messages.storage.base import Message
from django.contrib.messages.storage.fallback import CookieStorage, FallbackStorage
from django.test import SimpleTestCase
from django.utils.crypto import get_random_string
@ -90,8 +91,8 @@ class FallbackTests(BaseTests, SimpleTestCase):
cookie_storage = self.get_cookie_storage(storage)
session_storage = self.get_session_storage(storage)
# Set initial cookie and session data.
set_cookie_data(cookie_storage, ["cookie", CookieStorage.not_finished])
set_session_data(session_storage, ["session"])
set_cookie_data(cookie_storage, [Message(constants.INFO, "cookie"), CookieStorage.not_finished])
set_session_data(session_storage, [Message(constants.INFO, "session")])
# When updating, previously used but no longer needed backends are
# flushed.
response = self.get_response()

View File

@ -0,0 +1,68 @@
from django.test import TestCase
from django.contrib.messages.storage.base import Message
from django.contrib.messages import restrictions
from django.contrib.messages.restrictions import AmountRestriction, TimeRestriction, time_provider
from django.contrib.messages import constants
from .time_provider import TestTimeProvider
class MessageTest(TestCase):
def setUp(self):
self.tp = restrictions.time_provider = TestTimeProvider()
def __check_active(self, msg, iterations):
"""
Reads msg given amount of iterations, and after each read
checks whether before each read message is active
"""
for i in range(iterations):
self.assertTrue(msg.active())
msg.on_display()
self.assertFalse(msg.active())
msg.on_display()
self.assertFalse(msg.active())
def test_active_default(self):
msg = Message(constants.INFO, "Test message")
self.__check_active(msg, 1)
def test_active_custom_one_amount_restriction(self):
msg = Message(constants.INFO, "Test message", restrictions = [AmountRestriction(3),])
self.__check_active(msg, 3)
def test_active_custom_few_amount_restriction(self):
msg = Message(constants.INFO, "Test message", restrictions = [AmountRestriction(x) for x in (2, 3, 5)])
self.__check_active(msg, 2)
def test_active_custom_one_time_restriction(self):
msg = Message(constants.INFO, "Test message", restrictions = [TimeRestriction(3),])
def check_iter():
for i in range(10): # iteration doesn't have direct impact for TimeRestriction
self.assertTrue(msg.active())
msg.on_display()
check_iter()
self.tp.set_act_time(3)
check_iter()
self.tp.set_act_time(4)
self.assertFalse(msg.active())
def test_mixed_restrictions(self):
get_restrictions = lambda:[TimeRestriction(3), TimeRestriction(5), AmountRestriction(2), AmountRestriction(3)]
get_msg = lambda:Message(constants.INFO, "Test message", restrictions = get_restrictions())
msg = get_msg()
for i in range(2):
self.assertTrue(msg.active())
msg.on_display()
self.assertFalse(msg.active())
msg = get_msg()
self.assertTrue(msg.active())
msg.on_display()
self.assertTrue(msg.active())
self.tp.set_act_time(4)
self.assertFalse(msg.active())
for i in range(10):
self.assertFalse(msg.active())
msg.on_display()

View File

@ -0,0 +1,31 @@
from django.test import TestCase
from django.contrib.messages import restrictions
from django.contrib.messages.restrictions import AmountRestriction, TimeRestriction
from .time_provider import TestTimeProvider
restrictions.time_provider = TestTimeProvider ()
class RestrictionsTest(TestCase):
def __check_expired(self, amount_restriction, iterations_amount):
"""
Checks whether after iterations_amount of on_displayate given restriction will become expired
But before iterations_amount given amount_restriction must not indicate is_expired
"""
for i in range(iterations_amount):
self.assertFalse(amount_restriction.is_expired())
amount_restriction.on_display()
self.assertTrue(amount_restriction.is_expired())
def test_amount_restrictions(self):
res = AmountRestriction(4)
self.__check_expired(res, 4)
def test_amount_restrictions_invalid_argument(self):
self.assertRaises(AssertionError, AmountRestriction, -1)
def test_equal(self):
self.assertEqual(AmountRestriction(5), AmountRestriction(5))
self.assertFalse(AmountRestriction(1) == AmountRestriction(3))
self.assertEqual(TimeRestriction(2), TimeRestriction(2))
self.assertFalse(TimeRestriction(3) == TimeRestriction(4))

View File

@ -4,6 +4,8 @@ from unittest import mock
from django.conf import settings
from django.contrib.messages import Message, add_message, constants
from django.contrib.messages import restrictions
from django.contrib.messages.restrictions import AmountRestriction
from django.contrib.messages.storage import base
from django.contrib.messages.test import MessagesTestMixin
from django.test import RequestFactory, SimpleTestCase, override_settings
@ -110,11 +112,11 @@ class AssertMessagesTest(MessagesTestMixin, SimpleTestCase):
self.assertMessages(
response,
[
Message(constants.DEBUG, "DEBUG message."),
Message(constants.INFO, "INFO message."),
Message(constants.SUCCESS, "SUCCESS message."),
Message(constants.WARNING, "WARNING message."),
Message(constants.ERROR, "ERROR message."),
Message(constants.DEBUG, "DEBUG message.", restrictions=[AmountRestriction(2)]),
Message(constants.INFO, "INFO message.",restrictions=[AmountRestriction(2)]),
Message(constants.SUCCESS, "SUCCESS message.",restrictions=[AmountRestriction(2)]),
Message(constants.WARNING, "WARNING message.",restrictions=[AmountRestriction(2)]),
Message(constants.ERROR, "ERROR message.",restrictions=[AmountRestriction(2)]),
],
)
@ -147,10 +149,10 @@ class AssertMessagesTest(MessagesTestMixin, SimpleTestCase):
self.assertMessages(
response,
[
Message(constants.INFO, "INFO message.", "extra-info"),
Message(constants.SUCCESS, "SUCCESS message.", "extra-success"),
Message(constants.WARNING, "WARNING message.", "extra-warning"),
Message(constants.ERROR, "ERROR message.", "extra-error"),
Message(constants.INFO, "INFO message.", "extra-info", restrictions=[AmountRestriction(2)]),
Message(constants.SUCCESS, "SUCCESS message.", "extra-success", restrictions=[AmountRestriction(2)]),
Message(constants.WARNING, "WARNING message.", "extra-warning", restrictions=[AmountRestriction(2)]),
Message(constants.ERROR, "ERROR message.", "extra-error", restrictions=[AmountRestriction(2)]),
],
)
@ -158,15 +160,15 @@ class AssertMessagesTest(MessagesTestMixin, SimpleTestCase):
def test_custom_levelname(self):
response = FakeResponse()
add_message(response.wsgi_request, 42, "CUSTOM message.")
self.assertMessages(response, [Message(42, "CUSTOM message.")])
self.assertMessages(response, [Message(42, "CUSTOM message.", restrictions=[AmountRestriction(2)])])
def test_ordered(self):
response = FakeResponse()
add_message(response.wsgi_request, constants.INFO, "First message.")
add_message(response.wsgi_request, constants.WARNING, "Second message.")
expected_messages = [
Message(constants.WARNING, "Second message."),
Message(constants.INFO, "First message."),
Message(constants.WARNING, "Second message.", restrictions=[AmountRestriction(2)]),
Message(constants.INFO, "First message.", restrictions=[AmountRestriction(2)]),
]
self.assertMessages(response, expected_messages, ordered=False)
with self.assertRaisesMessage(AssertionError, "Lists differ: "):

View File

@ -0,0 +1,9 @@
class TestTimeProvider:
def __init__(self, act_time = 0):
self.act_time = act_time
def set_act_time(self, act_time):
self.act_time = act_time
def time(self):
return self.act_time
def inc_act_time(self):
self.act_time += 1

View File

@ -7,8 +7,8 @@ class DummyStorage:
def __init__(self):
self.store = []
def add(self, level, message, extra_tags=""):
self.store.append(Message(level, message, extra_tags))
def add(self, level, message, extra_tags="", restrictions=[]):
self.store.append(Message(level, message, extra_tags, restrictions))
def __iter__(self):
return iter(self.store)