1
0
mirror of https://github.com/django/django.git synced 2025-10-24 06:06:09 +00:00

Fixed #24092 -- Widened base field support for ArrayField.

Several issues resolved here, following from a report that a base_field
of GenericIpAddressField was failing.

We were using get_prep_value instead of get_db_prep_value in ArrayField
which was bypassing any extra modifications to the value being made in
the base field's get_db_prep_value. Changing this broke datetime
support, so the postgres backend has gained the relevant operation
methods to send dates/times/datetimes directly to the db backend instead
of casting them to strings. Similarly, a new database feature has been
added allowing the uuid to be passed directly to the backend, as we do
with timedeltas.

On the other side, psycopg2 expects an Inet() instance for IP address
fields, so we add a value_to_db_ipaddress method to wrap the strings on
postgres. We also have to manually add a database adapter to psycopg2,
as we do not wish to use the built in adapter which would turn
everything into Inet() instances.

Thanks to smclenithan for the report.
This commit is contained in:
Marc Tamlyn
2015-01-10 18:13:28 +00:00
committed by Tim Graham
parent a17724b791
commit 39d95fb6ad
13 changed files with 111 additions and 23 deletions

View File

@@ -70,9 +70,9 @@ class ArrayField(Field):
size = self.size or ''
return '%s[%s]' % (self.base_field.db_type(connection), size)
def get_prep_value(self, value):
def get_db_prep_value(self, value, connection, prepared=False):
if isinstance(value, list) or isinstance(value, tuple):
return [self.base_field.get_prep_value(i) for i in value]
return [self.base_field.get_db_prep_value(i, connection, prepared) for i in value]
return value
def deconstruct(self):

View File

@@ -59,6 +59,9 @@ class BaseDatabaseFeatures(object):
supports_subqueries_in_group_by = True
supports_bitwise_or = True
# Is there a true datatype for uuid?
has_native_uuid_field = False
# Is there a true datatype for timedeltas?
has_native_duration_field = False

View File

@@ -219,7 +219,7 @@ class BaseDatabaseOperations(object):
"""
return cursor.lastrowid
def lookup_cast(self, lookup_type):
def lookup_cast(self, lookup_type, internal_type=None):
"""
Returns the string to use in a query when performing lookups
("contains", "like", etc). The resulting string should contain a '%s'
@@ -442,7 +442,7 @@ class BaseDatabaseOperations(object):
def value_to_db_date(self, value):
"""
Transform a date value to an object compatible with what is expected
Transforms a date value to an object compatible with what is expected
by the backend driver for date columns.
"""
if value is None:
@@ -451,7 +451,7 @@ class BaseDatabaseOperations(object):
def value_to_db_datetime(self, value):
"""
Transform a datetime value to an object compatible with what is expected
Transforms a datetime value to an object compatible with what is expected
by the backend driver for datetime columns.
"""
if value is None:
@@ -460,7 +460,7 @@ class BaseDatabaseOperations(object):
def value_to_db_time(self, value):
"""
Transform a time value to an object compatible with what is expected
Transforms a time value to an object compatible with what is expected
by the backend driver for time columns.
"""
if value is None:
@@ -471,11 +471,18 @@ class BaseDatabaseOperations(object):
def value_to_db_decimal(self, value, max_digits, decimal_places):
"""
Transform a decimal.Decimal value to an object compatible with what is
Transforms a decimal.Decimal value to an object compatible with what is
expected by the backend driver for decimal (numeric) columns.
"""
return utils.format_number(value, max_digits, decimal_places)
def value_to_db_ipaddress(self, value):
"""
Transforms a string representation of an IP address into the expected
type for the backend driver.
"""
return value
def year_lookup_bounds_for_date_field(self, value):
"""
Returns a two-elements list with the lower and upper bound to be used

View File

@@ -246,7 +246,7 @@ WHEN (new.%(col_name)s IS NULL)
cursor.execute('SELECT "%s".currval FROM dual' % sq_name)
return cursor.fetchone()[0]
def lookup_cast(self, lookup_type):
def lookup_cast(self, lookup_type, internal_type=None):
if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'):
return "UPPER(%s)"
return "%s"

View File

@@ -38,6 +38,16 @@ psycopg2.extensions.register_adapter(SafeBytes, psycopg2.extensions.QuotedString
psycopg2.extensions.register_adapter(SafeText, psycopg2.extensions.QuotedString)
psycopg2.extras.register_uuid()
# Register support for inet[] manually so we don't have to handle the Inet()
# object on load all the time.
INETARRAY_OID = 1041
INETARRAY = psycopg2.extensions.new_array_type(
(INETARRAY_OID,),
'INETARRAY',
psycopg2.extensions.UNICODE,
)
psycopg2.extensions.register_type(INETARRAY)
class DatabaseWrapper(BaseDatabaseWrapper):
vendor = 'postgresql'

View File

@@ -6,6 +6,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
needs_datetime_string_cast = False
can_return_id_from_insert = True
has_real_datatype = True
has_native_uuid_field = True
has_native_duration_field = True
driver_supports_timedelta_args = True
can_defer_constraint_checks = True

View File

@@ -3,6 +3,8 @@ from __future__ import unicode_literals
from django.conf import settings
from django.db.backends.base.operations import BaseDatabaseOperations
from psycopg2.extras import Inet
class DatabaseOperations(BaseDatabaseOperations):
def unification_cast_sql(self, output_field):
@@ -57,13 +59,16 @@ class DatabaseOperations(BaseDatabaseOperations):
def deferrable_sql(self):
return " DEFERRABLE INITIALLY DEFERRED"
def lookup_cast(self, lookup_type):
def lookup_cast(self, lookup_type, internal_type=None):
lookup = '%s'
# Cast text lookups to text to allow things like filter(x__contains=4)
if lookup_type in ('iexact', 'contains', 'icontains', 'startswith',
'istartswith', 'endswith', 'iendswith', 'regex', 'iregex'):
lookup = "%s::text"
if internal_type in ('IPAddressField', 'GenericIPAddressField'):
lookup = "HOST(%s)"
else:
lookup = "%s::text"
# Use UPPER(x) for case-insensitive lookups; it's faster.
if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'):
@@ -71,11 +76,6 @@ class DatabaseOperations(BaseDatabaseOperations):
return lookup
def field_cast_sql(self, db_type, internal_type):
if internal_type == "GenericIPAddressField" or internal_type == "IPAddressField":
return 'HOST(%s)'
return '%s'
def last_insert_id(self, cursor, table_name, pk_name):
# Use pg_get_serial_sequence to get the underlying sequence name
# from the table name and column name (available since PostgreSQL 8)
@@ -224,3 +224,17 @@ class DatabaseOperations(BaseDatabaseOperations):
def bulk_insert_sql(self, fields, num_values):
items_sql = "(%s)" % ", ".join(["%s"] * len(fields))
return "VALUES " + ", ".join([items_sql] * num_values)
def value_to_db_date(self, value):
return value
def value_to_db_datetime(self, value):
return value
def value_to_db_time(self, value):
return value
def value_to_db_ipaddress(self, value):
if value:
return Inet(value)
return None

View File

@@ -1983,7 +1983,7 @@ class GenericIPAddressField(Field):
def get_db_prep_value(self, value, connection, prepared=False):
if not prepared:
value = self.get_prep_value(value)
return value or None
return connection.ops.value_to_db_ipaddress(value)
def get_prep_value(self, value):
value = super(GenericIPAddressField, self).get_prep_value(value)
@@ -2366,8 +2366,10 @@ class UUIDField(Field):
def get_internal_type(self):
return "UUIDField"
def get_prep_value(self, value):
def get_db_prep_value(self, value, connection, prepared=False):
if isinstance(value, uuid.UUID):
if connection.features.has_native_uuid_field:
return value
return value.hex
if isinstance(value, six.string_types):
return value.replace('-', '')

View File

@@ -198,7 +198,7 @@ class BuiltinLookup(Lookup):
db_type = self.lhs.output_field.db_type(connection=connection)
lhs_sql = connection.ops.field_cast_sql(
db_type, field_internal_type) % lhs_sql
lhs_sql = connection.ops.lookup_cast(self.lookup_name) % lhs_sql
lhs_sql = connection.ops.lookup_cast(self.lookup_name, field_internal_type) % lhs_sql
return lhs_sql, params
def as_sql(self, compiler, connection):

View File

@@ -695,6 +695,11 @@ class GenericIPAddressFieldTests(test.TestCase):
o = GenericIPAddress.objects.get()
self.assertIsNone(o.ip)
def test_save_load(self):
instance = GenericIPAddress.objects.create(ip='::1')
loaded = GenericIPAddress.objects.get()
self.assertEqual(loaded.ip, instance.ip)
class PromiseTest(test.TestCase):
def test_AutoField(self):

View File

@@ -27,7 +27,9 @@ class Migration(migrations.Migration):
name='DateTimeArrayModel',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('field', django.contrib.postgres.fields.ArrayField(models.DateTimeField(), size=None)),
('datetimes', django.contrib.postgres.fields.ArrayField(models.DateTimeField(), size=None)),
('dates', django.contrib.postgres.fields.ArrayField(models.DateField(), size=None)),
('times', django.contrib.postgres.fields.ArrayField(models.TimeField(), size=None)),
],
options={
},
@@ -43,6 +45,18 @@ class Migration(migrations.Migration):
},
bases=(models.Model,),
),
migrations.CreateModel(
name='OtherTypesArrayModel',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('ips', django.contrib.postgres.fields.ArrayField(models.GenericIPAddressField(), size=None)),
('uuids', django.contrib.postgres.fields.ArrayField(models.UUIDField(), size=None)),
('decimals', django.contrib.postgres.fields.ArrayField(models.DecimalField(max_digits=5, decimal_places=2), size=None)),
],
options={
},
bases=(models.Model,),
),
migrations.CreateModel(
name='IntegerArrayModel',
fields=[

View File

@@ -18,13 +18,21 @@ class CharArrayModel(models.Model):
class DateTimeArrayModel(models.Model):
field = ArrayField(models.DateTimeField())
datetimes = ArrayField(models.DateTimeField())
dates = ArrayField(models.DateField())
times = ArrayField(models.TimeField())
class NestedIntegerArrayModel(models.Model):
field = ArrayField(ArrayField(models.IntegerField()))
class OtherTypesArrayModel(models.Model):
ips = ArrayField(models.GenericIPAddressField())
uuids = ArrayField(models.UUIDField())
decimals = ArrayField(models.DecimalField(max_digits=5, decimal_places=2))
class HStoreModel(models.Model):
field = HStoreField(blank=True, null=True)

View File

@@ -1,5 +1,7 @@
import decimal
import json
import unittest
import uuid
from django.contrib.postgres.fields import ArrayField
from django.contrib.postgres.forms import SimpleArrayField, SplitArrayField
@@ -10,7 +12,11 @@ from django import forms
from django.test import TestCase, override_settings
from django.utils import timezone
from .models import IntegerArrayModel, NullableIntegerArrayModel, CharArrayModel, DateTimeArrayModel, NestedIntegerArrayModel, ArrayFieldSubclass
from .models import (
IntegerArrayModel, NullableIntegerArrayModel, CharArrayModel,
DateTimeArrayModel, NestedIntegerArrayModel, OtherTypesArrayModel,
ArrayFieldSubclass,
)
@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL required')
@@ -29,10 +35,16 @@ class TestSaveLoad(TestCase):
self.assertEqual(instance.field, loaded.field)
def test_dates(self):
instance = DateTimeArrayModel(field=[timezone.now()])
instance = DateTimeArrayModel(
datetimes=[timezone.now()],
dates=[timezone.now().date()],
times=[timezone.now().time()],
)
instance.save()
loaded = DateTimeArrayModel.objects.get()
self.assertEqual(instance.field, loaded.field)
self.assertEqual(instance.datetimes, loaded.datetimes)
self.assertEqual(instance.dates, loaded.dates)
self.assertEqual(instance.times, loaded.times)
def test_tuples(self):
instance = IntegerArrayModel(field=(1,))
@@ -70,6 +82,18 @@ class TestSaveLoad(TestCase):
loaded = NestedIntegerArrayModel.objects.get()
self.assertEqual(instance.field, loaded.field)
def test_other_array_types(self):
instance = OtherTypesArrayModel(
ips=['192.168.0.1', '::1'],
uuids=[uuid.uuid4()],
decimals=[decimal.Decimal(1.25), 1.75],
)
instance.save()
loaded = OtherTypesArrayModel.objects.get()
self.assertEqual(instance.ips, loaded.ips)
self.assertEqual(instance.uuids, loaded.uuids)
self.assertEqual(instance.decimals, loaded.decimals)
@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL required')
class TestQuerying(TestCase):