From 3e84dfbd2426a5080d6837474a097265a31d9ef6 Mon Sep 17 00:00:00 2001 From: Ben Cail Date: Mon, 18 Mar 2024 14:34:59 -0400 Subject: [PATCH] make MultiValueDict mutable or immutable --- django/http/multipartparser.py | 10 +-- django/utils/datastructures.py | 93 +++++++----------------- tests/utils_tests/test_datastructures.py | 44 +---------- 3 files changed, 33 insertions(+), 114 deletions(-) diff --git a/django/http/multipartparser.py b/django/http/multipartparser.py index 7b8b62e9e8..c6a008ef5c 100644 --- a/django/http/multipartparser.py +++ b/django/http/multipartparser.py @@ -18,7 +18,7 @@ from django.core.exceptions import ( TooManyFilesSent, ) from django.core.files.uploadhandler import SkipFile, StopFutureHandlers, StopUpload -from django.utils.datastructures import ImmutableMultiValueDict, MultiValueDict +from django.utils.datastructures import MultiValueDict from django.utils.encoding import force_str from django.utils.http import parse_header_parameters from django.utils.regex_helper import _lazy_re_compile @@ -50,7 +50,7 @@ class MultiPartParser: An RFC 7578 multipart/form-data parser. ``MultiPartParser.parse()`` reads the input stream in ``chunk_size`` chunks - and returns a tuple of ``(QueryDict(POST), ImmutableMultiValueDict(FILES))``. + and returns a tuple of ``(QueryDict(POST), MultiValueDict(FILES))``. """ boundary_re = _lazy_re_compile(r"[ -~]{0,200}[!-~]") @@ -132,7 +132,7 @@ class MultiPartParser: def _parse(self): """ - Parse the POST data and break it into a FILES ImmutableMultiValueDict + Parse the POST data and break it into a FILES MultiValueDict and a POST QueryDict. Return a tuple containing the POST and FILES dictionary, respectively. @@ -145,7 +145,7 @@ class MultiPartParser: # HTTP spec says that Content-Length >= 0 is valid # handling content-length == 0 before continuing if self._content_length == 0: - return QueryDict(encoding=self._encoding), ImmutableMultiValueDict() + return QueryDict(encoding=self._encoding), MultiValueDict(mutable=False) # See if any of the handlers take care of the parsing. # This allows overriding everything if need be. @@ -163,7 +163,7 @@ class MultiPartParser: # Create the data structures to be used later. self._post = QueryDict(mutable=True) - self._files = ImmutableMultiValueDict(mutable=True) + self._files = MultiValueDict(mutable=True) # Instantiate the parser and stream: stream = LazyStream(ChunkIter(self._input_data, self._chunk_size)) diff --git a/django/utils/datastructures.py b/django/utils/datastructures.py index 9381a63eb1..5de3f8b8c3 100644 --- a/django/utils/datastructures.py +++ b/django/utils/datastructures.py @@ -69,7 +69,8 @@ class MultiValueDict(dict): single name-value pairs. """ - def __init__(self, key_to_list_mapping=()): + def __init__(self, key_to_list_mapping=(), mutable=True): + self._mutable = mutable super().__init__(key_to_list_mapping) def __repr__(self): @@ -90,6 +91,7 @@ class MultiValueDict(dict): return [] def __setitem__(self, key, value): + self._assert_mutable() super().__setitem__(key, [value]) def __copy__(self): @@ -108,11 +110,16 @@ class MultiValueDict(dict): return {**self.__dict__, "_data": {k: self._getlist(k) for k in self}} def __setstate__(self, obj_dict): + self._assert_mutable() data = obj_dict.pop("_data", {}) for k, v in data.items(): self.setlist(k, v) self.__dict__.update(obj_dict) + def _assert_mutable(self): + if hasattr(self, '_mutable') and not self._mutable: + raise AttributeError("This MultiValueDict instance is immutable") + def get(self, key, default=None): """ Return the last data value for the passed key. If key doesn't exist @@ -152,9 +159,11 @@ class MultiValueDict(dict): return self._getlist(key, default, force_list=True) def setlist(self, key, list_): + self._assert_mutable() super().__setitem__(key, list_) def setdefault(self, key, default=None): + self._assert_mutable() if key not in self: self[key] = default # Do not return default here because __setitem__() may store @@ -162,6 +171,7 @@ class MultiValueDict(dict): return self[key] def setlistdefault(self, key, default_list=None): + self._assert_mutable() if key not in self: if default_list is None: default_list = [] @@ -170,8 +180,21 @@ class MultiValueDict(dict): # another value -- QueryDict.setlist() does. Look it up. return self._getlist(key) + def pop(self, key, *args): + self._assert_mutable() + return super().pop(key, *args) + + def popitem(self): + self._assert_mutable() + return super().popitem() + + def clear(self): + self._assert_mutable() + super().clear() + def appendlist(self, key, value): """Append an item to the internal list associated with key.""" + self._assert_mutable() self.setlistdefault(key).append(value) def items(self): @@ -197,6 +220,7 @@ class MultiValueDict(dict): def update(self, *args, **kwargs): """Extend rather than replace existing key lists.""" + self._assert_mutable() if len(args) > 1: raise TypeError("update expected at most 1 argument, got %d" % len(args)) if args: @@ -217,73 +241,6 @@ class MultiValueDict(dict): return {key: self[key] for key in self} -class ImmutableMultiValueDict(MultiValueDict): - _mutable = False - - def __init__(self, key_to_list_mapping=(), mutable=False): - super().__init__(key_to_list_mapping) - self._mutable = mutable - - def _assert_mutable(self): - if not self._mutable: - raise AttributeError( - "This ImmutableMultiValueDict instance is immutable" - ) - - def __setitem__(self, key, value): - self._assert_mutable() - super().__setitem__(key, value) - - def __delitem__(self, key): - self._assert_mutable() - super().__delitem__(key) - - def __copy__(self): - result = self.__class__(mutable=True) - for key, value in self.lists(): - result.setlist(key, value) - return result - - def __deepcopy__(self, memo): - result = self.__class__(mutable=True) - memo[id(self)] = result - for key, value in self.lists(): - result.setlist(copy.deepcopy(key, memo), copy.deepcopy(value, memo)) - return result - - def setlist(self, key, list_): - self._assert_mutable() - super().setlist(key, list_) - - def setlistdefault(self, key, default_list=None): - self._assert_mutable() - return super().setlistdefault(key, default_list) - - def appendlist(self, key, value): - self._assert_mutable() - super().appendlist(key, value) - - def pop(self, key, *args): - self._assert_mutable() - return super().pop(key, *args) - - def popitem(self): - self._assert_mutable() - return super().popitem() - - def clear(self): - self._assert_mutable() - super().clear() - - def setdefault(self, key, default=None): - self._assert_mutable() - return super().setdefault(key, default) - - def copy(self): - """Return a mutable copy of this object.""" - return self.__deepcopy__({}) - - class ImmutableList(tuple): """ A tuple-like object that raises useful errors when it is asked to mutate. diff --git a/tests/utils_tests/test_datastructures.py b/tests/utils_tests/test_datastructures.py index f0a44f22ed..957e1a8f4d 100644 --- a/tests/utils_tests/test_datastructures.py +++ b/tests/utils_tests/test_datastructures.py @@ -11,7 +11,6 @@ from django.utils.datastructures import ( CaseInsensitiveMapping, DictWrapper, ImmutableList, - ImmutableMultiValueDict, MultiValueDict, MultiValueDictKeyError, OrderedSet, @@ -252,7 +251,7 @@ class MultiValueDictTests(SimpleTestCase): class ImmutableMultiValueDictTests(SimpleTestCase): def test_immutability(self): - q = ImmutableMultiValueDict() + q = MultiValueDict(mutable=False) with self.assertRaises(AttributeError): q.__setitem__('something', 'bar') with self.assertRaises(AttributeError): @@ -269,50 +268,13 @@ class ImmutableMultiValueDictTests(SimpleTestCase): q.clear() def test_mutable_copy(self): - """A copy of a QueryDict is mutable.""" - q = ImmutableMultiValueDict().copy() + """A copy of an immutable MultiValueDict is mutable.""" + q = MultiValueDict(mutable=False).copy() with self.assertRaises(KeyError): q.__getitem__("foo") q['name'] = 'john' self.assertEqual(q['name'], 'john') - def test_basic_mutable_operations(self): - q = ImmutableMultiValueDict(mutable=True) - q['name'] = 'john' - self.assertEqual(q.get('foo', 'default'), 'default') - self.assertEqual(q.get('name', 'default'), 'john') - self.assertEqual(q.getlist('name'), ['john']) - self.assertEqual(q.getlist('foo'), []) - - q.setlist('foo', ['bar', 'baz']) - self.assertEqual(q.get('foo', 'default'), 'baz') - self.assertEqual(q.getlist('foo'), ['bar', 'baz']) - - q.appendlist('foo', 'another') - self.assertEqual(q.getlist('foo'), ['bar', 'baz', 'another']) - self.assertEqual(q['foo'], 'another') - self.assertIn('foo', q) - - self.assertCountEqual(q, ['foo', 'name']) - self.assertCountEqual(q.items(), [('foo', 'another'), ('name', 'john')]) - self.assertCountEqual(q.lists(), [('foo', ['bar', 'baz', 'another']), ('name', ['john'])]) - self.assertCountEqual(q.keys(), ['foo', 'name']) - self.assertCountEqual(q.values(), ['another', 'john']) - - q.update({'foo': 'hello'}) - self.assertEqual(q['foo'], 'hello') - self.assertEqual(q.get('foo', 'not available'), 'hello') - self.assertEqual(q.getlist('foo'), ['bar', 'baz', 'another', 'hello']) - self.assertEqual(q.pop('foo'), ['bar', 'baz', 'another', 'hello']) - self.assertEqual(q.pop('foo', 'not there'), 'not there') - self.assertEqual(q.get('foo', 'not there'), 'not there') - self.assertEqual(q.setdefault('foo', 'bar'), 'bar') - self.assertEqual(q['foo'], 'bar') - self.assertEqual(q.getlist('foo'), ['bar']) - - q.clear() - self.assertEqual(len(q), 0) - class ImmutableListTests(SimpleTestCase): def test_sort(self):