diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index 8e1a43a877..6569120287 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -67,9 +67,6 @@ class BaseDatabaseWrapper(object): self.allow_thread_sharing = allow_thread_sharing self._thread_ident = thread.get_ident() - # Compile implementations, used by compiler.compile(someelem) - self.compile_implementations = utils.get_implementations(self.vendor) - def __eq__(self, other): if isinstance(other, BaseDatabaseWrapper): return self.alias == other.alias diff --git a/django/db/backends/utils.py b/django/db/backends/utils.py index 2610228085..c22cd9e587 100644 --- a/django/db/backends/utils.py +++ b/django/db/backends/utils.py @@ -195,31 +195,3 @@ def format_number(value, max_digits, decimal_places): return "{0:f}".format(value.quantize(decimal.Decimal(".1") ** decimal_places, context=context)) else: return "%.*f" % (decimal_places, value) - -# Map of vendor name -> map of query element class -> implementation function -compile_implementations = defaultdict(dict) - - -def get_implementations(vendor): - return compile_implementations[vendor] - - -class add_implementation(object): - """ - A decorator to allow customised implementations for query expressions. - For example: - @add_implementation(Exact, 'mysql') - def mysql_exact(node, qn, connection): - # Play with the node here. - return somesql, list_of_params - Now Exact nodes are compiled to SQL using mysql_exact instead of - Exact.as_sql() when using MySQL backend. - """ - def __init__(self, klass, vendor): - self.klass = klass - self.vendor = vendor - - def __call__(self, func): - implementations = get_implementations(self.vendor) - implementations[self.klass] = func - return func diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index 3465d4f163..d3baa876c8 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -4,7 +4,6 @@ import collections import copy import datetime import decimal -import inspect import math import warnings from base64 import b64decode, b64encode @@ -12,7 +11,7 @@ from itertools import tee from django.apps import apps from django.db import connection -from django.db.models.lookups import default_lookups +from django.db.models.lookups import default_lookups, RegisterLookupMixin from django.db.models.query_utils import QueryWrapper from django.conf import settings from django import forms @@ -82,7 +81,7 @@ def _empty(of_cls): @total_ordering -class Field(object): +class Field(RegisterLookupMixin): """Base class for all field types""" # Designates whether empty strings fundamentally are allowed at the @@ -459,30 +458,6 @@ class Field(object): def get_internal_type(self): return self.__class__.__name__ - def get_lookup(self, lookup_name): - try: - return self.class_lookups[lookup_name] - except KeyError: - for parent in inspect.getmro(self.__class__): - if not 'class_lookups' in parent.__dict__: - continue - if lookup_name in parent.class_lookups: - return parent.class_lookups[lookup_name] - - @classmethod - def register_lookup(cls, lookup): - if not 'class_lookups' in cls.__dict__: - cls.class_lookups = {} - cls.class_lookups[lookup.lookup_name] = lookup - - @classmethod - def _unregister_lookup(cls, lookup): - """ - Removes given lookup from cls lookups. Meant to be used in - tests only. - """ - del cls.class_lookups[lookup.lookup_name] - def pre_save(self, model_instance, add): """ Returns field's value just before saving. diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index 00cbcd5173..1e5a872d9c 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -1,18 +1,49 @@ from copy import copy +import inspect from django.conf import settings from django.utils import timezone from django.utils.functional import cached_property -class Extract(object): +class RegisterLookupMixin(object): + def get_lookup(self, lookup_name): + try: + return self.class_lookups[lookup_name] + except KeyError: + # To allow for inheritance, check parent class class lookups. + for parent in inspect.getmro(self.__class__): + if not 'class_lookups' in parent.__dict__: + continue + if lookup_name in parent.class_lookups: + return parent.class_lookups[lookup_name] + except AttributeError: + # This class didn't have any class_lookups + pass + if hasattr(self, 'output_type'): + return self.output_type.get_lookup(lookup_name) + return None + + @classmethod + def register_lookup(cls, lookup): + if not 'class_lookups' in cls.__dict__: + cls.class_lookups = {} + cls.class_lookups[lookup.lookup_name] = lookup + + @classmethod + def _unregister_lookup(cls, lookup): + """ + Removes given lookup from cls lookups. Meant to be used in + tests only. + """ + del cls.class_lookups[lookup.lookup_name] + + +class Extract(RegisterLookupMixin): def __init__(self, lhs, lookups): self.lhs = lhs self.init_lookups = lookups[:] - def get_lookup(self, lookup): - return self.output_type.get_lookup(lookup) - def as_sql(self, qn, connection): raise NotImplementedError @@ -27,7 +58,7 @@ class Extract(object): return self.lhs.get_cols() -class Lookup(object): +class Lookup(RegisterLookupMixin): lookup_name = None def __init__(self, lhs, rhs): diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py index 7c4ec71be0..445079b72e 100644 --- a/django/db/models/sql/aggregates.py +++ b/django/db/models/sql/aggregates.py @@ -4,6 +4,7 @@ Classes to represent the default SQL aggregate functions import copy from django.db.models.fields import IntegerField, FloatField +from django.db.models.lookups import RegisterLookupMixin __all__ = ['Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance'] @@ -14,7 +15,7 @@ ordinal_aggregate_field = IntegerField() computed_aggregate_field = FloatField() -class Aggregate(object): +class Aggregate(RegisterLookupMixin): """ Default SQL Aggregate. """ @@ -100,9 +101,6 @@ class Aggregate(object): def output_type(self): return self.field - def get_lookup(self, lookup): - return self.output_type.get_lookup(lookup) - class Avg(Aggregate): is_computed = True diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 9e355ee6c5..90b32c1b1e 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -70,9 +70,10 @@ class SQLCompiler(object): return self(name) def compile(self, node): - if node.__class__ in self.connection.compile_implementations: - return self.connection.compile_implementations[node.__class__]( - node, self, self.connection) + vendor_impl = getattr( + node, 'as_' + self.connection.vendor, None) + if vendor_impl: + return vendor_impl(self, self.connection) else: return node.as_sql(self, self.connection) diff --git a/tests/custom_lookups/tests.py b/tests/custom_lookups/tests.py index 19d952c4d7..12f557e304 100644 --- a/tests/custom_lookups/tests.py +++ b/tests/custom_lookups/tests.py @@ -1,4 +1,3 @@ -from copy import copy from datetime import date import unittest @@ -6,7 +5,6 @@ from django.test import TestCase from .models import Author from django.db import models from django.db import connection -from django.db.backends.utils import add_implementation class Div3Lookup(models.lookups.Lookup): @@ -27,6 +25,37 @@ class Div3Extract(models.lookups.Extract): return '%s %%%% 3' % (lhs,), lhs_params +class YearExtract(models.lookups.Extract): + lookup_name = 'year' + + def as_sql(self, qn, connection): + lhs_sql, params = qn.compile(self.lhs) + return connection.ops.date_extract_sql('year', lhs_sql), params + + @property + def output_type(self): + return models.IntegerField() + + +class YearExact(models.lookups.Lookup): + lookup_name = 'exact' + + def as_sql(self, qn, connection): + # We will need to skip the extract part, and instead go + # directly with the originating field, that is self.lhs.lhs + lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs) + rhs_sql, rhs_params = self.process_rhs(qn, connection) + # Note that we must be careful so that we have params in the + # same order as we have the parts in the SQL. + params = lhs_params + rhs_params + lhs_params + rhs_params + # We use PostgreSQL specific SQL here. Note that we must do the + # conversions in SQL instead of in Python to support F() references. + return ("%(lhs)s >= (%(rhs)s || '-01-01')::date " + "AND %(lhs)s <= (%(rhs)s || '-12-31')::date" % + {'lhs': lhs_sql, 'rhs': rhs_sql}, params) +YearExtract.register_lookup(YearExact) + + class YearLte(models.lookups.LessThanOrEqual): """ The purpose of this lookup is to efficiently compare the year of the field. @@ -44,80 +73,27 @@ class YearLte(models.lookups.LessThanOrEqual): # WHERE somecol <= '2013-12-31') # but also make it work if the rhs_sql is field reference. return "%s <= (%s || '-12-31')::date" % (lhs_sql, rhs_sql), params +YearExtract.register_lookup(YearLte) -class YearExtract(models.lookups.Extract): - lookup_name = 'year' - - def as_sql(self, qn, connection): - lhs_sql, params = qn.compile(self.lhs) - return connection.ops.date_extract_sql('year', lhs_sql), params - - @property - def output_type(self): - return models.IntegerField() - - def get_lookup(self, lookup): - if lookup == 'lte': - return YearLte - elif lookup == 'exact': - return YearExact - else: - return super(YearExtract, self).get_lookup(lookup) - - -class YearExact(models.lookups.Lookup): - def as_sql(self, qn, connection): - # We will need to skip the extract part, and instead go - # directly with the originating field, that is self.lhs.lhs - lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs) - rhs_sql, rhs_params = self.process_rhs(qn, connection) - # Note that we must be careful so that we have params in the - # same order as we have the parts in the SQL. - params = [] - params.extend(lhs_params) - params.extend(rhs_params) - params.extend(lhs_params) - params.extend(rhs_params) - # We use PostgreSQL specific SQL here. Note that we must do the - # conversions in SQL instead of in Python to support F() references. - return ("%(lhs)s >= (%(rhs)s || '-01-01')::date " - "AND %(lhs)s <= (%(rhs)s || '-12-31')::date" % - {'lhs': lhs_sql, 'rhs': rhs_sql}, params) - - -@add_implementation(YearExact, 'mysql') -def mysql_year_exact(node, qn, connection): - lhs_sql, lhs_params = node.process_lhs(qn, connection, node.lhs.lhs) - rhs_sql, rhs_params = node.process_rhs(qn, connection) - params = [] - params.extend(lhs_params) - params.extend(rhs_params) - params.extend(lhs_params) - params.extend(rhs_params) - return ("%(lhs)s >= str_to_date(concat(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') " - "AND %(lhs)s <= str_to_date(concat(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" % - {'lhs': lhs_sql, 'rhs': rhs_sql}, params) +# We will register this class temporarily in the test method. class InMonth(models.lookups.Lookup): """ - InMonth matches if the column's month is contained in the value's month. + InMonth matches if the column's month is the same as value's month. """ lookup_name = 'inmonth' def as_sql(self, qn, connection): - lhs, params = self.process_lhs(qn, connection) + lhs, lhs_params = self.process_lhs(qn, connection) rhs, rhs_params = self.process_rhs(qn, connection) # We need to be careful so that we get the params in right # places. - full_params = params[:] - full_params.extend(rhs_params) - full_params.extend(params) - full_params.extend(rhs_params) + params = lhs_params + rhs_params + lhs_params + rhs_params return ("%s >= date_trunc('month', %s) and " "%s < date_trunc('month', %s) + interval '1 months'" % - (lhs, rhs, lhs, rhs), full_params) + (lhs, rhs, lhs, rhs), params) class LookupTests(TestCase): @@ -178,44 +154,6 @@ class LookupTests(TestCase): finally: models.DateField._unregister_lookup(InMonth) - def test_custom_compiles(self): - a1 = Author.objects.create(name='a1', age=1) - a2 = Author.objects.create(name='a2', age=2) - a3 = Author.objects.create(name='a3', age=3) - a4 = Author.objects.create(name='a4', age=4) - - class AnotherEqual(models.lookups.Exact): - lookup_name = 'anotherequal' - models.Field.register_lookup(AnotherEqual) - try: - @add_implementation(AnotherEqual, connection.vendor) - def custom_eq_sql(node, qn, connection): - return '1 = 1', [] - - self.assertIn('1 = 1', str(Author.objects.filter(name__anotherequal='asdf').query)) - self.assertQuerysetEqual( - Author.objects.filter(name__anotherequal='asdf').order_by('name'), - [a1, a2, a3, a4], lambda x: x) - - @add_implementation(AnotherEqual, connection.vendor) - def another_custom_eq_sql(node, qn, connection): - # If you need to override one method, it seems this is the best - # option. - node = copy(node) - - class OverriddenAnotherEqual(AnotherEqual): - def get_rhs_op(self, connection, rhs): - return ' <> %s' - node.__class__ = OverriddenAnotherEqual - return node.as_sql(qn, connection) - self.assertIn(' <> ', str(Author.objects.filter(name__anotherequal='a1').query)) - self.assertQuerysetEqual( - Author.objects.filter(name__anotherequal='a1').order_by('name'), - [a2, a3, a4], lambda x: x - ) - finally: - models.Field._unregister_lookup(AnotherEqual) - def test_div3_extract(self): models.IntegerField.register_lookup(Div3Extract) try: @@ -293,11 +231,49 @@ class YearLteTests(TestCase): self.assertIn( '-12-31', str(baseqs.filter(birthdate__year__lte=2011).query)) - @unittest.skipUnless(connection.vendor == 'mysql', 'MySQL specific SQL used') - def test_mysql_year_exact(self): - self.assertQuerysetEqual( - Author.objects.filter(birthdate__year=2012).order_by('name'), - [self.a2, self.a3, self.a4], lambda x: x) + def test_postgres_year_exact(self): + baseqs = Author.objects.order_by('name') self.assertIn( - 'concat(', - str(Author.objects.filter(birthdate__year=2012).query)) + '= (2011 || ', str(baseqs.filter(birthdate__year=2011).query)) + self.assertIn( + '-12-31', str(baseqs.filter(birthdate__year=2011).query)) + + def test_custom_implementation_year_exact(self): + try: + # Two ways to add a customized implementation for different backends: + # First is MonkeyPatch of the class. + def as_custom_sql(self, qn, connection): + lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs) + rhs_sql, rhs_params = self.process_rhs(qn, connection) + params = lhs_params + rhs_params + lhs_params + rhs_params + return ("%(lhs)s >= str_to_date(concat(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') " + "AND %(lhs)s <= str_to_date(concat(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" % + {'lhs': lhs_sql, 'rhs': rhs_sql}, params) + setattr(YearExact, 'as_' + connection.vendor, as_custom_sql) + self.assertIn( + 'concat(', + str(Author.objects.filter(birthdate__year=2012).query)) + finally: + delattr(YearExact, 'as_' + connection.vendor) + try: + # The other way is to subclass the original lookup and register the subclassed + # lookup instead of the original. + class CustomYearExact(YearExact): + # This method should be named "as_mysql" for MySQL, "as_postgresql" for postgres + # and so on, but as we don't know which DB we are running on, we need to use + # setattr. + def as_custom_sql(self, qn, connection): + lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs) + rhs_sql, rhs_params = self.process_rhs(qn, connection) + params = lhs_params + rhs_params + lhs_params + rhs_params + return ("%(lhs)s >= str_to_date(CONCAT(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') " + "AND %(lhs)s <= str_to_date(CONCAT(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" % + {'lhs': lhs_sql, 'rhs': rhs_sql}, params) + setattr(CustomYearExact, 'as_' + connection.vendor, CustomYearExact.as_custom_sql) + YearExtract.register_lookup(CustomYearExact) + self.assertIn( + 'CONCAT(', + str(Author.objects.filter(birthdate__year=2012).query)) + finally: + YearExtract._unregister_lookup(CustomYearExact) + YearExtract.register_lookup(YearExact)