From 45b94ce0777b13bccc640decbb635044c6e5d2bf Mon Sep 17 00:00:00 2001 From: Andrew Miller Date: Wed, 19 Jun 2024 20:37:38 +0100 Subject: [PATCH] Initial application of last patch from #13376 and made tests pass --- django/contrib/messages/api.py | 4 +- django/contrib/messages/restrictions.py | 105 ++++++++++++++++++++++ django/contrib/messages/storage/base.py | 86 +++++++++++++++--- django/contrib/messages/storage/cookie.py | 13 ++- django/contrib/messages/test.py | 1 + tests/messages_tests/base.py | 5 +- tests/messages_tests/test_fallback.py | 5 +- tests/messages_tests/test_message.py | 68 ++++++++++++++ tests/messages_tests/test_restrictions.py | 31 +++++++ tests/messages_tests/tests.py | 26 +++--- tests/messages_tests/time_provider.py | 9 ++ tests/messages_tests/utils.py | 4 +- 12 files changed, 321 insertions(+), 36 deletions(-) create mode 100644 django/contrib/messages/restrictions.py create mode 100644 tests/messages_tests/test_message.py create mode 100644 tests/messages_tests/test_restrictions.py create mode 100644 tests/messages_tests/time_provider.py diff --git a/django/contrib/messages/api.py b/django/contrib/messages/api.py index 7a67e8b4b0..b46cab6aa7 100644 --- a/django/contrib/messages/api.py +++ b/django/contrib/messages/api.py @@ -19,7 +19,7 @@ class MessageFailure(Exception): pass -def add_message(request, level, message, extra_tags="", fail_silently=False): +def add_message(request, level, message, extra_tags="", fail_silently=False, restrictions=[]): """ Attempt to add a message to the request using the 'messages' app. """ @@ -37,7 +37,7 @@ def add_message(request, level, message, extra_tags="", fail_silently=False): "django.contrib.messages.middleware.MessageMiddleware" ) else: - return messages.add(level, message, extra_tags) + return messages.add(level, message, extra_tags, restrictions = restrictions) def get_messages(request): diff --git a/django/contrib/messages/restrictions.py b/django/contrib/messages/restrictions.py new file mode 100644 index 0000000000..6fc9a07996 --- /dev/null +++ b/django/contrib/messages/restrictions.py @@ -0,0 +1,105 @@ +import time as time_provider + + +class Restriction: + JSON_SEPARATOR = '@' + + def __init__(self, name): + self.type = name + + def is_expired(self): + """ + indicates whether given message should be removed + """ + raise NotImplementedError () + + def on_display(self): + """ + called when iterated - does nothing by default + """ + pass + + def to_json(self): + """ + returns json representation of restriction + """ + raise NotImplementedError + + @classmethod + def from_json_param(cls, *args): + """ + returns restriction on the basis of data encoded in json + """ + raise NotImplementedError + + def __cmp__(self, other): + if self.__eq__(other): return 0 + return -1 + + +class TimeRestriction(Restriction): + JSON_TYPE_CODE = 't' + + def __init__(self, seconds): + """ + seconds - expiration time since now + """ + Restriction.__init__(self, 'time') + created = time_provider.time() + self.expires = created + int(seconds) + + def set_expirity_time(self, expiration_time): + """ + Sets expilcity expiration time + """ + self.expires = int(expiration_time) + + def is_expired(self): + return self.expires < time_provider.time() + + def __eq__(self, other): + return self.type == other.type and not bool(self.expires ^ other.expires) + + def __hash__(self): + return self.expires + + def to_json(self): + return '%s%s%s' % (self.JSON_TYPE_CODE, self.JSON_SEPARATOR, self.expires) + + @classmethod + def from_json_param(cls, expirity_time): + ret = TimeRestriction(0) + ret.set_expirity_time(expirity_time) + return ret + + + +class AmountRestriction(Restriction): + JSON_TYPE_CODE = 'a' + + def __init__(self, amount): + assert int(amount) >= 0 + Restriction.__init__(self, 'amount') + self.can_be_shown = int(amount) + + def on_display(self): + self.can_be_shown -= 1 + + def is_expired(self): + return int(self.can_be_shown) <= 0 + + def __eq__(self, other): + return self.type == other.type and not bool(self.can_be_shown ^ other.can_be_shown) + + def __hash__(self): + return self.can_be_shown + + def __repr__(self): + return self.to_json() + + def to_json(self): + return '%s%s%s' % (self.JSON_TYPE_CODE, self.JSON_SEPARATOR, self.can_be_shown) + + @classmethod + def from_json_param(cls, amount): + return AmountRestriction(amount) diff --git a/django/contrib/messages/storage/base.py b/django/contrib/messages/storage/base.py index 5d89acfe69..0c9f8a030f 100644 --- a/django/contrib/messages/storage/base.py +++ b/django/contrib/messages/storage/base.py @@ -1,10 +1,29 @@ from django.conf import settings -from django.contrib.messages import constants, utils +from django.contrib.messages import constants, utils, restrictions as res from django.utils.functional import SimpleLazyObject LEVEL_TAGS = SimpleLazyObject(utils.get_level_tags) +class RestrictionsContainer(list): + res_map = {res.AmountRestriction.JSON_TYPE_CODE: res.AmountRestriction, + res.TimeRestriction.JSON_TYPE_CODE: res.TimeRestriction} + + def to_json_obj(self): + return [r.to_json() for r in self] + + @classmethod + def create_from_josn(cls, enc_restrictions): + #set_trace() + ret = [] + for r in enc_restrictions: + restriction_type, values = r.split(res.Restriction.JSON_SEPARATOR) + ret.append(cls.res_map[restriction_type].from_json_param(values)) + return RestrictionsContainer(ret) + + def __eq__(self, other): + return set(self) == set(other) + class Message: """ Represent an actual message that can be stored in any of the supported @@ -12,10 +31,14 @@ class Message: or template. """ - def __init__(self, level, message, extra_tags=None): + def __init__(self, level, message, extra_tags=None, restrictions = []): self.level = int(level) self.message = message self.extra_tags = extra_tags + self.restrictions = restrictions or list([res.AmountRestriction(1)]) + self.restrictions = RestrictionsContainer(self.restrictions) + # if not given any restriction - one show by default + # todo: self.restrictions = def _prepare(self): """ @@ -28,11 +51,14 @@ class Message: def __eq__(self, other): if not isinstance(other, Message): return NotImplemented - return self.level == other.level and self.message == other.message + return self.level == other.level and self.message == other.message and self.restrictions == other.restrictions def __str__(self): return str(self.message) + def __hash__(self): + return hash(self.message) + def __repr__(self): extra_tags = f", extra_tags={self.extra_tags!r}" if self.extra_tags else "" return f"Message(level={self.level}, message={self.message!r}{extra_tags})" @@ -45,6 +71,14 @@ class Message: def level_tag(self): return LEVEL_TAGS.get(self.level, "") + def active(self): + for r in self.restrictions: + if r.is_expired(): return False + return True + + def on_display(self): + for r in self.restrictions: + r.on_display() class BaseStorage: """ @@ -62,14 +96,33 @@ class BaseStorage: super().__init__(*args, **kwargs) def __len__(self): - return len(self._loaded_messages) + len(self._queued_messages) + # in case that there was a call for render template which would cause iterating throught messages, + # and then (e.g. in some middleware, would be call for iterating through messages (e.g. by iterating of context['messages']) + # TODO: implement a way to access messages without affecting calling __iter__ method + all_msgs = set(self._loaded_messages + self._queued_messages) + return len(all_msgs) + #return len(self._loaded_messages) + len(self._queued_messages) def __iter__(self): - self.used = True - if self._queued_messages: - self._loaded_messages.extend(self._queued_messages) - self._queued_messages = [] - return iter(self._loaded_messages) + if not self.used: + self.used = True + if self._queued_messages: + self._loaded_messages.extend(self._queued_messages) + self._queued_messages = [] + + active_messages = [] + for message in self._loaded_messages: + if isinstance(message, Message): + if message.active(): + active_messages.append(message) + else: + active_messages.append(message) + for x in active_messages: + if isinstance(x, Message): + x.on_display() + # self._queued_messages.extend(m for m in active_messages if m not in self._queued_messages) + self._queued_messages = active_messages + return iter(self._queued_messages) def __contains__(self, item): return item in self._loaded_messages or item in self._queued_messages @@ -77,6 +130,11 @@ class BaseStorage: def __repr__(self): return f"<{self.__class__.__qualname__}: request={self.request!r}>" + def filter_store(self, messages, response, *args, **kwargs): + ''' stores only active messages from given messages in storage ''' + filtered_messages = [x for x in messages if x.active()] + return self._store(filtered_messages, response, *args, **kwargs) + @property def _loaded_messages(self): """ @@ -132,14 +190,16 @@ class BaseStorage: If the backend has yet to be iterated, store previously stored messages again. Otherwise, only store messages added after the last iteration. """ + # if used or used and added, then _queued_messages contains all messages that should be saved + # if added, then save: all messages currently stored and added ones self._prepare_messages(self._queued_messages) if self.used: - return self._store(self._queued_messages, response) + return self.filter_store(self._queued_messages, response) elif self.added_new: messages = self._loaded_messages + self._queued_messages - return self._store(messages, response) + return self.filter_store(messages, response) - def add(self, level, message, extra_tags=""): + def add(self, level, message, extra_tags="", restrictions=[]): """ Queue a message to be stored. @@ -154,7 +214,7 @@ class BaseStorage: return # Add the message. self.added_new = True - message = Message(level, message, extra_tags=extra_tags) + message = Message(level, message, extra_tags=extra_tags, restrictions=restrictions) self._queued_messages.append(message) def _get_level(self): diff --git a/django/contrib/messages/storage/cookie.py b/django/contrib/messages/storage/cookie.py index 2008b31843..d381e629be 100644 --- a/django/contrib/messages/storage/cookie.py +++ b/django/contrib/messages/storage/cookie.py @@ -2,11 +2,11 @@ import binascii import json from django.conf import settings -from django.contrib.messages.storage.base import BaseStorage, Message +from django.contrib.messages.storage.base import BaseStorage, Message, RestrictionsContainer from django.core import signing from django.http import SimpleCookie from django.utils.safestring import SafeData, mark_safe - +from django.contrib.messages.restrictions import AmountRestriction, TimeRestriction class MessageEncoder(json.JSONEncoder): """ @@ -22,6 +22,7 @@ class MessageEncoder(json.JSONEncoder): message = [self.message_key, is_safedata, obj.level, obj.message] if obj.extra_tags is not None: message.append(obj.extra_tags) + message.append(obj.restrictions.to_json_obj()) return message return super().default(obj) @@ -30,13 +31,19 @@ class MessageDecoder(json.JSONDecoder): """ Decode JSON that includes serialized ``Message`` instances. """ + def create_message(self, *vars): + ''' creates message on the basis of encoded data ''' + args = vars[:-1] + restrictions = vars[-1] + restrictions = RestrictionsContainer.create_from_josn(restrictions) + return Message(*args, **{'restrictions': restrictions}) def process_messages(self, obj): if isinstance(obj, list) and obj: if obj[0] == MessageEncoder.message_key: if obj[1]: obj[3] = mark_safe(obj[3]) - return Message(*obj[2:]) + return self.create_message(*obj[2:]) return [self.process_messages(item) for item in obj] if isinstance(obj, dict): return {key: self.process_messages(value) for key, value in obj.items()} diff --git a/django/contrib/messages/test.py b/django/contrib/messages/test.py index 3a69f54585..5349655f6d 100644 --- a/django/contrib/messages/test.py +++ b/django/contrib/messages/test.py @@ -4,5 +4,6 @@ from .api import get_messages class MessagesTestMixin: def assertMessages(self, response, expected_messages, *, ordered=True): request_messages = list(get_messages(response.wsgi_request)) + [i.on_display() for i in expected_messages] assertion = self.assertEqual if ordered else self.assertCountEqual assertion(request_messages, expected_messages) diff --git a/tests/messages_tests/base.py b/tests/messages_tests/base.py index ce4b2acac8..1ca6324226 100644 --- a/tests/messages_tests/base.py +++ b/tests/messages_tests/base.py @@ -1,6 +1,7 @@ from django.contrib.messages import Message, constants, get_level, set_level from django.contrib.messages.api import MessageFailure from django.contrib.messages.constants import DEFAULT_LEVELS +from django.contrib.messages.restrictions import AmountRestriction from django.contrib.messages.storage import default_storage from django.http import HttpRequest, HttpResponse from django.test import modify_settings, override_settings @@ -158,7 +159,7 @@ class BaseTests: response = self.client.post(add_url, data, follow=True) self.assertRedirects(response, show_url) self.assertIn("messages", response.context) - messages = [Message(self.levels[level], msg) for msg in data["messages"]] + messages = [Message(self.levels[level], msg, restrictions=[AmountRestriction(0)]) for msg in data["messages"]] self.assertEqual(list(response.context["messages"]), messages) for msg in data["messages"]: self.assertContains(response, msg) @@ -201,7 +202,7 @@ class BaseTests: messages = [] for level in ("debug", "info", "success", "warning", "error"): messages.extend( - Message(self.levels[level], msg) for msg in data["messages"] + Message(self.levels[level], msg, restrictions=[AmountRestriction(0)]) for msg in data["messages"] ) add_url = reverse("add_message", args=(level,)) self.client.post(add_url, data) diff --git a/tests/messages_tests/test_fallback.py b/tests/messages_tests/test_fallback.py index 7a335114c5..5458f5c6f2 100644 --- a/tests/messages_tests/test_fallback.py +++ b/tests/messages_tests/test_fallback.py @@ -1,6 +1,7 @@ import random from django.contrib.messages import constants +from django.contrib.messages.storage.base import Message from django.contrib.messages.storage.fallback import CookieStorage, FallbackStorage from django.test import SimpleTestCase from django.utils.crypto import get_random_string @@ -90,8 +91,8 @@ class FallbackTests(BaseTests, SimpleTestCase): cookie_storage = self.get_cookie_storage(storage) session_storage = self.get_session_storage(storage) # Set initial cookie and session data. - set_cookie_data(cookie_storage, ["cookie", CookieStorage.not_finished]) - set_session_data(session_storage, ["session"]) + set_cookie_data(cookie_storage, [Message(constants.INFO, "cookie"), CookieStorage.not_finished]) + set_session_data(session_storage, [Message(constants.INFO, "session")]) # When updating, previously used but no longer needed backends are # flushed. response = self.get_response() diff --git a/tests/messages_tests/test_message.py b/tests/messages_tests/test_message.py new file mode 100644 index 0000000000..f9acd8f52f --- /dev/null +++ b/tests/messages_tests/test_message.py @@ -0,0 +1,68 @@ +from django.test import TestCase +from django.contrib.messages.storage.base import Message +from django.contrib.messages import restrictions +from django.contrib.messages.restrictions import AmountRestriction, TimeRestriction, time_provider +from django.contrib.messages import constants + +from .time_provider import TestTimeProvider + + +class MessageTest(TestCase): + def setUp(self): + self.tp = restrictions.time_provider = TestTimeProvider() + def __check_active(self, msg, iterations): + """ + Reads msg given amount of iterations, and after each read + checks whether before each read message is active + """ + for i in range(iterations): + self.assertTrue(msg.active()) + msg.on_display() + self.assertFalse(msg.active()) + msg.on_display() + self.assertFalse(msg.active()) + + def test_active_default(self): + msg = Message(constants.INFO, "Test message") + self.__check_active(msg, 1) + + def test_active_custom_one_amount_restriction(self): + msg = Message(constants.INFO, "Test message", restrictions = [AmountRestriction(3),]) + self.__check_active(msg, 3) + + def test_active_custom_few_amount_restriction(self): + msg = Message(constants.INFO, "Test message", restrictions = [AmountRestriction(x) for x in (2, 3, 5)]) + self.__check_active(msg, 2) + + def test_active_custom_one_time_restriction(self): + msg = Message(constants.INFO, "Test message", restrictions = [TimeRestriction(3),]) + def check_iter(): + for i in range(10): # iteration doesn't have direct impact for TimeRestriction + self.assertTrue(msg.active()) + msg.on_display() + check_iter() + self.tp.set_act_time(3) + check_iter() + self.tp.set_act_time(4) + self.assertFalse(msg.active()) + + + def test_mixed_restrictions(self): + get_restrictions = lambda:[TimeRestriction(3), TimeRestriction(5), AmountRestriction(2), AmountRestriction(3)] + get_msg = lambda:Message(constants.INFO, "Test message", restrictions = get_restrictions()) + + msg = get_msg() + for i in range(2): + self.assertTrue(msg.active()) + msg.on_display() + self.assertFalse(msg.active()) + + msg = get_msg() + self.assertTrue(msg.active()) + msg.on_display() + self.assertTrue(msg.active()) + self.tp.set_act_time(4) + self.assertFalse(msg.active()) + for i in range(10): + self.assertFalse(msg.active()) + msg.on_display() diff --git a/tests/messages_tests/test_restrictions.py b/tests/messages_tests/test_restrictions.py new file mode 100644 index 0000000000..9c52e28af1 --- /dev/null +++ b/tests/messages_tests/test_restrictions.py @@ -0,0 +1,31 @@ +from django.test import TestCase +from django.contrib.messages import restrictions +from django.contrib.messages.restrictions import AmountRestriction, TimeRestriction + +from .time_provider import TestTimeProvider + +restrictions.time_provider = TestTimeProvider () + +class RestrictionsTest(TestCase): + def __check_expired(self, amount_restriction, iterations_amount): + """ + Checks whether after iterations_amount of on_displayate given restriction will become expired + But before iterations_amount given amount_restriction must not indicate is_expired + """ + for i in range(iterations_amount): + self.assertFalse(amount_restriction.is_expired()) + amount_restriction.on_display() + self.assertTrue(amount_restriction.is_expired()) + + def test_amount_restrictions(self): + res = AmountRestriction(4) + self.__check_expired(res, 4) + + def test_amount_restrictions_invalid_argument(self): + self.assertRaises(AssertionError, AmountRestriction, -1) + + def test_equal(self): + self.assertEqual(AmountRestriction(5), AmountRestriction(5)) + self.assertFalse(AmountRestriction(1) == AmountRestriction(3)) + self.assertEqual(TimeRestriction(2), TimeRestriction(2)) + self.assertFalse(TimeRestriction(3) == TimeRestriction(4)) diff --git a/tests/messages_tests/tests.py b/tests/messages_tests/tests.py index 19aeee9a08..0279927adf 100644 --- a/tests/messages_tests/tests.py +++ b/tests/messages_tests/tests.py @@ -4,6 +4,8 @@ from unittest import mock from django.conf import settings from django.contrib.messages import Message, add_message, constants +from django.contrib.messages import restrictions +from django.contrib.messages.restrictions import AmountRestriction from django.contrib.messages.storage import base from django.contrib.messages.test import MessagesTestMixin from django.test import RequestFactory, SimpleTestCase, override_settings @@ -110,11 +112,11 @@ class AssertMessagesTest(MessagesTestMixin, SimpleTestCase): self.assertMessages( response, [ - Message(constants.DEBUG, "DEBUG message."), - Message(constants.INFO, "INFO message."), - Message(constants.SUCCESS, "SUCCESS message."), - Message(constants.WARNING, "WARNING message."), - Message(constants.ERROR, "ERROR message."), + Message(constants.DEBUG, "DEBUG message.", restrictions=[AmountRestriction(2)]), + Message(constants.INFO, "INFO message.",restrictions=[AmountRestriction(2)]), + Message(constants.SUCCESS, "SUCCESS message.",restrictions=[AmountRestriction(2)]), + Message(constants.WARNING, "WARNING message.",restrictions=[AmountRestriction(2)]), + Message(constants.ERROR, "ERROR message.",restrictions=[AmountRestriction(2)]), ], ) @@ -147,10 +149,10 @@ class AssertMessagesTest(MessagesTestMixin, SimpleTestCase): self.assertMessages( response, [ - Message(constants.INFO, "INFO message.", "extra-info"), - Message(constants.SUCCESS, "SUCCESS message.", "extra-success"), - Message(constants.WARNING, "WARNING message.", "extra-warning"), - Message(constants.ERROR, "ERROR message.", "extra-error"), + Message(constants.INFO, "INFO message.", "extra-info", restrictions=[AmountRestriction(2)]), + Message(constants.SUCCESS, "SUCCESS message.", "extra-success", restrictions=[AmountRestriction(2)]), + Message(constants.WARNING, "WARNING message.", "extra-warning", restrictions=[AmountRestriction(2)]), + Message(constants.ERROR, "ERROR message.", "extra-error", restrictions=[AmountRestriction(2)]), ], ) @@ -158,15 +160,15 @@ class AssertMessagesTest(MessagesTestMixin, SimpleTestCase): def test_custom_levelname(self): response = FakeResponse() add_message(response.wsgi_request, 42, "CUSTOM message.") - self.assertMessages(response, [Message(42, "CUSTOM message.")]) + self.assertMessages(response, [Message(42, "CUSTOM message.", restrictions=[AmountRestriction(2)])]) def test_ordered(self): response = FakeResponse() add_message(response.wsgi_request, constants.INFO, "First message.") add_message(response.wsgi_request, constants.WARNING, "Second message.") expected_messages = [ - Message(constants.WARNING, "Second message."), - Message(constants.INFO, "First message."), + Message(constants.WARNING, "Second message.", restrictions=[AmountRestriction(2)]), + Message(constants.INFO, "First message.", restrictions=[AmountRestriction(2)]), ] self.assertMessages(response, expected_messages, ordered=False) with self.assertRaisesMessage(AssertionError, "Lists differ: "): diff --git a/tests/messages_tests/time_provider.py b/tests/messages_tests/time_provider.py new file mode 100644 index 0000000000..007b9e81fa --- /dev/null +++ b/tests/messages_tests/time_provider.py @@ -0,0 +1,9 @@ +class TestTimeProvider: + def __init__(self, act_time = 0): + self.act_time = act_time + def set_act_time(self, act_time): + self.act_time = act_time + def time(self): + return self.act_time + def inc_act_time(self): + self.act_time += 1 diff --git a/tests/messages_tests/utils.py b/tests/messages_tests/utils.py index f00ea4585b..3305445a33 100644 --- a/tests/messages_tests/utils.py +++ b/tests/messages_tests/utils.py @@ -7,8 +7,8 @@ class DummyStorage: def __init__(self): self.store = [] - def add(self, level, message, extra_tags=""): - self.store.append(Message(level, message, extra_tags)) + def add(self, level, message, extra_tags="", restrictions=[]): + self.store.append(Message(level, message, extra_tags, restrictions)) def __iter__(self): return iter(self.store)