1
0
mirror of https://github.com/django/django.git synced 2025-03-31 11:37:06 +00:00

make MultiValueDict mutable or immutable

This commit is contained in:
Ben Cail 2024-03-18 14:34:59 -04:00
parent 1baf829382
commit 3e84dfbd24
3 changed files with 33 additions and 114 deletions

View File

@ -18,7 +18,7 @@ from django.core.exceptions import (
TooManyFilesSent,
)
from django.core.files.uploadhandler import SkipFile, StopFutureHandlers, StopUpload
from django.utils.datastructures import ImmutableMultiValueDict, MultiValueDict
from django.utils.datastructures import MultiValueDict
from django.utils.encoding import force_str
from django.utils.http import parse_header_parameters
from django.utils.regex_helper import _lazy_re_compile
@ -50,7 +50,7 @@ class MultiPartParser:
An RFC 7578 multipart/form-data parser.
``MultiPartParser.parse()`` reads the input stream in ``chunk_size`` chunks
and returns a tuple of ``(QueryDict(POST), ImmutableMultiValueDict(FILES))``.
and returns a tuple of ``(QueryDict(POST), MultiValueDict(FILES))``.
"""
boundary_re = _lazy_re_compile(r"[ -~]{0,200}[!-~]")
@ -132,7 +132,7 @@ class MultiPartParser:
def _parse(self):
"""
Parse the POST data and break it into a FILES ImmutableMultiValueDict
Parse the POST data and break it into a FILES MultiValueDict
and a POST QueryDict.
Return a tuple containing the POST and FILES dictionary, respectively.
@ -145,7 +145,7 @@ class MultiPartParser:
# HTTP spec says that Content-Length >= 0 is valid
# handling content-length == 0 before continuing
if self._content_length == 0:
return QueryDict(encoding=self._encoding), ImmutableMultiValueDict()
return QueryDict(encoding=self._encoding), MultiValueDict(mutable=False)
# See if any of the handlers take care of the parsing.
# This allows overriding everything if need be.
@ -163,7 +163,7 @@ class MultiPartParser:
# Create the data structures to be used later.
self._post = QueryDict(mutable=True)
self._files = ImmutableMultiValueDict(mutable=True)
self._files = MultiValueDict(mutable=True)
# Instantiate the parser and stream:
stream = LazyStream(ChunkIter(self._input_data, self._chunk_size))

View File

@ -69,7 +69,8 @@ class MultiValueDict(dict):
single name-value pairs.
"""
def __init__(self, key_to_list_mapping=()):
def __init__(self, key_to_list_mapping=(), mutable=True):
self._mutable = mutable
super().__init__(key_to_list_mapping)
def __repr__(self):
@ -90,6 +91,7 @@ class MultiValueDict(dict):
return []
def __setitem__(self, key, value):
self._assert_mutable()
super().__setitem__(key, [value])
def __copy__(self):
@ -108,11 +110,16 @@ class MultiValueDict(dict):
return {**self.__dict__, "_data": {k: self._getlist(k) for k in self}}
def __setstate__(self, obj_dict):
self._assert_mutable()
data = obj_dict.pop("_data", {})
for k, v in data.items():
self.setlist(k, v)
self.__dict__.update(obj_dict)
def _assert_mutable(self):
if hasattr(self, '_mutable') and not self._mutable:
raise AttributeError("This MultiValueDict instance is immutable")
def get(self, key, default=None):
"""
Return the last data value for the passed key. If key doesn't exist
@ -152,9 +159,11 @@ class MultiValueDict(dict):
return self._getlist(key, default, force_list=True)
def setlist(self, key, list_):
self._assert_mutable()
super().__setitem__(key, list_)
def setdefault(self, key, default=None):
self._assert_mutable()
if key not in self:
self[key] = default
# Do not return default here because __setitem__() may store
@ -162,6 +171,7 @@ class MultiValueDict(dict):
return self[key]
def setlistdefault(self, key, default_list=None):
self._assert_mutable()
if key not in self:
if default_list is None:
default_list = []
@ -170,8 +180,21 @@ class MultiValueDict(dict):
# another value -- QueryDict.setlist() does. Look it up.
return self._getlist(key)
def pop(self, key, *args):
self._assert_mutable()
return super().pop(key, *args)
def popitem(self):
self._assert_mutable()
return super().popitem()
def clear(self):
self._assert_mutable()
super().clear()
def appendlist(self, key, value):
"""Append an item to the internal list associated with key."""
self._assert_mutable()
self.setlistdefault(key).append(value)
def items(self):
@ -197,6 +220,7 @@ class MultiValueDict(dict):
def update(self, *args, **kwargs):
"""Extend rather than replace existing key lists."""
self._assert_mutable()
if len(args) > 1:
raise TypeError("update expected at most 1 argument, got %d" % len(args))
if args:
@ -217,73 +241,6 @@ class MultiValueDict(dict):
return {key: self[key] for key in self}
class ImmutableMultiValueDict(MultiValueDict):
_mutable = False
def __init__(self, key_to_list_mapping=(), mutable=False):
super().__init__(key_to_list_mapping)
self._mutable = mutable
def _assert_mutable(self):
if not self._mutable:
raise AttributeError(
"This ImmutableMultiValueDict instance is immutable"
)
def __setitem__(self, key, value):
self._assert_mutable()
super().__setitem__(key, value)
def __delitem__(self, key):
self._assert_mutable()
super().__delitem__(key)
def __copy__(self):
result = self.__class__(mutable=True)
for key, value in self.lists():
result.setlist(key, value)
return result
def __deepcopy__(self, memo):
result = self.__class__(mutable=True)
memo[id(self)] = result
for key, value in self.lists():
result.setlist(copy.deepcopy(key, memo), copy.deepcopy(value, memo))
return result
def setlist(self, key, list_):
self._assert_mutable()
super().setlist(key, list_)
def setlistdefault(self, key, default_list=None):
self._assert_mutable()
return super().setlistdefault(key, default_list)
def appendlist(self, key, value):
self._assert_mutable()
super().appendlist(key, value)
def pop(self, key, *args):
self._assert_mutable()
return super().pop(key, *args)
def popitem(self):
self._assert_mutable()
return super().popitem()
def clear(self):
self._assert_mutable()
super().clear()
def setdefault(self, key, default=None):
self._assert_mutable()
return super().setdefault(key, default)
def copy(self):
"""Return a mutable copy of this object."""
return self.__deepcopy__({})
class ImmutableList(tuple):
"""
A tuple-like object that raises useful errors when it is asked to mutate.

View File

@ -11,7 +11,6 @@ from django.utils.datastructures import (
CaseInsensitiveMapping,
DictWrapper,
ImmutableList,
ImmutableMultiValueDict,
MultiValueDict,
MultiValueDictKeyError,
OrderedSet,
@ -252,7 +251,7 @@ class MultiValueDictTests(SimpleTestCase):
class ImmutableMultiValueDictTests(SimpleTestCase):
def test_immutability(self):
q = ImmutableMultiValueDict()
q = MultiValueDict(mutable=False)
with self.assertRaises(AttributeError):
q.__setitem__('something', 'bar')
with self.assertRaises(AttributeError):
@ -269,50 +268,13 @@ class ImmutableMultiValueDictTests(SimpleTestCase):
q.clear()
def test_mutable_copy(self):
"""A copy of a QueryDict is mutable."""
q = ImmutableMultiValueDict().copy()
"""A copy of an immutable MultiValueDict is mutable."""
q = MultiValueDict(mutable=False).copy()
with self.assertRaises(KeyError):
q.__getitem__("foo")
q['name'] = 'john'
self.assertEqual(q['name'], 'john')
def test_basic_mutable_operations(self):
q = ImmutableMultiValueDict(mutable=True)
q['name'] = 'john'
self.assertEqual(q.get('foo', 'default'), 'default')
self.assertEqual(q.get('name', 'default'), 'john')
self.assertEqual(q.getlist('name'), ['john'])
self.assertEqual(q.getlist('foo'), [])
q.setlist('foo', ['bar', 'baz'])
self.assertEqual(q.get('foo', 'default'), 'baz')
self.assertEqual(q.getlist('foo'), ['bar', 'baz'])
q.appendlist('foo', 'another')
self.assertEqual(q.getlist('foo'), ['bar', 'baz', 'another'])
self.assertEqual(q['foo'], 'another')
self.assertIn('foo', q)
self.assertCountEqual(q, ['foo', 'name'])
self.assertCountEqual(q.items(), [('foo', 'another'), ('name', 'john')])
self.assertCountEqual(q.lists(), [('foo', ['bar', 'baz', 'another']), ('name', ['john'])])
self.assertCountEqual(q.keys(), ['foo', 'name'])
self.assertCountEqual(q.values(), ['another', 'john'])
q.update({'foo': 'hello'})
self.assertEqual(q['foo'], 'hello')
self.assertEqual(q.get('foo', 'not available'), 'hello')
self.assertEqual(q.getlist('foo'), ['bar', 'baz', 'another', 'hello'])
self.assertEqual(q.pop('foo'), ['bar', 'baz', 'another', 'hello'])
self.assertEqual(q.pop('foo', 'not there'), 'not there')
self.assertEqual(q.get('foo', 'not there'), 'not there')
self.assertEqual(q.setdefault('foo', 'bar'), 'bar')
self.assertEqual(q['foo'], 'bar')
self.assertEqual(q.getlist('foo'), ['bar'])
q.clear()
self.assertEqual(len(q), 0)
class ImmutableListTests(SimpleTestCase):
def test_sort(self):