mirror of
				https://github.com/django/django.git
				synced 2025-10-25 06:36:07 +00:00 
			
		
		
		
	Fixed #24636 -- Added model field validation for decimal places and max digits.
This commit is contained in:
		
				
					committed by
					
						 Simon Charette
						Simon Charette
					
				
			
			
				
	
			
			
			
						parent
						
							6f1b09bb5c
						
					
				
				
					commit
					75ed590032
				
			| @@ -346,3 +346,72 @@ class MaxLengthValidator(BaseValidator): | ||||
|         'Ensure this value has at most %(limit_value)d characters (it has %(show_value)d).', | ||||
|         'limit_value') | ||||
|     code = 'max_length' | ||||
|  | ||||
|  | ||||
| @deconstructible | ||||
| class DecimalValidator(object): | ||||
|     """ | ||||
|     Validate that the input does not exceed the maximum number of digits | ||||
|     expected, otherwise raise ValidationError. | ||||
|     """ | ||||
|     messages = { | ||||
|         'max_digits': ungettext_lazy( | ||||
|             'Ensure that there are no more than %(max)s digit in total.', | ||||
|             'Ensure that there are no more than %(max)s digits in total.', | ||||
|             'max' | ||||
|         ), | ||||
|         'max_decimal_places': ungettext_lazy( | ||||
|             'Ensure that there are no more than %(max)s decimal place.', | ||||
|             'Ensure that there are no more than %(max)s decimal places.', | ||||
|             'max' | ||||
|         ), | ||||
|         'max_whole_digits': ungettext_lazy( | ||||
|             'Ensure that there are no more than %(max)s digit before the decimal point.', | ||||
|             'Ensure that there are no more than %(max)s digits before the decimal point.', | ||||
|             'max' | ||||
|         ), | ||||
|     } | ||||
|  | ||||
|     def __init__(self, max_digits, decimal_places): | ||||
|         self.max_digits = max_digits | ||||
|         self.decimal_places = decimal_places | ||||
|  | ||||
|     def __call__(self, value): | ||||
|         digit_tuple, exponent = value.as_tuple()[1:] | ||||
|         decimals = abs(exponent) | ||||
|         # digit_tuple doesn't include any leading zeros. | ||||
|         digits = len(digit_tuple) | ||||
|         if decimals > digits: | ||||
|             # We have leading zeros up to or past the decimal point. Count | ||||
|             # everything past the decimal point as a digit. We do not count | ||||
|             # 0 before the decimal point as a digit since that would mean | ||||
|             # we would not allow max_digits = decimal_places. | ||||
|             digits = decimals | ||||
|         whole_digits = digits - decimals | ||||
|  | ||||
|         if self.max_digits is not None and digits > self.max_digits: | ||||
|             raise ValidationError( | ||||
|                 self.messages['max_digits'], | ||||
|                 code='max_digits', | ||||
|                 params={'max': self.max_digits}, | ||||
|             ) | ||||
|         if self.decimal_places is not None and decimals > self.decimal_places: | ||||
|             raise ValidationError( | ||||
|                 self.messages['max_decimal_places'], | ||||
|                 code='max_decimal_places', | ||||
|                 params={'max': self.decimal_places}, | ||||
|             ) | ||||
|         if (self.max_digits is not None and self.decimal_places is not None | ||||
|                 and whole_digits > (self.max_digits - self.decimal_places)): | ||||
|             raise ValidationError( | ||||
|                 self.messages['max_whole_digits'], | ||||
|                 code='max_whole_digits', | ||||
|                 params={'max': (self.max_digits - self.decimal_places)}, | ||||
|             ) | ||||
|  | ||||
|     def __eq__(self, other): | ||||
|         return ( | ||||
|             isinstance(other, self.__class__) and | ||||
|             self.max_digits == other.max_digits and | ||||
|             self.decimal_places == other.decimal_places | ||||
|         ) | ||||
|   | ||||
| @@ -1578,6 +1578,12 @@ class DecimalField(Field): | ||||
|             ] | ||||
|         return [] | ||||
|  | ||||
|     @cached_property | ||||
|     def validators(self): | ||||
|         return super(DecimalField, self).validators + [ | ||||
|             validators.DecimalValidator(self.max_digits, self.decimal_places) | ||||
|         ] | ||||
|  | ||||
|     def deconstruct(self): | ||||
|         name, path, args, kwargs = super(DecimalField, self).deconstruct() | ||||
|         if self.max_digits is not None: | ||||
|   | ||||
| @@ -334,23 +334,12 @@ class FloatField(IntegerField): | ||||
| class DecimalField(IntegerField): | ||||
|     default_error_messages = { | ||||
|         'invalid': _('Enter a number.'), | ||||
|         'max_digits': ungettext_lazy( | ||||
|             'Ensure that there are no more than %(max)s digit in total.', | ||||
|             'Ensure that there are no more than %(max)s digits in total.', | ||||
|             'max'), | ||||
|         'max_decimal_places': ungettext_lazy( | ||||
|             'Ensure that there are no more than %(max)s decimal place.', | ||||
|             'Ensure that there are no more than %(max)s decimal places.', | ||||
|             'max'), | ||||
|         'max_whole_digits': ungettext_lazy( | ||||
|             'Ensure that there are no more than %(max)s digit before the decimal point.', | ||||
|             'Ensure that there are no more than %(max)s digits before the decimal point.', | ||||
|             'max'), | ||||
|     } | ||||
|  | ||||
|     def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, *args, **kwargs): | ||||
|         self.max_digits, self.decimal_places = max_digits, decimal_places | ||||
|         super(DecimalField, self).__init__(max_value, min_value, *args, **kwargs) | ||||
|         self.validators.append(validators.DecimalValidator(max_digits, decimal_places)) | ||||
|  | ||||
|     def to_python(self, value): | ||||
|         """ | ||||
| @@ -379,38 +368,6 @@ class DecimalField(IntegerField): | ||||
|         # isn't equal to itself, so we can use this to identify NaN | ||||
|         if value != value or value == Decimal("Inf") or value == Decimal("-Inf"): | ||||
|             raise ValidationError(self.error_messages['invalid'], code='invalid') | ||||
|         sign, digittuple, exponent = value.as_tuple() | ||||
|         decimals = abs(exponent) | ||||
|         # digittuple doesn't include any leading zeros. | ||||
|         digits = len(digittuple) | ||||
|         if decimals > digits: | ||||
|             # We have leading zeros up to or past the decimal point.  Count | ||||
|             # everything past the decimal point as a digit.  We do not count | ||||
|             # 0 before the decimal point as a digit since that would mean | ||||
|             # we would not allow max_digits = decimal_places. | ||||
|             digits = decimals | ||||
|         whole_digits = digits - decimals | ||||
|  | ||||
|         if self.max_digits is not None and digits > self.max_digits: | ||||
|             raise ValidationError( | ||||
|                 self.error_messages['max_digits'], | ||||
|                 code='max_digits', | ||||
|                 params={'max': self.max_digits}, | ||||
|             ) | ||||
|         if self.decimal_places is not None and decimals > self.decimal_places: | ||||
|             raise ValidationError( | ||||
|                 self.error_messages['max_decimal_places'], | ||||
|                 code='max_decimal_places', | ||||
|                 params={'max': self.decimal_places}, | ||||
|             ) | ||||
|         if (self.max_digits is not None and self.decimal_places is not None | ||||
|                 and whole_digits > (self.max_digits - self.decimal_places)): | ||||
|             raise ValidationError( | ||||
|                 self.error_messages['max_whole_digits'], | ||||
|                 code='max_whole_digits', | ||||
|                 params={'max': (self.max_digits - self.decimal_places)}, | ||||
|             ) | ||||
|         return value | ||||
|  | ||||
|     def widget_attrs(self, widget): | ||||
|         attrs = super(DecimalField, self).widget_attrs(widget) | ||||
|   | ||||
| @@ -281,3 +281,19 @@ to, or in lieu of custom ``field.clean()`` methods. | ||||
|     .. versionchanged:: 1.8 | ||||
|  | ||||
|        The ``message`` parameter was added. | ||||
|  | ||||
| ``DecimalValidator`` | ||||
| -------------------- | ||||
|  | ||||
| .. class:: DecimalValidator(max_digits, decimal_places) | ||||
|  | ||||
|     .. versionadded:: 1.9 | ||||
|  | ||||
|     Raises :exc:`~django.core.exceptions.ValidationError` with the following | ||||
|     codes: | ||||
|  | ||||
|     - ``'max_digits'`` if the number of digits is larger than ``max_digits``. | ||||
|     - ``'max_decimal_places'`` if the number of decimals is larger than | ||||
|       ``decimal_places``. | ||||
|     - ``'max_whole_digits'`` if the number of whole digits is larger than | ||||
|       the difference between ``max_digits`` and ``decimal_places``. | ||||
|   | ||||
| @@ -165,6 +165,24 @@ class DecimalFieldTests(test.TestCase): | ||||
|         # This should not crash. That counts as a win for our purposes. | ||||
|         Foo.objects.filter(d__gte=100000000000) | ||||
|  | ||||
|     def test_max_digits_validation(self): | ||||
|         field = models.DecimalField(max_digits=2) | ||||
|         expected_message = validators.DecimalValidator.messages['max_digits'] % {'max': 2} | ||||
|         with self.assertRaisesMessage(ValidationError, expected_message): | ||||
|             field.clean(100, None) | ||||
|  | ||||
|     def test_max_decimal_places_validation(self): | ||||
|         field = models.DecimalField(decimal_places=1) | ||||
|         expected_message = validators.DecimalValidator.messages['max_decimal_places'] % {'max': 1} | ||||
|         with self.assertRaisesMessage(ValidationError, expected_message): | ||||
|             field.clean(Decimal('0.99'), None) | ||||
|  | ||||
|     def test_max_whole_digits_validation(self): | ||||
|         field = models.DecimalField(max_digits=3, decimal_places=1) | ||||
|         expected_message = validators.DecimalValidator.messages['max_whole_digits'] % {'max': 2} | ||||
|         with self.assertRaisesMessage(ValidationError, expected_message): | ||||
|             field.clean(Decimal('999'), None) | ||||
|  | ||||
|  | ||||
| class ForeignKeyTests(test.TestCase): | ||||
|     def test_callable_default(self): | ||||
|   | ||||
| @@ -10,11 +10,12 @@ from unittest import TestCase | ||||
|  | ||||
| from django.core.exceptions import ValidationError | ||||
| from django.core.validators import ( | ||||
|     BaseValidator, EmailValidator, MaxLengthValidator, MaxValueValidator, | ||||
|     MinLengthValidator, MinValueValidator, RegexValidator, URLValidator, | ||||
|     int_list_validator, validate_comma_separated_integer_list, validate_email, | ||||
|     validate_integer, validate_ipv4_address, validate_ipv6_address, | ||||
|     validate_ipv46_address, validate_slug, validate_unicode_slug, | ||||
|     BaseValidator, DecimalValidator, EmailValidator, MaxLengthValidator, | ||||
|     MaxValueValidator, MinLengthValidator, MinValueValidator, RegexValidator, | ||||
|     URLValidator, int_list_validator, validate_comma_separated_integer_list, | ||||
|     validate_email, validate_integer, validate_ipv4_address, | ||||
|     validate_ipv6_address, validate_ipv46_address, validate_slug, | ||||
|     validate_unicode_slug, | ||||
| ) | ||||
| from django.test import SimpleTestCase | ||||
| from django.test.utils import str_prefix | ||||
| @@ -401,3 +402,21 @@ class TestValidatorEquality(TestCase): | ||||
|             MinValueValidator(45), | ||||
|             MinValueValidator(11), | ||||
|         ) | ||||
|  | ||||
|     def test_decimal_equality(self): | ||||
|         self.assertEqual( | ||||
|             DecimalValidator(1, 2), | ||||
|             DecimalValidator(1, 2), | ||||
|         ) | ||||
|         self.assertNotEqual( | ||||
|             DecimalValidator(1, 2), | ||||
|             DecimalValidator(1, 1), | ||||
|         ) | ||||
|         self.assertNotEqual( | ||||
|             DecimalValidator(1, 2), | ||||
|             DecimalValidator(2, 2), | ||||
|         ) | ||||
|         self.assertNotEqual( | ||||
|             DecimalValidator(1, 2), | ||||
|             MinValueValidator(11), | ||||
|         ) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user