diff --git a/django/contrib/messages/storage/cookie.py b/django/contrib/messages/storage/cookie.py index b493b207c8..2008b31843 100644 --- a/django/contrib/messages/storage/cookie.py +++ b/django/contrib/messages/storage/cookie.py @@ -144,20 +144,31 @@ class CookieStorage(BaseStorage): # adds its own overhead, which we must account for. cookie = SimpleCookie() # create outside the loop - def stored_length(val): - return len(cookie.value_encode(val)[1]) + def is_too_large_for_cookie(data): + return data and len(cookie.value_encode(data)[1]) > self.max_cookie_size - while encoded_data and stored_length(encoded_data) > self.max_cookie_size: - if remove_oldest: - unstored_messages.append(messages.pop(0)) - serialized_messages.pop(0) - else: - unstored_messages.insert(0, messages.pop()) - serialized_messages.pop() - encoded_data = self._encode_parts( - serialized_messages + [self.not_finished_json], - encode_empty=bool(unstored_messages), + def compute_msg(some_serialized_msg): + return self._encode_parts( + some_serialized_msg + [self.not_finished_json], + encode_empty=True, ) + + if is_too_large_for_cookie(encoded_data): + if remove_oldest: + idx = bisect_keep_right( + serialized_messages, + fn=lambda m: is_too_large_for_cookie(compute_msg(m)), + ) + unstored_messages = messages[:idx] + encoded_data = compute_msg(serialized_messages[idx:]) + else: + idx = bisect_keep_left( + serialized_messages, + fn=lambda m: is_too_large_for_cookie(compute_msg(m)), + ) + unstored_messages = messages[idx:] + encoded_data = compute_msg(serialized_messages[:idx]) + self._update_cookie(encoded_data, response) return unstored_messages @@ -201,3 +212,37 @@ class CookieStorage(BaseStorage): # with the data. self.used = True return None + + +def bisect_keep_left(a, fn): + """ + Find the index of the first element from the start of the array that + verifies the given condition. + The function is applied from the start of the array to the pivot. + """ + lo = 0 + hi = len(a) + while lo < hi: + mid = (lo + hi) // 2 + if fn(a[: mid + 1]): + hi = mid + else: + lo = mid + 1 + return lo + + +def bisect_keep_right(a, fn): + """ + Find the index of the first element from the end of the array that verifies + the given condition. + The function is applied from the pivot to the end of array. + """ + lo = 0 + hi = len(a) + while lo < hi: + mid = (lo + hi) // 2 + if fn(a[mid:]): + lo = mid + 1 + else: + hi = mid + return lo diff --git a/tests/messages_tests/test_cookie.py b/tests/messages_tests/test_cookie.py index 0fd2ed34d8..344df96886 100644 --- a/tests/messages_tests/test_cookie.py +++ b/tests/messages_tests/test_cookie.py @@ -1,5 +1,6 @@ import json import random +from unittest import TestCase from django.conf import settings from django.contrib.messages import constants @@ -8,6 +9,8 @@ from django.contrib.messages.storage.cookie import ( CookieStorage, MessageDecoder, MessageEncoder, + bisect_keep_left, + bisect_keep_right, ) from django.test import SimpleTestCase, override_settings from django.utils.crypto import get_random_string @@ -204,3 +207,20 @@ class CookieTests(BaseTests, SimpleTestCase): self.encode_decode("message", extra_tags=extra_tags).extra_tags, extra_tags, ) + + +class BisectTests(TestCase): + def test_bisect_keep_left(self): + self.assertEqual(bisect_keep_left([1, 1, 1], fn=lambda arr: sum(arr) != 2), 2) + self.assertEqual(bisect_keep_left([1, 1, 1], fn=lambda arr: sum(arr) != 0), 0) + self.assertEqual(bisect_keep_left([], fn=lambda arr: sum(arr) != 0), 0) + + def test_bisect_keep_right(self): + self.assertEqual(bisect_keep_right([1, 1, 1], fn=lambda arr: sum(arr) != 2), 1) + self.assertEqual( + bisect_keep_right([1, 1, 1, 1], fn=lambda arr: sum(arr) != 2), 2 + ) + self.assertEqual( + bisect_keep_right([1, 1, 1, 1, 1], fn=lambda arr: sum(arr) != 1), 4 + ) + self.assertEqual(bisect_keep_right([], fn=lambda arr: sum(arr) != 0), 0)