diff --git a/django/contrib/messages/api.py b/django/contrib/messages/api.py index b46cab6aa7..c3bacd6eeb 100644 --- a/django/contrib/messages/api.py +++ b/django/contrib/messages/api.py @@ -19,7 +19,9 @@ class MessageFailure(Exception): pass -def add_message(request, level, message, extra_tags="", fail_silently=False, restrictions=[]): +def add_message( + request, level, message, extra_tags="", fail_silently=False, restrictions=[] +): """ Attempt to add a message to the request using the 'messages' app. """ @@ -37,7 +39,7 @@ def add_message(request, level, message, extra_tags="", fail_silently=False, res "django.contrib.messages.middleware.MessageMiddleware" ) else: - return messages.add(level, message, extra_tags, restrictions = restrictions) + return messages.add(level, message, extra_tags, restrictions=restrictions) def get_messages(request): diff --git a/django/contrib/messages/restrictions.py b/django/contrib/messages/restrictions.py index 6fc9a07996..484ec89060 100644 --- a/django/contrib/messages/restrictions.py +++ b/django/contrib/messages/restrictions.py @@ -2,7 +2,7 @@ import time as time_provider class Restriction: - JSON_SEPARATOR = '@' + JSON_SEPARATOR = "@" def __init__(self, name): self.type = name @@ -11,7 +11,7 @@ class Restriction: """ indicates whether given message should be removed """ - raise NotImplementedError () + raise NotImplementedError() def on_display(self): """ @@ -33,18 +33,19 @@ class Restriction: raise NotImplementedError def __cmp__(self, other): - if self.__eq__(other): return 0 + if self.__eq__(other): + return 0 return -1 class TimeRestriction(Restriction): - JSON_TYPE_CODE = 't' + JSON_TYPE_CODE = "t" def __init__(self, seconds): """ seconds - expiration time since now """ - Restriction.__init__(self, 'time') + Restriction.__init__(self, "time") created = time_provider.time() self.expires = created + int(seconds) @@ -64,7 +65,7 @@ class TimeRestriction(Restriction): return self.expires def to_json(self): - return '%s%s%s' % (self.JSON_TYPE_CODE, self.JSON_SEPARATOR, self.expires) + return "%s%s%s" % (self.JSON_TYPE_CODE, self.JSON_SEPARATOR, self.expires) @classmethod def from_json_param(cls, expirity_time): @@ -73,13 +74,12 @@ class TimeRestriction(Restriction): return ret - class AmountRestriction(Restriction): - JSON_TYPE_CODE = 'a' + JSON_TYPE_CODE = "a" def __init__(self, amount): assert int(amount) >= 0 - Restriction.__init__(self, 'amount') + Restriction.__init__(self, "amount") self.can_be_shown = int(amount) def on_display(self): @@ -89,7 +89,9 @@ class AmountRestriction(Restriction): return int(self.can_be_shown) <= 0 def __eq__(self, other): - return self.type == other.type and not bool(self.can_be_shown ^ other.can_be_shown) + return self.type == other.type and not bool( + self.can_be_shown ^ other.can_be_shown + ) def __hash__(self): return self.can_be_shown @@ -98,7 +100,7 @@ class AmountRestriction(Restriction): return self.to_json() def to_json(self): - return '%s%s%s' % (self.JSON_TYPE_CODE, self.JSON_SEPARATOR, self.can_be_shown) + return "%s%s%s" % (self.JSON_TYPE_CODE, self.JSON_SEPARATOR, self.can_be_shown) @classmethod def from_json_param(cls, amount): diff --git a/django/contrib/messages/storage/base.py b/django/contrib/messages/storage/base.py index 0c9f8a030f..f55c67da67 100644 --- a/django/contrib/messages/storage/base.py +++ b/django/contrib/messages/storage/base.py @@ -1,20 +1,24 @@ from django.conf import settings -from django.contrib.messages import constants, utils, restrictions as res +from django.contrib.messages import constants +from django.contrib.messages import restrictions as res +from django.contrib.messages import utils from django.utils.functional import SimpleLazyObject LEVEL_TAGS = SimpleLazyObject(utils.get_level_tags) class RestrictionsContainer(list): - res_map = {res.AmountRestriction.JSON_TYPE_CODE: res.AmountRestriction, - res.TimeRestriction.JSON_TYPE_CODE: res.TimeRestriction} + res_map = { + res.AmountRestriction.JSON_TYPE_CODE: res.AmountRestriction, + res.TimeRestriction.JSON_TYPE_CODE: res.TimeRestriction, + } def to_json_obj(self): return [r.to_json() for r in self] @classmethod def create_from_josn(cls, enc_restrictions): - #set_trace() + # set_trace() ret = [] for r in enc_restrictions: restriction_type, values = r.split(res.Restriction.JSON_SEPARATOR) @@ -24,6 +28,7 @@ class RestrictionsContainer(list): def __eq__(self, other): return set(self) == set(other) + class Message: """ Represent an actual message that can be stored in any of the supported @@ -31,7 +36,7 @@ class Message: or template. """ - def __init__(self, level, message, extra_tags=None, restrictions = []): + def __init__(self, level, message, extra_tags=None, restrictions=[]): self.level = int(level) self.message = message self.extra_tags = extra_tags @@ -51,7 +56,11 @@ class Message: def __eq__(self, other): if not isinstance(other, Message): return NotImplemented - return self.level == other.level and self.message == other.message and self.restrictions == other.restrictions + return ( + self.level == other.level + and self.message == other.message + and self.restrictions == other.restrictions + ) def __str__(self): return str(self.message) @@ -73,13 +82,15 @@ class Message: def active(self): for r in self.restrictions: - if r.is_expired(): return False + if r.is_expired(): + return False return True def on_display(self): for r in self.restrictions: r.on_display() + class BaseStorage: """ This is the base backend for temporary message storage. @@ -96,12 +107,15 @@ class BaseStorage: super().__init__(*args, **kwargs) def __len__(self): - # in case that there was a call for render template which would cause iterating throught messages, - # and then (e.g. in some middleware, would be call for iterating through messages (e.g. by iterating of context['messages']) - # TODO: implement a way to access messages without affecting calling __iter__ method + # in case that there was a call for render template which would + # cause iterating throught messages, + # and then (e.g. in some middleware, would be call for iterating + # through messages (e.g. by iterating of context['messages']) + # TODO: implement a way to access messages without affecting + # calling __iter__ method all_msgs = set(self._loaded_messages + self._queued_messages) return len(all_msgs) - #return len(self._loaded_messages) + len(self._queued_messages) + # return len(self._loaded_messages) + len(self._queued_messages) def __iter__(self): if not self.used: @@ -120,7 +134,8 @@ class BaseStorage: for x in active_messages: if isinstance(x, Message): x.on_display() - # self._queued_messages.extend(m for m in active_messages if m not in self._queued_messages) + # self._queued_messages.extend(m for m in active_messages + # if m not in self._queued_messages) self._queued_messages = active_messages return iter(self._queued_messages) @@ -131,7 +146,7 @@ class BaseStorage: return f"<{self.__class__.__qualname__}: request={self.request!r}>" def filter_store(self, messages, response, *args, **kwargs): - ''' stores only active messages from given messages in storage ''' + """stores only active messages from given messages in storage""" filtered_messages = [x for x in messages if x.active()] return self._store(filtered_messages, response, *args, **kwargs) @@ -190,7 +205,8 @@ class BaseStorage: If the backend has yet to be iterated, store previously stored messages again. Otherwise, only store messages added after the last iteration. """ - # if used or used and added, then _queued_messages contains all messages that should be saved + # if used or used and added, + # then _queued_messages contains all messages that should be saved # if added, then save: all messages currently stored and added ones self._prepare_messages(self._queued_messages) if self.used: @@ -214,7 +230,9 @@ class BaseStorage: return # Add the message. self.added_new = True - message = Message(level, message, extra_tags=extra_tags, restrictions=restrictions) + message = Message( + level, message, extra_tags=extra_tags, restrictions=restrictions + ) self._queued_messages.append(message) def _get_level(self): diff --git a/django/contrib/messages/storage/cookie.py b/django/contrib/messages/storage/cookie.py index d381e629be..cbcbe9b22e 100644 --- a/django/contrib/messages/storage/cookie.py +++ b/django/contrib/messages/storage/cookie.py @@ -2,11 +2,15 @@ import binascii import json from django.conf import settings -from django.contrib.messages.storage.base import BaseStorage, Message, RestrictionsContainer +from django.contrib.messages.storage.base import ( + BaseStorage, + Message, + RestrictionsContainer, +) from django.core import signing from django.http import SimpleCookie from django.utils.safestring import SafeData, mark_safe -from django.contrib.messages.restrictions import AmountRestriction, TimeRestriction + class MessageEncoder(json.JSONEncoder): """ @@ -31,12 +35,13 @@ class MessageDecoder(json.JSONDecoder): """ Decode JSON that includes serialized ``Message`` instances. """ + def create_message(self, *vars): - ''' creates message on the basis of encoded data ''' + """creates message on the basis of encoded data""" args = vars[:-1] restrictions = vars[-1] restrictions = RestrictionsContainer.create_from_josn(restrictions) - return Message(*args, **{'restrictions': restrictions}) + return Message(*args, **{"restrictions": restrictions}) def process_messages(self, obj): if isinstance(obj, list) and obj: diff --git a/tests/messages_tests/base.py b/tests/messages_tests/base.py index 1ca6324226..784d17a047 100644 --- a/tests/messages_tests/base.py +++ b/tests/messages_tests/base.py @@ -159,7 +159,10 @@ class BaseTests: response = self.client.post(add_url, data, follow=True) self.assertRedirects(response, show_url) self.assertIn("messages", response.context) - messages = [Message(self.levels[level], msg, restrictions=[AmountRestriction(0)]) for msg in data["messages"]] + messages = [ + Message(self.levels[level], msg, restrictions=[AmountRestriction(0)]) + for msg in data["messages"] + ] self.assertEqual(list(response.context["messages"]), messages) for msg in data["messages"]: self.assertContains(response, msg) @@ -202,7 +205,8 @@ class BaseTests: messages = [] for level in ("debug", "info", "success", "warning", "error"): messages.extend( - Message(self.levels[level], msg, restrictions=[AmountRestriction(0)]) for msg in data["messages"] + Message(self.levels[level], msg, restrictions=[AmountRestriction(0)]) + for msg in data["messages"] ) add_url = reverse("add_message", args=(level,)) self.client.post(add_url, data) diff --git a/tests/messages_tests/test_fallback.py b/tests/messages_tests/test_fallback.py index 5458f5c6f2..baa1545530 100644 --- a/tests/messages_tests/test_fallback.py +++ b/tests/messages_tests/test_fallback.py @@ -91,7 +91,10 @@ class FallbackTests(BaseTests, SimpleTestCase): cookie_storage = self.get_cookie_storage(storage) session_storage = self.get_session_storage(storage) # Set initial cookie and session data. - set_cookie_data(cookie_storage, [Message(constants.INFO, "cookie"), CookieStorage.not_finished]) + set_cookie_data( + cookie_storage, + [Message(constants.INFO, "cookie"), CookieStorage.not_finished], + ) set_session_data(session_storage, [Message(constants.INFO, "session")]) # When updating, previously used but no longer needed backends are # flushed. diff --git a/tests/messages_tests/test_message.py b/tests/messages_tests/test_message.py index f9acd8f52f..beb548db7e 100644 --- a/tests/messages_tests/test_message.py +++ b/tests/messages_tests/test_message.py @@ -1,8 +1,10 @@ -from django.test import TestCase +from django.contrib.messages import constants, restrictions +from django.contrib.messages.restrictions import ( + AmountRestriction, + TimeRestriction, +) from django.contrib.messages.storage.base import Message -from django.contrib.messages import restrictions -from django.contrib.messages.restrictions import AmountRestriction, TimeRestriction, time_provider -from django.contrib.messages import constants +from django.test import TestCase from .time_provider import TestTimeProvider @@ -10,6 +12,7 @@ from .time_provider import TestTimeProvider class MessageTest(TestCase): def setUp(self): self.tp = restrictions.time_provider = TestTimeProvider() + def __check_active(self, msg, iterations): """ Reads msg given amount of iterations, and after each read @@ -27,29 +30,57 @@ class MessageTest(TestCase): self.__check_active(msg, 1) def test_active_custom_one_amount_restriction(self): - msg = Message(constants.INFO, "Test message", restrictions = [AmountRestriction(3),]) + msg = Message( + constants.INFO, + "Test message", + restrictions=[ + AmountRestriction(3), + ], + ) self.__check_active(msg, 3) def test_active_custom_few_amount_restriction(self): - msg = Message(constants.INFO, "Test message", restrictions = [AmountRestriction(x) for x in (2, 3, 5)]) + msg = Message( + constants.INFO, + "Test message", + restrictions=[AmountRestriction(x) for x in (2, 3, 5)], + ) self.__check_active(msg, 2) def test_active_custom_one_time_restriction(self): - msg = Message(constants.INFO, "Test message", restrictions = [TimeRestriction(3),]) + msg = Message( + constants.INFO, + "Test message", + restrictions=[ + TimeRestriction(3), + ], + ) + def check_iter(): - for i in range(10): # iteration doesn't have direct impact for TimeRestriction + for i in range( + 10 + ): # iteration doesn't have direct impact for TimeRestriction self.assertTrue(msg.active()) msg.on_display() + check_iter() self.tp.set_act_time(3) check_iter() self.tp.set_act_time(4) self.assertFalse(msg.active()) - def test_mixed_restrictions(self): - get_restrictions = lambda:[TimeRestriction(3), TimeRestriction(5), AmountRestriction(2), AmountRestriction(3)] - get_msg = lambda:Message(constants.INFO, "Test message", restrictions = get_restrictions()) + def get_restrictions(): + return [ + TimeRestriction(3), + TimeRestriction(5), + AmountRestriction(2), + AmountRestriction(3), + ] + def get_msg(): + return Message( + constants.INFO, "Test message", restrictions=get_restrictions() + ) msg = get_msg() for i in range(2): diff --git a/tests/messages_tests/test_restrictions.py b/tests/messages_tests/test_restrictions.py index 9c52e28af1..9da8e88908 100644 --- a/tests/messages_tests/test_restrictions.py +++ b/tests/messages_tests/test_restrictions.py @@ -1,16 +1,19 @@ -from django.test import TestCase from django.contrib.messages import restrictions from django.contrib.messages.restrictions import AmountRestriction, TimeRestriction +from django.test import TestCase from .time_provider import TestTimeProvider -restrictions.time_provider = TestTimeProvider () +restrictions.time_provider = TestTimeProvider() + class RestrictionsTest(TestCase): def __check_expired(self, amount_restriction, iterations_amount): """ - Checks whether after iterations_amount of on_displayate given restriction will become expired - But before iterations_amount given amount_restriction must not indicate is_expired + Checks whether after iterations_amount of on_display given + restriction will become expired + But before iterations_amount given amount_restriction must + not indicate is_expired """ for i in range(iterations_amount): self.assertFalse(amount_restriction.is_expired()) diff --git a/tests/messages_tests/tests.py b/tests/messages_tests/tests.py index 0279927adf..679679e7d2 100644 --- a/tests/messages_tests/tests.py +++ b/tests/messages_tests/tests.py @@ -4,7 +4,6 @@ from unittest import mock from django.conf import settings from django.contrib.messages import Message, add_message, constants -from django.contrib.messages import restrictions from django.contrib.messages.restrictions import AmountRestriction from django.contrib.messages.storage import base from django.contrib.messages.test import MessagesTestMixin @@ -112,11 +111,29 @@ class AssertMessagesTest(MessagesTestMixin, SimpleTestCase): self.assertMessages( response, [ - Message(constants.DEBUG, "DEBUG message.", restrictions=[AmountRestriction(2)]), - Message(constants.INFO, "INFO message.",restrictions=[AmountRestriction(2)]), - Message(constants.SUCCESS, "SUCCESS message.",restrictions=[AmountRestriction(2)]), - Message(constants.WARNING, "WARNING message.",restrictions=[AmountRestriction(2)]), - Message(constants.ERROR, "ERROR message.",restrictions=[AmountRestriction(2)]), + Message( + constants.DEBUG, + "DEBUG message.", + restrictions=[AmountRestriction(2)], + ), + Message( + constants.INFO, "INFO message.", restrictions=[AmountRestriction(2)] + ), + Message( + constants.SUCCESS, + "SUCCESS message.", + restrictions=[AmountRestriction(2)], + ), + Message( + constants.WARNING, + "WARNING message.", + restrictions=[AmountRestriction(2)], + ), + Message( + constants.ERROR, + "ERROR message.", + restrictions=[AmountRestriction(2)], + ), ], ) @@ -149,10 +166,30 @@ class AssertMessagesTest(MessagesTestMixin, SimpleTestCase): self.assertMessages( response, [ - Message(constants.INFO, "INFO message.", "extra-info", restrictions=[AmountRestriction(2)]), - Message(constants.SUCCESS, "SUCCESS message.", "extra-success", restrictions=[AmountRestriction(2)]), - Message(constants.WARNING, "WARNING message.", "extra-warning", restrictions=[AmountRestriction(2)]), - Message(constants.ERROR, "ERROR message.", "extra-error", restrictions=[AmountRestriction(2)]), + Message( + constants.INFO, + "INFO message.", + "extra-info", + restrictions=[AmountRestriction(2)], + ), + Message( + constants.SUCCESS, + "SUCCESS message.", + "extra-success", + restrictions=[AmountRestriction(2)], + ), + Message( + constants.WARNING, + "WARNING message.", + "extra-warning", + restrictions=[AmountRestriction(2)], + ), + Message( + constants.ERROR, + "ERROR message.", + "extra-error", + restrictions=[AmountRestriction(2)], + ), ], ) @@ -160,15 +197,24 @@ class AssertMessagesTest(MessagesTestMixin, SimpleTestCase): def test_custom_levelname(self): response = FakeResponse() add_message(response.wsgi_request, 42, "CUSTOM message.") - self.assertMessages(response, [Message(42, "CUSTOM message.", restrictions=[AmountRestriction(2)])]) + self.assertMessages( + response, + [Message(42, "CUSTOM message.", restrictions=[AmountRestriction(2)])], + ) def test_ordered(self): response = FakeResponse() add_message(response.wsgi_request, constants.INFO, "First message.") add_message(response.wsgi_request, constants.WARNING, "Second message.") expected_messages = [ - Message(constants.WARNING, "Second message.", restrictions=[AmountRestriction(2)]), - Message(constants.INFO, "First message.", restrictions=[AmountRestriction(2)]), + Message( + constants.WARNING, + "Second message.", + restrictions=[AmountRestriction(2)], + ), + Message( + constants.INFO, "First message.", restrictions=[AmountRestriction(2)] + ), ] self.assertMessages(response, expected_messages, ordered=False) with self.assertRaisesMessage(AssertionError, "Lists differ: "): diff --git a/tests/messages_tests/time_provider.py b/tests/messages_tests/time_provider.py index 007b9e81fa..496249a640 100644 --- a/tests/messages_tests/time_provider.py +++ b/tests/messages_tests/time_provider.py @@ -1,9 +1,12 @@ class TestTimeProvider: - def __init__(self, act_time = 0): + def __init__(self, act_time=0): self.act_time = act_time + def set_act_time(self, act_time): self.act_time = act_time + def time(self): return self.act_time + def inc_act_time(self): self.act_time += 1