mirror of
				https://github.com/django/django.git
				synced 2025-10-26 07:06:08 +00:00 
			
		
		
		
	Fixed #24629 -- Unified Transform and Expression APIs
This commit is contained in:
		| @@ -81,14 +81,14 @@ class KeyTransformFactory(object): | ||||
|  | ||||
|  | ||||
| @HStoreField.register_lookup | ||||
| class KeysTransform(lookups.FunctionTransform): | ||||
| class KeysTransform(Transform): | ||||
|     lookup_name = 'keys' | ||||
|     function = 'akeys' | ||||
|     output_field = ArrayField(TextField()) | ||||
|  | ||||
|  | ||||
| @HStoreField.register_lookup | ||||
| class ValuesTransform(lookups.FunctionTransform): | ||||
| class ValuesTransform(Transform): | ||||
|     lookup_name = 'values' | ||||
|     function = 'avals' | ||||
|     output_field = ArrayField(TextField()) | ||||
|   | ||||
| @@ -173,7 +173,7 @@ class AdjacentToLookup(lookups.PostgresSimpleLookup): | ||||
|  | ||||
|  | ||||
| @RangeField.register_lookup | ||||
| class RangeStartsWith(lookups.FunctionTransform): | ||||
| class RangeStartsWith(models.Transform): | ||||
|     lookup_name = 'startswith' | ||||
|     function = 'lower' | ||||
|  | ||||
| @@ -183,7 +183,7 @@ class RangeStartsWith(lookups.FunctionTransform): | ||||
|  | ||||
|  | ||||
| @RangeField.register_lookup | ||||
| class RangeEndsWith(lookups.FunctionTransform): | ||||
| class RangeEndsWith(models.Transform): | ||||
|     lookup_name = 'endswith' | ||||
|     function = 'upper' | ||||
|  | ||||
| @@ -193,7 +193,7 @@ class RangeEndsWith(lookups.FunctionTransform): | ||||
|  | ||||
|  | ||||
| @RangeField.register_lookup | ||||
| class IsEmpty(lookups.FunctionTransform): | ||||
| class IsEmpty(models.Transform): | ||||
|     lookup_name = 'isempty' | ||||
|     function = 'isempty' | ||||
|     output_field = models.BooleanField() | ||||
|   | ||||
| @@ -9,12 +9,6 @@ class PostgresSimpleLookup(Lookup): | ||||
|         return '%s %s %s' % (lhs, self.operator, rhs), params | ||||
|  | ||||
|  | ||||
| class FunctionTransform(Transform): | ||||
|     def as_sql(self, qn, connection): | ||||
|         lhs, params = qn.compile(self.lhs) | ||||
|         return "%s(%s)" % (self.function, lhs), params | ||||
|  | ||||
|  | ||||
| class DataContains(PostgresSimpleLookup): | ||||
|     lookup_name = 'contains' | ||||
|     operator = '@>' | ||||
| @@ -45,7 +39,7 @@ class HasAnyKeys(PostgresSimpleLookup): | ||||
|     operator = '?|' | ||||
|  | ||||
|  | ||||
| class Unaccent(FunctionTransform): | ||||
| class Unaccent(Transform): | ||||
|     bilateral = True | ||||
|     lookup_name = 'unaccent' | ||||
|     function = 'UNACCENT' | ||||
|   | ||||
| @@ -20,10 +20,7 @@ from django.core import checks, exceptions, validators | ||||
| # purposes. | ||||
| from django.core.exceptions import FieldDoesNotExist  # NOQA | ||||
| from django.db import connection, connections, router | ||||
| from django.db.models.lookups import ( | ||||
|     Lookup, RegisterLookupMixin, Transform, default_lookups, | ||||
| ) | ||||
| from django.db.models.query_utils import QueryWrapper | ||||
| from django.db.models.query_utils import QueryWrapper, RegisterLookupMixin | ||||
| from django.utils import six, timezone | ||||
| from django.utils.datastructures import DictWrapper | ||||
| from django.utils.dateparse import ( | ||||
| @@ -120,7 +117,6 @@ class Field(RegisterLookupMixin): | ||||
|         'unique_for_date': _("%(field_label)s must be unique for " | ||||
|                              "%(date_field_label)s %(lookup_type)s."), | ||||
|     } | ||||
|     class_lookups = default_lookups.copy() | ||||
|     system_check_deprecated_details = None | ||||
|     system_check_removed_details = None | ||||
|  | ||||
| @@ -1492,22 +1488,6 @@ class DateTimeField(DateField): | ||||
|         return super(DateTimeField, self).formfield(**defaults) | ||||
|  | ||||
|  | ||||
| @DateTimeField.register_lookup | ||||
| class DateTimeDateTransform(Transform): | ||||
|     lookup_name = 'date' | ||||
|  | ||||
|     @cached_property | ||||
|     def output_field(self): | ||||
|         return DateField() | ||||
|  | ||||
|     def as_sql(self, compiler, connection): | ||||
|         lhs, lhs_params = compiler.compile(self.lhs) | ||||
|         tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None | ||||
|         sql, tz_params = connection.ops.datetime_cast_date_sql(lhs, tzname) | ||||
|         lhs_params.extend(tz_params) | ||||
|         return sql, lhs_params | ||||
|  | ||||
|  | ||||
| class DecimalField(Field): | ||||
|     empty_strings_allowed = False | ||||
|     default_error_messages = { | ||||
| @@ -2450,146 +2430,3 @@ class UUIDField(Field): | ||||
|         } | ||||
|         defaults.update(kwargs) | ||||
|         return super(UUIDField, self).formfield(**defaults) | ||||
|  | ||||
|  | ||||
| class DateTransform(Transform): | ||||
|     def as_sql(self, compiler, connection): | ||||
|         sql, params = compiler.compile(self.lhs) | ||||
|         lhs_output_field = self.lhs.output_field | ||||
|         if isinstance(lhs_output_field, DateTimeField): | ||||
|             tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None | ||||
|             sql, tz_params = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname) | ||||
|             params.extend(tz_params) | ||||
|         elif isinstance(lhs_output_field, DateField): | ||||
|             sql = connection.ops.date_extract_sql(self.lookup_name, sql) | ||||
|         elif isinstance(lhs_output_field, TimeField): | ||||
|             sql = connection.ops.time_extract_sql(self.lookup_name, sql) | ||||
|         else: | ||||
|             raise ValueError('DateTransform only valid on Date/Time/DateTimeFields') | ||||
|         return sql, params | ||||
|  | ||||
|     @cached_property | ||||
|     def output_field(self): | ||||
|         return IntegerField() | ||||
|  | ||||
|  | ||||
| class YearTransform(DateTransform): | ||||
|     lookup_name = 'year' | ||||
|  | ||||
|  | ||||
| class YearLookup(Lookup): | ||||
|     def year_lookup_bounds(self, connection, year): | ||||
|         output_field = self.lhs.lhs.output_field | ||||
|         if isinstance(output_field, DateTimeField): | ||||
|             bounds = connection.ops.year_lookup_bounds_for_datetime_field(year) | ||||
|         else: | ||||
|             bounds = connection.ops.year_lookup_bounds_for_date_field(year) | ||||
|         return bounds | ||||
|  | ||||
|  | ||||
| @YearTransform.register_lookup | ||||
| class YearExact(YearLookup): | ||||
|     lookup_name = 'exact' | ||||
|  | ||||
|     def as_sql(self, compiler, connection): | ||||
|         # We will need to skip the extract part and instead go | ||||
|         # directly with the originating field, that is self.lhs.lhs. | ||||
|         lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs) | ||||
|         rhs_sql, rhs_params = self.process_rhs(compiler, connection) | ||||
|         bounds = self.year_lookup_bounds(connection, rhs_params[0]) | ||||
|         params.extend(bounds) | ||||
|         return '%s BETWEEN %%s AND %%s' % lhs_sql, params | ||||
|  | ||||
|  | ||||
| class YearComparisonLookup(YearLookup): | ||||
|     def as_sql(self, compiler, connection): | ||||
|         # We will need to skip the extract part and instead go | ||||
|         # directly with the originating field, that is self.lhs.lhs. | ||||
|         lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs) | ||||
|         rhs_sql, rhs_params = self.process_rhs(compiler, connection) | ||||
|         rhs_sql = self.get_rhs_op(connection, rhs_sql) | ||||
|         start, finish = self.year_lookup_bounds(connection, rhs_params[0]) | ||||
|         params.append(self.get_bound(start, finish)) | ||||
|         return '%s %s' % (lhs_sql, rhs_sql), params | ||||
|  | ||||
|     def get_rhs_op(self, connection, rhs): | ||||
|         return connection.operators[self.lookup_name] % rhs | ||||
|  | ||||
|     def get_bound(self): | ||||
|         raise NotImplementedError( | ||||
|             'subclasses of YearComparisonLookup must provide a get_bound() method' | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @YearTransform.register_lookup | ||||
| class YearGt(YearComparisonLookup): | ||||
|     lookup_name = 'gt' | ||||
|  | ||||
|     def get_bound(self, start, finish): | ||||
|         return finish | ||||
|  | ||||
|  | ||||
| @YearTransform.register_lookup | ||||
| class YearGte(YearComparisonLookup): | ||||
|     lookup_name = 'gte' | ||||
|  | ||||
|     def get_bound(self, start, finish): | ||||
|         return start | ||||
|  | ||||
|  | ||||
| @YearTransform.register_lookup | ||||
| class YearLt(YearComparisonLookup): | ||||
|     lookup_name = 'lt' | ||||
|  | ||||
|     def get_bound(self, start, finish): | ||||
|         return start | ||||
|  | ||||
|  | ||||
| @YearTransform.register_lookup | ||||
| class YearLte(YearComparisonLookup): | ||||
|     lookup_name = 'lte' | ||||
|  | ||||
|     def get_bound(self, start, finish): | ||||
|         return finish | ||||
|  | ||||
|  | ||||
| class MonthTransform(DateTransform): | ||||
|     lookup_name = 'month' | ||||
|  | ||||
|  | ||||
| class DayTransform(DateTransform): | ||||
|     lookup_name = 'day' | ||||
|  | ||||
|  | ||||
| class WeekDayTransform(DateTransform): | ||||
|     lookup_name = 'week_day' | ||||
|  | ||||
|  | ||||
| class HourTransform(DateTransform): | ||||
|     lookup_name = 'hour' | ||||
|  | ||||
|  | ||||
| class MinuteTransform(DateTransform): | ||||
|     lookup_name = 'minute' | ||||
|  | ||||
|  | ||||
| class SecondTransform(DateTransform): | ||||
|     lookup_name = 'second' | ||||
|  | ||||
|  | ||||
| DateField.register_lookup(YearTransform) | ||||
| DateField.register_lookup(MonthTransform) | ||||
| DateField.register_lookup(DayTransform) | ||||
| DateField.register_lookup(WeekDayTransform) | ||||
|  | ||||
| TimeField.register_lookup(HourTransform) | ||||
| TimeField.register_lookup(MinuteTransform) | ||||
| TimeField.register_lookup(SecondTransform) | ||||
|  | ||||
| DateTimeField.register_lookup(YearTransform) | ||||
| DateTimeField.register_lookup(MonthTransform) | ||||
| DateTimeField.register_lookup(DayTransform) | ||||
| DateTimeField.register_lookup(WeekDayTransform) | ||||
| DateTimeField.register_lookup(HourTransform) | ||||
| DateTimeField.register_lookup(MinuteTransform) | ||||
| DateTimeField.register_lookup(SecondTransform) | ||||
|   | ||||
| @@ -1,8 +1,9 @@ | ||||
| """ | ||||
| Classes that represent database functions. | ||||
| """ | ||||
| from django.db.models import DateTimeField, IntegerField | ||||
| from django.db.models.expressions import Func, Value | ||||
| from django.db.models import ( | ||||
|     DateTimeField, Func, IntegerField, Transform, Value, | ||||
| ) | ||||
|  | ||||
|  | ||||
| class Coalesce(Func): | ||||
| @@ -123,9 +124,10 @@ class Least(Func): | ||||
|         return super(Least, self).as_sql(compiler, connection, function='MIN') | ||||
|  | ||||
|  | ||||
| class Length(Func): | ||||
| class Length(Transform): | ||||
|     """Returns the number of characters in the expression""" | ||||
|     function = 'LENGTH' | ||||
|     lookup_name = 'length' | ||||
|  | ||||
|     def __init__(self, expression, **extra): | ||||
|         output_field = extra.pop('output_field', IntegerField()) | ||||
| @@ -136,8 +138,9 @@ class Length(Func): | ||||
|         return super(Length, self).as_sql(compiler, connection) | ||||
|  | ||||
|  | ||||
| class Lower(Func): | ||||
| class Lower(Transform): | ||||
|     function = 'LOWER' | ||||
|     lookup_name = 'lower' | ||||
|  | ||||
|     def __init__(self, expression, **extra): | ||||
|         super(Lower, self).__init__(expression, **extra) | ||||
| @@ -188,8 +191,9 @@ class Substr(Func): | ||||
|         return super(Substr, self).as_sql(compiler, connection) | ||||
|  | ||||
|  | ||||
| class Upper(Func): | ||||
| class Upper(Transform): | ||||
|     function = 'UPPER' | ||||
|     lookup_name = 'upper' | ||||
|  | ||||
|     def __init__(self, expression, **extra): | ||||
|         super(Upper, self).__init__(expression, **extra) | ||||
|   | ||||
| @@ -1,101 +1,17 @@ | ||||
| import inspect | ||||
| from copy import copy | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.db.models.expressions import Func, Value | ||||
| from django.db.models.fields import ( | ||||
|     DateField, DateTimeField, Field, IntegerField, TimeField, | ||||
| ) | ||||
| from django.db.models.query_utils import RegisterLookupMixin | ||||
| from django.utils import timezone | ||||
| from django.utils.functional import cached_property | ||||
| from django.utils.six.moves import range | ||||
|  | ||||
| from .query_utils import QueryWrapper | ||||
|  | ||||
|  | ||||
| class RegisterLookupMixin(object): | ||||
|     def _get_lookup(self, lookup_name): | ||||
|         try: | ||||
|             return self.class_lookups[lookup_name] | ||||
|         except KeyError: | ||||
|             # To allow for inheritance, check parent class' class_lookups. | ||||
|             for parent in inspect.getmro(self.__class__): | ||||
|                 if 'class_lookups' not in parent.__dict__: | ||||
|                     continue | ||||
|                 if lookup_name in parent.class_lookups: | ||||
|                     return parent.class_lookups[lookup_name] | ||||
|         except AttributeError: | ||||
|             # This class didn't have any class_lookups | ||||
|             pass | ||||
|         return None | ||||
|  | ||||
|     def get_lookup(self, lookup_name): | ||||
|         found = self._get_lookup(lookup_name) | ||||
|         if found is None and hasattr(self, 'output_field'): | ||||
|             return self.output_field.get_lookup(lookup_name) | ||||
|         if found is not None and not issubclass(found, Lookup): | ||||
|             return None | ||||
|         return found | ||||
|  | ||||
|     def get_transform(self, lookup_name): | ||||
|         found = self._get_lookup(lookup_name) | ||||
|         if found is None and hasattr(self, 'output_field'): | ||||
|             return self.output_field.get_transform(lookup_name) | ||||
|         if found is not None and not issubclass(found, Transform): | ||||
|             return None | ||||
|         return found | ||||
|  | ||||
|     @classmethod | ||||
|     def register_lookup(cls, lookup): | ||||
|         if 'class_lookups' not in cls.__dict__: | ||||
|             cls.class_lookups = {} | ||||
|         cls.class_lookups[lookup.lookup_name] = lookup | ||||
|         return lookup | ||||
|  | ||||
|     @classmethod | ||||
|     def _unregister_lookup(cls, lookup): | ||||
|         """ | ||||
|         Removes given lookup from cls lookups. Meant to be used in | ||||
|         tests only. | ||||
|         """ | ||||
|         del cls.class_lookups[lookup.lookup_name] | ||||
|  | ||||
|  | ||||
| class Transform(RegisterLookupMixin): | ||||
|  | ||||
|     bilateral = False | ||||
|  | ||||
|     def __init__(self, lhs, lookups): | ||||
|         self.lhs = lhs | ||||
|         self.init_lookups = lookups[:] | ||||
|  | ||||
|     def as_sql(self, compiler, connection): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     @cached_property | ||||
|     def output_field(self): | ||||
|         return self.lhs.output_field | ||||
|  | ||||
|     def copy(self): | ||||
|         return copy(self) | ||||
|  | ||||
|     def relabeled_clone(self, relabels): | ||||
|         copy = self.copy() | ||||
|         copy.lhs = self.lhs.relabeled_clone(relabels) | ||||
|         return copy | ||||
|  | ||||
|     def get_group_by_cols(self): | ||||
|         return self.lhs.get_group_by_cols() | ||||
|  | ||||
|     def get_bilateral_transforms(self): | ||||
|         if hasattr(self.lhs, 'get_bilateral_transforms'): | ||||
|             bilateral_transforms = self.lhs.get_bilateral_transforms() | ||||
|         else: | ||||
|             bilateral_transforms = [] | ||||
|         if self.bilateral: | ||||
|             bilateral_transforms.append((self.__class__, self.init_lookups)) | ||||
|         return bilateral_transforms | ||||
|  | ||||
|     @cached_property | ||||
|     def contains_aggregate(self): | ||||
|         return self.lhs.contains_aggregate | ||||
|  | ||||
|  | ||||
| class Lookup(RegisterLookupMixin): | ||||
| class Lookup(object): | ||||
|     lookup_name = None | ||||
|  | ||||
|     def __init__(self, lhs, rhs): | ||||
| @@ -115,8 +31,8 @@ class Lookup(RegisterLookupMixin): | ||||
|         self.bilateral_transforms = bilateral_transforms | ||||
|  | ||||
|     def apply_bilateral_transforms(self, value): | ||||
|         for transform, lookups in self.bilateral_transforms: | ||||
|             value = transform(value, lookups) | ||||
|         for transform in self.bilateral_transforms: | ||||
|             value = transform(value) | ||||
|         return value | ||||
|  | ||||
|     def batch_process_rhs(self, compiler, connection, rhs=None): | ||||
| @@ -125,9 +41,9 @@ class Lookup(RegisterLookupMixin): | ||||
|         if self.bilateral_transforms: | ||||
|             sqls, sqls_params = [], [] | ||||
|             for p in rhs: | ||||
|                 value = QueryWrapper('%s', | ||||
|                     [self.lhs.output_field.get_db_prep_value(p, connection)]) | ||||
|                 value = Value(p, output_field=self.lhs.output_field) | ||||
|                 value = self.apply_bilateral_transforms(value) | ||||
|                 value = value.resolve_expression(compiler.query) | ||||
|                 sql, sql_params = compiler.compile(value) | ||||
|                 sqls.append(sql) | ||||
|                 sqls_params.extend(sql_params) | ||||
| @@ -155,9 +71,9 @@ class Lookup(RegisterLookupMixin): | ||||
|             if self.rhs_is_direct_value(): | ||||
|                 # Do not call get_db_prep_lookup here as the value will be | ||||
|                 # transformed before being used for lookup | ||||
|                 value = QueryWrapper("%s", | ||||
|                     [self.lhs.output_field.get_db_prep_value(value, connection)]) | ||||
|                 value = Value(value, output_field=self.lhs.output_field) | ||||
|             value = self.apply_bilateral_transforms(value) | ||||
|             value = value.resolve_expression(compiler.query) | ||||
|         # Due to historical reasons there are a couple of different | ||||
|         # ways to produce sql here. get_compiler is likely a Query | ||||
|         # instance, _as_sql QuerySet and as_sql just something with | ||||
| @@ -201,6 +117,31 @@ class Lookup(RegisterLookupMixin): | ||||
|         return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False) | ||||
|  | ||||
|  | ||||
| class Transform(RegisterLookupMixin, Func): | ||||
|     """ | ||||
|     RegisterLookupMixin() is first so that get_lookup() and get_transform() | ||||
|     first examine self and then check output_field. | ||||
|     """ | ||||
|     bilateral = False | ||||
|  | ||||
|     def __init__(self, expression, **extra): | ||||
|         # Restrict Transform to allow only a single expression. | ||||
|         super(Transform, self).__init__(expression, **extra) | ||||
|  | ||||
|     @property | ||||
|     def lhs(self): | ||||
|         return self.get_source_expressions()[0] | ||||
|  | ||||
|     def get_bilateral_transforms(self): | ||||
|         if hasattr(self.lhs, 'get_bilateral_transforms'): | ||||
|             bilateral_transforms = self.lhs.get_bilateral_transforms() | ||||
|         else: | ||||
|             bilateral_transforms = [] | ||||
|         if self.bilateral: | ||||
|             bilateral_transforms.append(self.__class__) | ||||
|         return bilateral_transforms | ||||
|  | ||||
|  | ||||
| class BuiltinLookup(Lookup): | ||||
|     def process_lhs(self, compiler, connection, lhs=None): | ||||
|         lhs_sql, params = super(BuiltinLookup, self).process_lhs( | ||||
| @@ -223,12 +164,9 @@ class BuiltinLookup(Lookup): | ||||
|         return connection.operators[self.lookup_name] % rhs | ||||
|  | ||||
|  | ||||
| default_lookups = {} | ||||
|  | ||||
|  | ||||
| class Exact(BuiltinLookup): | ||||
|     lookup_name = 'exact' | ||||
| default_lookups['exact'] = Exact | ||||
| Field.register_lookup(Exact) | ||||
|  | ||||
|  | ||||
| class IExact(BuiltinLookup): | ||||
| @@ -241,27 +179,27 @@ class IExact(BuiltinLookup): | ||||
|         return rhs, params | ||||
|  | ||||
|  | ||||
| default_lookups['iexact'] = IExact | ||||
| Field.register_lookup(IExact) | ||||
|  | ||||
|  | ||||
| class GreaterThan(BuiltinLookup): | ||||
|     lookup_name = 'gt' | ||||
| default_lookups['gt'] = GreaterThan | ||||
| Field.register_lookup(GreaterThan) | ||||
|  | ||||
|  | ||||
| class GreaterThanOrEqual(BuiltinLookup): | ||||
|     lookup_name = 'gte' | ||||
| default_lookups['gte'] = GreaterThanOrEqual | ||||
| Field.register_lookup(GreaterThanOrEqual) | ||||
|  | ||||
|  | ||||
| class LessThan(BuiltinLookup): | ||||
|     lookup_name = 'lt' | ||||
| default_lookups['lt'] = LessThan | ||||
| Field.register_lookup(LessThan) | ||||
|  | ||||
|  | ||||
| class LessThanOrEqual(BuiltinLookup): | ||||
|     lookup_name = 'lte' | ||||
| default_lookups['lte'] = LessThanOrEqual | ||||
| Field.register_lookup(LessThanOrEqual) | ||||
|  | ||||
|  | ||||
| class In(BuiltinLookup): | ||||
| @@ -286,32 +224,32 @@ class In(BuiltinLookup): | ||||
|  | ||||
|     def as_sql(self, compiler, connection): | ||||
|         max_in_list_size = connection.ops.max_in_list_size() | ||||
|         if self.rhs_is_direct_value() and (max_in_list_size and | ||||
|                                            len(self.rhs) > max_in_list_size): | ||||
|             # This is a special case for Oracle which limits the number of elements | ||||
|             # which can appear in an 'IN' clause. | ||||
|             lhs, lhs_params = self.process_lhs(compiler, connection) | ||||
|             rhs, rhs_params = self.batch_process_rhs(compiler, connection) | ||||
|             in_clause_elements = ['('] | ||||
|             params = [] | ||||
|             for offset in range(0, len(rhs_params), max_in_list_size): | ||||
|                 if offset > 0: | ||||
|                     in_clause_elements.append(' OR ') | ||||
|                 in_clause_elements.append('%s IN (' % lhs) | ||||
|                 params.extend(lhs_params) | ||||
|                 sqls = rhs[offset: offset + max_in_list_size] | ||||
|                 sqls_params = rhs_params[offset: offset + max_in_list_size] | ||||
|                 param_group = ', '.join(sqls) | ||||
|                 in_clause_elements.append(param_group) | ||||
|                 in_clause_elements.append(')') | ||||
|                 params.extend(sqls_params) | ||||
|         if self.rhs_is_direct_value() and max_in_list_size and len(self.rhs) > max_in_list_size: | ||||
|             return self.split_parameter_list_as_sql(compiler, connection) | ||||
|         return super(In, self).as_sql(compiler, connection) | ||||
|  | ||||
|     def split_parameter_list_as_sql(self, compiler, connection): | ||||
|         # This is a special case for databases which limit the number of | ||||
|         # elements which can appear in an 'IN' clause. | ||||
|         max_in_list_size = connection.ops.max_in_list_size() | ||||
|         lhs, lhs_params = self.process_lhs(compiler, connection) | ||||
|         rhs, rhs_params = self.batch_process_rhs(compiler, connection) | ||||
|         in_clause_elements = ['('] | ||||
|         params = [] | ||||
|         for offset in range(0, len(rhs_params), max_in_list_size): | ||||
|             if offset > 0: | ||||
|                 in_clause_elements.append(' OR ') | ||||
|             in_clause_elements.append('%s IN (' % lhs) | ||||
|             params.extend(lhs_params) | ||||
|             sqls = rhs[offset: offset + max_in_list_size] | ||||
|             sqls_params = rhs_params[offset: offset + max_in_list_size] | ||||
|             param_group = ', '.join(sqls) | ||||
|             in_clause_elements.append(param_group) | ||||
|             in_clause_elements.append(')') | ||||
|             return ''.join(in_clause_elements), params | ||||
|         else: | ||||
|             return super(In, self).as_sql(compiler, connection) | ||||
|  | ||||
|  | ||||
| default_lookups['in'] = In | ||||
|             params.extend(sqls_params) | ||||
|         in_clause_elements.append(')') | ||||
|         return ''.join(in_clause_elements), params | ||||
| Field.register_lookup(In) | ||||
|  | ||||
|  | ||||
| class PatternLookup(BuiltinLookup): | ||||
| @@ -342,16 +280,12 @@ class Contains(PatternLookup): | ||||
|         if params and not self.bilateral_transforms: | ||||
|             params[0] = "%%%s%%" % connection.ops.prep_for_like_query(params[0]) | ||||
|         return rhs, params | ||||
|  | ||||
|  | ||||
| default_lookups['contains'] = Contains | ||||
| Field.register_lookup(Contains) | ||||
|  | ||||
|  | ||||
| class IContains(Contains): | ||||
|     lookup_name = 'icontains' | ||||
|  | ||||
|  | ||||
| default_lookups['icontains'] = IContains | ||||
| Field.register_lookup(IContains) | ||||
|  | ||||
|  | ||||
| class StartsWith(PatternLookup): | ||||
| @@ -362,9 +296,7 @@ class StartsWith(PatternLookup): | ||||
|         if params and not self.bilateral_transforms: | ||||
|             params[0] = "%s%%" % connection.ops.prep_for_like_query(params[0]) | ||||
|         return rhs, params | ||||
|  | ||||
|  | ||||
| default_lookups['startswith'] = StartsWith | ||||
| Field.register_lookup(StartsWith) | ||||
|  | ||||
|  | ||||
| class IStartsWith(PatternLookup): | ||||
| @@ -375,9 +307,7 @@ class IStartsWith(PatternLookup): | ||||
|         if params and not self.bilateral_transforms: | ||||
|             params[0] = "%s%%" % connection.ops.prep_for_like_query(params[0]) | ||||
|         return rhs, params | ||||
|  | ||||
|  | ||||
| default_lookups['istartswith'] = IStartsWith | ||||
| Field.register_lookup(IStartsWith) | ||||
|  | ||||
|  | ||||
| class EndsWith(PatternLookup): | ||||
| @@ -388,9 +318,7 @@ class EndsWith(PatternLookup): | ||||
|         if params and not self.bilateral_transforms: | ||||
|             params[0] = "%%%s" % connection.ops.prep_for_like_query(params[0]) | ||||
|         return rhs, params | ||||
|  | ||||
|  | ||||
| default_lookups['endswith'] = EndsWith | ||||
| Field.register_lookup(EndsWith) | ||||
|  | ||||
|  | ||||
| class IEndsWith(PatternLookup): | ||||
| @@ -401,9 +329,7 @@ class IEndsWith(PatternLookup): | ||||
|         if params and not self.bilateral_transforms: | ||||
|             params[0] = "%%%s" % connection.ops.prep_for_like_query(params[0]) | ||||
|         return rhs, params | ||||
|  | ||||
|  | ||||
| default_lookups['iendswith'] = IEndsWith | ||||
| Field.register_lookup(IEndsWith) | ||||
|  | ||||
|  | ||||
| class Between(BuiltinLookup): | ||||
| @@ -424,8 +350,7 @@ class Range(BuiltinLookup): | ||||
|             return self.batch_process_rhs(compiler, connection) | ||||
|         else: | ||||
|             return super(Range, self).process_rhs(compiler, connection) | ||||
|  | ||||
| default_lookups['range'] = Range | ||||
| Field.register_lookup(Range) | ||||
|  | ||||
|  | ||||
| class IsNull(BuiltinLookup): | ||||
| @@ -437,7 +362,7 @@ class IsNull(BuiltinLookup): | ||||
|             return "%s IS NULL" % sql, params | ||||
|         else: | ||||
|             return "%s IS NOT NULL" % sql, params | ||||
| default_lookups['isnull'] = IsNull | ||||
| Field.register_lookup(IsNull) | ||||
|  | ||||
|  | ||||
| class Search(BuiltinLookup): | ||||
| @@ -448,8 +373,7 @@ class Search(BuiltinLookup): | ||||
|         rhs, rhs_params = self.process_rhs(compiler, connection) | ||||
|         sql_template = connection.ops.fulltext_search_sql(field_name=lhs) | ||||
|         return sql_template, lhs_params + rhs_params | ||||
|  | ||||
| default_lookups['search'] = Search | ||||
| Field.register_lookup(Search) | ||||
|  | ||||
|  | ||||
| class Regex(BuiltinLookup): | ||||
| @@ -463,9 +387,168 @@ class Regex(BuiltinLookup): | ||||
|             rhs, rhs_params = self.process_rhs(compiler, connection) | ||||
|             sql_template = connection.ops.regex_lookup(self.lookup_name) | ||||
|             return sql_template % (lhs, rhs), lhs_params + rhs_params | ||||
| default_lookups['regex'] = Regex | ||||
| Field.register_lookup(Regex) | ||||
|  | ||||
|  | ||||
| class IRegex(Regex): | ||||
|     lookup_name = 'iregex' | ||||
| default_lookups['iregex'] = IRegex | ||||
| Field.register_lookup(IRegex) | ||||
|  | ||||
|  | ||||
| class DateTimeDateTransform(Transform): | ||||
|     lookup_name = 'date' | ||||
|  | ||||
|     @cached_property | ||||
|     def output_field(self): | ||||
|         return DateField() | ||||
|  | ||||
|     def as_sql(self, compiler, connection): | ||||
|         lhs, lhs_params = compiler.compile(self.lhs) | ||||
|         tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None | ||||
|         sql, tz_params = connection.ops.datetime_cast_date_sql(lhs, tzname) | ||||
|         lhs_params.extend(tz_params) | ||||
|         return sql, lhs_params | ||||
|  | ||||
|  | ||||
| class DateTransform(Transform): | ||||
|     def as_sql(self, compiler, connection): | ||||
|         sql, params = compiler.compile(self.lhs) | ||||
|         lhs_output_field = self.lhs.output_field | ||||
|         if isinstance(lhs_output_field, DateTimeField): | ||||
|             tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None | ||||
|             sql, tz_params = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname) | ||||
|             params.extend(tz_params) | ||||
|         elif isinstance(lhs_output_field, DateField): | ||||
|             sql = connection.ops.date_extract_sql(self.lookup_name, sql) | ||||
|         elif isinstance(lhs_output_field, TimeField): | ||||
|             sql = connection.ops.time_extract_sql(self.lookup_name, sql) | ||||
|         else: | ||||
|             raise ValueError('DateTransform only valid on Date/Time/DateTimeFields') | ||||
|         return sql, params | ||||
|  | ||||
|     @cached_property | ||||
|     def output_field(self): | ||||
|         return IntegerField() | ||||
|  | ||||
|  | ||||
| class YearTransform(DateTransform): | ||||
|     lookup_name = 'year' | ||||
|  | ||||
|  | ||||
| class YearLookup(Lookup): | ||||
|     def year_lookup_bounds(self, connection, year): | ||||
|         output_field = self.lhs.lhs.output_field | ||||
|         if isinstance(output_field, DateTimeField): | ||||
|             bounds = connection.ops.year_lookup_bounds_for_datetime_field(year) | ||||
|         else: | ||||
|             bounds = connection.ops.year_lookup_bounds_for_date_field(year) | ||||
|         return bounds | ||||
|  | ||||
|  | ||||
| @YearTransform.register_lookup | ||||
| class YearExact(YearLookup): | ||||
|     lookup_name = 'exact' | ||||
|  | ||||
|     def as_sql(self, compiler, connection): | ||||
|         # We will need to skip the extract part and instead go | ||||
|         # directly with the originating field, that is self.lhs.lhs. | ||||
|         lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs) | ||||
|         rhs_sql, rhs_params = self.process_rhs(compiler, connection) | ||||
|         bounds = self.year_lookup_bounds(connection, rhs_params[0]) | ||||
|         params.extend(bounds) | ||||
|         return '%s BETWEEN %%s AND %%s' % lhs_sql, params | ||||
|  | ||||
|  | ||||
| class YearComparisonLookup(YearLookup): | ||||
|     def as_sql(self, compiler, connection): | ||||
|         # We will need to skip the extract part and instead go | ||||
|         # directly with the originating field, that is self.lhs.lhs. | ||||
|         lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs) | ||||
|         rhs_sql, rhs_params = self.process_rhs(compiler, connection) | ||||
|         rhs_sql = self.get_rhs_op(connection, rhs_sql) | ||||
|         start, finish = self.year_lookup_bounds(connection, rhs_params[0]) | ||||
|         params.append(self.get_bound(start, finish)) | ||||
|         return '%s %s' % (lhs_sql, rhs_sql), params | ||||
|  | ||||
|     def get_rhs_op(self, connection, rhs): | ||||
|         return connection.operators[self.lookup_name] % rhs | ||||
|  | ||||
|     def get_bound(self): | ||||
|         raise NotImplementedError( | ||||
|             'subclasses of YearComparisonLookup must provide a get_bound() method' | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @YearTransform.register_lookup | ||||
| class YearGt(YearComparisonLookup): | ||||
|     lookup_name = 'gt' | ||||
|  | ||||
|     def get_bound(self, start, finish): | ||||
|         return finish | ||||
|  | ||||
|  | ||||
| @YearTransform.register_lookup | ||||
| class YearGte(YearComparisonLookup): | ||||
|     lookup_name = 'gte' | ||||
|  | ||||
|     def get_bound(self, start, finish): | ||||
|         return start | ||||
|  | ||||
|  | ||||
| @YearTransform.register_lookup | ||||
| class YearLt(YearComparisonLookup): | ||||
|     lookup_name = 'lt' | ||||
|  | ||||
|     def get_bound(self, start, finish): | ||||
|         return start | ||||
|  | ||||
|  | ||||
| @YearTransform.register_lookup | ||||
| class YearLte(YearComparisonLookup): | ||||
|     lookup_name = 'lte' | ||||
|  | ||||
|     def get_bound(self, start, finish): | ||||
|         return finish | ||||
|  | ||||
|  | ||||
| class MonthTransform(DateTransform): | ||||
|     lookup_name = 'month' | ||||
|  | ||||
|  | ||||
| class DayTransform(DateTransform): | ||||
|     lookup_name = 'day' | ||||
|  | ||||
|  | ||||
| class WeekDayTransform(DateTransform): | ||||
|     lookup_name = 'week_day' | ||||
|  | ||||
|  | ||||
| class HourTransform(DateTransform): | ||||
|     lookup_name = 'hour' | ||||
|  | ||||
|  | ||||
| class MinuteTransform(DateTransform): | ||||
|     lookup_name = 'minute' | ||||
|  | ||||
|  | ||||
| class SecondTransform(DateTransform): | ||||
|     lookup_name = 'second' | ||||
|  | ||||
|  | ||||
| DateField.register_lookup(YearTransform) | ||||
| DateField.register_lookup(MonthTransform) | ||||
| DateField.register_lookup(DayTransform) | ||||
| DateField.register_lookup(WeekDayTransform) | ||||
|  | ||||
| TimeField.register_lookup(HourTransform) | ||||
| TimeField.register_lookup(MinuteTransform) | ||||
| TimeField.register_lookup(SecondTransform) | ||||
|  | ||||
| DateTimeField.register_lookup(DateTimeDateTransform) | ||||
| DateTimeField.register_lookup(YearTransform) | ||||
| DateTimeField.register_lookup(MonthTransform) | ||||
| DateTimeField.register_lookup(DayTransform) | ||||
| DateTimeField.register_lookup(WeekDayTransform) | ||||
| DateTimeField.register_lookup(HourTransform) | ||||
| DateTimeField.register_lookup(MinuteTransform) | ||||
| DateTimeField.register_lookup(SecondTransform) | ||||
|   | ||||
| @@ -7,6 +7,7 @@ circular import difficulties. | ||||
| """ | ||||
| from __future__ import unicode_literals | ||||
|  | ||||
| import inspect | ||||
| from collections import namedtuple | ||||
|  | ||||
| from django.apps import apps | ||||
| @@ -169,6 +170,60 @@ class DeferredAttribute(object): | ||||
|         return None | ||||
|  | ||||
|  | ||||
| class RegisterLookupMixin(object): | ||||
|     def _get_lookup(self, lookup_name): | ||||
|         try: | ||||
|             return self.class_lookups[lookup_name] | ||||
|         except KeyError: | ||||
|             # To allow for inheritance, check parent class' class_lookups. | ||||
|             for parent in inspect.getmro(self.__class__): | ||||
|                 if 'class_lookups' not in parent.__dict__: | ||||
|                     continue | ||||
|                 if lookup_name in parent.class_lookups: | ||||
|                     return parent.class_lookups[lookup_name] | ||||
|         except AttributeError: | ||||
|             # This class didn't have any class_lookups | ||||
|             pass | ||||
|         return None | ||||
|  | ||||
|     def get_lookup(self, lookup_name): | ||||
|         from django.db.models.lookups import Lookup | ||||
|         found = self._get_lookup(lookup_name) | ||||
|         if found is None and hasattr(self, 'output_field'): | ||||
|             return self.output_field.get_lookup(lookup_name) | ||||
|         if found is not None and not issubclass(found, Lookup): | ||||
|             return None | ||||
|         return found | ||||
|  | ||||
|     def get_transform(self, lookup_name): | ||||
|         from django.db.models.lookups import Transform | ||||
|         found = self._get_lookup(lookup_name) | ||||
|         if found is None and hasattr(self, 'output_field'): | ||||
|             return self.output_field.get_transform(lookup_name) | ||||
|         if found is not None and not issubclass(found, Transform): | ||||
|             return None | ||||
|         return found | ||||
|  | ||||
|     @classmethod | ||||
|     def register_lookup(cls, lookup, lookup_name=None): | ||||
|         if lookup_name is None: | ||||
|             lookup_name = lookup.lookup_name | ||||
|         if 'class_lookups' not in cls.__dict__: | ||||
|             cls.class_lookups = {} | ||||
|         cls.class_lookups[lookup_name] = lookup | ||||
|         return lookup | ||||
|  | ||||
|     @classmethod | ||||
|     def _unregister_lookup(cls, lookup, lookup_name=None): | ||||
|         """ | ||||
|         Remove given lookup from cls lookups. For use in tests only as it's | ||||
|         not thread-safe. | ||||
|         """ | ||||
|         if lookup_name is None: | ||||
|             lookup_name = lookup.lookup_name | ||||
|         del cls.class_lookups[lookup_name] | ||||
|  | ||||
|  | ||||
| def select_related_descend(field, restricted, requested, load_fields, reverse=False): | ||||
|     """ | ||||
|     Returns True if this field should be used to descend deeper for | ||||
|   | ||||
| @@ -5,7 +5,7 @@ import copy | ||||
| import warnings | ||||
|  | ||||
| from django.db.models.fields import FloatField, IntegerField | ||||
| from django.db.models.lookups import RegisterLookupMixin | ||||
| from django.db.models.query_utils import RegisterLookupMixin | ||||
| from django.utils.deprecation import RemovedInDjango110Warning | ||||
| from django.utils.functional import cached_property | ||||
|  | ||||
|   | ||||
| @@ -1105,9 +1105,9 @@ class Query(object): | ||||
|         Helper method for build_lookup. Tries to fetch and initialize | ||||
|         a transform for name parameter from lhs. | ||||
|         """ | ||||
|         next = lhs.get_transform(name) | ||||
|         if next: | ||||
|             return next(lhs, rest_of_lookups) | ||||
|         transform_class = lhs.get_transform(name) | ||||
|         if transform_class: | ||||
|             return transform_class(lhs) | ||||
|         else: | ||||
|             raise FieldError( | ||||
|                 "Unsupported lookup '%s' for %s or join on the field not " | ||||
|   | ||||
| @@ -120,10 +120,7 @@ function ``ABS()`` to transform the value before comparison:: | ||||
|  | ||||
|   class AbsoluteValue(Transform): | ||||
|       lookup_name = 'abs' | ||||
|  | ||||
|       def as_sql(self, compiler, connection): | ||||
|           lhs, params = compiler.compile(self.lhs) | ||||
|           return "ABS(%s)" % lhs, params | ||||
|       function = 'ABS' | ||||
|  | ||||
| Next, let's register it for ``IntegerField``:: | ||||
|  | ||||
| @@ -157,10 +154,7 @@ be done by adding an ``output_field`` attribute to the transform:: | ||||
|  | ||||
|     class AbsoluteValue(Transform): | ||||
|         lookup_name = 'abs' | ||||
|  | ||||
|         def as_sql(self, compiler, connection): | ||||
|             lhs, params = compiler.compile(self.lhs) | ||||
|             return "ABS(%s)" % lhs, params | ||||
|         function = 'ABS' | ||||
|  | ||||
|         @property | ||||
|         def output_field(self): | ||||
| @@ -243,12 +237,9 @@ this transformation should apply to both ``lhs`` and ``rhs``:: | ||||
|  | ||||
|   class UpperCase(Transform): | ||||
|       lookup_name = 'upper' | ||||
|       function = 'UPPER' | ||||
|       bilateral = True | ||||
|  | ||||
|       def as_sql(self, compiler, connection): | ||||
|           lhs, params = compiler.compile(self.lhs) | ||||
|           return "UPPER(%s)" % lhs, params | ||||
|  | ||||
| Next, let's register it:: | ||||
|  | ||||
|   from django.db.models import CharField, TextField | ||||
|   | ||||
| @@ -180,6 +180,18 @@ Usage example:: | ||||
|     >>> print(author.name_length, author.goes_by_length) | ||||
|     (14, None) | ||||
|  | ||||
| It can also be registered as a transform. For example:: | ||||
|  | ||||
|     >>> from django.db.models import CharField | ||||
|     >>> from django.db.models.functions import Length | ||||
|     >>> CharField.register_lookup(Length, 'length') | ||||
|     >>> # Get authors whose name is longer than 7 characters | ||||
|     >>> authors = Author.objects.filter(name__length__gt=7) | ||||
|  | ||||
| .. versionchanged:: 1.9 | ||||
|  | ||||
|     The ability to register the function as a transform was added. | ||||
|  | ||||
| Lower | ||||
| ------ | ||||
|  | ||||
| @@ -188,6 +200,8 @@ Lower | ||||
| Accepts a single text field or expression and returns the lowercase | ||||
| representation. | ||||
|  | ||||
| It can also be registered as a transform as described in :class:`Length`. | ||||
|  | ||||
| Usage example:: | ||||
|  | ||||
|     >>> from django.db.models.functions import Lower | ||||
| @@ -196,6 +210,10 @@ Usage example:: | ||||
|     >>> print(author.name_lower) | ||||
|     margaret smith | ||||
|  | ||||
| .. versionchanged:: 1.9 | ||||
|  | ||||
|     The ability to register the function as a transform was added. | ||||
|  | ||||
| Now | ||||
| --- | ||||
|  | ||||
| @@ -246,6 +264,8 @@ Upper | ||||
| Accepts a single text field or expression and returns the uppercase | ||||
| representation. | ||||
|  | ||||
| It can also be registered as a transform as described in :class:`Length`. | ||||
|  | ||||
| Usage example:: | ||||
|  | ||||
|     >>> from django.db.models.functions import Upper | ||||
| @@ -253,3 +273,7 @@ Usage example:: | ||||
|     >>> author = Author.objects.annotate(name_upper=Upper('name')).get() | ||||
|     >>> print(author.name_upper) | ||||
|     MARGARET SMITH | ||||
|  | ||||
| .. versionchanged:: 1.9 | ||||
|  | ||||
|     The ability to register the function as a transform was added. | ||||
|   | ||||
| @@ -42,12 +42,17 @@ register lookups on itself. The two prominent examples are | ||||
|  | ||||
|     A mixin that implements the lookup API on a class. | ||||
|  | ||||
|     .. classmethod:: register_lookup(lookup) | ||||
|     .. classmethod:: register_lookup(lookup, lookup_name=None) | ||||
|  | ||||
|         Registers a new lookup in the class. For example | ||||
|         ``DateField.register_lookup(YearExact)`` will register ``YearExact`` | ||||
|         lookup on ``DateField``. It overrides a lookup that already exists with | ||||
|         the same name. | ||||
|         the same name. ``lookup_name`` will be used for this lookup if | ||||
|         provided, otherwise ``lookup.lookup_name`` will be used. | ||||
|  | ||||
|         .. versionchanged:: 1.9 | ||||
|  | ||||
|             The ``lookup_name`` parameter was added. | ||||
|  | ||||
|     .. method:: get_lookup(lookup_name) | ||||
|  | ||||
| @@ -125,7 +130,14 @@ Transform reference | ||||
|     ``<expression>__<transformation>`` (e.g. ``date__year``). | ||||
|  | ||||
|     This class follows the :ref:`Query Expression API <query-expression>`, which | ||||
|     implies that you can use ``<expression>__<transform1>__<transform2>``. | ||||
|     implies that you can use ``<expression>__<transform1>__<transform2>``. It's | ||||
|     a specialized :ref:`Func() expression <func-expressions>` that only accepts | ||||
|     one argument.  It can also be used on the right hand side of a filter or | ||||
|     directly as an annotation. | ||||
|  | ||||
|     .. versionchanged:: 1.9 | ||||
|  | ||||
|         ``Transform`` is now a subclass of ``Func``. | ||||
|  | ||||
|     .. attribute:: bilateral | ||||
|  | ||||
| @@ -152,18 +164,6 @@ Transform reference | ||||
|         :class:`~django.db.models.Field` instance. By default is the same as | ||||
|         its ``lhs.output_field``. | ||||
|  | ||||
|     .. method:: as_sql | ||||
|  | ||||
|         To be overridden; raises :exc:`NotImplementedError`. | ||||
|  | ||||
|     .. method:: get_lookup(lookup_name) | ||||
|  | ||||
|         Same as :meth:`~lookups.RegisterLookupMixin.get_lookup()`. | ||||
|  | ||||
|     .. method:: get_transform(transform_name) | ||||
|  | ||||
|         Same as :meth:`~lookups.RegisterLookupMixin.get_transform()`. | ||||
|  | ||||
| Lookup reference | ||||
| ~~~~~~~~~~~~~~~~ | ||||
|  | ||||
|   | ||||
| @@ -520,6 +520,14 @@ Models | ||||
| * Added the :class:`~django.db.models.functions.Now` database function, which | ||||
|   returns the current date and time. | ||||
|  | ||||
| * :class:`~django.db.models.Transform` is now a subclass of | ||||
|   :ref:`Func() <func-expressions>` which allows ``Transform``\s to be used on | ||||
|   the right hand side of an expression, just like regular ``Func``\s. This | ||||
|   allows registering some database functions like | ||||
|   :class:`~django.db.models.functions.Length`, | ||||
|   :class:`~django.db.models.functions.Lower`, and | ||||
|   :class:`~django.db.models.functions.Upper` as transforms. | ||||
|  | ||||
| * :class:`~django.db.models.SlugField` now accepts an | ||||
|   :attr:`~django.db.models.SlugField.allow_unicode` argument to allow Unicode | ||||
|   characters in slugs. | ||||
|   | ||||
| @@ -126,11 +126,17 @@ class YearLte(models.lookups.LessThanOrEqual): | ||||
|         return "%s <= (%s || '-12-31')::date" % (lhs_sql, rhs_sql), params | ||||
|  | ||||
|  | ||||
| class SQLFunc(models.Lookup): | ||||
|     def __init__(self, name, *args, **kwargs): | ||||
|         super(SQLFunc, self).__init__(*args, **kwargs) | ||||
|         self.name = name | ||||
| class Exactly(models.lookups.Exact): | ||||
|     """ | ||||
|     This lookup is used to test lookup registration. | ||||
|     """ | ||||
|     lookup_name = 'exactly' | ||||
|  | ||||
|     def get_rhs_op(self, connection, rhs): | ||||
|         return connection.operators['exact'] % rhs | ||||
|  | ||||
|  | ||||
| class SQLFuncMixin(object): | ||||
|     def as_sql(self, compiler, connection): | ||||
|         return '%s()', [self.name] | ||||
|  | ||||
| @@ -139,13 +145,28 @@ class SQLFunc(models.Lookup): | ||||
|         return CustomField() | ||||
|  | ||||
|  | ||||
| class SQLFuncLookup(SQLFuncMixin, models.Lookup): | ||||
|     def __init__(self, name, *args, **kwargs): | ||||
|         super(SQLFuncLookup, self).__init__(*args, **kwargs) | ||||
|         self.name = name | ||||
|  | ||||
|  | ||||
| class SQLFuncTransform(SQLFuncMixin, models.Transform): | ||||
|     def __init__(self, name, *args, **kwargs): | ||||
|         super(SQLFuncTransform, self).__init__(*args, **kwargs) | ||||
|         self.name = name | ||||
|  | ||||
|  | ||||
| class SQLFuncFactory(object): | ||||
|  | ||||
|     def __init__(self, name): | ||||
|     def __init__(self, key, name): | ||||
|         self.key = key | ||||
|         self.name = name | ||||
|  | ||||
|     def __call__(self, *args, **kwargs): | ||||
|         return SQLFunc(self.name, *args, **kwargs) | ||||
|         if self.key == 'lookupfunc': | ||||
|             return SQLFuncLookup(self.name, *args, **kwargs) | ||||
|         return SQLFuncTransform(self.name, *args, **kwargs) | ||||
|  | ||||
|  | ||||
| class CustomField(models.TextField): | ||||
| @@ -153,13 +174,13 @@ class CustomField(models.TextField): | ||||
|     def get_lookup(self, lookup_name): | ||||
|         if lookup_name.startswith('lookupfunc_'): | ||||
|             key, name = lookup_name.split('_', 1) | ||||
|             return SQLFuncFactory(name) | ||||
|             return SQLFuncFactory(key, name) | ||||
|         return super(CustomField, self).get_lookup(lookup_name) | ||||
|  | ||||
|     def get_transform(self, lookup_name): | ||||
|         if lookup_name.startswith('transformfunc_'): | ||||
|             key, name = lookup_name.split('_', 1) | ||||
|             return SQLFuncFactory(name) | ||||
|             return SQLFuncFactory(key, name) | ||||
|         return super(CustomField, self).get_transform(lookup_name) | ||||
|  | ||||
|  | ||||
| @@ -200,6 +221,27 @@ class DateTimeTransform(models.Transform): | ||||
|  | ||||
|  | ||||
| class LookupTests(TestCase): | ||||
|  | ||||
|     def test_custom_name_lookup(self): | ||||
|         a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16)) | ||||
|         Author.objects.create(name='a2', birthdate=date(2012, 2, 29)) | ||||
|         custom_lookup_name = 'isactually' | ||||
|         custom_transform_name = 'justtheyear' | ||||
|         try: | ||||
|             models.DateField.register_lookup(YearTransform) | ||||
|             models.DateField.register_lookup(YearTransform, custom_transform_name) | ||||
|             YearTransform.register_lookup(Exactly) | ||||
|             YearTransform.register_lookup(Exactly, custom_lookup_name) | ||||
|             qs1 = Author.objects.filter(birthdate__testyear__exactly=1981) | ||||
|             qs2 = Author.objects.filter(birthdate__justtheyear__isactually=1981) | ||||
|             self.assertQuerysetEqual(qs1, [a1], lambda x: x) | ||||
|             self.assertQuerysetEqual(qs2, [a1], lambda x: x) | ||||
|         finally: | ||||
|             YearTransform._unregister_lookup(Exactly) | ||||
|             YearTransform._unregister_lookup(Exactly, custom_lookup_name) | ||||
|             models.DateField._unregister_lookup(YearTransform) | ||||
|             models.DateField._unregister_lookup(YearTransform, custom_transform_name) | ||||
|  | ||||
|     def test_basic_lookup(self): | ||||
|         a1 = Author.objects.create(name='a1', age=1) | ||||
|         a2 = Author.objects.create(name='a2', age=2) | ||||
| @@ -299,6 +341,19 @@ class BilateralTransformTests(TestCase): | ||||
|             with self.assertRaises(NotImplementedError): | ||||
|                 Author.objects.filter(name__upper__in=Author.objects.values_list('name')) | ||||
|  | ||||
|     def test_bilateral_multi_value(self): | ||||
|         with register_lookup(models.CharField, UpperBilateralTransform): | ||||
|             Author.objects.bulk_create([ | ||||
|                 Author(name='Foo'), | ||||
|                 Author(name='Bar'), | ||||
|                 Author(name='Ray'), | ||||
|             ]) | ||||
|             self.assertQuerysetEqual( | ||||
|                 Author.objects.filter(name__upper__in=['foo', 'bar', 'doe']).order_by('name'), | ||||
|                 ['Bar', 'Foo'], | ||||
|                 lambda a: a.name | ||||
|             ) | ||||
|  | ||||
|     def test_div3_bilateral_extract(self): | ||||
|         with register_lookup(models.IntegerField, Div3BilateralTransform): | ||||
|             a1 = Author.objects.create(name='a1', age=1) | ||||
|   | ||||
| @@ -547,3 +547,97 @@ class FunctionTests(TestCase): | ||||
|             ['How to Time Travel'], | ||||
|             lambda a: a.title | ||||
|         ) | ||||
|  | ||||
|     def test_length_transform(self): | ||||
|         try: | ||||
|             CharField.register_lookup(Length, 'length') | ||||
|             Author.objects.create(name='John Smith', alias='smithj') | ||||
|             Author.objects.create(name='Rhonda') | ||||
|             authors = Author.objects.filter(name__length__gt=7) | ||||
|             self.assertQuerysetEqual( | ||||
|                 authors.order_by('name'), [ | ||||
|                     'John Smith', | ||||
|                 ], | ||||
|                 lambda a: a.name | ||||
|             ) | ||||
|         finally: | ||||
|             CharField._unregister_lookup(Length, 'length') | ||||
|  | ||||
|     def test_lower_transform(self): | ||||
|         try: | ||||
|             CharField.register_lookup(Lower, 'lower') | ||||
|             Author.objects.create(name='John Smith', alias='smithj') | ||||
|             Author.objects.create(name='Rhonda') | ||||
|             authors = Author.objects.filter(name__lower__exact='john smith') | ||||
|             self.assertQuerysetEqual( | ||||
|                 authors.order_by('name'), [ | ||||
|                     'John Smith', | ||||
|                 ], | ||||
|                 lambda a: a.name | ||||
|             ) | ||||
|         finally: | ||||
|             CharField._unregister_lookup(Lower, 'lower') | ||||
|  | ||||
|     def test_upper_transform(self): | ||||
|         try: | ||||
|             CharField.register_lookup(Upper, 'upper') | ||||
|             Author.objects.create(name='John Smith', alias='smithj') | ||||
|             Author.objects.create(name='Rhonda') | ||||
|             authors = Author.objects.filter(name__upper__exact='JOHN SMITH') | ||||
|             self.assertQuerysetEqual( | ||||
|                 authors.order_by('name'), [ | ||||
|                     'John Smith', | ||||
|                 ], | ||||
|                 lambda a: a.name | ||||
|             ) | ||||
|         finally: | ||||
|             CharField._unregister_lookup(Upper, 'upper') | ||||
|  | ||||
|     def test_func_transform_bilateral(self): | ||||
|         class UpperBilateral(Upper): | ||||
|             bilateral = True | ||||
|  | ||||
|         try: | ||||
|             CharField.register_lookup(UpperBilateral, 'upper') | ||||
|             Author.objects.create(name='John Smith', alias='smithj') | ||||
|             Author.objects.create(name='Rhonda') | ||||
|             authors = Author.objects.filter(name__upper__exact='john smith') | ||||
|             self.assertQuerysetEqual( | ||||
|                 authors.order_by('name'), [ | ||||
|                     'John Smith', | ||||
|                 ], | ||||
|                 lambda a: a.name | ||||
|             ) | ||||
|         finally: | ||||
|             CharField._unregister_lookup(UpperBilateral, 'upper') | ||||
|  | ||||
|     def test_func_transform_bilateral_multivalue(self): | ||||
|         class UpperBilateral(Upper): | ||||
|             bilateral = True | ||||
|  | ||||
|         try: | ||||
|             CharField.register_lookup(UpperBilateral, 'upper') | ||||
|             Author.objects.create(name='John Smith', alias='smithj') | ||||
|             Author.objects.create(name='Rhonda') | ||||
|             authors = Author.objects.filter(name__upper__in=['john smith', 'rhonda']) | ||||
|             self.assertQuerysetEqual( | ||||
|                 authors.order_by('name'), [ | ||||
|                     'John Smith', | ||||
|                     'Rhonda', | ||||
|                 ], | ||||
|                 lambda a: a.name | ||||
|             ) | ||||
|         finally: | ||||
|             CharField._unregister_lookup(UpperBilateral, 'upper') | ||||
|  | ||||
|     def test_function_as_filter(self): | ||||
|         Author.objects.create(name='John Smith', alias='SMITHJ') | ||||
|         Author.objects.create(name='Rhonda') | ||||
|         self.assertQuerysetEqual( | ||||
|             Author.objects.filter(alias=Upper(V('smithj'))), | ||||
|             ['John Smith'], lambda x: x.name | ||||
|         ) | ||||
|         self.assertQuerysetEqual( | ||||
|             Author.objects.exclude(alias=Upper(V('smithj'))), | ||||
|             ['Rhonda'], lambda x: x.name | ||||
|         ) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user