mirror of
				https://github.com/django/django.git
				synced 2025-10-31 01:25:32 +00:00 
			
		
		
		
	Implemented nested lookups
But there is no support of using lookups outside filtering yet.
This commit is contained in:
		| @@ -1136,11 +1136,14 @@ class ForeignObject(RelatedField): | |||||||
|         pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.rel, not self.unique, False)] |         pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.rel, not self.unique, False)] | ||||||
|         return pathinfos |         return pathinfos | ||||||
|  |  | ||||||
|     def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookup_type, |     def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookups, | ||||||
|                               raw_value): |                               raw_value): | ||||||
|         from django.db.models.sql.where import SubqueryConstraint, Constraint, AND, OR |         from django.db.models.sql.where import SubqueryConstraint, Constraint, AND, OR | ||||||
|         root_constraint = constraint_class() |         root_constraint = constraint_class() | ||||||
|         assert len(targets) == len(sources) |         assert len(targets) == len(sources) | ||||||
|  |         if len(lookups) > 1: | ||||||
|  |             raise exceptions.FieldError('Relation fields do not support nested lookups') | ||||||
|  |         lookup_type = lookups[0] | ||||||
|  |  | ||||||
|         def get_normalized_value(value): |         def get_normalized_value(value): | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,27 +1,58 @@ | |||||||
| from copy import copy | from copy import copy | ||||||
|  |  | ||||||
|  | from django.core.exceptions import FieldError | ||||||
| from django.conf import settings | from django.conf import settings | ||||||
| from django.utils import timezone | from django.utils import timezone | ||||||
|  | from django.utils.functional import cached_property | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Extract(object): | ||||||
|  |     def __init__(self, constraint_class, lhs): | ||||||
|  |         self.constraint_class, self.lhs = constraint_class, lhs | ||||||
|  |  | ||||||
|  |     def get_lookup(self, lookup): | ||||||
|  |         return self.output_type.get_lookup(lookup) | ||||||
|  |  | ||||||
|  |     def as_sql(self, qn, connection): | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     @cached_property | ||||||
|  |     def output_type(self): | ||||||
|  |         return self.lhs.output_type | ||||||
|  |  | ||||||
|  |     def relabeled_clone(self, relabels): | ||||||
|  |         return self.__class__(self.constraint_class, self.lhs.relabeled_clone(relabels)) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Lookup(object): | class Lookup(object): | ||||||
|  |     lookup_name = None | ||||||
|  |     extract_class = None | ||||||
|  |  | ||||||
|     def __init__(self, constraint_class, lhs, rhs): |     def __init__(self, constraint_class, lhs, rhs): | ||||||
|         self.constraint_class, self.lhs, self.rhs = constraint_class, lhs, rhs |         self.constraint_class, self.lhs, self.rhs = constraint_class, lhs, rhs | ||||||
|  |         if rhs is None: | ||||||
|  |             if not self.extract_class: | ||||||
|  |                 raise FieldError("Lookup '%s' doesn't support nesting." % self.lookup_name) | ||||||
|  |         else: | ||||||
|             self.rhs = self.get_prep_lookup() |             self.rhs = self.get_prep_lookup() | ||||||
|  |  | ||||||
|  |     def get_extract(self): | ||||||
|  |         return self.extract_class(self.constraint_class, self.lhs) | ||||||
|  |  | ||||||
|  |     def get_prep_lookup(self): | ||||||
|  |         return self.lhs.output_type.get_prep_lookup(self.lookup_name, self.rhs) | ||||||
|  |  | ||||||
|     def get_db_prep_lookup(self, value, connection): |     def get_db_prep_lookup(self, value, connection): | ||||||
|         return ( |         return ( | ||||||
|             '%s', self.lhs.output_type.get_db_prep_lookup( |             '%s', self.lhs.output_type.get_db_prep_lookup( | ||||||
|                 self.lookup_name, value, connection, prepared=True)) |                 self.lookup_name, value, connection, prepared=True)) | ||||||
|  |  | ||||||
|     def get_prep_lookup(self): |     def process_lhs(self, qn, connection, lhs=None): | ||||||
|         return self.lhs.output_type.get_prep_lookup(self.lookup_name, self.rhs) |         lhs = lhs or self.lhs | ||||||
|  |         return qn.compile(lhs) | ||||||
|  |  | ||||||
|     def process_lhs(self, qn, connection): |     def process_rhs(self, qn, connection, rhs=None): | ||||||
|         return qn.compile(self.lhs) |         value = rhs or self.rhs | ||||||
|  |  | ||||||
|     def process_rhs(self, qn, connection): |  | ||||||
|         value = self.rhs |  | ||||||
|         # Due to historical reasons there are a couple of different |         # Due to historical reasons there are a couple of different | ||||||
|         # ways to produce sql here. get_compiler is likely a Query |         # ways to produce sql here. get_compiler is likely a Query | ||||||
|         # instance, _as_sql QuerySet and as_sql just something with |         # instance, _as_sql QuerySet and as_sql just something with | ||||||
| @@ -118,7 +149,7 @@ class In(DjangoLookup): | |||||||
|     lookup_name = 'in' |     lookup_name = 'in' | ||||||
|  |  | ||||||
|     def get_db_prep_lookup(self, value, connection): |     def get_db_prep_lookup(self, value, connection): | ||||||
|         params = self.lhs.field.get_db_prep_lookup( |         params = self.lhs.output_type.get_db_prep_lookup( | ||||||
|             self.lookup_name, value, connection, prepared=True) |             self.lookup_name, value, connection, prepared=True) | ||||||
|         if not params: |         if not params: | ||||||
|             # TODO: check why this leads to circular import |             # TODO: check why this leads to circular import | ||||||
|   | |||||||
| @@ -100,6 +100,9 @@ class Aggregate(object): | |||||||
|     def output_type(self): |     def output_type(self): | ||||||
|         return self.field |         return self.field | ||||||
|  |  | ||||||
|  |     def get_lookup(self, lookup): | ||||||
|  |         return self.output_type.get_lookup(lookup) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Avg(Aggregate): | class Avg(Aggregate): | ||||||
|     is_computed = True |     is_computed = True | ||||||
|   | |||||||
| @@ -25,6 +25,9 @@ class Col(object): | |||||||
|     def get_cols(self): |     def get_cols(self): | ||||||
|         return [(self.alias, self.target.column)] |         return [(self.alias, self.target.column)] | ||||||
|  |  | ||||||
|  |     def get_lookup(self, name): | ||||||
|  |         return self.output_type.get_lookup(name) | ||||||
|  |  | ||||||
|  |  | ||||||
| class EmptyResultSet(Exception): | class EmptyResultSet(Exception): | ||||||
|     pass |     pass | ||||||
|   | |||||||
| @@ -1027,19 +1027,16 @@ class Query(object): | |||||||
|         # Add the aggregate to the query |         # Add the aggregate to the query | ||||||
|         aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary) |         aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary) | ||||||
|  |  | ||||||
|     def prepare_lookup_value(self, value, lookup_type, can_reuse): |     def prepare_lookup_value(self, value, lookups, can_reuse): | ||||||
|  |         # Default lookup if none given is exact. | ||||||
|  |         if len(lookups) == 0: | ||||||
|  |             lookups = ['exact'] | ||||||
|         # Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all |         # Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all | ||||||
|         # uses of None as a query value. |         # uses of None as a query value. | ||||||
|         if len(lookup_type) > 1: |  | ||||||
|             raise FieldError('Nested lookups not allowed') |  | ||||||
|         elif len(lookup_type) == 0: |  | ||||||
|             lookup_type = 'exact' |  | ||||||
|         else: |  | ||||||
|             lookup_type = lookup_type[0] |  | ||||||
|         if value is None: |         if value is None: | ||||||
|             if lookup_type != 'exact': |             if lookups[-1] != 'exact': | ||||||
|                 raise ValueError("Cannot use None as a query value") |                 raise ValueError("Cannot use None as a query value") | ||||||
|             lookup_type = 'isnull' |             lookups[-1] = 'isnull' | ||||||
|             value = True |             value = True | ||||||
|         elif callable(value): |         elif callable(value): | ||||||
|             value = value() |             value = value() | ||||||
| @@ -1057,10 +1054,10 @@ class Query(object): | |||||||
|         # stage. Using DEFAULT_DB_ALIAS isn't nice, but it is the best we |         # stage. Using DEFAULT_DB_ALIAS isn't nice, but it is the best we | ||||||
|         # can do here. Similar thing is done in is_nullable(), too. |         # can do here. Similar thing is done in is_nullable(), too. | ||||||
|         if (connections[DEFAULT_DB_ALIAS].features.interprets_empty_strings_as_nulls and |         if (connections[DEFAULT_DB_ALIAS].features.interprets_empty_strings_as_nulls and | ||||||
|                 lookup_type == 'exact' and value == ''): |                 lookups[-1] == 'exact' and value == ''): | ||||||
|             value = True |             value = True | ||||||
|             lookup_type = 'isnull' |             lookups[-1] = ['isnull'] | ||||||
|         return value, lookup_type |         return value, lookups | ||||||
|  |  | ||||||
|     def solve_lookup_type(self, lookup): |     def solve_lookup_type(self, lookup): | ||||||
|         """ |         """ | ||||||
| @@ -1069,35 +1066,36 @@ class Query(object): | |||||||
|         lookup_splitted = lookup.split(LOOKUP_SEP) |         lookup_splitted = lookup.split(LOOKUP_SEP) | ||||||
|         aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.aggregates) |         aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.aggregates) | ||||||
|         if aggregate: |         if aggregate: | ||||||
|             if len(aggregate_lookups) > 1: |  | ||||||
|                 raise FieldError("Nested lookups not allowed.") |  | ||||||
|             return aggregate_lookups, (), aggregate |             return aggregate_lookups, (), aggregate | ||||||
|         _, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta()) |         _, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta()) | ||||||
|         field_parts = lookup_splitted[0:len(lookup_splitted) - len(lookup_parts)] |         field_parts = lookup_splitted[0:len(lookup_splitted) - len(lookup_parts)] | ||||||
|         if len(lookup_parts) == 0: |         if len(lookup_parts) == 0: | ||||||
|             lookup_parts = ['exact'] |             lookup_parts = ['exact'] | ||||||
|         elif len(lookup_parts) > 1: |         elif len(lookup_parts) > 1: | ||||||
|             if field_parts: |             if not field_parts: | ||||||
|                 raise FieldError( |  | ||||||
|                     'Only one lookup part allowed (found path "%s" from "%s").' % |  | ||||||
|                     (LOOKUP_SEP.join(field_parts), lookup)) |  | ||||||
|             else: |  | ||||||
|                 raise FieldError( |                 raise FieldError( | ||||||
|                     'Invalid lookup "%s" for model %s".' % |                     'Invalid lookup "%s" for model %s".' % | ||||||
|                     (lookup, self.get_meta().model.__name__)) |                     (lookup, self.get_meta().model.__name__)) | ||||||
|         else: |  | ||||||
|             if not hasattr(field, 'get_lookup_constraint'): |  | ||||||
|                 lookup_class = field.get_lookup(lookup_parts[0]) |  | ||||||
|                 if lookup_class is None and lookup_parts[0] not in self.query_terms: |  | ||||||
|                     raise FieldError( |  | ||||||
|                         'Invalid lookup name %s' % lookup_parts[0]) |  | ||||||
|         return lookup_parts, field_parts, False |         return lookup_parts, field_parts, False | ||||||
|  |  | ||||||
|     def build_lookup(self, lookup_type, lhs, rhs): |     def build_lookup(self, lookups, lhs, rhs): | ||||||
|         if hasattr(lhs.output_type, 'get_lookup'): |         lookups = lookups[:] | ||||||
|             lookup = lhs.output_type.get_lookup(lookup_type) |         lookups.reverse() | ||||||
|             if lookup: |         while lookups: | ||||||
|                 return lookup(self.where_class, lhs, rhs) |             lookup = lookups.pop() | ||||||
|  |             next = lhs.get_lookup(lookup) | ||||||
|  |             if next: | ||||||
|  |                 if not lookups: | ||||||
|  |                     # This was the last lookup, so return value lookup. | ||||||
|  |                     return next(self.where_class, lhs, rhs) | ||||||
|  |                 else: | ||||||
|  |                     lhs = next(self.where_class, lhs, None).get_extract() | ||||||
|  |             # A field's get_lookup() can return None to opt for backwards | ||||||
|  |             # compatibility path. | ||||||
|  |             elif len(lookups) > 1: | ||||||
|  |                 raise FieldError( | ||||||
|  |                     "Unsupported lookup for field '%s'" % lhs.output_type.name) | ||||||
|  |             else: | ||||||
|                 return None |                 return None | ||||||
|  |  | ||||||
|     def build_filter(self, filter_expr, branch_negated=False, current_negated=False, |     def build_filter(self, filter_expr, branch_negated=False, current_negated=False, | ||||||
| @@ -1130,19 +1128,20 @@ class Query(object): | |||||||
|         arg, value = filter_expr |         arg, value = filter_expr | ||||||
|         if not arg: |         if not arg: | ||||||
|             raise FieldError("Cannot parse keyword query %r" % arg) |             raise FieldError("Cannot parse keyword query %r" % arg) | ||||||
|         lookup_type, parts, reffed_aggregate = self.solve_lookup_type(arg) |         lookups, parts, reffed_aggregate = self.solve_lookup_type(arg) | ||||||
|  |  | ||||||
|         # Work out the lookup type and remove it from the end of 'parts', |         # Work out the lookup type and remove it from the end of 'parts', | ||||||
|         # if necessary. |         # if necessary. | ||||||
|         value, lookup_type = self.prepare_lookup_value(value, lookup_type, can_reuse) |         value, lookups = self.prepare_lookup_value(value, lookups, can_reuse) | ||||||
|         used_joins = getattr(value, '_used_joins', []) |         used_joins = getattr(value, '_used_joins', []) | ||||||
|  |  | ||||||
|         clause = self.where_class() |         clause = self.where_class() | ||||||
|         if reffed_aggregate: |         if reffed_aggregate: | ||||||
|             condition = self.build_lookup(lookup_type, reffed_aggregate, value) |             condition = self.build_lookup(lookups, reffed_aggregate, value) | ||||||
|             if not condition: |             if not condition: | ||||||
|                 # Backwards compat for custom lookups |                 # Backwards compat for custom lookups | ||||||
|                 condition = (reffed_aggregate, lookup_type, value) |                 assert len(lookups) == 1 | ||||||
|  |                 condition = (reffed_aggregate, lookups[0], value) | ||||||
|             clause.add(condition, AND) |             clause.add(condition, AND) | ||||||
|             return clause, [] |             return clause, [] | ||||||
|  |  | ||||||
| @@ -1169,14 +1168,27 @@ class Query(object): | |||||||
|             # For now foreign keys get special treatment. This should be |             # For now foreign keys get special treatment. This should be | ||||||
|             # refactored when composite fields lands. |             # refactored when composite fields lands. | ||||||
|             condition = field.get_lookup_constraint(self.where_class, alias, targets, sources, |             condition = field.get_lookup_constraint(self.where_class, alias, targets, sources, | ||||||
|                                                     lookup_type, value) |                                                     lookups, value) | ||||||
|  |             lookup_type = lookups[-1] | ||||||
|         else: |         else: | ||||||
|             assert(len(targets) == 1) |             assert(len(targets) == 1) | ||||||
|             col = Col(alias, targets[0], field) |             col = Col(alias, targets[0], field) | ||||||
|             condition = self.build_lookup(lookup_type, col, value) |             condition = self.build_lookup(lookups, col, value) | ||||||
|             if not condition: |             if not condition: | ||||||
|                 # Backwards compat for custom lookups |                 # Backwards compat for custom lookups | ||||||
|                 condition = (Constraint(alias, targets[0].column, field), lookup_type, value) |                 if lookups[0] not in self.query_terms: | ||||||
|  |                     raise FieldError( | ||||||
|  |                         "Join on field '%s' not permitted. Did you " | ||||||
|  |                         "misspell '%s' for the lookup type?" % | ||||||
|  |                         (col.output_type.name, lookups[0])) | ||||||
|  |                 if len(lookups) > 1: | ||||||
|  |                     raise FieldError("Nested lookup '%s' not supported." % | ||||||
|  |                                      LOOKUP_SEP.join(lookups)) | ||||||
|  |                 condition = (Constraint(alias, targets[0].column, field), lookups[0], value) | ||||||
|  |                 lookup_type = lookups[-1] | ||||||
|  |             else: | ||||||
|  |                 lookup_type = condition.lookup_name | ||||||
|  |  | ||||||
|         clause.add(condition, AND) |         clause.add(condition, AND) | ||||||
|  |  | ||||||
|         require_outer = lookup_type == 'isnull' and value is True and not current_negated |         require_outer = lookup_type == 'isnull' and value is True and not current_negated | ||||||
| @@ -1296,7 +1308,7 @@ class Query(object): | |||||||
|         needed_inner = joinpromoter.update_join_types(self) |         needed_inner = joinpromoter.update_join_types(self) | ||||||
|         return target_clause, needed_inner |         return target_clause, needed_inner | ||||||
|  |  | ||||||
|     def names_to_path(self, names, opts, allow_many=True): |     def names_to_path(self, names, opts, allow_many=True, fail_on_missing=False): | ||||||
|         """ |         """ | ||||||
|         Walks the names path and turns them PathInfo tuples. Note that a |         Walks the names path and turns them PathInfo tuples. Note that a | ||||||
|         single name in 'names' can generate multiple PathInfos (m2m for |         single name in 'names' can generate multiple PathInfos (m2m for | ||||||
| @@ -1354,10 +1366,15 @@ class Query(object): | |||||||
|                 final_field = field |                 final_field = field | ||||||
|                 targets = (field,) |                 targets = (field,) | ||||||
|                 break |                 break | ||||||
|         if pos == -1: |         if pos == -1 or (fail_on_missing and pos + 1 != len(names)): | ||||||
|             raise FieldError('Whazaa') |             self.raise_field_error(opts, name) | ||||||
|         return path, final_field, targets, names[pos + 1:] |         return path, final_field, targets, names[pos + 1:] | ||||||
|  |  | ||||||
|  |     def raise_field_error(self, opts, name): | ||||||
|  |         available = opts.get_all_field_names() + list(self.aggregate_select) | ||||||
|  |         raise FieldError("Cannot resolve keyword %r into field. " | ||||||
|  |                          "Choices are: %s" % (name, ", ".join(available))) | ||||||
|  |  | ||||||
|     def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True): |     def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True): | ||||||
|         """ |         """ | ||||||
|         Compute the necessary table joins for the passage through the fields |         Compute the necessary table joins for the passage through the fields | ||||||
| @@ -1386,9 +1403,8 @@ class Query(object): | |||||||
|         joins = [alias] |         joins = [alias] | ||||||
|         # First, generate the path for the names |         # First, generate the path for the names | ||||||
|         path, final_field, targets, rest = self.names_to_path( |         path, final_field, targets, rest = self.names_to_path( | ||||||
|             names, opts, allow_many) |             names, opts, allow_many, fail_on_missing=True) | ||||||
|         if rest: |  | ||||||
|             raise FieldError('Invalid lookup') |  | ||||||
|         # Then, add the path to the query's joins. Note that we can't trim |         # Then, add the path to the query's joins. Note that we can't trim | ||||||
|         # joins at this stage - we will need the information about join type |         # joins at this stage - we will need the information about join type | ||||||
|         # of the trimmed joins. |         # of the trimmed joins. | ||||||
|   | |||||||
| @@ -1,7 +1,13 @@ | |||||||
| from django.db import models | from django.db import models | ||||||
|  | from django.utils.encoding import python_2_unicode_compatible | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @python_2_unicode_compatible | ||||||
| class Author(models.Model): | class Author(models.Model): | ||||||
|     name = models.CharField(max_length=20) |     name = models.CharField(max_length=20) | ||||||
|     age = models.IntegerField(null=True) |     age = models.IntegerField(null=True) | ||||||
|     birthdate = models.DateField(null=True) |     birthdate = models.DateField(null=True) | ||||||
|  |     average_rating = models.FloatField(null=True) | ||||||
|  |  | ||||||
|  |     def __str__(self): | ||||||
|  |         return self.name | ||||||
|   | |||||||
| @@ -19,6 +19,56 @@ class Div3Lookup(models.lookups.Lookup): | |||||||
|         return '%s %%%% 3 = %s' % (lhs, rhs), params |         return '%s %%%% 3 = %s' % (lhs, rhs), params | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Div3Extract(models.lookups.Extract): | ||||||
|  |     def as_sql(self, qn, connection): | ||||||
|  |         lhs, lhs_params = qn.compile(self.lhs) | ||||||
|  |         return '%s %%%% 3' % (lhs,), lhs_params | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Div3LookupWithExtract(Div3Lookup): | ||||||
|  |     lookup_name = 'div3' | ||||||
|  |     extract_class = Div3Extract | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class YearLte(models.lookups.LessThanOrEqual): | ||||||
|  |     """ | ||||||
|  |     The purpose of this lookup is to efficiently compare the year of the field. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def as_sql(self, qn, connection): | ||||||
|  |         # Skip the YearExtract above us (no possibility for efficient | ||||||
|  |         # lookup otherwise). | ||||||
|  |         real_lhs = self.lhs.lhs | ||||||
|  |         lhs_sql, params = self.process_lhs(qn, connection, real_lhs) | ||||||
|  |         rhs_sql, rhs_params = self.process_rhs(qn, connection) | ||||||
|  |         params.extend(rhs_params) | ||||||
|  |         # Build SQL where the integer year is concatenated with last month | ||||||
|  |         # and day, then convert that to date. (We try to have SQL like: | ||||||
|  |         #     WHERE somecol <= '2013-12-31') | ||||||
|  |         # but also make it work if the rhs_sql is field reference. | ||||||
|  |         return "%s <= (%s || '-12-31')::date" % (lhs_sql, rhs_sql), params | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class YearExtract(models.lookups.Extract): | ||||||
|  |     def as_sql(self, qn, connection): | ||||||
|  |         lhs_sql, params = qn.compile(self.lhs) | ||||||
|  |         return connection.ops.date_extract_sql('year', lhs_sql), params | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def output_type(self): | ||||||
|  |         return models.IntegerField() | ||||||
|  |  | ||||||
|  |     def get_lookup(self, lookup): | ||||||
|  |         if lookup == 'lte': | ||||||
|  |             return YearLte | ||||||
|  |         else: | ||||||
|  |             return super(YearExtract, self).get_lookup(lookup) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class YearWithExtract(models.lookups.Year): | ||||||
|  |     extract_class = YearExtract | ||||||
|  |  | ||||||
|  |  | ||||||
| class InMonth(models.lookups.Lookup): | class InMonth(models.lookups.Lookup): | ||||||
|     """ |     """ | ||||||
|     InMonth matches if the column's month is contained in the value's month. |     InMonth matches if the column's month is contained in the value's month. | ||||||
| @@ -134,3 +184,72 @@ class LookupTests(TestCase): | |||||||
|             ) |             ) | ||||||
|         finally: |         finally: | ||||||
|             models.Field._unregister_lookup(AnotherEqual) |             models.Field._unregister_lookup(AnotherEqual) | ||||||
|  |  | ||||||
|  |     def test_div3_extract(self): | ||||||
|  |         models.IntegerField.register_lookup(Div3LookupWithExtract) | ||||||
|  |         try: | ||||||
|  |             a1 = Author.objects.create(name='a1', age=1) | ||||||
|  |             a2 = Author.objects.create(name='a2', age=2) | ||||||
|  |             a3 = Author.objects.create(name='a3', age=3) | ||||||
|  |             a4 = Author.objects.create(name='a4', age=4) | ||||||
|  |             baseqs = Author.objects.order_by('name') | ||||||
|  |             self.assertQuerysetEqual( | ||||||
|  |                 baseqs.filter(age__div3__lte=3), | ||||||
|  |                 [a1, a2, a3, a4], lambda x: x) | ||||||
|  |             self.assertQuerysetEqual( | ||||||
|  |                 baseqs.filter(age__div3__in=[0, 2]), | ||||||
|  |                 [a2, a3], lambda x: x) | ||||||
|  |         finally: | ||||||
|  |             models.IntegerField._unregister_lookup(Div3LookupWithExtract) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class YearLteTests(TestCase): | ||||||
|  |     def setUp(self): | ||||||
|  |         models.DateField.register_lookup(YearWithExtract) | ||||||
|  |         self.a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16)) | ||||||
|  |         self.a2 = Author.objects.create(name='a2', birthdate=date(2012, 2, 29)) | ||||||
|  |         self.a3 = Author.objects.create(name='a3', birthdate=date(2012, 1, 31)) | ||||||
|  |         self.a4 = Author.objects.create(name='a4', birthdate=date(2012, 3, 1)) | ||||||
|  |  | ||||||
|  |     def tearDown(self): | ||||||
|  |         models.DateField._unregister_lookup(YearWithExtract) | ||||||
|  |  | ||||||
|  |     @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used") | ||||||
|  |     def test_year_lte(self): | ||||||
|  |         baseqs = Author.objects.order_by('name') | ||||||
|  |         self.assertQuerysetEqual( | ||||||
|  |             baseqs.filter(birthdate__year__lte=2012), | ||||||
|  |             [self.a1, self.a2, self.a3, self.a4], lambda x: x) | ||||||
|  |         self.assertQuerysetEqual( | ||||||
|  |             baseqs.filter(birthdate__year__lte=2011), | ||||||
|  |             [self.a1], lambda x: x) | ||||||
|  |         # The non-optimized version works, too. | ||||||
|  |         self.assertQuerysetEqual( | ||||||
|  |             baseqs.filter(birthdate__year__lt=2012), | ||||||
|  |             [self.a1], lambda x: x) | ||||||
|  |  | ||||||
|  |     @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used") | ||||||
|  |     def test_year_lte_fexpr(self): | ||||||
|  |         self.a2.age = 2011 | ||||||
|  |         self.a2.save() | ||||||
|  |         self.a3.age = 2012 | ||||||
|  |         self.a3.save() | ||||||
|  |         self.a4.age = 2013 | ||||||
|  |         self.a4.save() | ||||||
|  |         baseqs = Author.objects.order_by('name') | ||||||
|  |         self.assertQuerysetEqual( | ||||||
|  |             baseqs.filter(birthdate__year__lte=models.F('age')), | ||||||
|  |             [self.a3, self.a4], lambda x: x) | ||||||
|  |         self.assertQuerysetEqual( | ||||||
|  |             baseqs.filter(birthdate__year__lt=models.F('age')), | ||||||
|  |             [self.a4], lambda x: x) | ||||||
|  |  | ||||||
|  |     def test_year_lte_sql(self): | ||||||
|  |         # This test will just check the generated SQL for __lte. This | ||||||
|  |         # doesn't require running on PostgreSQL and spots the most likely | ||||||
|  |         # error - not running YearLte SQL at all. | ||||||
|  |         baseqs = Author.objects.order_by('name') | ||||||
|  |         self.assertIn( | ||||||
|  |             '<= (2011 || ', str(baseqs.filter(birthdate__year__lte=2011).query)) | ||||||
|  |         self.assertIn( | ||||||
|  |             '-12-31', str(baseqs.filter(birthdate__year__lte=2011).query)) | ||||||
|   | |||||||
| @@ -41,9 +41,6 @@ class NullQueriesTests(TestCase): | |||||||
|         # Can't use None on anything other than __exact |         # Can't use None on anything other than __exact | ||||||
|         self.assertRaises(ValueError, Choice.objects.filter, id__gt=None) |         self.assertRaises(ValueError, Choice.objects.filter, id__gt=None) | ||||||
|  |  | ||||||
|         # Can't use None on anything other than __exact |  | ||||||
|         self.assertRaises(ValueError, Choice.objects.filter, foo__gt=None) |  | ||||||
|  |  | ||||||
|         # Related managers use __exact=None implicitly if the object hasn't been saved. |         # Related managers use __exact=None implicitly if the object hasn't been saved. | ||||||
|         p2 = Poll(question="How?") |         p2 = Poll(question="How?") | ||||||
|         self.assertEqual(repr(p2.choice_set.all()), '[]') |         self.assertEqual(repr(p2.choice_set.all()), '[]') | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user