From c7d5f8661b7d364962bed2e6f81161c1b4f1bcc3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anssi=20K=C3=A4=C3=A4ri=C3=A4inen?= Date: Sat, 11 Jan 2014 14:45:53 +0200 Subject: [PATCH] Altered query string customization for backends vendors The new way is trying to call first method 'as_' + connection.vendor. If that doesn't exist, then call as_sql(). Also altered how lookup registration is done. There is now RegisterLookupMixin class that is used by Field, Extract and sql.Aggregate. This allows one to register lookups for extracts and aggregates in the same way lookup registration is done for fields. --- django/db/backends/__init__.py | 3 - django/db/backends/utils.py | 28 ----- django/db/models/fields/__init__.py | 29 +---- django/db/models/lookups.py | 41 +++++- django/db/models/sql/aggregates.py | 6 +- django/db/models/sql/compiler.py | 7 +- tests/custom_lookups/tests.py | 188 ++++++++++++---------------- 7 files changed, 126 insertions(+), 176 deletions(-) diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index 8e1a43a877..6569120287 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -67,9 +67,6 @@ class BaseDatabaseWrapper(object): self.allow_thread_sharing = allow_thread_sharing self._thread_ident = thread.get_ident() - # Compile implementations, used by compiler.compile(someelem) - self.compile_implementations = utils.get_implementations(self.vendor) - def __eq__(self, other): if isinstance(other, BaseDatabaseWrapper): return self.alias == other.alias diff --git a/django/db/backends/utils.py b/django/db/backends/utils.py index 2610228085..c22cd9e587 100644 --- a/django/db/backends/utils.py +++ b/django/db/backends/utils.py @@ -195,31 +195,3 @@ def format_number(value, max_digits, decimal_places): return "{0:f}".format(value.quantize(decimal.Decimal(".1") ** decimal_places, context=context)) else: return "%.*f" % (decimal_places, value) - -# Map of vendor name -> map of query element class -> implementation function -compile_implementations = defaultdict(dict) - - -def get_implementations(vendor): - return compile_implementations[vendor] - - -class add_implementation(object): - """ - A decorator to allow customised implementations for query expressions. - For example: - @add_implementation(Exact, 'mysql') - def mysql_exact(node, qn, connection): - # Play with the node here. - return somesql, list_of_params - Now Exact nodes are compiled to SQL using mysql_exact instead of - Exact.as_sql() when using MySQL backend. - """ - def __init__(self, klass, vendor): - self.klass = klass - self.vendor = vendor - - def __call__(self, func): - implementations = get_implementations(self.vendor) - implementations[self.klass] = func - return func diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index 3465d4f163..d3baa876c8 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -4,7 +4,6 @@ import collections import copy import datetime import decimal -import inspect import math import warnings from base64 import b64decode, b64encode @@ -12,7 +11,7 @@ from itertools import tee from django.apps import apps from django.db import connection -from django.db.models.lookups import default_lookups +from django.db.models.lookups import default_lookups, RegisterLookupMixin from django.db.models.query_utils import QueryWrapper from django.conf import settings from django import forms @@ -82,7 +81,7 @@ def _empty(of_cls): @total_ordering -class Field(object): +class Field(RegisterLookupMixin): """Base class for all field types""" # Designates whether empty strings fundamentally are allowed at the @@ -459,30 +458,6 @@ class Field(object): def get_internal_type(self): return self.__class__.__name__ - def get_lookup(self, lookup_name): - try: - return self.class_lookups[lookup_name] - except KeyError: - for parent in inspect.getmro(self.__class__): - if not 'class_lookups' in parent.__dict__: - continue - if lookup_name in parent.class_lookups: - return parent.class_lookups[lookup_name] - - @classmethod - def register_lookup(cls, lookup): - if not 'class_lookups' in cls.__dict__: - cls.class_lookups = {} - cls.class_lookups[lookup.lookup_name] = lookup - - @classmethod - def _unregister_lookup(cls, lookup): - """ - Removes given lookup from cls lookups. Meant to be used in - tests only. - """ - del cls.class_lookups[lookup.lookup_name] - def pre_save(self, model_instance, add): """ Returns field's value just before saving. diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index 00cbcd5173..1e5a872d9c 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -1,18 +1,49 @@ from copy import copy +import inspect from django.conf import settings from django.utils import timezone from django.utils.functional import cached_property -class Extract(object): +class RegisterLookupMixin(object): + def get_lookup(self, lookup_name): + try: + return self.class_lookups[lookup_name] + except KeyError: + # To allow for inheritance, check parent class class lookups. + for parent in inspect.getmro(self.__class__): + if not 'class_lookups' in parent.__dict__: + continue + if lookup_name in parent.class_lookups: + return parent.class_lookups[lookup_name] + except AttributeError: + # This class didn't have any class_lookups + pass + if hasattr(self, 'output_type'): + return self.output_type.get_lookup(lookup_name) + return None + + @classmethod + def register_lookup(cls, lookup): + if not 'class_lookups' in cls.__dict__: + cls.class_lookups = {} + cls.class_lookups[lookup.lookup_name] = lookup + + @classmethod + def _unregister_lookup(cls, lookup): + """ + Removes given lookup from cls lookups. Meant to be used in + tests only. + """ + del cls.class_lookups[lookup.lookup_name] + + +class Extract(RegisterLookupMixin): def __init__(self, lhs, lookups): self.lhs = lhs self.init_lookups = lookups[:] - def get_lookup(self, lookup): - return self.output_type.get_lookup(lookup) - def as_sql(self, qn, connection): raise NotImplementedError @@ -27,7 +58,7 @@ class Extract(object): return self.lhs.get_cols() -class Lookup(object): +class Lookup(RegisterLookupMixin): lookup_name = None def __init__(self, lhs, rhs): diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py index 7c4ec71be0..445079b72e 100644 --- a/django/db/models/sql/aggregates.py +++ b/django/db/models/sql/aggregates.py @@ -4,6 +4,7 @@ Classes to represent the default SQL aggregate functions import copy from django.db.models.fields import IntegerField, FloatField +from django.db.models.lookups import RegisterLookupMixin __all__ = ['Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance'] @@ -14,7 +15,7 @@ ordinal_aggregate_field = IntegerField() computed_aggregate_field = FloatField() -class Aggregate(object): +class Aggregate(RegisterLookupMixin): """ Default SQL Aggregate. """ @@ -100,9 +101,6 @@ class Aggregate(object): def output_type(self): return self.field - def get_lookup(self, lookup): - return self.output_type.get_lookup(lookup) - class Avg(Aggregate): is_computed = True diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 9e355ee6c5..90b32c1b1e 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -70,9 +70,10 @@ class SQLCompiler(object): return self(name) def compile(self, node): - if node.__class__ in self.connection.compile_implementations: - return self.connection.compile_implementations[node.__class__]( - node, self, self.connection) + vendor_impl = getattr( + node, 'as_' + self.connection.vendor, None) + if vendor_impl: + return vendor_impl(self, self.connection) else: return node.as_sql(self, self.connection) diff --git a/tests/custom_lookups/tests.py b/tests/custom_lookups/tests.py index 19d952c4d7..12f557e304 100644 --- a/tests/custom_lookups/tests.py +++ b/tests/custom_lookups/tests.py @@ -1,4 +1,3 @@ -from copy import copy from datetime import date import unittest @@ -6,7 +5,6 @@ from django.test import TestCase from .models import Author from django.db import models from django.db import connection -from django.db.backends.utils import add_implementation class Div3Lookup(models.lookups.Lookup): @@ -27,6 +25,37 @@ class Div3Extract(models.lookups.Extract): return '%s %%%% 3' % (lhs,), lhs_params +class YearExtract(models.lookups.Extract): + lookup_name = 'year' + + def as_sql(self, qn, connection): + lhs_sql, params = qn.compile(self.lhs) + return connection.ops.date_extract_sql('year', lhs_sql), params + + @property + def output_type(self): + return models.IntegerField() + + +class YearExact(models.lookups.Lookup): + lookup_name = 'exact' + + def as_sql(self, qn, connection): + # We will need to skip the extract part, and instead go + # directly with the originating field, that is self.lhs.lhs + lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs) + rhs_sql, rhs_params = self.process_rhs(qn, connection) + # Note that we must be careful so that we have params in the + # same order as we have the parts in the SQL. + params = lhs_params + rhs_params + lhs_params + rhs_params + # We use PostgreSQL specific SQL here. Note that we must do the + # conversions in SQL instead of in Python to support F() references. + return ("%(lhs)s >= (%(rhs)s || '-01-01')::date " + "AND %(lhs)s <= (%(rhs)s || '-12-31')::date" % + {'lhs': lhs_sql, 'rhs': rhs_sql}, params) +YearExtract.register_lookup(YearExact) + + class YearLte(models.lookups.LessThanOrEqual): """ The purpose of this lookup is to efficiently compare the year of the field. @@ -44,80 +73,27 @@ class YearLte(models.lookups.LessThanOrEqual): # WHERE somecol <= '2013-12-31') # but also make it work if the rhs_sql is field reference. return "%s <= (%s || '-12-31')::date" % (lhs_sql, rhs_sql), params +YearExtract.register_lookup(YearLte) -class YearExtract(models.lookups.Extract): - lookup_name = 'year' - - def as_sql(self, qn, connection): - lhs_sql, params = qn.compile(self.lhs) - return connection.ops.date_extract_sql('year', lhs_sql), params - - @property - def output_type(self): - return models.IntegerField() - - def get_lookup(self, lookup): - if lookup == 'lte': - return YearLte - elif lookup == 'exact': - return YearExact - else: - return super(YearExtract, self).get_lookup(lookup) - - -class YearExact(models.lookups.Lookup): - def as_sql(self, qn, connection): - # We will need to skip the extract part, and instead go - # directly with the originating field, that is self.lhs.lhs - lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs) - rhs_sql, rhs_params = self.process_rhs(qn, connection) - # Note that we must be careful so that we have params in the - # same order as we have the parts in the SQL. - params = [] - params.extend(lhs_params) - params.extend(rhs_params) - params.extend(lhs_params) - params.extend(rhs_params) - # We use PostgreSQL specific SQL here. Note that we must do the - # conversions in SQL instead of in Python to support F() references. - return ("%(lhs)s >= (%(rhs)s || '-01-01')::date " - "AND %(lhs)s <= (%(rhs)s || '-12-31')::date" % - {'lhs': lhs_sql, 'rhs': rhs_sql}, params) - - -@add_implementation(YearExact, 'mysql') -def mysql_year_exact(node, qn, connection): - lhs_sql, lhs_params = node.process_lhs(qn, connection, node.lhs.lhs) - rhs_sql, rhs_params = node.process_rhs(qn, connection) - params = [] - params.extend(lhs_params) - params.extend(rhs_params) - params.extend(lhs_params) - params.extend(rhs_params) - return ("%(lhs)s >= str_to_date(concat(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') " - "AND %(lhs)s <= str_to_date(concat(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" % - {'lhs': lhs_sql, 'rhs': rhs_sql}, params) +# We will register this class temporarily in the test method. class InMonth(models.lookups.Lookup): """ - InMonth matches if the column's month is contained in the value's month. + InMonth matches if the column's month is the same as value's month. """ lookup_name = 'inmonth' def as_sql(self, qn, connection): - lhs, params = self.process_lhs(qn, connection) + lhs, lhs_params = self.process_lhs(qn, connection) rhs, rhs_params = self.process_rhs(qn, connection) # We need to be careful so that we get the params in right # places. - full_params = params[:] - full_params.extend(rhs_params) - full_params.extend(params) - full_params.extend(rhs_params) + params = lhs_params + rhs_params + lhs_params + rhs_params return ("%s >= date_trunc('month', %s) and " "%s < date_trunc('month', %s) + interval '1 months'" % - (lhs, rhs, lhs, rhs), full_params) + (lhs, rhs, lhs, rhs), params) class LookupTests(TestCase): @@ -178,44 +154,6 @@ class LookupTests(TestCase): finally: models.DateField._unregister_lookup(InMonth) - def test_custom_compiles(self): - a1 = Author.objects.create(name='a1', age=1) - a2 = Author.objects.create(name='a2', age=2) - a3 = Author.objects.create(name='a3', age=3) - a4 = Author.objects.create(name='a4', age=4) - - class AnotherEqual(models.lookups.Exact): - lookup_name = 'anotherequal' - models.Field.register_lookup(AnotherEqual) - try: - @add_implementation(AnotherEqual, connection.vendor) - def custom_eq_sql(node, qn, connection): - return '1 = 1', [] - - self.assertIn('1 = 1', str(Author.objects.filter(name__anotherequal='asdf').query)) - self.assertQuerysetEqual( - Author.objects.filter(name__anotherequal='asdf').order_by('name'), - [a1, a2, a3, a4], lambda x: x) - - @add_implementation(AnotherEqual, connection.vendor) - def another_custom_eq_sql(node, qn, connection): - # If you need to override one method, it seems this is the best - # option. - node = copy(node) - - class OverriddenAnotherEqual(AnotherEqual): - def get_rhs_op(self, connection, rhs): - return ' <> %s' - node.__class__ = OverriddenAnotherEqual - return node.as_sql(qn, connection) - self.assertIn(' <> ', str(Author.objects.filter(name__anotherequal='a1').query)) - self.assertQuerysetEqual( - Author.objects.filter(name__anotherequal='a1').order_by('name'), - [a2, a3, a4], lambda x: x - ) - finally: - models.Field._unregister_lookup(AnotherEqual) - def test_div3_extract(self): models.IntegerField.register_lookup(Div3Extract) try: @@ -293,11 +231,49 @@ class YearLteTests(TestCase): self.assertIn( '-12-31', str(baseqs.filter(birthdate__year__lte=2011).query)) - @unittest.skipUnless(connection.vendor == 'mysql', 'MySQL specific SQL used') - def test_mysql_year_exact(self): - self.assertQuerysetEqual( - Author.objects.filter(birthdate__year=2012).order_by('name'), - [self.a2, self.a3, self.a4], lambda x: x) + def test_postgres_year_exact(self): + baseqs = Author.objects.order_by('name') self.assertIn( - 'concat(', - str(Author.objects.filter(birthdate__year=2012).query)) + '= (2011 || ', str(baseqs.filter(birthdate__year=2011).query)) + self.assertIn( + '-12-31', str(baseqs.filter(birthdate__year=2011).query)) + + def test_custom_implementation_year_exact(self): + try: + # Two ways to add a customized implementation for different backends: + # First is MonkeyPatch of the class. + def as_custom_sql(self, qn, connection): + lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs) + rhs_sql, rhs_params = self.process_rhs(qn, connection) + params = lhs_params + rhs_params + lhs_params + rhs_params + return ("%(lhs)s >= str_to_date(concat(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') " + "AND %(lhs)s <= str_to_date(concat(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" % + {'lhs': lhs_sql, 'rhs': rhs_sql}, params) + setattr(YearExact, 'as_' + connection.vendor, as_custom_sql) + self.assertIn( + 'concat(', + str(Author.objects.filter(birthdate__year=2012).query)) + finally: + delattr(YearExact, 'as_' + connection.vendor) + try: + # The other way is to subclass the original lookup and register the subclassed + # lookup instead of the original. + class CustomYearExact(YearExact): + # This method should be named "as_mysql" for MySQL, "as_postgresql" for postgres + # and so on, but as we don't know which DB we are running on, we need to use + # setattr. + def as_custom_sql(self, qn, connection): + lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs) + rhs_sql, rhs_params = self.process_rhs(qn, connection) + params = lhs_params + rhs_params + lhs_params + rhs_params + return ("%(lhs)s >= str_to_date(CONCAT(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') " + "AND %(lhs)s <= str_to_date(CONCAT(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" % + {'lhs': lhs_sql, 'rhs': rhs_sql}, params) + setattr(CustomYearExact, 'as_' + connection.vendor, CustomYearExact.as_custom_sql) + YearExtract.register_lookup(CustomYearExact) + self.assertIn( + 'CONCAT(', + str(Author.objects.filter(birthdate__year=2012).query)) + finally: + YearExtract._unregister_lookup(CustomYearExact) + YearExtract.register_lookup(YearExact)