mirror of
				https://github.com/django/django.git
				synced 2025-10-24 14:16:09 +00:00 
			
		
		
		
	Fixed #24092 -- Widened base field support for ArrayField.
Several issues resolved here, following from a report that a base_field of GenericIpAddressField was failing. We were using get_prep_value instead of get_db_prep_value in ArrayField which was bypassing any extra modifications to the value being made in the base field's get_db_prep_value. Changing this broke datetime support, so the postgres backend has gained the relevant operation methods to send dates/times/datetimes directly to the db backend instead of casting them to strings. Similarly, a new database feature has been added allowing the uuid to be passed directly to the backend, as we do with timedeltas. On the other side, psycopg2 expects an Inet() instance for IP address fields, so we add a value_to_db_ipaddress method to wrap the strings on postgres. We also have to manually add a database adapter to psycopg2, as we do not wish to use the built in adapter which would turn everything into Inet() instances. Thanks to smclenithan for the report.
This commit is contained in:
		| @@ -70,9 +70,9 @@ class ArrayField(Field): | |||||||
|         size = self.size or '' |         size = self.size or '' | ||||||
|         return '%s[%s]' % (self.base_field.db_type(connection), size) |         return '%s[%s]' % (self.base_field.db_type(connection), size) | ||||||
|  |  | ||||||
|     def get_prep_value(self, value): |     def get_db_prep_value(self, value, connection, prepared=False): | ||||||
|         if isinstance(value, list) or isinstance(value, tuple): |         if isinstance(value, list) or isinstance(value, tuple): | ||||||
|             return [self.base_field.get_prep_value(i) for i in value] |             return [self.base_field.get_db_prep_value(i, connection, prepared) for i in value] | ||||||
|         return value |         return value | ||||||
|  |  | ||||||
|     def deconstruct(self): |     def deconstruct(self): | ||||||
|   | |||||||
| @@ -59,6 +59,9 @@ class BaseDatabaseFeatures(object): | |||||||
|     supports_subqueries_in_group_by = True |     supports_subqueries_in_group_by = True | ||||||
|     supports_bitwise_or = True |     supports_bitwise_or = True | ||||||
|  |  | ||||||
|  |     # Is there a true datatype for uuid? | ||||||
|  |     has_native_uuid_field = False | ||||||
|  |  | ||||||
|     # Is there a true datatype for timedeltas? |     # Is there a true datatype for timedeltas? | ||||||
|     has_native_duration_field = False |     has_native_duration_field = False | ||||||
|  |  | ||||||
|   | |||||||
| @@ -219,7 +219,7 @@ class BaseDatabaseOperations(object): | |||||||
|         """ |         """ | ||||||
|         return cursor.lastrowid |         return cursor.lastrowid | ||||||
|  |  | ||||||
|     def lookup_cast(self, lookup_type): |     def lookup_cast(self, lookup_type, internal_type=None): | ||||||
|         """ |         """ | ||||||
|         Returns the string to use in a query when performing lookups |         Returns the string to use in a query when performing lookups | ||||||
|         ("contains", "like", etc). The resulting string should contain a '%s' |         ("contains", "like", etc). The resulting string should contain a '%s' | ||||||
| @@ -442,7 +442,7 @@ class BaseDatabaseOperations(object): | |||||||
|  |  | ||||||
|     def value_to_db_date(self, value): |     def value_to_db_date(self, value): | ||||||
|         """ |         """ | ||||||
|         Transform a date value to an object compatible with what is expected |         Transforms a date value to an object compatible with what is expected | ||||||
|         by the backend driver for date columns. |         by the backend driver for date columns. | ||||||
|         """ |         """ | ||||||
|         if value is None: |         if value is None: | ||||||
| @@ -451,7 +451,7 @@ class BaseDatabaseOperations(object): | |||||||
|  |  | ||||||
|     def value_to_db_datetime(self, value): |     def value_to_db_datetime(self, value): | ||||||
|         """ |         """ | ||||||
|         Transform a datetime value to an object compatible with what is expected |         Transforms a datetime value to an object compatible with what is expected | ||||||
|         by the backend driver for datetime columns. |         by the backend driver for datetime columns. | ||||||
|         """ |         """ | ||||||
|         if value is None: |         if value is None: | ||||||
| @@ -460,7 +460,7 @@ class BaseDatabaseOperations(object): | |||||||
|  |  | ||||||
|     def value_to_db_time(self, value): |     def value_to_db_time(self, value): | ||||||
|         """ |         """ | ||||||
|         Transform a time value to an object compatible with what is expected |         Transforms a time value to an object compatible with what is expected | ||||||
|         by the backend driver for time columns. |         by the backend driver for time columns. | ||||||
|         """ |         """ | ||||||
|         if value is None: |         if value is None: | ||||||
| @@ -471,11 +471,18 @@ class BaseDatabaseOperations(object): | |||||||
|  |  | ||||||
|     def value_to_db_decimal(self, value, max_digits, decimal_places): |     def value_to_db_decimal(self, value, max_digits, decimal_places): | ||||||
|         """ |         """ | ||||||
|         Transform a decimal.Decimal value to an object compatible with what is |         Transforms a decimal.Decimal value to an object compatible with what is | ||||||
|         expected by the backend driver for decimal (numeric) columns. |         expected by the backend driver for decimal (numeric) columns. | ||||||
|         """ |         """ | ||||||
|         return utils.format_number(value, max_digits, decimal_places) |         return utils.format_number(value, max_digits, decimal_places) | ||||||
|  |  | ||||||
|  |     def value_to_db_ipaddress(self, value): | ||||||
|  |         """ | ||||||
|  |         Transforms a string representation of an IP address into the expected | ||||||
|  |         type for the backend driver. | ||||||
|  |         """ | ||||||
|  |         return value | ||||||
|  |  | ||||||
|     def year_lookup_bounds_for_date_field(self, value): |     def year_lookup_bounds_for_date_field(self, value): | ||||||
|         """ |         """ | ||||||
|         Returns a two-elements list with the lower and upper bound to be used |         Returns a two-elements list with the lower and upper bound to be used | ||||||
|   | |||||||
| @@ -246,7 +246,7 @@ WHEN (new.%(col_name)s IS NULL) | |||||||
|         cursor.execute('SELECT "%s".currval FROM dual' % sq_name) |         cursor.execute('SELECT "%s".currval FROM dual' % sq_name) | ||||||
|         return cursor.fetchone()[0] |         return cursor.fetchone()[0] | ||||||
|  |  | ||||||
|     def lookup_cast(self, lookup_type): |     def lookup_cast(self, lookup_type, internal_type=None): | ||||||
|         if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'): |         if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'): | ||||||
|             return "UPPER(%s)" |             return "UPPER(%s)" | ||||||
|         return "%s" |         return "%s" | ||||||
|   | |||||||
| @@ -38,6 +38,16 @@ psycopg2.extensions.register_adapter(SafeBytes, psycopg2.extensions.QuotedString | |||||||
| psycopg2.extensions.register_adapter(SafeText, psycopg2.extensions.QuotedString) | psycopg2.extensions.register_adapter(SafeText, psycopg2.extensions.QuotedString) | ||||||
| psycopg2.extras.register_uuid() | psycopg2.extras.register_uuid() | ||||||
|  |  | ||||||
|  | # Register support for inet[] manually so we don't have to handle the Inet() | ||||||
|  | # object on load all the time. | ||||||
|  | INETARRAY_OID = 1041 | ||||||
|  | INETARRAY = psycopg2.extensions.new_array_type( | ||||||
|  |     (INETARRAY_OID,), | ||||||
|  |     'INETARRAY', | ||||||
|  |     psycopg2.extensions.UNICODE, | ||||||
|  | ) | ||||||
|  | psycopg2.extensions.register_type(INETARRAY) | ||||||
|  |  | ||||||
|  |  | ||||||
| class DatabaseWrapper(BaseDatabaseWrapper): | class DatabaseWrapper(BaseDatabaseWrapper): | ||||||
|     vendor = 'postgresql' |     vendor = 'postgresql' | ||||||
|   | |||||||
| @@ -6,6 +6,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): | |||||||
|     needs_datetime_string_cast = False |     needs_datetime_string_cast = False | ||||||
|     can_return_id_from_insert = True |     can_return_id_from_insert = True | ||||||
|     has_real_datatype = True |     has_real_datatype = True | ||||||
|  |     has_native_uuid_field = True | ||||||
|     has_native_duration_field = True |     has_native_duration_field = True | ||||||
|     driver_supports_timedelta_args = True |     driver_supports_timedelta_args = True | ||||||
|     can_defer_constraint_checks = True |     can_defer_constraint_checks = True | ||||||
|   | |||||||
| @@ -3,6 +3,8 @@ from __future__ import unicode_literals | |||||||
| from django.conf import settings | from django.conf import settings | ||||||
| from django.db.backends.base.operations import BaseDatabaseOperations | from django.db.backends.base.operations import BaseDatabaseOperations | ||||||
|  |  | ||||||
|  | from psycopg2.extras import Inet | ||||||
|  |  | ||||||
|  |  | ||||||
| class DatabaseOperations(BaseDatabaseOperations): | class DatabaseOperations(BaseDatabaseOperations): | ||||||
|     def unification_cast_sql(self, output_field): |     def unification_cast_sql(self, output_field): | ||||||
| @@ -57,13 +59,16 @@ class DatabaseOperations(BaseDatabaseOperations): | |||||||
|     def deferrable_sql(self): |     def deferrable_sql(self): | ||||||
|         return " DEFERRABLE INITIALLY DEFERRED" |         return " DEFERRABLE INITIALLY DEFERRED" | ||||||
|  |  | ||||||
|     def lookup_cast(self, lookup_type): |     def lookup_cast(self, lookup_type, internal_type=None): | ||||||
|         lookup = '%s' |         lookup = '%s' | ||||||
|  |  | ||||||
|         # Cast text lookups to text to allow things like filter(x__contains=4) |         # Cast text lookups to text to allow things like filter(x__contains=4) | ||||||
|         if lookup_type in ('iexact', 'contains', 'icontains', 'startswith', |         if lookup_type in ('iexact', 'contains', 'icontains', 'startswith', | ||||||
|                            'istartswith', 'endswith', 'iendswith', 'regex', 'iregex'): |                            'istartswith', 'endswith', 'iendswith', 'regex', 'iregex'): | ||||||
|             lookup = "%s::text" |             if internal_type in ('IPAddressField', 'GenericIPAddressField'): | ||||||
|  |                 lookup = "HOST(%s)" | ||||||
|  |             else: | ||||||
|  |                 lookup = "%s::text" | ||||||
|  |  | ||||||
|         # Use UPPER(x) for case-insensitive lookups; it's faster. |         # Use UPPER(x) for case-insensitive lookups; it's faster. | ||||||
|         if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'): |         if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'): | ||||||
| @@ -71,11 +76,6 @@ class DatabaseOperations(BaseDatabaseOperations): | |||||||
|  |  | ||||||
|         return lookup |         return lookup | ||||||
|  |  | ||||||
|     def field_cast_sql(self, db_type, internal_type): |  | ||||||
|         if internal_type == "GenericIPAddressField" or internal_type == "IPAddressField": |  | ||||||
|             return 'HOST(%s)' |  | ||||||
|         return '%s' |  | ||||||
|  |  | ||||||
|     def last_insert_id(self, cursor, table_name, pk_name): |     def last_insert_id(self, cursor, table_name, pk_name): | ||||||
|         # Use pg_get_serial_sequence to get the underlying sequence name |         # Use pg_get_serial_sequence to get the underlying sequence name | ||||||
|         # from the table name and column name (available since PostgreSQL 8) |         # from the table name and column name (available since PostgreSQL 8) | ||||||
| @@ -224,3 +224,17 @@ class DatabaseOperations(BaseDatabaseOperations): | |||||||
|     def bulk_insert_sql(self, fields, num_values): |     def bulk_insert_sql(self, fields, num_values): | ||||||
|         items_sql = "(%s)" % ", ".join(["%s"] * len(fields)) |         items_sql = "(%s)" % ", ".join(["%s"] * len(fields)) | ||||||
|         return "VALUES " + ", ".join([items_sql] * num_values) |         return "VALUES " + ", ".join([items_sql] * num_values) | ||||||
|  |  | ||||||
|  |     def value_to_db_date(self, value): | ||||||
|  |         return value | ||||||
|  |  | ||||||
|  |     def value_to_db_datetime(self, value): | ||||||
|  |         return value | ||||||
|  |  | ||||||
|  |     def value_to_db_time(self, value): | ||||||
|  |         return value | ||||||
|  |  | ||||||
|  |     def value_to_db_ipaddress(self, value): | ||||||
|  |         if value: | ||||||
|  |             return Inet(value) | ||||||
|  |         return None | ||||||
|   | |||||||
| @@ -1983,7 +1983,7 @@ class GenericIPAddressField(Field): | |||||||
|     def get_db_prep_value(self, value, connection, prepared=False): |     def get_db_prep_value(self, value, connection, prepared=False): | ||||||
|         if not prepared: |         if not prepared: | ||||||
|             value = self.get_prep_value(value) |             value = self.get_prep_value(value) | ||||||
|         return value or None |         return connection.ops.value_to_db_ipaddress(value) | ||||||
|  |  | ||||||
|     def get_prep_value(self, value): |     def get_prep_value(self, value): | ||||||
|         value = super(GenericIPAddressField, self).get_prep_value(value) |         value = super(GenericIPAddressField, self).get_prep_value(value) | ||||||
| @@ -2366,8 +2366,10 @@ class UUIDField(Field): | |||||||
|     def get_internal_type(self): |     def get_internal_type(self): | ||||||
|         return "UUIDField" |         return "UUIDField" | ||||||
|  |  | ||||||
|     def get_prep_value(self, value): |     def get_db_prep_value(self, value, connection, prepared=False): | ||||||
|         if isinstance(value, uuid.UUID): |         if isinstance(value, uuid.UUID): | ||||||
|  |             if connection.features.has_native_uuid_field: | ||||||
|  |                 return value | ||||||
|             return value.hex |             return value.hex | ||||||
|         if isinstance(value, six.string_types): |         if isinstance(value, six.string_types): | ||||||
|             return value.replace('-', '') |             return value.replace('-', '') | ||||||
|   | |||||||
| @@ -198,7 +198,7 @@ class BuiltinLookup(Lookup): | |||||||
|         db_type = self.lhs.output_field.db_type(connection=connection) |         db_type = self.lhs.output_field.db_type(connection=connection) | ||||||
|         lhs_sql = connection.ops.field_cast_sql( |         lhs_sql = connection.ops.field_cast_sql( | ||||||
|             db_type, field_internal_type) % lhs_sql |             db_type, field_internal_type) % lhs_sql | ||||||
|         lhs_sql = connection.ops.lookup_cast(self.lookup_name) % lhs_sql |         lhs_sql = connection.ops.lookup_cast(self.lookup_name, field_internal_type) % lhs_sql | ||||||
|         return lhs_sql, params |         return lhs_sql, params | ||||||
|  |  | ||||||
|     def as_sql(self, compiler, connection): |     def as_sql(self, compiler, connection): | ||||||
|   | |||||||
| @@ -695,6 +695,11 @@ class GenericIPAddressFieldTests(test.TestCase): | |||||||
|         o = GenericIPAddress.objects.get() |         o = GenericIPAddress.objects.get() | ||||||
|         self.assertIsNone(o.ip) |         self.assertIsNone(o.ip) | ||||||
|  |  | ||||||
|  |     def test_save_load(self): | ||||||
|  |         instance = GenericIPAddress.objects.create(ip='::1') | ||||||
|  |         loaded = GenericIPAddress.objects.get() | ||||||
|  |         self.assertEqual(loaded.ip, instance.ip) | ||||||
|  |  | ||||||
|  |  | ||||||
| class PromiseTest(test.TestCase): | class PromiseTest(test.TestCase): | ||||||
|     def test_AutoField(self): |     def test_AutoField(self): | ||||||
|   | |||||||
| @@ -27,7 +27,9 @@ class Migration(migrations.Migration): | |||||||
|             name='DateTimeArrayModel', |             name='DateTimeArrayModel', | ||||||
|             fields=[ |             fields=[ | ||||||
|                 ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), |                 ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), | ||||||
|                 ('field', django.contrib.postgres.fields.ArrayField(models.DateTimeField(), size=None)), |                 ('datetimes', django.contrib.postgres.fields.ArrayField(models.DateTimeField(), size=None)), | ||||||
|  |                 ('dates', django.contrib.postgres.fields.ArrayField(models.DateField(), size=None)), | ||||||
|  |                 ('times', django.contrib.postgres.fields.ArrayField(models.TimeField(), size=None)), | ||||||
|             ], |             ], | ||||||
|             options={ |             options={ | ||||||
|             }, |             }, | ||||||
| @@ -43,6 +45,18 @@ class Migration(migrations.Migration): | |||||||
|             }, |             }, | ||||||
|             bases=(models.Model,), |             bases=(models.Model,), | ||||||
|         ), |         ), | ||||||
|  |         migrations.CreateModel( | ||||||
|  |             name='OtherTypesArrayModel', | ||||||
|  |             fields=[ | ||||||
|  |                 ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), | ||||||
|  |                 ('ips', django.contrib.postgres.fields.ArrayField(models.GenericIPAddressField(), size=None)), | ||||||
|  |                 ('uuids', django.contrib.postgres.fields.ArrayField(models.UUIDField(), size=None)), | ||||||
|  |                 ('decimals', django.contrib.postgres.fields.ArrayField(models.DecimalField(max_digits=5, decimal_places=2), size=None)), | ||||||
|  |             ], | ||||||
|  |             options={ | ||||||
|  |             }, | ||||||
|  |             bases=(models.Model,), | ||||||
|  |         ), | ||||||
|         migrations.CreateModel( |         migrations.CreateModel( | ||||||
|             name='IntegerArrayModel', |             name='IntegerArrayModel', | ||||||
|             fields=[ |             fields=[ | ||||||
|   | |||||||
| @@ -18,13 +18,21 @@ class CharArrayModel(models.Model): | |||||||
|  |  | ||||||
|  |  | ||||||
| class DateTimeArrayModel(models.Model): | class DateTimeArrayModel(models.Model): | ||||||
|     field = ArrayField(models.DateTimeField()) |     datetimes = ArrayField(models.DateTimeField()) | ||||||
|  |     dates = ArrayField(models.DateField()) | ||||||
|  |     times = ArrayField(models.TimeField()) | ||||||
|  |  | ||||||
|  |  | ||||||
| class NestedIntegerArrayModel(models.Model): | class NestedIntegerArrayModel(models.Model): | ||||||
|     field = ArrayField(ArrayField(models.IntegerField())) |     field = ArrayField(ArrayField(models.IntegerField())) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class OtherTypesArrayModel(models.Model): | ||||||
|  |     ips = ArrayField(models.GenericIPAddressField()) | ||||||
|  |     uuids = ArrayField(models.UUIDField()) | ||||||
|  |     decimals = ArrayField(models.DecimalField(max_digits=5, decimal_places=2)) | ||||||
|  |  | ||||||
|  |  | ||||||
| class HStoreModel(models.Model): | class HStoreModel(models.Model): | ||||||
|     field = HStoreField(blank=True, null=True) |     field = HStoreField(blank=True, null=True) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,5 +1,7 @@ | |||||||
|  | import decimal | ||||||
| import json | import json | ||||||
| import unittest | import unittest | ||||||
|  | import uuid | ||||||
|  |  | ||||||
| from django.contrib.postgres.fields import ArrayField | from django.contrib.postgres.fields import ArrayField | ||||||
| from django.contrib.postgres.forms import SimpleArrayField, SplitArrayField | from django.contrib.postgres.forms import SimpleArrayField, SplitArrayField | ||||||
| @@ -10,7 +12,11 @@ from django import forms | |||||||
| from django.test import TestCase, override_settings | from django.test import TestCase, override_settings | ||||||
| from django.utils import timezone | from django.utils import timezone | ||||||
|  |  | ||||||
| from .models import IntegerArrayModel, NullableIntegerArrayModel, CharArrayModel, DateTimeArrayModel, NestedIntegerArrayModel, ArrayFieldSubclass | from .models import ( | ||||||
|  |     IntegerArrayModel, NullableIntegerArrayModel, CharArrayModel, | ||||||
|  |     DateTimeArrayModel, NestedIntegerArrayModel, OtherTypesArrayModel, | ||||||
|  |     ArrayFieldSubclass, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
| @unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL required') | @unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL required') | ||||||
| @@ -29,10 +35,16 @@ class TestSaveLoad(TestCase): | |||||||
|         self.assertEqual(instance.field, loaded.field) |         self.assertEqual(instance.field, loaded.field) | ||||||
|  |  | ||||||
|     def test_dates(self): |     def test_dates(self): | ||||||
|         instance = DateTimeArrayModel(field=[timezone.now()]) |         instance = DateTimeArrayModel( | ||||||
|  |             datetimes=[timezone.now()], | ||||||
|  |             dates=[timezone.now().date()], | ||||||
|  |             times=[timezone.now().time()], | ||||||
|  |         ) | ||||||
|         instance.save() |         instance.save() | ||||||
|         loaded = DateTimeArrayModel.objects.get() |         loaded = DateTimeArrayModel.objects.get() | ||||||
|         self.assertEqual(instance.field, loaded.field) |         self.assertEqual(instance.datetimes, loaded.datetimes) | ||||||
|  |         self.assertEqual(instance.dates, loaded.dates) | ||||||
|  |         self.assertEqual(instance.times, loaded.times) | ||||||
|  |  | ||||||
|     def test_tuples(self): |     def test_tuples(self): | ||||||
|         instance = IntegerArrayModel(field=(1,)) |         instance = IntegerArrayModel(field=(1,)) | ||||||
| @@ -70,6 +82,18 @@ class TestSaveLoad(TestCase): | |||||||
|         loaded = NestedIntegerArrayModel.objects.get() |         loaded = NestedIntegerArrayModel.objects.get() | ||||||
|         self.assertEqual(instance.field, loaded.field) |         self.assertEqual(instance.field, loaded.field) | ||||||
|  |  | ||||||
|  |     def test_other_array_types(self): | ||||||
|  |         instance = OtherTypesArrayModel( | ||||||
|  |             ips=['192.168.0.1', '::1'], | ||||||
|  |             uuids=[uuid.uuid4()], | ||||||
|  |             decimals=[decimal.Decimal(1.25), 1.75], | ||||||
|  |         ) | ||||||
|  |         instance.save() | ||||||
|  |         loaded = OtherTypesArrayModel.objects.get() | ||||||
|  |         self.assertEqual(instance.ips, loaded.ips) | ||||||
|  |         self.assertEqual(instance.uuids, loaded.uuids) | ||||||
|  |         self.assertEqual(instance.decimals, loaded.decimals) | ||||||
|  |  | ||||||
|  |  | ||||||
| @unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL required') | @unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL required') | ||||||
| class TestQuerying(TestCase): | class TestQuerying(TestCase): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user