From 7bda2d8ebc7833747363ac837fecb6535c817dcd Mon Sep 17 00:00:00 2001
From: Marc Tamlyn <marc.tamlyn@gmail.com>
Date: Thu, 21 May 2015 20:55:50 +0930
Subject: [PATCH] Fixed #24837 -- field__contained_by=Range

Provide `contained_by` lookups for the equivalent single valued fields
related to the range field types. This acts as the opposite direction to
rangefield__contains.

With thanks to schinckel for the idea and initial tests.
---
 django/contrib/postgres/fields/ranges.py      | 32 +++++++
 docs/ref/contrib/postgres/fields.txt          | 24 ++++-
 docs/releases/1.9.txt                         |  2 +
 .../migrations/0002_create_test_models.py     | 15 +++
 tests/postgres_tests/models.py                | 12 +++
 tests/postgres_tests/test_ranges.py           | 93 ++++++++++++++++++-
 6 files changed, 175 insertions(+), 3 deletions(-)

diff --git a/django/contrib/postgres/fields/ranges.py b/django/contrib/postgres/fields/ranges.py
index 679e87f44e..6e0f8e2284 100644
--- a/django/contrib/postgres/fields/ranges.py
+++ b/django/contrib/postgres/fields/ranges.py
@@ -98,6 +98,38 @@ RangeField.register_lookup(lookups.ContainedBy)
 RangeField.register_lookup(lookups.Overlap)
 
 
+class RangeContainedBy(models.Lookup):
+    lookup_name = 'contained_by'
+    type_mapping = {
+        'integer': 'int4range',
+        'bigint': 'int8range',
+        'double precision': 'numrange',
+        'date': 'daterange',
+        'timestamp with time zone': 'tstzrange',
+    }
+
+    def as_sql(self, qn, connection):
+        field = self.lhs.output_field
+        if isinstance(field, models.FloatField):
+            sql = '%s::numeric <@ %s::{}'.format(self.type_mapping[field.db_type(connection)])
+        else:
+            sql = '%s <@ %s::{}'.format(self.type_mapping[field.db_type(connection)])
+        lhs, lhs_params = self.process_lhs(qn, connection)
+        rhs, rhs_params = self.process_rhs(qn, connection)
+        params = lhs_params + rhs_params
+        return sql % (lhs, rhs), params
+
+    def get_prep_lookup(self):
+        return RangeField().get_prep_lookup(self.lookup_name, self.rhs)
+
+
+models.DateField.register_lookup(RangeContainedBy)
+models.DateTimeField.register_lookup(RangeContainedBy)
+models.IntegerField.register_lookup(RangeContainedBy)
+models.BigIntegerField.register_lookup(RangeContainedBy)
+models.FloatField.register_lookup(RangeContainedBy)
+
+
 @RangeField.register_lookup
 class FullyLessThan(lookups.PostgresSimpleLookup):
     lookup_name = 'fully_lt'
diff --git a/docs/ref/contrib/postgres/fields.txt b/docs/ref/contrib/postgres/fields.txt
index 4b934c97fd..5ca8177eb5 100644
--- a/docs/ref/contrib/postgres/fields.txt
+++ b/docs/ref/contrib/postgres/fields.txt
@@ -631,14 +631,18 @@ model::
     class Event(models.Model):
         name = models.CharField(max_length=200)
         ages = IntegerRangeField()
+        start = models.DateTimeField()
 
         def __str__(self):  # __unicode__ on Python 2
             return self.name
 
 We will also use the following example objects::
 
-    >>> Event.objects.create(name='Soft play', ages=(0, 10))
-    >>> Event.objects.create(name='Pub trip', ages=(21, None))
+    >>> import datetime
+    >>> from django.utils import timezone
+    >>> now = timezone.now()
+    >>> Event.objects.create(name='Soft play', ages=(0, 10), start=now)
+    >>> Event.objects.create(name='Pub trip', ages=(21, None), start=now - datetime.timedelta(days=1))
 
 and ``NumericRange``:
 
@@ -667,6 +671,22 @@ contained_by
     >>> Event.objects.filter(ages__contained_by=NumericRange(0, 15))
     [<Event: Soft play>]
 
+.. versionadded 1.9
+
+    The `contained_by` lookup is also available on the non-range field types:
+    :class:`~django.db.models.fields.IntegerField`,
+    :class:`~django.db.models.fields.BigIntegerField`,
+    :class:`~django.db.models.fields.FloatField`,
+    :class:`~django.db.models.fields.DateField`, and
+    :class:`~django.db.models.fields.DateTimeField`. For example::
+
+    >>> from psycopg2.extras import DateTimeTZRange
+    >>> Event.objects.filter(start__contained_by=DateTimeTZRange(
+    ...     timezone.now() - datetime.timedelta(hours=1),
+    ...     timezone.now() + datetime.timedelta(hours=1),
+    ... )
+    [<Event: Soft play>]
+
 .. fieldlookup:: rangefield.overlap
 
 overlap
diff --git a/docs/releases/1.9.txt b/docs/releases/1.9.txt
index 51be301659..793e26ffa3 100644
--- a/docs/releases/1.9.txt
+++ b/docs/releases/1.9.txt
@@ -91,6 +91,8 @@ Minor features
 :mod:`django.contrib.postgres`
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
+* Added support for the :lookup:`rangefield.contained_by` lookup for some built
+  in fields which correspond to the range fields.
 * Added :class:`~django.contrib.postgres.fields.JSONField`.
 * Added :doc:`/ref/contrib/postgres/aggregates`.
 
diff --git a/tests/postgres_tests/migrations/0002_create_test_models.py b/tests/postgres_tests/migrations/0002_create_test_models.py
index 4eb2154f02..f1d06cc2d9 100644
--- a/tests/postgres_tests/migrations/0002_create_test_models.py
+++ b/tests/postgres_tests/migrations/0002_create_test_models.py
@@ -143,6 +143,21 @@ class Migration(migrations.Migration):
                 ('timestamps', DateTimeRangeField(null=True, blank=True)),
                 ('dates', DateRangeField(null=True, blank=True)),
             ],
+            options={
+                'required_db_vendor': 'postgresql'
+            },
+            bases=(models.Model,)
+        ),
+        migrations.CreateModel(
+            name='RangeLookupsModel',
+            fields=[
+                ('parent', models.ForeignKey('postgres_tests.RangesModel', blank=True, null=True)),
+                ('integer', models.IntegerField(blank=True, null=True)),
+                ('big_integer', models.BigIntegerField(blank=True, null=True)),
+                ('float', models.FloatField(blank=True, null=True)),
+                ('timestamp', models.DateTimeField(blank=True, null=True)),
+                ('date', models.DateField(blank=True, null=True)),
+            ],
             options={
                 'required_db_vendor': 'postgresql',
             },
diff --git a/tests/postgres_tests/models.py b/tests/postgres_tests/models.py
index 329a91c951..bef77b7b21 100644
--- a/tests/postgres_tests/models.py
+++ b/tests/postgres_tests/models.py
@@ -60,11 +60,23 @@ if connection.vendor == 'postgresql' and connection.pg_version >= 90200:
         floats = FloatRangeField(blank=True, null=True)
         timestamps = DateTimeRangeField(blank=True, null=True)
         dates = DateRangeField(blank=True, null=True)
+
+    class RangeLookupsModel(PostgreSQLModel):
+        parent = models.ForeignKey(RangesModel, blank=True, null=True)
+        integer = models.IntegerField(blank=True, null=True)
+        big_integer = models.BigIntegerField(blank=True, null=True)
+        float = models.FloatField(blank=True, null=True)
+        timestamp = models.DateTimeField(blank=True, null=True)
+        date = models.DateField(blank=True, null=True)
+
 else:
     # create an object with this name so we don't have failing imports
     class RangesModel(object):
         pass
 
+    class RangeLookupsModel(object):
+        pass
+
 
 # Only create this model for postgres >= 9.4
 if connection.vendor == 'postgresql' and connection.pg_version >= 90400:
diff --git a/tests/postgres_tests/test_ranges.py b/tests/postgres_tests/test_ranges.py
index 7cd8a60b2f..2461130b35 100644
--- a/tests/postgres_tests/test_ranges.py
+++ b/tests/postgres_tests/test_ranges.py
@@ -5,11 +5,12 @@ import unittest
 from django import forms
 from django.core import exceptions, serializers
 from django.db import connection
+from django.db.models import F
 from django.test import TestCase, override_settings
 from django.utils import timezone
 
 from . import PostgreSQLTestCase
-from .models import RangesModel
+from .models import RangeLookupsModel, RangesModel
 
 try:
     from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange
@@ -197,6 +198,96 @@ class TestQuerying(TestCase):
         )
 
 
+@skipUnlessPG92
+class TestQueringWithRanges(TestCase):
+    def test_date_range(self):
+        objs = [
+            RangeLookupsModel.objects.create(date='2015-01-01'),
+            RangeLookupsModel.objects.create(date='2015-05-05'),
+        ]
+        self.assertSequenceEqual(
+            RangeLookupsModel.objects.filter(date__contained_by=DateRange('2015-01-01', '2015-05-04')),
+            [objs[0]],
+        )
+
+    def test_date_range_datetime_field(self):
+        objs = [
+            RangeLookupsModel.objects.create(timestamp='2015-01-01'),
+            RangeLookupsModel.objects.create(timestamp='2015-05-05'),
+        ]
+        self.assertSequenceEqual(
+            RangeLookupsModel.objects.filter(timestamp__date__contained_by=DateRange('2015-01-01', '2015-05-04')),
+            [objs[0]],
+        )
+
+    def test_datetime_range(self):
+        objs = [
+            RangeLookupsModel.objects.create(timestamp='2015-01-01T09:00:00'),
+            RangeLookupsModel.objects.create(timestamp='2015-05-05T17:00:00'),
+        ]
+        self.assertSequenceEqual(
+            RangeLookupsModel.objects.filter(
+                timestamp__contained_by=DateTimeTZRange('2015-01-01T09:00', '2015-05-04T23:55')
+            ),
+            [objs[0]],
+        )
+
+    def test_integer_range(self):
+        objs = [
+            RangeLookupsModel.objects.create(integer=5),
+            RangeLookupsModel.objects.create(integer=99),
+            RangeLookupsModel.objects.create(integer=-1),
+        ]
+        self.assertSequenceEqual(
+            RangeLookupsModel.objects.filter(integer__contained_by=NumericRange(1, 98)),
+            [objs[0]]
+        )
+
+    def test_biginteger_range(self):
+        objs = [
+            RangeLookupsModel.objects.create(big_integer=5),
+            RangeLookupsModel.objects.create(big_integer=99),
+            RangeLookupsModel.objects.create(big_integer=-1),
+        ]
+        self.assertSequenceEqual(
+            RangeLookupsModel.objects.filter(big_integer__contained_by=NumericRange(1, 98)),
+            [objs[0]]
+        )
+
+    def test_float_range(self):
+        objs = [
+            RangeLookupsModel.objects.create(float=5),
+            RangeLookupsModel.objects.create(float=99),
+            RangeLookupsModel.objects.create(float=-1),
+        ]
+        self.assertSequenceEqual(
+            RangeLookupsModel.objects.filter(float__contained_by=NumericRange(1, 98)),
+            [objs[0]]
+        )
+
+    def test_f_ranges(self):
+        parent = RangesModel.objects.create(floats=NumericRange(0, 10))
+        objs = [
+            RangeLookupsModel.objects.create(float=5, parent=parent),
+            RangeLookupsModel.objects.create(float=99, parent=parent),
+        ]
+        self.assertSequenceEqual(
+            RangeLookupsModel.objects.filter(float__contained_by=F('parent__floats')),
+            [objs[0]]
+        )
+
+    def test_exclude(self):
+        objs = [
+            RangeLookupsModel.objects.create(float=5),
+            RangeLookupsModel.objects.create(float=99),
+            RangeLookupsModel.objects.create(float=-1),
+        ]
+        self.assertSequenceEqual(
+            RangeLookupsModel.objects.exclude(float__contained_by=NumericRange(0, 100)),
+            [objs[2]]
+        )
+
+
 @skipUnlessPG92
 class TestSerialization(TestCase):
     test_data = (