Added documentation, polished implementation

This commit is contained in:
Anssi Kääriäinen 2013-12-01 02:14:19 +02:00
parent 32c04357a8
commit 2adf50428d
5 changed files with 322 additions and 23 deletions

View File

@ -7,8 +7,8 @@ from django.utils.functional import cached_property
class Extract(object):
def __init__(self, constraint_class, lhs):
self.constraint_class, self.lhs = constraint_class, lhs
def __init__(self, lhs):
self.lhs = lhs
def get_lookup(self, lookup):
return self.output_type.get_lookup(lookup)
@ -21,15 +21,18 @@ class Extract(object):
return self.lhs.output_type
def relabeled_clone(self, relabels):
return self.__class__(self.constraint_class, self.lhs.relabeled_clone(relabels))
return self.__class__(self.lhs.relabeled_clone(relabels))
def get_cols(self):
return self.lhs.get_cols()
class Lookup(object):
lookup_name = None
extract_class = None
def __init__(self, constraint_class, lhs, rhs):
self.constraint_class, self.lhs, self.rhs = constraint_class, lhs, rhs
def __init__(self, lhs, rhs):
self.lhs, self.rhs = lhs, rhs
if rhs is None:
if not self.extract_class:
raise FieldError("Lookup '%s' doesn't support nesting." % self.lookup_name)
@ -37,7 +40,7 @@ class Lookup(object):
self.rhs = self.get_prep_lookup()
def get_extract(self):
return self.extract_class(self.constraint_class, self.lhs)
return self.extract_class(self.lhs)
def get_prep_lookup(self):
return self.lhs.output_type.get_prep_lookup(self.lookup_name, self.rhs)

View File

@ -71,7 +71,8 @@ class SQLCompiler(object):
def compile(self, node):
if node.__class__ in self.connection.compile_implementations:
return self.connection.compile_implementations[node.__class__](node, self)
return self.connection.compile_implementations[node.__class__](
node, self, self.connection)
else:
return node.as_sql(self, self.connection)

View File

@ -18,6 +18,7 @@ from django.db.models.constants import LOOKUP_SEP
from django.db.models.aggregates import refs_aggregate
from django.db.models.expressions import ExpressionNode
from django.db.models.fields import FieldDoesNotExist
from django.db.models.lookups import Extract
from django.db.models.query_utils import Q
from django.db.models.related import PathInfo
from django.db.models.sql import aggregates as base_aggregates_module
@ -1088,9 +1089,12 @@ class Query(object):
if next:
if not lookups:
# This was the last lookup, so return value lookup.
return next(self.where_class, lhs, rhs)
if issubclass(next, Extract):
lhs = next(lhs)
next = lhs.get_lookup('exact')
return next(lhs, rhs)
else:
lhs = next(self.where_class, lhs, None).get_extract()
lhs = next(lhs)
# A field's get_lookup() can return None to opt for backwards
# compatibility path.
elif len(lookups) > 1:

243
docs/ref/models/lookups.txt Normal file
View File

@ -0,0 +1,243 @@
==============
Custom lookups
==============
.. module:: django.db.models.lookups
:synopsis: Custom lookups
.. currentmodule:: django.db.models
(This documentation is candidate for complete rewrite, but contains
useful information of how to test the current implementation.)
This documentation constains instructions of how to create custom lookups
for model fields.
Django's ORM works using lookup paths when building query filters and other
query structures. For example in the query Book.filter(author__age__lte=30)
the author__age__lte is the lookup path.
The lookup path consist of three different part. First is the related lookups,
above part author refers to Book's related model Author. Second part of the
lookup path is the final field, above this is Author's field age. Finally the
lte part is commonly called just lookup (TODO: this nomenclature is confusing,
can we invent something better).
This documentation concentrates on writing custom lookups, that is custom
implementations for lte or any other lookup you wish to use.
Django will fetch a ``Lookup`` class from the final field using the field's
method get_lookup(lookup_name). This method can do three things:
1. Return a Lookup class
2. Raise a FieldError
3. Return None
Above return None is only available during backwards compatibility period and
returning None will not be allowed in Django 1.9 or later. The interpretation
is to use the old way of lookup hadling inside the ORM.
The returned Lookup will be used to build the query.
The Lookup class
~~~~~~~~~~~~~~~~
The API is as follows:
.. attribute:: lookup_name
A string used by Django to distinguish different lookups.
.. method:: __init__(lhs, rhs)
The lhs and rhs are the field reference (reference to field age in the
author__age__lte=30 example), and rhs is the value (30 in the example).
.. attribute:: Lookup.lhs
The left hand side part of this lookup. You can assume it implements the
query part interface (TODO: write interface definition...).
.. method:: Lookup.as_sql(qn, connection)
This method is used to produce the query string of the Lookup. A typical
implementation is usually something like::
def as_sql(self, qn, connection):
lhs, params = self.process_lhs(qn, connection)
rhs, rhs_params = self.process_rhs(qn, connection)
params = lhs_params.extend(rhs_params)
return '%s <OPERATOR> %s', (lhs, rhs), params
where the <OPERATOR> is some query operator. The qn is a callable that
can be used to convert strings to quoted variants (that is, colname to
"colname"). Note that the quotation is *not* safe against SQL injection.
In addition the qn implements method compile() which can be used to turn
anything with as_sql() method to query string. You should always call
qn.compile(part) instead of part.as_sql(qn, connection) so that 3rd party
backends have ability to customize the produced query string. More of this
later on.
The connection is the used connection.
.. method:: Lookup.process_lhs(qn, connection, lhs=None)
This method is used to convert the left hand side of the lookup into query
string. The left hand side can be a field reference or a nested lookup. The
lhs kwarg can be used to convert something else than self.lhs to query string.
.. method:: Lookup.process_rhs(qn, connection, rhs=None)
The process_rhs method is used to convert the right hand side into query string.
The rhs is the value given in the filter clause. It can be a raw value to
compare agains, a F() reference to another field or even a QuerySet.
.. method:: get_extract()
The get_extract method is used in nested lookups. It must return an Extract instance.
.. classattribute:: Lookup.extract_class
The default implementation of get_extract() will return an instance of extract_class.
In addition there are some private methods - that is, implementing just the above
mentioned attributes and methods is not enough, you must subclass Lookup instead.
The Extract class
~~~~~~~~~~~~~~~~~
An Extract is something that converts a value to another value in the query string.
For example you could have an Extract that procudes modulo 3 of the given value.
In SQL this would be something like "author"."age" % 3.
Extracts are used in nested lookups. The Extract class must implement the query
part interface.
A simple Lookup example
~~~~~~~~~~~~~~~~~~~~~~~
This is how to write a simple div3 lookup for IntegerField::
from django.db.models import Lookup, IntegerField
class Div3(Lookup):
lookup_name = 'div3'
def as_sql(self, qn, connection):
lhs_sql, params = self.process_lhs(qn, connection)
rhs_sql, rhs_params = self.process_rhs(qn, connection)
params.extend(rhs_params)
# We need doulbe-escaping for the %%%% operator.
return '%s %%%% %s' % (lhs_sql, rhs_sql), params
IntegerField.register_lookup(Div3)
Now all IntegerFields or subclasses of IntegerField will have
a div3 lookup. For example you could do Author.objects.filter(age__div3=2).
This query would return every author whose age % 3 == 2.
A simple nested lookup example
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Here is how to write an Extract and a Lookup for IntegerField. The example
lookup can be used similarly as the above div3 lookup, and in addition it
support nesting lookups::
class Div3Extract(Extract):
lookup_name = 'div3'
def as_sql(self, qn, connection):
lhs, lhs_params = qn.compile(self.lhs)
return '%s %%%% 3' % (lhs,), lhs_params
IntegerField.register_lookup(Div3Extract)
Note that if you already added Div3 for IntegerField in the above
example, now Div3LookupWithExtract will override that lookup.
This lookup can be used like Div3 lookup, but in addition it supports
nesting, too. The default output type for Extracts is the same type as the
lhs' output_type. So, the Div3Extract supports all the same lookups as
IntegerField. For example Author.objects.filter(age__div3__in=[1, 2])
returns all authors for which age % 3 in (1, 2).
A more complex nested lookup
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
We will write a Year lookup that extracts year from date field. This
field will convert the output type of the field - the lhs (or "input")
field is DateField, but output is of type IntegerField.::
from django.db.models import IntegerField, DateField
from django.db.models.lookups import Extract
class YearExtract(Extract):
lookup_name = 'year'
def as_sql(self, qn, connection):
lhs_sql, params = qn.compile(self.lhs)
# hmmh - this is internal API...
return connection.ops.date_extract_sql('year', lhs_sql), params
@property
def output_type(self):
return IntegerField()
DateField.register_lookup(YearExtract)
Now you could write Author.objects.filter(birthdate__year=1981). This will
produce SQL like 'EXTRACT('year' from "author"."birthdate") = 1981'. The
produces SQL depends on used backend. In addtition you can use any lookup
defined for IntegerField, even div3 if you added that. So,
Authos.objects.filter(birthdate__year__div3=2) will return every author
with birthdate.year % 3 == 2.
We could go further and add an optimized implementation for exact lookups::
from django.db.models.lookups import Lookup
class YearExtractOptimized(YearExtract):
def get_lookup(self, lookup):
if lookup == 'exact':
return YearExact
return super(YearExtractOptimized, self).get_lookup()
class YearExact(Lookup):
def as_sql(self, qn, connection):
# We will need to skip the extract part, and instead go
# directly with the originating field, that is self.lhs.lhs
lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs)
rhs_sql, rhs_params = self.process_rhs(qn, connection)
# Note that we must be careful so that we have params in the
# same order as we have the parts in the SQL.
params = []
params.extend(lhs_params)
params.extend(rhs_params)
params.extend(lhs_params)
params.extend(rhs_params)
# We use PostgreSQL specific SQL here. Note that we must do the
# conversions in SQL instead of in Python to support F() references.
return ("%(lhs)s >= (%(rhs)s || '-01-01')::date "
"AND %(lhs)s <= (%(rhs)s || '-12-31')::date" %
{'lhs': lhs_sql, 'rhs': rhs_sql}, params)
Note that we used PostgreSQL specific SQL above. What if we want to support
MySQL, too? This can be done by registering a different compiling implementation
for MySQL::
from django.db.backends.utils import add_implementation
@add_implementation(YearExact, 'mysql')
def mysql_year_exact(node, qn, connection):
lhs_sql, lhs_params = node.process_lhs(qn, connection, node.lhs.lhs)
rhs_sql, rhs_params = node.process_rhs(qn, connection)
params = []
params.extend(lhs_params)
params.extend(rhs_params)
params.extend(lhs_params)
params.extend(rhs_params)
return ("%(lhs)s >= str_to_date(concat(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
"AND %(lhs)s <= str_to_date(concat(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" %
{'lhs': lhs_sql, 'rhs': rhs_sql}, params)
Now, on MySQL instead of calling as_sql() of the YearExact Django will use the
above compile implementation.

View File

@ -20,16 +20,13 @@ class Div3Lookup(models.lookups.Lookup):
class Div3Extract(models.lookups.Extract):
lookup_name = 'div3'
def as_sql(self, qn, connection):
lhs, lhs_params = qn.compile(self.lhs)
return '%s %%%% 3' % (lhs,), lhs_params
class Div3LookupWithExtract(Div3Lookup):
lookup_name = 'div3'
extract_class = Div3Extract
class YearLte(models.lookups.LessThanOrEqual):
"""
The purpose of this lookup is to efficiently compare the year of the field.
@ -50,6 +47,8 @@ class YearLte(models.lookups.LessThanOrEqual):
class YearExtract(models.lookups.Extract):
lookup_name = 'year'
def as_sql(self, qn, connection):
lhs_sql, params = qn.compile(self.lhs)
return connection.ops.date_extract_sql('year', lhs_sql), params
@ -61,12 +60,44 @@ class YearExtract(models.lookups.Extract):
def get_lookup(self, lookup):
if lookup == 'lte':
return YearLte
elif lookup == 'exact':
return YearExact
else:
return super(YearExtract, self).get_lookup(lookup)
class YearWithExtract(models.lookups.Year):
extract_class = YearExtract
class YearExact(models.lookups.Lookup):
def as_sql(self, qn, connection):
# We will need to skip the extract part, and instead go
# directly with the originating field, that is self.lhs.lhs
lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs)
rhs_sql, rhs_params = self.process_rhs(qn, connection)
# Note that we must be careful so that we have params in the
# same order as we have the parts in the SQL.
params = []
params.extend(lhs_params)
params.extend(rhs_params)
params.extend(lhs_params)
params.extend(rhs_params)
# We use PostgreSQL specific SQL here. Note that we must do the
# conversions in SQL instead of in Python to support F() references.
return ("%(lhs)s >= (%(rhs)s || '-01-01')::date "
"AND %(lhs)s <= (%(rhs)s || '-12-31')::date" %
{'lhs': lhs_sql, 'rhs': rhs_sql}, params)
@add_implementation(YearExact, 'mysql')
def mysql_year_exact(node, qn, connection):
lhs_sql, lhs_params = node.process_lhs(qn, connection, node.lhs.lhs)
rhs_sql, rhs_params = node.process_rhs(qn, connection)
params = []
params.extend(lhs_params)
params.extend(rhs_params)
params.extend(lhs_params)
params.extend(rhs_params)
return ("%(lhs)s >= str_to_date(concat(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
"AND %(lhs)s <= str_to_date(concat(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" %
{'lhs': lhs_sql, 'rhs': rhs_sql}, params)
class InMonth(models.lookups.Lookup):
@ -158,7 +189,7 @@ class LookupTests(TestCase):
models.Field.register_lookup(AnotherEqual)
try:
@add_implementation(AnotherEqual, connection.vendor)
def custom_eq_sql(node, compiler):
def custom_eq_sql(node, qn, connection):
return '1 = 1', []
self.assertIn('1 = 1', str(Author.objects.filter(name__anotherequal='asdf').query))
@ -167,7 +198,7 @@ class LookupTests(TestCase):
[a1, a2, a3, a4], lambda x: x)
@add_implementation(AnotherEqual, connection.vendor)
def another_custom_eq_sql(node, compiler):
def another_custom_eq_sql(node, qn, connection):
# If you need to override one method, it seems this is the best
# option.
node = copy(node)
@ -176,7 +207,7 @@ class LookupTests(TestCase):
def get_rhs_op(self, connection, rhs):
return ' <> %s'
node.__class__ = OverriddenAnotherEqual
return node.as_sql(compiler, compiler.connection)
return node.as_sql(qn, connection)
self.assertIn(' <> ', str(Author.objects.filter(name__anotherequal='a1').query))
self.assertQuerysetEqual(
Author.objects.filter(name__anotherequal='a1').order_by('name'),
@ -186,13 +217,16 @@ class LookupTests(TestCase):
models.Field._unregister_lookup(AnotherEqual)
def test_div3_extract(self):
models.IntegerField.register_lookup(Div3LookupWithExtract)
models.IntegerField.register_lookup(Div3Extract)
try:
a1 = Author.objects.create(name='a1', age=1)
a2 = Author.objects.create(name='a2', age=2)
a3 = Author.objects.create(name='a3', age=3)
a4 = Author.objects.create(name='a4', age=4)
baseqs = Author.objects.order_by('name')
self.assertQuerysetEqual(
baseqs.filter(age__div3=2),
[a2], lambda x: x)
self.assertQuerysetEqual(
baseqs.filter(age__div3__lte=3),
[a1, a2, a3, a4], lambda x: x)
@ -200,19 +234,19 @@ class LookupTests(TestCase):
baseqs.filter(age__div3__in=[0, 2]),
[a2, a3], lambda x: x)
finally:
models.IntegerField._unregister_lookup(Div3LookupWithExtract)
models.IntegerField._unregister_lookup(Div3Extract)
class YearLteTests(TestCase):
def setUp(self):
models.DateField.register_lookup(YearWithExtract)
models.DateField.register_lookup(YearExtract)
self.a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16))
self.a2 = Author.objects.create(name='a2', birthdate=date(2012, 2, 29))
self.a3 = Author.objects.create(name='a3', birthdate=date(2012, 1, 31))
self.a4 = Author.objects.create(name='a4', birthdate=date(2012, 3, 1))
def tearDown(self):
models.DateField._unregister_lookup(YearWithExtract)
models.DateField._unregister_lookup(YearExtract)
@unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
def test_year_lte(self):
@ -220,6 +254,11 @@ class YearLteTests(TestCase):
self.assertQuerysetEqual(
baseqs.filter(birthdate__year__lte=2012),
[self.a1, self.a2, self.a3, self.a4], lambda x: x)
self.assertQuerysetEqual(
baseqs.filter(birthdate__year=2012),
[self.a2, self.a3, self.a4], lambda x: x)
self.assertNotIn('BETWEEN', str(baseqs.filter(birthdate__year=2012).query))
self.assertQuerysetEqual(
baseqs.filter(birthdate__year__lte=2011),
[self.a1], lambda x: x)
@ -253,3 +292,12 @@ class YearLteTests(TestCase):
'<= (2011 || ', str(baseqs.filter(birthdate__year__lte=2011).query))
self.assertIn(
'-12-31', str(baseqs.filter(birthdate__year__lte=2011).query))
@unittest.skipUnless(connection.vendor == 'mysql', 'MySQL specific SQL used')
def test_mysql_year_exact(self):
self.assertQuerysetEqual(
Author.objects.filter(birthdate__year=2012).order_by('name'),
[self.a2, self.a3, self.a4], lambda x: x)
self.assertIn(
'concat(',
str(Author.objects.filter(birthdate__year=2012).query))