diff --git a/django/core/validators.py b/django/core/validators.py index 3c731f1459..805dd8860f 100644 --- a/django/core/validators.py +++ b/django/core/validators.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals +import os import re from django.core.exceptions import ValidationError @@ -452,3 +453,53 @@ class DecimalValidator(object): self.max_digits == other.max_digits and self.decimal_places == other.decimal_places ) + + +@deconstructible +class FileExtensionValidator(object): + message = _( + "File extension '%(extension)s' is not allowed. " + "Allowed extensions are: '%(allowed_extensions)s'." + ) + code = 'invalid_extension' + + def __init__(self, allowed_extensions=None, message=None, code=None): + self.allowed_extensions = allowed_extensions + if message is not None: + self.message = message + if code is not None: + self.code = code + + def __call__(self, value): + extension = os.path.splitext(value.name)[1][1:].lower() + if self.allowed_extensions is not None and extension not in self.allowed_extensions: + raise ValidationError( + self.message, + code=self.code, + params={ + 'extension': extension, + 'allowed_extensions': ', '.join(self.allowed_extensions) + } + ) + + def __eq__(self, other): + return ( + isinstance(other, self.__class__) and + self.allowed_extensions == other.allowed_extensions and + self.message == other.message and + self.code == other.code + ) + + +def get_available_image_extensions(): + try: + from PIL import Image + except ImportError: + return [] + else: + Image.init() + return [ext.lower()[1:] for ext in Image.EXTENSION.keys()] + +validate_image_file_extension = FileExtensionValidator( + allowed_extensions=get_available_image_extensions(), +) diff --git a/django/db/models/fields/files.py b/django/db/models/fields/files.py index e3a3fee8f6..8dfd9bf08d 100644 --- a/django/db/models/fields/files.py +++ b/django/db/models/fields/files.py @@ -8,6 +8,7 @@ from django.core import checks from django.core.files.base import File from django.core.files.images import ImageFile from django.core.files.storage import default_storage +from django.core.validators import validate_image_file_extension from django.db.models import signals from django.db.models.fields import Field from django.utils import six @@ -378,6 +379,7 @@ class ImageFieldFile(ImageFile, FieldFile): class ImageField(FileField): + default_validators = [validate_image_file_extension] attr_class = ImageFieldFile descriptor_class = ImageFileDescriptor description = _("Image") diff --git a/docs/ref/validators.txt b/docs/ref/validators.txt index 7c82f21605..417df517a4 100644 --- a/docs/ref/validators.txt +++ b/docs/ref/validators.txt @@ -279,3 +279,30 @@ to, or in lieu of custom ``field.clean()`` methods. ``decimal_places``. - ``'max_whole_digits'`` if the number of whole digits is larger than the difference between ``max_digits`` and ``decimal_places``. + +``FileExtensionValidator`` +-------------------------- + +.. class:: FileExtensionValidator(allowed_extensions, message, code) + + .. versionadded:: 1.11 + + Raises a :exc:`~django.core.exceptions.ValidationError` with a code of + ``'invalid_extension'`` if the ``value`` cannot be found in + ``allowed_extensions``. + + .. warning:: + + Don't rely on validation of the file extension to determine a file's + type. Files can be renamed to have any extension no matter what data + they contain. + +``validate_image_file_extension`` +--------------------------------- + +.. data:: validate_image_file_extension + + .. versionadded:: 1.11 + + Uses Pillow to ensure that the ``value`` is `a valid image extension + `_. diff --git a/docs/releases/1.11.txt b/docs/releases/1.11.txt index 5bac7e0b0a..1d79b7ffca 100644 --- a/docs/releases/1.11.txt +++ b/docs/releases/1.11.txt @@ -192,6 +192,9 @@ Models ` and :meth:`~django.db.models.query.QuerySet.get_or_create`. +* :class:`~django.db.models.ImageField` now has a default + :data:`~django.core.validators.validate_image_file_extension` validator. + Requests and Responses ~~~~~~~~~~~~~~~~~~~~~~ @@ -237,7 +240,10 @@ URLs Validators ~~~~~~~~~~ -* ... +* Added :class:`~django.core.validators.FileExtensionValidator` to validate + file extensions and + :data:`~django.core.validators.validate_image_file_extension` to validate + image files. .. _backwards-incompatible-1.11: diff --git a/tests/model_fields/test_imagefield.py b/tests/model_fields/test_imagefield.py index f9ae83b8b2..51f91d90c0 100644 --- a/tests/model_fields/test_imagefield.py +++ b/tests/model_fields/test_imagefield.py @@ -4,7 +4,7 @@ import os import shutil from unittest import skipIf -from django.core.exceptions import ImproperlyConfigured +from django.core.exceptions import ImproperlyConfigured, ValidationError from django.core.files import File from django.core.files.images import ImageFile from django.test import TestCase @@ -133,6 +133,12 @@ class ImageFieldTests(ImageFieldTestMixin, TestCase): self.assertEqual(hash(p1_db.mugshot), hash(p1.mugshot)) self.assertIs(p1_db.mugshot != p1.mugshot, False) + def test_validation(self): + p = self.PersonModel(name="Joan") + p.mugshot.save("shot.txt", self.file1) + with self.assertRaisesMessage(ValidationError, "File extension 'txt' is not allowed."): + p.full_clean() + def test_instantiate_missing(self): """ If the underlying file is unavailable, still create instantiate the diff --git a/tests/validators/tests.py b/tests/validators/tests.py index 56ebbe4cce..56dea8804f 100644 --- a/tests/validators/tests.py +++ b/tests/validators/tests.py @@ -9,11 +9,13 @@ from datetime import datetime, timedelta from unittest import TestCase from django.core.exceptions import ValidationError +from django.core.files.base import ContentFile from django.core.validators import ( - BaseValidator, DecimalValidator, EmailValidator, MaxLengthValidator, - MaxValueValidator, MinLengthValidator, MinValueValidator, RegexValidator, - URLValidator, int_list_validator, validate_comma_separated_integer_list, - validate_email, validate_integer, validate_ipv4_address, + BaseValidator, DecimalValidator, EmailValidator, FileExtensionValidator, + MaxLengthValidator, MaxValueValidator, MinLengthValidator, + MinValueValidator, RegexValidator, URLValidator, int_list_validator, + validate_comma_separated_integer_list, validate_email, + validate_image_file_extension, validate_integer, validate_ipv4_address, validate_ipv6_address, validate_ipv46_address, validate_slug, validate_unicode_slug, ) @@ -242,6 +244,17 @@ TEST_DATA = [ (RegexValidator('x', flags=re.IGNORECASE), 'y', ValidationError), (RegexValidator('a'), 'A', ValidationError), (RegexValidator('a', flags=re.IGNORECASE), 'A', None), + + (FileExtensionValidator(['txt']), ContentFile('contents', name='fileWithUnsupportedExt.jpg'), ValidationError), + (FileExtensionValidator(['txt']), ContentFile('contents', name='fileWithNoExtenstion'), ValidationError), + (FileExtensionValidator([]), ContentFile('contents', name='file.txt'), ValidationError), + (FileExtensionValidator(['txt']), ContentFile('contents', name='file.txt'), None), + (FileExtensionValidator(), ContentFile('contents', name='file.jpg'), None), + + (validate_image_file_extension, ContentFile('contents', name='file.jpg'), None), + (validate_image_file_extension, ContentFile('contents', name='file.png'), None), + (validate_image_file_extension, ContentFile('contents', name='file.txt'), ValidationError), + (validate_image_file_extension, ContentFile('contents', name='file'), ValidationError), ] @@ -422,3 +435,33 @@ class TestValidatorEquality(TestCase): DecimalValidator(1, 2), MinValueValidator(11), ) + + def test_file_extension_equality(self): + self.assertEqual( + FileExtensionValidator(), + FileExtensionValidator() + ) + self.assertEqual( + FileExtensionValidator(['txt']), + FileExtensionValidator(['txt']) + ) + self.assertEqual( + FileExtensionValidator(['txt']), + FileExtensionValidator(['txt'], code='invalid_extension') + ) + self.assertNotEqual( + FileExtensionValidator(['txt']), + FileExtensionValidator(['png']) + ) + self.assertNotEqual( + FileExtensionValidator(['txt']), + FileExtensionValidator(['png', 'jpg']) + ) + self.assertNotEqual( + FileExtensionValidator(['txt']), + FileExtensionValidator(['txt'], code='custom_code') + ) + self.assertNotEqual( + FileExtensionValidator(['txt']), + FileExtensionValidator(['txt'], message='custom error message') + )