From 514104cf236c1039644b70c0c0f128cecd42b233 Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Mon, 20 May 2019 19:58:11 -0400 Subject: [PATCH] Refs #29396, #30494 -- Reduced code duplication in year lookups. --- django/db/models/lookups.py | 53 ++++++++++++++---------------------- tests/lookup/test_lookups.py | 12 ++++---- 2 files changed, 27 insertions(+), 38 deletions(-) diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index 1fdb11b273..92c1a0fcc3 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -483,8 +483,6 @@ class YearLookup(Lookup): bounds = connection.ops.year_lookup_bounds_for_date_field(year) return bounds - -class YearComparisonLookup(YearLookup): def as_sql(self, compiler, connection): # Avoid the extract operation if the rhs is a direct value to allow # indexes to be used. @@ -493,53 +491,44 @@ class YearComparisonLookup(YearLookup): # that is self.lhs.lhs. lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs) rhs_sql, _ = self.process_rhs(compiler, connection) - rhs_sql = self.get_rhs_op(connection, rhs_sql) + rhs_sql = self.get_direct_rhs_sql(connection, rhs_sql) start, finish = self.year_lookup_bounds(connection, self.rhs) - params.append(self.get_bound(start, finish)) + params.extend(self.get_bound_params(start, finish)) return '%s %s' % (lhs_sql, rhs_sql), params return super().as_sql(compiler, connection) - def get_rhs_op(self, connection, rhs): + def get_direct_rhs_sql(self, connection, rhs): return connection.operators[self.lookup_name] % rhs - def get_bound(self, start, finish): + def get_bound_params(self, start, finish): raise NotImplementedError( - 'subclasses of YearComparisonLookup must provide a get_bound() method' + 'subclasses of YearLookup must provide a get_bound_params() method' ) class YearExact(YearLookup, Exact): - lookup_name = 'exact' + def get_direct_rhs_sql(self, connection, rhs): + return 'BETWEEN %s AND %s' - def as_sql(self, compiler, connection): - # Avoid the extract operation if the rhs is a direct value to allow - # indexes to be used. - if self.rhs_is_direct_value(): - # Skip the extract part by directly using the originating field, - # that is self.lhs.lhs. - lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs) - bounds = self.year_lookup_bounds(connection, self.rhs) - params.extend(bounds) - return '%s BETWEEN %%s AND %%s' % lhs_sql, params - return super().as_sql(compiler, connection) + def get_bound_params(self, start, finish): + return (start, finish) - -class YearGt(YearComparisonLookup, GreaterThan): - def get_bound(self, start, finish): - return finish +class YearGt(YearLookup, GreaterThan): + def get_bound_params(self, start, finish): + return (finish,) -class YearGte(YearComparisonLookup, GreaterThanOrEqual): - def get_bound(self, start, finish): - return start +class YearGte(YearLookup, GreaterThanOrEqual): + def get_bound_params(self, start, finish): + return (start,) -class YearLt(YearComparisonLookup, LessThan): - def get_bound(self, start, finish): - return start +class YearLt(YearLookup, LessThan): + def get_bound_params(self, start, finish): + return (start,) -class YearLte(YearComparisonLookup, LessThanOrEqual): - def get_bound(self, start, finish): - return finish +class YearLte(YearLookup, LessThanOrEqual): + def get_bound_params(self, start, finish): + return (finish,) diff --git a/tests/lookup/test_lookups.py b/tests/lookup/test_lookups.py index d327b472fc..9b2d90fdd9 100644 --- a/tests/lookup/test_lookups.py +++ b/tests/lookup/test_lookups.py @@ -2,16 +2,16 @@ from datetime import datetime from django.db.models import Value from django.db.models.fields import DateTimeField -from django.db.models.lookups import YearComparisonLookup +from django.db.models.lookups import YearLookup from django.test import SimpleTestCase -class YearComparisonLookupTests(SimpleTestCase): - def test_get_bound(self): - look_up = YearComparisonLookup( +class YearLookupTests(SimpleTestCase): + def test_get_bound_params(self): + look_up = YearLookup( lhs=Value(datetime(2010, 1, 1, 0, 0, 0), output_field=DateTimeField()), rhs=Value(datetime(2010, 1, 1, 23, 59, 59), output_field=DateTimeField()), ) - msg = 'subclasses of YearComparisonLookup must provide a get_bound() method' + msg = 'subclasses of YearLookup must provide a get_bound_params() method' with self.assertRaisesMessage(NotImplementedError, msg): - look_up.get_bound(datetime(2010, 1, 1, 0, 0, 0), datetime(2010, 1, 1, 23, 59, 59)) + look_up.get_bound_params(datetime(2010, 1, 1, 0, 0, 0), datetime(2010, 1, 1, 23, 59, 59))