1
0
mirror of https://github.com/django/django.git synced 2025-03-31 19:46:42 +00:00

Fixed #27147 -- Allowed specifying bounds of tuple inputs for non-discrete range fields.

This commit is contained in:
Guilherme Martins Crocetti 2021-06-17 18:13:49 -03:00 committed by Mariusz Felisiak
parent 52f6927d7f
commit fc565cb539
8 changed files with 181 additions and 13 deletions

View File

@ -44,6 +44,10 @@ class RangeField(models.Field):
empty_strings_allowed = False empty_strings_allowed = False
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
if 'default_bounds' in kwargs:
raise TypeError(
f"Cannot use 'default_bounds' with {self.__class__.__name__}."
)
# Initializing base_field here ensures that its model matches the model for self. # Initializing base_field here ensures that its model matches the model for self.
if hasattr(self, 'base_field'): if hasattr(self, 'base_field'):
self.base_field = self.base_field() self.base_field = self.base_field()
@ -112,6 +116,37 @@ class RangeField(models.Field):
return super().formfield(**kwargs) return super().formfield(**kwargs)
CANONICAL_RANGE_BOUNDS = '[)'
class ContinuousRangeField(RangeField):
"""
Continuous range field. It allows specifying default bounds for list and
tuple inputs.
"""
def __init__(self, *args, default_bounds=CANONICAL_RANGE_BOUNDS, **kwargs):
if default_bounds not in ('[)', '(]', '()', '[]'):
raise ValueError("default_bounds must be one of '[)', '(]', '()', or '[]'.")
self.default_bounds = default_bounds
super().__init__(*args, **kwargs)
def get_prep_value(self, value):
if isinstance(value, (list, tuple)):
return self.range_type(value[0], value[1], self.default_bounds)
return super().get_prep_value(value)
def formfield(self, **kwargs):
kwargs.setdefault('default_bounds', self.default_bounds)
return super().formfield(**kwargs)
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
if self.default_bounds and self.default_bounds != CANONICAL_RANGE_BOUNDS:
kwargs['default_bounds'] = self.default_bounds
return name, path, args, kwargs
class IntegerRangeField(RangeField): class IntegerRangeField(RangeField):
base_field = models.IntegerField base_field = models.IntegerField
range_type = NumericRange range_type = NumericRange
@ -130,7 +165,7 @@ class BigIntegerRangeField(RangeField):
return 'int8range' return 'int8range'
class DecimalRangeField(RangeField): class DecimalRangeField(ContinuousRangeField):
base_field = models.DecimalField base_field = models.DecimalField
range_type = NumericRange range_type = NumericRange
form_field = forms.DecimalRangeField form_field = forms.DecimalRangeField
@ -139,7 +174,7 @@ class DecimalRangeField(RangeField):
return 'numrange' return 'numrange'
class DateTimeRangeField(RangeField): class DateTimeRangeField(ContinuousRangeField):
base_field = models.DateTimeField base_field = models.DateTimeField
range_type = DateTimeTZRange range_type = DateTimeTZRange
form_field = forms.DateTimeRangeField form_field = forms.DateTimeRangeField

View File

@ -42,6 +42,9 @@ class BaseRangeField(forms.MultiValueField):
kwargs['fields'] = [self.base_field(required=False), self.base_field(required=False)] kwargs['fields'] = [self.base_field(required=False), self.base_field(required=False)]
kwargs.setdefault('required', False) kwargs.setdefault('required', False)
kwargs.setdefault('require_all_fields', False) kwargs.setdefault('require_all_fields', False)
self.range_kwargs = {}
if default_bounds := kwargs.pop('default_bounds', None):
self.range_kwargs = {'bounds': default_bounds}
super().__init__(**kwargs) super().__init__(**kwargs)
def prepare_value(self, value): def prepare_value(self, value):
@ -68,7 +71,7 @@ class BaseRangeField(forms.MultiValueField):
code='bound_ordering', code='bound_ordering',
) )
try: try:
range_value = self.range_type(lower, upper) range_value = self.range_type(lower, upper, **self.range_kwargs)
except TypeError: except TypeError:
raise exceptions.ValidationError( raise exceptions.ValidationError(
self.error_messages['invalid'], self.error_messages['invalid'],

View File

@ -503,9 +503,9 @@ All of the range fields translate to :ref:`psycopg2 Range objects
<psycopg2:adapt-range>` in Python, but also accept tuples as input if no bounds <psycopg2:adapt-range>` in Python, but also accept tuples as input if no bounds
information is necessary. The default is lower bound included, upper bound information is necessary. The default is lower bound included, upper bound
excluded, that is ``[)`` (see the PostgreSQL documentation for details about excluded, that is ``[)`` (see the PostgreSQL documentation for details about
`different bounds`_). `different bounds`_). The default bounds can be changed for non-discrete range
fields (:class:`.DateTimeRangeField` and :class:`.DecimalRangeField`) by using
.. _different bounds: https://www.postgresql.org/docs/current/rangetypes.html#RANGETYPES-IO the ``default_bounds`` argument.
``IntegerRangeField`` ``IntegerRangeField``
--------------------- ---------------------
@ -538,23 +538,43 @@ excluded, that is ``[)`` (see the PostgreSQL documentation for details about
``DecimalRangeField`` ``DecimalRangeField``
--------------------- ---------------------
.. class:: DecimalRangeField(**options) .. class:: DecimalRangeField(default_bounds='[)', **options)
Stores a range of floating point values. Based on a Stores a range of floating point values. Based on a
:class:`~django.db.models.DecimalField`. Represented by a ``numrange`` in :class:`~django.db.models.DecimalField`. Represented by a ``numrange`` in
the database and a :class:`~psycopg2:psycopg2.extras.NumericRange` in the database and a :class:`~psycopg2:psycopg2.extras.NumericRange` in
Python. Python.
.. attribute:: DecimalRangeField.default_bounds
.. versionadded:: 4.1
Optional. The value of ``bounds`` for list and tuple inputs. The
default is lower bound included, upper bound excluded, that is ``[)``
(see the PostgreSQL documentation for details about
`different bounds`_). ``default_bounds`` is not used for
:class:`~psycopg2:psycopg2.extras.NumericRange` inputs.
``DateTimeRangeField`` ``DateTimeRangeField``
---------------------- ----------------------
.. class:: DateTimeRangeField(**options) .. class:: DateTimeRangeField(default_bounds='[)', **options)
Stores a range of timestamps. Based on a Stores a range of timestamps. Based on a
:class:`~django.db.models.DateTimeField`. Represented by a ``tstzrange`` in :class:`~django.db.models.DateTimeField`. Represented by a ``tstzrange`` in
the database and a :class:`~psycopg2:psycopg2.extras.DateTimeTZRange` in the database and a :class:`~psycopg2:psycopg2.extras.DateTimeTZRange` in
Python. Python.
.. attribute:: DateTimeRangeField.default_bounds
.. versionadded:: 4.1
Optional. The value of ``bounds`` for list and tuple inputs. The
default is lower bound included, upper bound excluded, that is ``[)``
(see the PostgreSQL documentation for details about
`different bounds`_). ``default_bounds`` is not used for
:class:`~psycopg2:psycopg2.extras.DateTimeTZRange` inputs.
``DateRangeField`` ``DateRangeField``
------------------ ------------------
@ -884,3 +904,5 @@ used with a custom range functions that expected boundaries, for example to
define :class:`~django.contrib.postgres.constraints.ExclusionConstraint`. See define :class:`~django.contrib.postgres.constraints.ExclusionConstraint`. See
`the PostgreSQL documentation for the full details <https://www.postgresql.org/ `the PostgreSQL documentation for the full details <https://www.postgresql.org/
docs/current/rangetypes.html#RANGETYPES-INCLUSIVITY>`_. docs/current/rangetypes.html#RANGETYPES-INCLUSIVITY>`_.
.. _different bounds: https://www.postgresql.org/docs/current/rangetypes.html#RANGETYPES-IO

View File

@ -76,6 +76,12 @@ Minor features
supports covering exclusion constraints using SP-GiST indexes on PostgreSQL supports covering exclusion constraints using SP-GiST indexes on PostgreSQL
14+. 14+.
* The new ``default_bounds`` attribute of :attr:`DateTimeRangeField
<django.contrib.postgres.fields.DateTimeRangeField.default_bounds>` and
:attr:`DecimalRangeField
<django.contrib.postgres.fields.DecimalRangeField.default_bounds>` allows
specifying bounds for list and tuple inputs.
:mod:`django.contrib.redirects` :mod:`django.contrib.redirects`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -26,14 +26,23 @@ except ImportError:
}) })
return name, path, args, kwargs return name, path, args, kwargs
class DummyContinuousRangeField(models.Field):
def __init__(self, *args, default_bounds='[)', **kwargs):
super().__init__(**kwargs)
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
kwargs['default_bounds'] = '[)'
return name, path, args, kwargs
ArrayField = DummyArrayField ArrayField = DummyArrayField
BigIntegerRangeField = models.Field BigIntegerRangeField = models.Field
CICharField = models.Field CICharField = models.Field
CIEmailField = models.Field CIEmailField = models.Field
CITextField = models.Field CITextField = models.Field
DateRangeField = models.Field DateRangeField = models.Field
DateTimeRangeField = models.Field DateTimeRangeField = DummyContinuousRangeField
DecimalRangeField = models.Field DecimalRangeField = DummyContinuousRangeField
HStoreField = models.Field HStoreField = models.Field
IntegerRangeField = models.Field IntegerRangeField = models.Field
SearchVector = models.Expression SearchVector = models.Expression

View File

@ -249,6 +249,7 @@ class Migration(migrations.Migration):
('decimals', DecimalRangeField(null=True, blank=True)), ('decimals', DecimalRangeField(null=True, blank=True)),
('timestamps', DateTimeRangeField(null=True, blank=True)), ('timestamps', DateTimeRangeField(null=True, blank=True)),
('timestamps_inner', DateTimeRangeField(null=True, blank=True)), ('timestamps_inner', DateTimeRangeField(null=True, blank=True)),
('timestamps_closed_bounds', DateTimeRangeField(null=True, blank=True, default_bounds='[]')),
('dates', DateRangeField(null=True, blank=True)), ('dates', DateRangeField(null=True, blank=True)),
('dates_inner', DateRangeField(null=True, blank=True)), ('dates_inner', DateRangeField(null=True, blank=True)),
], ],

View File

@ -135,6 +135,9 @@ class RangesModel(PostgreSQLModel):
decimals = DecimalRangeField(blank=True, null=True) decimals = DecimalRangeField(blank=True, null=True)
timestamps = DateTimeRangeField(blank=True, null=True) timestamps = DateTimeRangeField(blank=True, null=True)
timestamps_inner = DateTimeRangeField(blank=True, null=True) timestamps_inner = DateTimeRangeField(blank=True, null=True)
timestamps_closed_bounds = DateTimeRangeField(
blank=True, null=True, default_bounds='[]',
)
dates = DateRangeField(blank=True, null=True) dates = DateRangeField(blank=True, null=True)
dates_inner = DateRangeField(blank=True, null=True) dates_inner = DateRangeField(blank=True, null=True)

View File

@ -50,6 +50,41 @@ class BasicTests(PostgreSQLSimpleTestCase):
instance = Model(field=value) instance = Model(field=value)
self.assertEqual(instance.get_field_display(), display) self.assertEqual(instance.get_field_display(), display)
def test_discrete_range_fields_unsupported_default_bounds(self):
discrete_range_types = [
pg_fields.BigIntegerRangeField,
pg_fields.IntegerRangeField,
pg_fields.DateRangeField,
]
for field_type in discrete_range_types:
msg = f"Cannot use 'default_bounds' with {field_type.__name__}."
with self.assertRaisesMessage(TypeError, msg):
field_type(choices=[((51, 100), '51-100')], default_bounds='[]')
def test_continuous_range_fields_default_bounds(self):
continuous_range_types = [
pg_fields.DecimalRangeField,
pg_fields.DateTimeRangeField,
]
for field_type in continuous_range_types:
field = field_type(choices=[((51, 100), '51-100')], default_bounds='[]')
self.assertEqual(field.default_bounds, '[]')
def test_invalid_default_bounds(self):
tests = [')]', ')[', '](', '])', '([', '[(', 'x', '', None]
msg = "default_bounds must be one of '[)', '(]', '()', or '[]'."
for invalid_bounds in tests:
with self.assertRaisesMessage(ValueError, msg):
pg_fields.DecimalRangeField(default_bounds=invalid_bounds)
def test_deconstruct(self):
field = pg_fields.DecimalRangeField()
*_, kwargs = field.deconstruct()
self.assertEqual(kwargs, {})
field = pg_fields.DecimalRangeField(default_bounds='[]')
*_, kwargs = field.deconstruct()
self.assertEqual(kwargs, {'default_bounds': '[]'})
class TestSaveLoad(PostgreSQLTestCase): class TestSaveLoad(PostgreSQLTestCase):
@ -83,6 +118,19 @@ class TestSaveLoad(PostgreSQLTestCase):
loaded = RangesModel.objects.get() loaded = RangesModel.objects.get()
self.assertEqual(NumericRange(0, 10), loaded.ints) self.assertEqual(NumericRange(0, 10), loaded.ints)
def test_tuple_range_with_default_bounds(self):
range_ = (timezone.now(), timezone.now() + datetime.timedelta(hours=1))
RangesModel.objects.create(timestamps_closed_bounds=range_, timestamps=range_)
loaded = RangesModel.objects.get()
self.assertEqual(
loaded.timestamps_closed_bounds,
DateTimeTZRange(range_[0], range_[1], '[]'),
)
self.assertEqual(
loaded.timestamps,
DateTimeTZRange(range_[0], range_[1], '[)'),
)
def test_range_object_boundaries(self): def test_range_object_boundaries(self):
r = NumericRange(0, 10, '[]') r = NumericRange(0, 10, '[]')
instance = RangesModel(decimals=r) instance = RangesModel(decimals=r)
@ -91,6 +139,16 @@ class TestSaveLoad(PostgreSQLTestCase):
self.assertEqual(r, loaded.decimals) self.assertEqual(r, loaded.decimals)
self.assertIn(10, loaded.decimals) self.assertIn(10, loaded.decimals)
def test_range_object_boundaries_range_with_default_bounds(self):
range_ = DateTimeTZRange(
timezone.now(),
timezone.now() + datetime.timedelta(hours=1),
bounds='()',
)
RangesModel.objects.create(timestamps_closed_bounds=range_)
loaded = RangesModel.objects.get()
self.assertEqual(loaded.timestamps_closed_bounds, range_)
def test_unbounded(self): def test_unbounded(self):
r = NumericRange(None, None, '()') r = NumericRange(None, None, '()')
instance = RangesModel(decimals=r) instance = RangesModel(decimals=r)
@ -478,6 +536,8 @@ class TestSerialization(PostgreSQLSimpleTestCase):
'"bigints": null, "timestamps": "{\\"upper\\": \\"2014-02-02T12:12:12+00:00\\", ' '"bigints": null, "timestamps": "{\\"upper\\": \\"2014-02-02T12:12:12+00:00\\", '
'\\"lower\\": \\"2014-01-01T00:00:00+00:00\\", \\"bounds\\": \\"[)\\"}", ' '\\"lower\\": \\"2014-01-01T00:00:00+00:00\\", \\"bounds\\": \\"[)\\"}", '
'"timestamps_inner": null, ' '"timestamps_inner": null, '
'"timestamps_closed_bounds": "{\\"upper\\": \\"2014-02-02T12:12:12+00:00\\", '
'\\"lower\\": \\"2014-01-01T00:00:00+00:00\\", \\"bounds\\": \\"()\\"}", '
'"dates": "{\\"upper\\": \\"2014-02-02\\", \\"lower\\": \\"2014-01-01\\", \\"bounds\\": \\"[)\\"}", ' '"dates": "{\\"upper\\": \\"2014-02-02\\", \\"lower\\": \\"2014-01-01\\", \\"bounds\\": \\"[)\\"}", '
'"dates_inner": null }, ' '"dates_inner": null }, '
'"model": "postgres_tests.rangesmodel", "pk": null}]' '"model": "postgres_tests.rangesmodel", "pk": null}]'
@ -492,15 +552,19 @@ class TestSerialization(PostgreSQLSimpleTestCase):
instance = RangesModel( instance = RangesModel(
ints=NumericRange(0, 10), decimals=NumericRange(empty=True), ints=NumericRange(0, 10), decimals=NumericRange(empty=True),
timestamps=DateTimeTZRange(self.lower_dt, self.upper_dt), timestamps=DateTimeTZRange(self.lower_dt, self.upper_dt),
timestamps_closed_bounds=DateTimeTZRange(
self.lower_dt, self.upper_dt, bounds='()',
),
dates=DateRange(self.lower_date, self.upper_date), dates=DateRange(self.lower_date, self.upper_date),
) )
data = serializers.serialize('json', [instance]) data = serializers.serialize('json', [instance])
dumped = json.loads(data) dumped = json.loads(data)
for field in ('ints', 'dates', 'timestamps'): for field in ('ints', 'dates', 'timestamps', 'timestamps_closed_bounds'):
dumped[0]['fields'][field] = json.loads(dumped[0]['fields'][field]) dumped[0]['fields'][field] = json.loads(dumped[0]['fields'][field])
check = json.loads(self.test_data) check = json.loads(self.test_data)
for field in ('ints', 'dates', 'timestamps'): for field in ('ints', 'dates', 'timestamps', 'timestamps_closed_bounds'):
check[0]['fields'][field] = json.loads(check[0]['fields'][field]) check[0]['fields'][field] = json.loads(check[0]['fields'][field])
self.assertEqual(dumped, check) self.assertEqual(dumped, check)
def test_loading(self): def test_loading(self):
@ -510,6 +574,10 @@ class TestSerialization(PostgreSQLSimpleTestCase):
self.assertIsNone(instance.bigints) self.assertIsNone(instance.bigints)
self.assertEqual(instance.dates, DateRange(self.lower_date, self.upper_date)) self.assertEqual(instance.dates, DateRange(self.lower_date, self.upper_date))
self.assertEqual(instance.timestamps, DateTimeTZRange(self.lower_dt, self.upper_dt)) self.assertEqual(instance.timestamps, DateTimeTZRange(self.lower_dt, self.upper_dt))
self.assertEqual(
instance.timestamps_closed_bounds,
DateTimeTZRange(self.lower_dt, self.upper_dt, bounds='()'),
)
def test_serialize_range_with_null(self): def test_serialize_range_with_null(self):
instance = RangesModel(ints=NumericRange(None, 10)) instance = RangesModel(ints=NumericRange(None, 10))
@ -886,26 +954,47 @@ class TestFormField(PostgreSQLSimpleTestCase):
model_field = pg_fields.IntegerRangeField() model_field = pg_fields.IntegerRangeField()
form_field = model_field.formfield() form_field = model_field.formfield()
self.assertIsInstance(form_field, pg_forms.IntegerRangeField) self.assertIsInstance(form_field, pg_forms.IntegerRangeField)
self.assertEqual(form_field.range_kwargs, {})
def test_model_field_formfield_biginteger(self): def test_model_field_formfield_biginteger(self):
model_field = pg_fields.BigIntegerRangeField() model_field = pg_fields.BigIntegerRangeField()
form_field = model_field.formfield() form_field = model_field.formfield()
self.assertIsInstance(form_field, pg_forms.IntegerRangeField) self.assertIsInstance(form_field, pg_forms.IntegerRangeField)
self.assertEqual(form_field.range_kwargs, {})
def test_model_field_formfield_float(self): def test_model_field_formfield_float(self):
model_field = pg_fields.DecimalRangeField() model_field = pg_fields.DecimalRangeField(default_bounds='()')
form_field = model_field.formfield() form_field = model_field.formfield()
self.assertIsInstance(form_field, pg_forms.DecimalRangeField) self.assertIsInstance(form_field, pg_forms.DecimalRangeField)
self.assertEqual(form_field.range_kwargs, {'bounds': '()'})
def test_model_field_formfield_date(self): def test_model_field_formfield_date(self):
model_field = pg_fields.DateRangeField() model_field = pg_fields.DateRangeField()
form_field = model_field.formfield() form_field = model_field.formfield()
self.assertIsInstance(form_field, pg_forms.DateRangeField) self.assertIsInstance(form_field, pg_forms.DateRangeField)
self.assertEqual(form_field.range_kwargs, {})
def test_model_field_formfield_datetime(self): def test_model_field_formfield_datetime(self):
model_field = pg_fields.DateTimeRangeField() model_field = pg_fields.DateTimeRangeField()
form_field = model_field.formfield() form_field = model_field.formfield()
self.assertIsInstance(form_field, pg_forms.DateTimeRangeField) self.assertIsInstance(form_field, pg_forms.DateTimeRangeField)
self.assertEqual(
form_field.range_kwargs,
{'bounds': pg_fields.ranges.CANONICAL_RANGE_BOUNDS},
)
def test_model_field_formfield_datetime_default_bounds(self):
model_field = pg_fields.DateTimeRangeField(default_bounds='[]')
form_field = model_field.formfield()
self.assertIsInstance(form_field, pg_forms.DateTimeRangeField)
self.assertEqual(form_field.range_kwargs, {'bounds': '[]'})
def test_model_field_with_default_bounds(self):
field = pg_forms.DateTimeRangeField(default_bounds='[]')
value = field.clean(['2014-01-01 00:00:00', '2014-02-03 12:13:14'])
lower = datetime.datetime(2014, 1, 1, 0, 0, 0)
upper = datetime.datetime(2014, 2, 3, 12, 13, 14)
self.assertEqual(value, DateTimeTZRange(lower, upper, '[]'))
def test_has_changed(self): def test_has_changed(self):
for field, value in ( for field, value in (