1
0
mirror of https://github.com/django/django.git synced 2024-12-22 17:16:24 +00:00

Refs #28948 -- Removed superfluous messages from cookie through bisect.

This commit is contained in:
David Wobrock 2023-03-08 17:39:26 +01:00 committed by Mariusz Felisiak
parent 9d0c878abf
commit 21757bbdcd
2 changed files with 77 additions and 12 deletions

View File

@ -144,20 +144,31 @@ class CookieStorage(BaseStorage):
# adds its own overhead, which we must account for. # adds its own overhead, which we must account for.
cookie = SimpleCookie() # create outside the loop cookie = SimpleCookie() # create outside the loop
def stored_length(val): def is_too_large_for_cookie(data):
return len(cookie.value_encode(val)[1]) return data and len(cookie.value_encode(data)[1]) > self.max_cookie_size
while encoded_data and stored_length(encoded_data) > self.max_cookie_size: def compute_msg(some_serialized_msg):
if remove_oldest: return self._encode_parts(
unstored_messages.append(messages.pop(0)) some_serialized_msg + [self.not_finished_json],
serialized_messages.pop(0) encode_empty=True,
else:
unstored_messages.insert(0, messages.pop())
serialized_messages.pop()
encoded_data = self._encode_parts(
serialized_messages + [self.not_finished_json],
encode_empty=bool(unstored_messages),
) )
if is_too_large_for_cookie(encoded_data):
if remove_oldest:
idx = bisect_keep_right(
serialized_messages,
fn=lambda m: is_too_large_for_cookie(compute_msg(m)),
)
unstored_messages = messages[:idx]
encoded_data = compute_msg(serialized_messages[idx:])
else:
idx = bisect_keep_left(
serialized_messages,
fn=lambda m: is_too_large_for_cookie(compute_msg(m)),
)
unstored_messages = messages[idx:]
encoded_data = compute_msg(serialized_messages[:idx])
self._update_cookie(encoded_data, response) self._update_cookie(encoded_data, response)
return unstored_messages return unstored_messages
@ -201,3 +212,37 @@ class CookieStorage(BaseStorage):
# with the data. # with the data.
self.used = True self.used = True
return None return None
def bisect_keep_left(a, fn):
"""
Find the index of the first element from the start of the array that
verifies the given condition.
The function is applied from the start of the array to the pivot.
"""
lo = 0
hi = len(a)
while lo < hi:
mid = (lo + hi) // 2
if fn(a[: mid + 1]):
hi = mid
else:
lo = mid + 1
return lo
def bisect_keep_right(a, fn):
"""
Find the index of the first element from the end of the array that verifies
the given condition.
The function is applied from the pivot to the end of array.
"""
lo = 0
hi = len(a)
while lo < hi:
mid = (lo + hi) // 2
if fn(a[mid:]):
lo = mid + 1
else:
hi = mid
return lo

View File

@ -1,5 +1,6 @@
import json import json
import random import random
from unittest import TestCase
from django.conf import settings from django.conf import settings
from django.contrib.messages import constants from django.contrib.messages import constants
@ -8,6 +9,8 @@ from django.contrib.messages.storage.cookie import (
CookieStorage, CookieStorage,
MessageDecoder, MessageDecoder,
MessageEncoder, MessageEncoder,
bisect_keep_left,
bisect_keep_right,
) )
from django.test import SimpleTestCase, override_settings from django.test import SimpleTestCase, override_settings
from django.utils.crypto import get_random_string from django.utils.crypto import get_random_string
@ -204,3 +207,20 @@ class CookieTests(BaseTests, SimpleTestCase):
self.encode_decode("message", extra_tags=extra_tags).extra_tags, self.encode_decode("message", extra_tags=extra_tags).extra_tags,
extra_tags, extra_tags,
) )
class BisectTests(TestCase):
def test_bisect_keep_left(self):
self.assertEqual(bisect_keep_left([1, 1, 1], fn=lambda arr: sum(arr) != 2), 2)
self.assertEqual(bisect_keep_left([1, 1, 1], fn=lambda arr: sum(arr) != 0), 0)
self.assertEqual(bisect_keep_left([], fn=lambda arr: sum(arr) != 0), 0)
def test_bisect_keep_right(self):
self.assertEqual(bisect_keep_right([1, 1, 1], fn=lambda arr: sum(arr) != 2), 1)
self.assertEqual(
bisect_keep_right([1, 1, 1, 1], fn=lambda arr: sum(arr) != 2), 2
)
self.assertEqual(
bisect_keep_right([1, 1, 1, 1, 1], fn=lambda arr: sum(arr) != 1), 4
)
self.assertEqual(bisect_keep_right([], fn=lambda arr: sum(arr) != 0), 0)