diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index d4edf2ce0f..1fdb11b273 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -486,14 +486,18 @@ class YearLookup(Lookup): class YearComparisonLookup(YearLookup): def as_sql(self, compiler, connection): - # We will need to skip the extract part and instead go - # directly with the originating field, that is self.lhs.lhs. - lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs) - rhs_sql, rhs_params = self.process_rhs(compiler, connection) - rhs_sql = self.get_rhs_op(connection, rhs_sql) - start, finish = self.year_lookup_bounds(connection, rhs_params[0]) - params.append(self.get_bound(start, finish)) - return '%s %s' % (lhs_sql, rhs_sql), params + # Avoid the extract operation if the rhs is a direct value to allow + # indexes to be used. + if self.rhs_is_direct_value(): + # Skip the extract part by directly using the originating field, + # that is self.lhs.lhs. + lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs) + rhs_sql, _ = self.process_rhs(compiler, connection) + rhs_sql = self.get_rhs_op(connection, rhs_sql) + start, finish = self.year_lookup_bounds(connection, self.rhs) + params.append(self.get_bound(start, finish)) + return '%s %s' % (lhs_sql, rhs_sql), params + return super().as_sql(compiler, connection) def get_rhs_op(self, connection, rhs): return connection.operators[self.lookup_name] % rhs @@ -520,29 +524,22 @@ class YearExact(YearLookup, Exact): return super().as_sql(compiler, connection) -class YearGt(YearComparisonLookup): - lookup_name = 'gt' +class YearGt(YearComparisonLookup, GreaterThan): def get_bound(self, start, finish): return finish -class YearGte(YearComparisonLookup): - lookup_name = 'gte' - +class YearGte(YearComparisonLookup, GreaterThanOrEqual): def get_bound(self, start, finish): return start -class YearLt(YearComparisonLookup): - lookup_name = 'lt' - +class YearLt(YearComparisonLookup, LessThan): def get_bound(self, start, finish): return start -class YearLte(YearComparisonLookup): - lookup_name = 'lte' - +class YearLte(YearComparisonLookup, LessThanOrEqual): def get_bound(self, start, finish): return finish diff --git a/tests/db_functions/datetime/test_extract_trunc.py b/tests/db_functions/datetime/test_extract_trunc.py index f62bd0f0b2..854959aca6 100644 --- a/tests/db_functions/datetime/test_extract_trunc.py +++ b/tests/db_functions/datetime/test_extract_trunc.py @@ -135,6 +135,11 @@ class DateFunctionTests(TestCase): qs = DTModel.objects.filter(**{'start_datetime__%s__gte' % lookup: 2015}) self.assertEqual(qs.count(), 2) self.assertEqual(str(qs.query).lower().count('extract'), 0) + qs = DTModel.objects.annotate( + start_year=ExtractYear('start_datetime'), + ).filter(**{'end_datetime__%s__gte' % lookup: F('start_year')}) + self.assertEqual(qs.count(), 1) + self.assertGreaterEqual(str(qs.query).lower().count('extract'), 2) def test_extract_year_lessthan_lookup(self): start_datetime = datetime(2015, 6, 15, 14, 10) @@ -153,6 +158,11 @@ class DateFunctionTests(TestCase): qs = DTModel.objects.filter(**{'start_datetime__%s__lte' % lookup: 2016}) self.assertEqual(qs.count(), 2) self.assertEqual(str(qs.query).count('extract'), 0) + qs = DTModel.objects.annotate( + end_year=ExtractYear('end_datetime'), + ).filter(**{'start_datetime__%s__lte' % lookup: F('end_year')}) + self.assertEqual(qs.count(), 1) + self.assertGreaterEqual(str(qs.query).lower().count('extract'), 2) def test_extract_func(self): start_datetime = datetime(2015, 6, 15, 14, 30, 50, 321)