import decimal import enum import json import unittest import uuid from django import forms from django.core import checks, exceptions, serializers, validators from django.core.exceptions import FieldError from django.core.management import call_command from django.db import IntegrityError, connection, models from django.db.models.expressions import RawSQL from django.db.models.functions import Cast from django.test import TransactionTestCase, modify_settings, override_settings from django.test.utils import isolate_apps from django.utils import timezone from . import ( PostgreSQLSimpleTestCase, PostgreSQLTestCase, PostgreSQLWidgetTestCase, ) from .models import ( ArrayEnumModel, ArrayFieldSubclass, CharArrayModel, DateTimeArrayModel, IntegerArrayModel, NestedIntegerArrayModel, NullableIntegerArrayModel, OtherTypesArrayModel, PostgreSQLModel, Tag, ) try: from django.contrib.postgres.fields import ArrayField from django.contrib.postgres.fields.array import IndexTransform, SliceTransform from django.contrib.postgres.forms import ( SimpleArrayField, SplitArrayField, SplitArrayWidget, ) from psycopg2.extras import NumericRange except ImportError: pass class TestSaveLoad(PostgreSQLTestCase): def test_integer(self): instance = IntegerArrayModel(field=[1, 2, 3]) instance.save() loaded = IntegerArrayModel.objects.get() self.assertEqual(instance.field, loaded.field) def test_char(self): instance = CharArrayModel(field=['hello', 'goodbye']) instance.save() loaded = CharArrayModel.objects.get() self.assertEqual(instance.field, loaded.field) def test_dates(self): instance = DateTimeArrayModel( datetimes=[timezone.now()], dates=[timezone.now().date()], times=[timezone.now().time()], ) instance.save() loaded = DateTimeArrayModel.objects.get() self.assertEqual(instance.datetimes, loaded.datetimes) self.assertEqual(instance.dates, loaded.dates) self.assertEqual(instance.times, loaded.times) def test_tuples(self): instance = IntegerArrayModel(field=(1,)) instance.save() loaded = IntegerArrayModel.objects.get() self.assertSequenceEqual(instance.field, loaded.field) def test_integers_passed_as_strings(self): # This checks that get_prep_value is deferred properly instance = IntegerArrayModel(field=['1']) instance.save() loaded = IntegerArrayModel.objects.get() self.assertEqual(loaded.field, [1]) def test_default_null(self): instance = NullableIntegerArrayModel() instance.save() loaded = NullableIntegerArrayModel.objects.get(pk=instance.pk) self.assertIsNone(loaded.field) self.assertEqual(instance.field, loaded.field) def test_null_handling(self): instance = NullableIntegerArrayModel(field=None) instance.save() loaded = NullableIntegerArrayModel.objects.get() self.assertEqual(instance.field, loaded.field) instance = IntegerArrayModel(field=None) with self.assertRaises(IntegrityError): instance.save() def test_nested(self): instance = NestedIntegerArrayModel(field=[[1, 2], [3, 4]]) instance.save() loaded = NestedIntegerArrayModel.objects.get() self.assertEqual(instance.field, loaded.field) def test_other_array_types(self): instance = OtherTypesArrayModel( ips=['192.168.0.1', '::1'], uuids=[uuid.uuid4()], decimals=[decimal.Decimal(1.25), 1.75], tags=[Tag(1), Tag(2), Tag(3)], json=[{'a': 1}, {'b': 2}], int_ranges=[NumericRange(10, 20), NumericRange(30, 40)], bigint_ranges=[ NumericRange(7000000000, 10000000000), NumericRange(50000000000, 70000000000), ] ) instance.save() loaded = OtherTypesArrayModel.objects.get() self.assertEqual(instance.ips, loaded.ips) self.assertEqual(instance.uuids, loaded.uuids) self.assertEqual(instance.decimals, loaded.decimals) self.assertEqual(instance.tags, loaded.tags) self.assertEqual(instance.json, loaded.json) self.assertEqual(instance.int_ranges, loaded.int_ranges) self.assertEqual(instance.bigint_ranges, loaded.bigint_ranges) def test_null_from_db_value_handling(self): instance = OtherTypesArrayModel.objects.create( ips=['192.168.0.1', '::1'], uuids=[uuid.uuid4()], decimals=[decimal.Decimal(1.25), 1.75], tags=None, ) instance.refresh_from_db() self.assertIsNone(instance.tags) self.assertEqual(instance.json, []) self.assertIsNone(instance.int_ranges) self.assertIsNone(instance.bigint_ranges) def test_model_set_on_base_field(self): instance = IntegerArrayModel() field = instance._meta.get_field('field') self.assertEqual(field.model, IntegerArrayModel) self.assertEqual(field.base_field.model, IntegerArrayModel) class TestQuerying(PostgreSQLTestCase): @classmethod def setUpTestData(cls): cls.objs = NullableIntegerArrayModel.objects.bulk_create([ NullableIntegerArrayModel(field=[1]), NullableIntegerArrayModel(field=[2]), NullableIntegerArrayModel(field=[2, 3]), NullableIntegerArrayModel(field=[20, 30, 40]), NullableIntegerArrayModel(field=None), ]) def test_empty_list(self): NullableIntegerArrayModel.objects.create(field=[]) obj = NullableIntegerArrayModel.objects.annotate( empty_array=models.Value([], output_field=ArrayField(models.IntegerField())), ).filter(field=models.F('empty_array')).get() self.assertEqual(obj.field, []) self.assertEqual(obj.empty_array, []) def test_exact(self): self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__exact=[1]), self.objs[:1] ) def test_exact_charfield(self): instance = CharArrayModel.objects.create(field=['text']) self.assertSequenceEqual( CharArrayModel.objects.filter(field=['text']), [instance] ) def test_exact_nested(self): instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]]) self.assertSequenceEqual( NestedIntegerArrayModel.objects.filter(field=[[1, 2], [3, 4]]), [instance] ) def test_isnull(self): self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__isnull=True), self.objs[-1:] ) def test_gt(self): self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__gt=[0]), self.objs[:4] ) def test_lt(self): self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__lt=[2]), self.objs[:1] ) def test_in(self): self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__in=[[1], [2]]), self.objs[:2] ) def test_in_subquery(self): IntegerArrayModel.objects.create(field=[2, 3]) self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter( field__in=IntegerArrayModel.objects.all().values_list('field', flat=True) ), self.objs[2:3] ) @unittest.expectedFailure def test_in_including_F_object(self): # This test asserts that Array objects passed to filters can be # constructed to contain F objects. This currently doesn't work as the # psycopg2 mogrify method that generates the ARRAY() syntax is # expecting literals, not column references (#27095). self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__in=[[models.F('id')]]), self.objs[:2] ) def test_in_as_F_object(self): self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__in=[models.F('field')]), self.objs[:4] ) def test_contained_by(self): self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__contained_by=[1, 2]), self.objs[:2] ) @unittest.expectedFailure def test_contained_by_including_F_object(self): # This test asserts that Array objects passed to filters can be # constructed to contain F objects. This currently doesn't work as the # psycopg2 mogrify method that generates the ARRAY() syntax is # expecting literals, not column references (#27095). self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__contained_by=[models.F('id'), 2]), self.objs[:2] ) def test_contains(self): self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__contains=[2]), self.objs[1:3] ) def test_icontains(self): # Using the __icontains lookup with ArrayField is inefficient. instance = CharArrayModel.objects.create(field=['FoO']) self.assertSequenceEqual( CharArrayModel.objects.filter(field__icontains='foo'), [instance] ) def test_contains_charfield(self): # Regression for #22907 self.assertSequenceEqual( CharArrayModel.objects.filter(field__contains=['text']), [] ) def test_contained_by_charfield(self): self.assertSequenceEqual( CharArrayModel.objects.filter(field__contained_by=['text']), [] ) def test_overlap_charfield(self): self.assertSequenceEqual( CharArrayModel.objects.filter(field__overlap=['text']), [] ) def test_index(self): self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__0=2), self.objs[1:3] ) def test_index_chained(self): self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__0__lt=3), self.objs[0:3] ) def test_index_nested(self): instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]]) self.assertSequenceEqual( NestedIntegerArrayModel.objects.filter(field__0__0=1), [instance] ) @unittest.expectedFailure def test_index_used_on_nested_data(self): instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]]) self.assertSequenceEqual( NestedIntegerArrayModel.objects.filter(field__0=[1, 2]), [instance] ) def test_index_transform_expression(self): expr = RawSQL("string_to_array(%s, ';')", ['1;2']) self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter( field__0=Cast( IndexTransform(1, models.IntegerField, expr), output_field=models.IntegerField(), ), ), self.objs[:1], ) def test_overlap(self): self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__overlap=[1, 2]), self.objs[0:3] ) def test_len(self): self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__len__lte=2), self.objs[0:3] ) def test_len_empty_array(self): obj = NullableIntegerArrayModel.objects.create(field=[]) self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__len=0), [obj] ) def test_slice(self): self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__0_1=[2]), self.objs[1:3] ) self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__0_2=[2, 3]), self.objs[2:3] ) def test_order_by_slice(self): more_objs = ( NullableIntegerArrayModel.objects.create(field=[1, 637]), NullableIntegerArrayModel.objects.create(field=[2, 1]), NullableIntegerArrayModel.objects.create(field=[3, -98123]), NullableIntegerArrayModel.objects.create(field=[4, 2]), ) self.assertSequenceEqual( NullableIntegerArrayModel.objects.order_by('field__1'), [ more_objs[2], more_objs[1], more_objs[3], self.objs[2], self.objs[3], more_objs[0], self.objs[4], self.objs[1], self.objs[0], ] ) @unittest.expectedFailure def test_slice_nested(self): instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]]) self.assertSequenceEqual( NestedIntegerArrayModel.objects.filter(field__0__0_1=[1]), [instance] ) def test_slice_transform_expression(self): expr = RawSQL("string_to_array(%s, ';')", ['9;2;3']) self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__0_2=SliceTransform(2, 3, expr)), self.objs[2:3], ) def test_usage_in_subquery(self): self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter( id__in=NullableIntegerArrayModel.objects.filter(field__len=3) ), [self.objs[3]] ) def test_enum_lookup(self): class TestEnum(enum.Enum): VALUE_1 = 'value_1' instance = ArrayEnumModel.objects.create(array_of_enums=[TestEnum.VALUE_1]) self.assertSequenceEqual( ArrayEnumModel.objects.filter(array_of_enums__contains=[TestEnum.VALUE_1]), [instance] ) def test_unsupported_lookup(self): msg = "Unsupported lookup '0_bar' for ArrayField or join on the field not permitted." with self.assertRaisesMessage(FieldError, msg): list(NullableIntegerArrayModel.objects.filter(field__0_bar=[2])) msg = "Unsupported lookup '0bar' for ArrayField or join on the field not permitted." with self.assertRaisesMessage(FieldError, msg): list(NullableIntegerArrayModel.objects.filter(field__0bar=[2])) def test_grouping_by_annotations_with_array_field_param(self): value = models.Value([1], output_field=ArrayField(models.IntegerField())) self.assertEqual( NullableIntegerArrayModel.objects.annotate( array_length=models.Func(value, 1, function='ARRAY_LENGTH'), ).values('array_length').annotate( count=models.Count('pk'), ).get()['array_length'], 1, ) class TestDateTimeExactQuerying(PostgreSQLTestCase): @classmethod def setUpTestData(cls): now = timezone.now() cls.datetimes = [now] cls.dates = [now.date()] cls.times = [now.time()] cls.objs = [ DateTimeArrayModel.objects.create(datetimes=cls.datetimes, dates=cls.dates, times=cls.times), ] def test_exact_datetimes(self): self.assertSequenceEqual( DateTimeArrayModel.objects.filter(datetimes=self.datetimes), self.objs ) def test_exact_dates(self): self.assertSequenceEqual( DateTimeArrayModel.objects.filter(dates=self.dates), self.objs ) def test_exact_times(self): self.assertSequenceEqual( DateTimeArrayModel.objects.filter(times=self.times), self.objs ) class TestOtherTypesExactQuerying(PostgreSQLTestCase): @classmethod def setUpTestData(cls): cls.ips = ['192.168.0.1', '::1'] cls.uuids = [uuid.uuid4()] cls.decimals = [decimal.Decimal(1.25), 1.75] cls.tags = [Tag(1), Tag(2), Tag(3)] cls.objs = [ OtherTypesArrayModel.objects.create( ips=cls.ips, uuids=cls.uuids, decimals=cls.decimals, tags=cls.tags, ) ] def test_exact_ip_addresses(self): self.assertSequenceEqual( OtherTypesArrayModel.objects.filter(ips=self.ips), self.objs ) def test_exact_uuids(self): self.assertSequenceEqual( OtherTypesArrayModel.objects.filter(uuids=self.uuids), self.objs ) def test_exact_decimals(self): self.assertSequenceEqual( OtherTypesArrayModel.objects.filter(decimals=self.decimals), self.objs ) def test_exact_tags(self): self.assertSequenceEqual( OtherTypesArrayModel.objects.filter(tags=self.tags), self.objs ) @isolate_apps('postgres_tests') class TestChecks(PostgreSQLSimpleTestCase): def test_field_checks(self): class MyModel(PostgreSQLModel): field = ArrayField(models.CharField()) model = MyModel() errors = model.check() self.assertEqual(len(errors), 1) # The inner CharField is missing a max_length. self.assertEqual(errors[0].id, 'postgres.E001') self.assertIn('max_length', errors[0].msg) def test_invalid_base_fields(self): class MyModel(PostgreSQLModel): field = ArrayField(models.ManyToManyField('postgres_tests.IntegerArrayModel')) model = MyModel() errors = model.check() self.assertEqual(len(errors), 1) self.assertEqual(errors[0].id, 'postgres.E002') def test_invalid_default(self): class MyModel(PostgreSQLModel): field = ArrayField(models.IntegerField(), default=[]) model = MyModel() self.assertEqual(model.check(), [ checks.Warning( msg=( "ArrayField default should be a callable instead of an " "instance so that it's not shared between all field " "instances." ), hint='Use a callable instead, e.g., use `list` instead of `[]`.', obj=MyModel._meta.get_field('field'), id='postgres.E003', ) ]) def test_valid_default(self): class MyModel(PostgreSQLModel): field = ArrayField(models.IntegerField(), default=list) model = MyModel() self.assertEqual(model.check(), []) def test_valid_default_none(self): class MyModel(PostgreSQLModel): field = ArrayField(models.IntegerField(), default=None) model = MyModel() self.assertEqual(model.check(), []) def test_nested_field_checks(self): """ Nested ArrayFields are permitted. """ class MyModel(PostgreSQLModel): field = ArrayField(ArrayField(models.CharField())) model = MyModel() errors = model.check() self.assertEqual(len(errors), 1) # The inner CharField is missing a max_length. self.assertEqual(errors[0].id, 'postgres.E001') self.assertIn('max_length', errors[0].msg) @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific tests") class TestMigrations(TransactionTestCase): available_apps = ['postgres_tests'] def test_deconstruct(self): field = ArrayField(models.IntegerField()) name, path, args, kwargs = field.deconstruct() new = ArrayField(*args, **kwargs) self.assertEqual(type(new.base_field), type(field.base_field)) self.assertIsNot(new.base_field, field.base_field) def test_deconstruct_with_size(self): field = ArrayField(models.IntegerField(), size=3) name, path, args, kwargs = field.deconstruct() new = ArrayField(*args, **kwargs) self.assertEqual(new.size, field.size) def test_deconstruct_args(self): field = ArrayField(models.CharField(max_length=20)) name, path, args, kwargs = field.deconstruct() new = ArrayField(*args, **kwargs) self.assertEqual(new.base_field.max_length, field.base_field.max_length) def test_subclass_deconstruct(self): field = ArrayField(models.IntegerField()) name, path, args, kwargs = field.deconstruct() self.assertEqual(path, 'django.contrib.postgres.fields.ArrayField') field = ArrayFieldSubclass() name, path, args, kwargs = field.deconstruct() self.assertEqual(path, 'postgres_tests.models.ArrayFieldSubclass') @override_settings(MIGRATION_MODULES={ "postgres_tests": "postgres_tests.array_default_migrations", }) def test_adding_field_with_default(self): # See #22962 table_name = 'postgres_tests_integerarraydefaultmodel' with connection.cursor() as cursor: self.assertNotIn(table_name, connection.introspection.table_names(cursor)) call_command('migrate', 'postgres_tests', verbosity=0) with connection.cursor() as cursor: self.assertIn(table_name, connection.introspection.table_names(cursor)) call_command('migrate', 'postgres_tests', 'zero', verbosity=0) with connection.cursor() as cursor: self.assertNotIn(table_name, connection.introspection.table_names(cursor)) @override_settings(MIGRATION_MODULES={ "postgres_tests": "postgres_tests.array_index_migrations", }) def test_adding_arrayfield_with_index(self): """ ArrayField shouldn't have varchar_patterns_ops or text_patterns_ops indexes. """ table_name = 'postgres_tests_chartextarrayindexmodel' call_command('migrate', 'postgres_tests', verbosity=0) with connection.cursor() as cursor: like_constraint_columns_list = [ v['columns'] for k, v in list(connection.introspection.get_constraints(cursor, table_name).items()) if k.endswith('_like') ] # Only the CharField should have a LIKE index. self.assertEqual(like_constraint_columns_list, [['char2']]) # All fields should have regular indexes. with connection.cursor() as cursor: indexes = [ c['columns'][0] for c in connection.introspection.get_constraints(cursor, table_name).values() if c['index'] and len(c['columns']) == 1 ] self.assertIn('char', indexes) self.assertIn('char2', indexes) self.assertIn('text', indexes) call_command('migrate', 'postgres_tests', 'zero', verbosity=0) with connection.cursor() as cursor: self.assertNotIn(table_name, connection.introspection.table_names(cursor)) class TestSerialization(PostgreSQLSimpleTestCase): test_data = ( '[{"fields": {"field": "[\\"1\\", \\"2\\", null]"}, "model": "postgres_tests.integerarraymodel", "pk": null}]' ) def test_dumping(self): instance = IntegerArrayModel(field=[1, 2, None]) data = serializers.serialize('json', [instance]) self.assertEqual(json.loads(data), json.loads(self.test_data)) def test_loading(self): instance = list(serializers.deserialize('json', self.test_data))[0].object self.assertEqual(instance.field, [1, 2, None]) class TestValidation(PostgreSQLSimpleTestCase): def test_unbounded(self): field = ArrayField(models.IntegerField()) with self.assertRaises(exceptions.ValidationError) as cm: field.clean([1, None], None) self.assertEqual(cm.exception.code, 'item_invalid') self.assertEqual( cm.exception.message % cm.exception.params, 'Item 2 in the array did not validate: This field cannot be null.' ) def test_blank_true(self): field = ArrayField(models.IntegerField(blank=True, null=True)) # This should not raise a validation error field.clean([1, None], None) def test_with_size(self): field = ArrayField(models.IntegerField(), size=3) field.clean([1, 2, 3], None) with self.assertRaises(exceptions.ValidationError) as cm: field.clean([1, 2, 3, 4], None) self.assertEqual(cm.exception.messages[0], 'List contains 4 items, it should contain no more than 3.') def test_nested_array_mismatch(self): field = ArrayField(ArrayField(models.IntegerField())) field.clean([[1, 2], [3, 4]], None) with self.assertRaises(exceptions.ValidationError) as cm: field.clean([[1, 2], [3, 4, 5]], None) self.assertEqual(cm.exception.code, 'nested_array_mismatch') self.assertEqual(cm.exception.messages[0], 'Nested arrays must have the same length.') def test_with_base_field_error_params(self): field = ArrayField(models.CharField(max_length=2)) with self.assertRaises(exceptions.ValidationError) as cm: field.clean(['abc'], None) self.assertEqual(len(cm.exception.error_list), 1) exception = cm.exception.error_list[0] self.assertEqual( exception.message, 'Item 1 in the array did not validate: Ensure this value has at most 2 characters (it has 3).' ) self.assertEqual(exception.code, 'item_invalid') self.assertEqual(exception.params, {'nth': 1, 'value': 'abc', 'limit_value': 2, 'show_value': 3}) def test_with_validators(self): field = ArrayField(models.IntegerField(validators=[validators.MinValueValidator(1)])) field.clean([1, 2], None) with self.assertRaises(exceptions.ValidationError) as cm: field.clean([0], None) self.assertEqual(len(cm.exception.error_list), 1) exception = cm.exception.error_list[0] self.assertEqual( exception.message, 'Item 1 in the array did not validate: Ensure this value is greater than or equal to 1.' ) self.assertEqual(exception.code, 'item_invalid') self.assertEqual(exception.params, {'nth': 1, 'value': 0, 'limit_value': 1, 'show_value': 0}) class TestSimpleFormField(PostgreSQLSimpleTestCase): def test_valid(self): field = SimpleArrayField(forms.CharField()) value = field.clean('a,b,c') self.assertEqual(value, ['a', 'b', 'c']) def test_to_python_fail(self): field = SimpleArrayField(forms.IntegerField()) with self.assertRaises(exceptions.ValidationError) as cm: field.clean('a,b,9') self.assertEqual(cm.exception.messages[0], 'Item 1 in the array did not validate: Enter a whole number.') def test_validate_fail(self): field = SimpleArrayField(forms.CharField(required=True)) with self.assertRaises(exceptions.ValidationError) as cm: field.clean('a,b,') self.assertEqual(cm.exception.messages[0], 'Item 3 in the array did not validate: This field is required.') def test_validate_fail_base_field_error_params(self): field = SimpleArrayField(forms.CharField(max_length=2)) with self.assertRaises(exceptions.ValidationError) as cm: field.clean('abc,c,defg') errors = cm.exception.error_list self.assertEqual(len(errors), 2) first_error = errors[0] self.assertEqual( first_error.message, 'Item 1 in the array did not validate: Ensure this value has at most 2 characters (it has 3).' ) self.assertEqual(first_error.code, 'item_invalid') self.assertEqual(first_error.params, {'nth': 1, 'value': 'abc', 'limit_value': 2, 'show_value': 3}) second_error = errors[1] self.assertEqual( second_error.message, 'Item 3 in the array did not validate: Ensure this value has at most 2 characters (it has 4).' ) self.assertEqual(second_error.code, 'item_invalid') self.assertEqual(second_error.params, {'nth': 3, 'value': 'defg', 'limit_value': 2, 'show_value': 4}) def test_validators_fail(self): field = SimpleArrayField(forms.RegexField('[a-e]{2}')) with self.assertRaises(exceptions.ValidationError) as cm: field.clean('a,bc,de') self.assertEqual(cm.exception.messages[0], 'Item 1 in the array did not validate: Enter a valid value.') def test_delimiter(self): field = SimpleArrayField(forms.CharField(), delimiter='|') value = field.clean('a|b|c') self.assertEqual(value, ['a', 'b', 'c']) def test_delimiter_with_nesting(self): field = SimpleArrayField(SimpleArrayField(forms.CharField()), delimiter='|') value = field.clean('a,b|c,d') self.assertEqual(value, [['a', 'b'], ['c', 'd']]) def test_prepare_value(self): field = SimpleArrayField(forms.CharField()) value = field.prepare_value(['a', 'b', 'c']) self.assertEqual(value, 'a,b,c') def test_max_length(self): field = SimpleArrayField(forms.CharField(), max_length=2) with self.assertRaises(exceptions.ValidationError) as cm: field.clean('a,b,c') self.assertEqual(cm.exception.messages[0], 'List contains 3 items, it should contain no more than 2.') def test_min_length(self): field = SimpleArrayField(forms.CharField(), min_length=4) with self.assertRaises(exceptions.ValidationError) as cm: field.clean('a,b,c') self.assertEqual(cm.exception.messages[0], 'List contains 3 items, it should contain no fewer than 4.') def test_required(self): field = SimpleArrayField(forms.CharField(), required=True) with self.assertRaises(exceptions.ValidationError) as cm: field.clean('') self.assertEqual(cm.exception.messages[0], 'This field is required.') def test_model_field_formfield(self): model_field = ArrayField(models.CharField(max_length=27)) form_field = model_field.formfield() self.assertIsInstance(form_field, SimpleArrayField) self.assertIsInstance(form_field.base_field, forms.CharField) self.assertEqual(form_field.base_field.max_length, 27) def test_model_field_formfield_size(self): model_field = ArrayField(models.CharField(max_length=27), size=4) form_field = model_field.formfield() self.assertIsInstance(form_field, SimpleArrayField) self.assertEqual(form_field.max_length, 4) def test_model_field_choices(self): model_field = ArrayField(models.IntegerField(choices=((1, 'A'), (2, 'B')))) form_field = model_field.formfield() self.assertEqual(form_field.clean('1,2'), [1, 2]) def test_already_converted_value(self): field = SimpleArrayField(forms.CharField()) vals = ['a', 'b', 'c'] self.assertEqual(field.clean(vals), vals) def test_has_changed(self): field = SimpleArrayField(forms.IntegerField()) self.assertIs(field.has_changed([1, 2], [1, 2]), False) self.assertIs(field.has_changed([1, 2], '1,2'), False) self.assertIs(field.has_changed([1, 2], '1,2,3'), True) self.assertIs(field.has_changed([1, 2], 'a,b'), True) def test_has_changed_empty(self): field = SimpleArrayField(forms.CharField()) self.assertIs(field.has_changed(None, None), False) self.assertIs(field.has_changed(None, ''), False) self.assertIs(field.has_changed(None, []), False) self.assertIs(field.has_changed([], None), False) self.assertIs(field.has_changed([], ''), False) class TestSplitFormField(PostgreSQLSimpleTestCase): def test_valid(self): class SplitForm(forms.Form): array = SplitArrayField(forms.CharField(), size=3) data = {'array_0': 'a', 'array_1': 'b', 'array_2': 'c'} form = SplitForm(data) self.assertTrue(form.is_valid()) self.assertEqual(form.cleaned_data, {'array': ['a', 'b', 'c']}) def test_required(self): class SplitForm(forms.Form): array = SplitArrayField(forms.CharField(), required=True, size=3) data = {'array_0': '', 'array_1': '', 'array_2': ''} form = SplitForm(data) self.assertFalse(form.is_valid()) self.assertEqual(form.errors, {'array': ['This field is required.']}) def test_remove_trailing_nulls(self): class SplitForm(forms.Form): array = SplitArrayField(forms.CharField(required=False), size=5, remove_trailing_nulls=True) data = {'array_0': 'a', 'array_1': '', 'array_2': 'b', 'array_3': '', 'array_4': ''} form = SplitForm(data) self.assertTrue(form.is_valid(), form.errors) self.assertEqual(form.cleaned_data, {'array': ['a', '', 'b']}) def test_remove_trailing_nulls_not_required(self): class SplitForm(forms.Form): array = SplitArrayField( forms.CharField(required=False), size=2, remove_trailing_nulls=True, required=False, ) data = {'array_0': '', 'array_1': ''} form = SplitForm(data) self.assertTrue(form.is_valid()) self.assertEqual(form.cleaned_data, {'array': []}) def test_required_field(self): class SplitForm(forms.Form): array = SplitArrayField(forms.CharField(), size=3) data = {'array_0': 'a', 'array_1': 'b', 'array_2': ''} form = SplitForm(data) self.assertFalse(form.is_valid()) self.assertEqual(form.errors, {'array': ['Item 3 in the array did not validate: This field is required.']}) def test_invalid_integer(self): msg = 'Item 2 in the array did not validate: Ensure this value is less than or equal to 100.' with self.assertRaisesMessage(exceptions.ValidationError, msg): SplitArrayField(forms.IntegerField(max_value=100), size=2).clean([0, 101]) # To locate the widget's template. @modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'}) def test_rendering(self): class SplitForm(forms.Form): array = SplitArrayField(forms.CharField(), size=3) self.assertHTMLEqual(str(SplitForm()), ''' ''') def test_invalid_char_length(self): field = SplitArrayField(forms.CharField(max_length=2), size=3) with self.assertRaises(exceptions.ValidationError) as cm: field.clean(['abc', 'c', 'defg']) self.assertEqual(cm.exception.messages, [ 'Item 1 in the array did not validate: Ensure this value has at most 2 characters (it has 3).', 'Item 3 in the array did not validate: Ensure this value has at most 2 characters (it has 4).', ]) def test_splitarraywidget_value_omitted_from_data(self): class Form(forms.ModelForm): field = SplitArrayField(forms.IntegerField(), required=False, size=2) class Meta: model = IntegerArrayModel fields = ('field',) form = Form({'field_0': '1', 'field_1': '2'}) self.assertEqual(form.errors, {}) obj = form.save(commit=False) self.assertEqual(obj.field, [1, 2]) def test_splitarrayfield_has_changed(self): class Form(forms.ModelForm): field = SplitArrayField(forms.IntegerField(), required=False, size=2) class Meta: model = IntegerArrayModel fields = ('field',) obj = IntegerArrayModel(field=[1, 2]) form = Form({'field_0': '1', 'field_1': '2'}, instance=obj) self.assertFalse(form.has_changed()) class TestSplitFormWidget(PostgreSQLWidgetTestCase): def test_get_context(self): self.assertEqual( SplitArrayWidget(forms.TextInput(), size=2).get_context('name', ['val1', 'val2']), { 'widget': { 'name': 'name', 'is_hidden': False, 'required': False, 'value': "['val1', 'val2']", 'attrs': {}, 'template_name': 'postgres/widgets/split_array.html', 'subwidgets': [ { 'name': 'name_0', 'is_hidden': False, 'required': False, 'value': 'val1', 'attrs': {}, 'template_name': 'django/forms/widgets/text.html', 'type': 'text', }, { 'name': 'name_1', 'is_hidden': False, 'required': False, 'value': 'val2', 'attrs': {}, 'template_name': 'django/forms/widgets/text.html', 'type': 'text', }, ] } } ) def test_render(self): self.check_html( SplitArrayWidget(forms.TextInput(), size=2), 'array', None, """ """ ) def test_render_attrs(self): self.check_html( SplitArrayWidget(forms.TextInput(), size=2), 'array', ['val1', 'val2'], attrs={'id': 'foo'}, html=( """ """ ) ) def test_value_omitted_from_data(self): widget = SplitArrayWidget(forms.TextInput(), size=2) self.assertIs(widget.value_omitted_from_data({}, {}, 'field'), True) self.assertIs(widget.value_omitted_from_data({'field_0': 'value'}, {}, 'field'), False) self.assertIs(widget.value_omitted_from_data({'field_1': 'value'}, {}, 'field'), False) self.assertIs(widget.value_omitted_from_data({'field_0': 'value', 'field_1': 'value'}, {}, 'field'), False)