mirror of
				https://github.com/django/django.git
				synced 2025-10-31 01:25:32 +00:00 
			
		
		
		
	Implemented nested lookups
But there is no support of using lookups outside filtering yet.
This commit is contained in:
		| @@ -1136,11 +1136,14 @@ class ForeignObject(RelatedField): | ||||
|         pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.rel, not self.unique, False)] | ||||
|         return pathinfos | ||||
|  | ||||
|     def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookup_type, | ||||
|     def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookups, | ||||
|                               raw_value): | ||||
|         from django.db.models.sql.where import SubqueryConstraint, Constraint, AND, OR | ||||
|         root_constraint = constraint_class() | ||||
|         assert len(targets) == len(sources) | ||||
|         if len(lookups) > 1: | ||||
|             raise exceptions.FieldError('Relation fields do not support nested lookups') | ||||
|         lookup_type = lookups[0] | ||||
|  | ||||
|         def get_normalized_value(value): | ||||
|  | ||||
|   | ||||
| @@ -1,27 +1,58 @@ | ||||
| from copy import copy | ||||
|  | ||||
| from django.core.exceptions import FieldError | ||||
| from django.conf import settings | ||||
| from django.utils import timezone | ||||
| from django.utils.functional import cached_property | ||||
|  | ||||
|  | ||||
| class Extract(object): | ||||
|     def __init__(self, constraint_class, lhs): | ||||
|         self.constraint_class, self.lhs = constraint_class, lhs | ||||
|  | ||||
|     def get_lookup(self, lookup): | ||||
|         return self.output_type.get_lookup(lookup) | ||||
|  | ||||
|     def as_sql(self, qn, connection): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     @cached_property | ||||
|     def output_type(self): | ||||
|         return self.lhs.output_type | ||||
|  | ||||
|     def relabeled_clone(self, relabels): | ||||
|         return self.__class__(self.constraint_class, self.lhs.relabeled_clone(relabels)) | ||||
|  | ||||
|  | ||||
| class Lookup(object): | ||||
|     lookup_name = None | ||||
|     extract_class = None | ||||
|  | ||||
|     def __init__(self, constraint_class, lhs, rhs): | ||||
|         self.constraint_class, self.lhs, self.rhs = constraint_class, lhs, rhs | ||||
|         self.rhs = self.get_prep_lookup() | ||||
|         if rhs is None: | ||||
|             if not self.extract_class: | ||||
|                 raise FieldError("Lookup '%s' doesn't support nesting." % self.lookup_name) | ||||
|         else: | ||||
|             self.rhs = self.get_prep_lookup() | ||||
|  | ||||
|     def get_extract(self): | ||||
|         return self.extract_class(self.constraint_class, self.lhs) | ||||
|  | ||||
|     def get_prep_lookup(self): | ||||
|         return self.lhs.output_type.get_prep_lookup(self.lookup_name, self.rhs) | ||||
|  | ||||
|     def get_db_prep_lookup(self, value, connection): | ||||
|         return ( | ||||
|             '%s', self.lhs.output_type.get_db_prep_lookup( | ||||
|                 self.lookup_name, value, connection, prepared=True)) | ||||
|  | ||||
|     def get_prep_lookup(self): | ||||
|         return self.lhs.output_type.get_prep_lookup(self.lookup_name, self.rhs) | ||||
|     def process_lhs(self, qn, connection, lhs=None): | ||||
|         lhs = lhs or self.lhs | ||||
|         return qn.compile(lhs) | ||||
|  | ||||
|     def process_lhs(self, qn, connection): | ||||
|         return qn.compile(self.lhs) | ||||
|  | ||||
|     def process_rhs(self, qn, connection): | ||||
|         value = self.rhs | ||||
|     def process_rhs(self, qn, connection, rhs=None): | ||||
|         value = rhs or self.rhs | ||||
|         # Due to historical reasons there are a couple of different | ||||
|         # ways to produce sql here. get_compiler is likely a Query | ||||
|         # instance, _as_sql QuerySet and as_sql just something with | ||||
| @@ -118,7 +149,7 @@ class In(DjangoLookup): | ||||
|     lookup_name = 'in' | ||||
|  | ||||
|     def get_db_prep_lookup(self, value, connection): | ||||
|         params = self.lhs.field.get_db_prep_lookup( | ||||
|         params = self.lhs.output_type.get_db_prep_lookup( | ||||
|             self.lookup_name, value, connection, prepared=True) | ||||
|         if not params: | ||||
|             # TODO: check why this leads to circular import | ||||
|   | ||||
| @@ -100,6 +100,9 @@ class Aggregate(object): | ||||
|     def output_type(self): | ||||
|         return self.field | ||||
|  | ||||
|     def get_lookup(self, lookup): | ||||
|         return self.output_type.get_lookup(lookup) | ||||
|  | ||||
|  | ||||
| class Avg(Aggregate): | ||||
|     is_computed = True | ||||
|   | ||||
| @@ -25,6 +25,9 @@ class Col(object): | ||||
|     def get_cols(self): | ||||
|         return [(self.alias, self.target.column)] | ||||
|  | ||||
|     def get_lookup(self, name): | ||||
|         return self.output_type.get_lookup(name) | ||||
|  | ||||
|  | ||||
| class EmptyResultSet(Exception): | ||||
|     pass | ||||
|   | ||||
| @@ -1027,19 +1027,16 @@ class Query(object): | ||||
|         # Add the aggregate to the query | ||||
|         aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary) | ||||
|  | ||||
|     def prepare_lookup_value(self, value, lookup_type, can_reuse): | ||||
|     def prepare_lookup_value(self, value, lookups, can_reuse): | ||||
|         # Default lookup if none given is exact. | ||||
|         if len(lookups) == 0: | ||||
|             lookups = ['exact'] | ||||
|         # Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all | ||||
|         # uses of None as a query value. | ||||
|         if len(lookup_type) > 1: | ||||
|             raise FieldError('Nested lookups not allowed') | ||||
|         elif len(lookup_type) == 0: | ||||
|             lookup_type = 'exact' | ||||
|         else: | ||||
|             lookup_type = lookup_type[0] | ||||
|         if value is None: | ||||
|             if lookup_type != 'exact': | ||||
|             if lookups[-1] != 'exact': | ||||
|                 raise ValueError("Cannot use None as a query value") | ||||
|             lookup_type = 'isnull' | ||||
|             lookups[-1] = 'isnull' | ||||
|             value = True | ||||
|         elif callable(value): | ||||
|             value = value() | ||||
| @@ -1057,10 +1054,10 @@ class Query(object): | ||||
|         # stage. Using DEFAULT_DB_ALIAS isn't nice, but it is the best we | ||||
|         # can do here. Similar thing is done in is_nullable(), too. | ||||
|         if (connections[DEFAULT_DB_ALIAS].features.interprets_empty_strings_as_nulls and | ||||
|                 lookup_type == 'exact' and value == ''): | ||||
|                 lookups[-1] == 'exact' and value == ''): | ||||
|             value = True | ||||
|             lookup_type = 'isnull' | ||||
|         return value, lookup_type | ||||
|             lookups[-1] = ['isnull'] | ||||
|         return value, lookups | ||||
|  | ||||
|     def solve_lookup_type(self, lookup): | ||||
|         """ | ||||
| @@ -1069,36 +1066,37 @@ class Query(object): | ||||
|         lookup_splitted = lookup.split(LOOKUP_SEP) | ||||
|         aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.aggregates) | ||||
|         if aggregate: | ||||
|             if len(aggregate_lookups) > 1: | ||||
|                 raise FieldError("Nested lookups not allowed.") | ||||
|             return aggregate_lookups, (), aggregate | ||||
|         _, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta()) | ||||
|         field_parts = lookup_splitted[0:len(lookup_splitted) - len(lookup_parts)] | ||||
|         if len(lookup_parts) == 0: | ||||
|             lookup_parts = ['exact'] | ||||
|         elif len(lookup_parts) > 1: | ||||
|             if field_parts: | ||||
|                 raise FieldError( | ||||
|                     'Only one lookup part allowed (found path "%s" from "%s").' % | ||||
|                     (LOOKUP_SEP.join(field_parts), lookup)) | ||||
|             else: | ||||
|             if not field_parts: | ||||
|                 raise FieldError( | ||||
|                     'Invalid lookup "%s" for model %s".' % | ||||
|                     (lookup, self.get_meta().model.__name__)) | ||||
|         else: | ||||
|             if not hasattr(field, 'get_lookup_constraint'): | ||||
|                 lookup_class = field.get_lookup(lookup_parts[0]) | ||||
|                 if lookup_class is None and lookup_parts[0] not in self.query_terms: | ||||
|                     raise FieldError( | ||||
|                         'Invalid lookup name %s' % lookup_parts[0]) | ||||
|         return lookup_parts, field_parts, False | ||||
|  | ||||
|     def build_lookup(self, lookup_type, lhs, rhs): | ||||
|         if hasattr(lhs.output_type, 'get_lookup'): | ||||
|             lookup = lhs.output_type.get_lookup(lookup_type) | ||||
|             if lookup: | ||||
|                 return lookup(self.where_class, lhs, rhs) | ||||
|         return None | ||||
|     def build_lookup(self, lookups, lhs, rhs): | ||||
|         lookups = lookups[:] | ||||
|         lookups.reverse() | ||||
|         while lookups: | ||||
|             lookup = lookups.pop() | ||||
|             next = lhs.get_lookup(lookup) | ||||
|             if next: | ||||
|                 if not lookups: | ||||
|                     # This was the last lookup, so return value lookup. | ||||
|                     return next(self.where_class, lhs, rhs) | ||||
|                 else: | ||||
|                     lhs = next(self.where_class, lhs, None).get_extract() | ||||
|             # A field's get_lookup() can return None to opt for backwards | ||||
|             # compatibility path. | ||||
|             elif len(lookups) > 1: | ||||
|                 raise FieldError( | ||||
|                     "Unsupported lookup for field '%s'" % lhs.output_type.name) | ||||
|             else: | ||||
|                 return None | ||||
|  | ||||
|     def build_filter(self, filter_expr, branch_negated=False, current_negated=False, | ||||
|                      can_reuse=None, connector=AND): | ||||
| @@ -1130,19 +1128,20 @@ class Query(object): | ||||
|         arg, value = filter_expr | ||||
|         if not arg: | ||||
|             raise FieldError("Cannot parse keyword query %r" % arg) | ||||
|         lookup_type, parts, reffed_aggregate = self.solve_lookup_type(arg) | ||||
|         lookups, parts, reffed_aggregate = self.solve_lookup_type(arg) | ||||
|  | ||||
|         # Work out the lookup type and remove it from the end of 'parts', | ||||
|         # if necessary. | ||||
|         value, lookup_type = self.prepare_lookup_value(value, lookup_type, can_reuse) | ||||
|         value, lookups = self.prepare_lookup_value(value, lookups, can_reuse) | ||||
|         used_joins = getattr(value, '_used_joins', []) | ||||
|  | ||||
|         clause = self.where_class() | ||||
|         if reffed_aggregate: | ||||
|             condition = self.build_lookup(lookup_type, reffed_aggregate, value) | ||||
|             condition = self.build_lookup(lookups, reffed_aggregate, value) | ||||
|             if not condition: | ||||
|                 # Backwards compat for custom lookups | ||||
|                 condition = (reffed_aggregate, lookup_type, value) | ||||
|                 assert len(lookups) == 1 | ||||
|                 condition = (reffed_aggregate, lookups[0], value) | ||||
|             clause.add(condition, AND) | ||||
|             return clause, [] | ||||
|  | ||||
| @@ -1169,14 +1168,27 @@ class Query(object): | ||||
|             # For now foreign keys get special treatment. This should be | ||||
|             # refactored when composite fields lands. | ||||
|             condition = field.get_lookup_constraint(self.where_class, alias, targets, sources, | ||||
|                                                     lookup_type, value) | ||||
|                                                     lookups, value) | ||||
|             lookup_type = lookups[-1] | ||||
|         else: | ||||
|             assert(len(targets) == 1) | ||||
|             col = Col(alias, targets[0], field) | ||||
|             condition = self.build_lookup(lookup_type, col, value) | ||||
|             condition = self.build_lookup(lookups, col, value) | ||||
|             if not condition: | ||||
|                 # Backwards compat for custom lookups | ||||
|                 condition = (Constraint(alias, targets[0].column, field), lookup_type, value) | ||||
|                 if lookups[0] not in self.query_terms: | ||||
|                     raise FieldError( | ||||
|                         "Join on field '%s' not permitted. Did you " | ||||
|                         "misspell '%s' for the lookup type?" % | ||||
|                         (col.output_type.name, lookups[0])) | ||||
|                 if len(lookups) > 1: | ||||
|                     raise FieldError("Nested lookup '%s' not supported." % | ||||
|                                      LOOKUP_SEP.join(lookups)) | ||||
|                 condition = (Constraint(alias, targets[0].column, field), lookups[0], value) | ||||
|                 lookup_type = lookups[-1] | ||||
|             else: | ||||
|                 lookup_type = condition.lookup_name | ||||
|  | ||||
|         clause.add(condition, AND) | ||||
|  | ||||
|         require_outer = lookup_type == 'isnull' and value is True and not current_negated | ||||
| @@ -1296,7 +1308,7 @@ class Query(object): | ||||
|         needed_inner = joinpromoter.update_join_types(self) | ||||
|         return target_clause, needed_inner | ||||
|  | ||||
|     def names_to_path(self, names, opts, allow_many=True): | ||||
|     def names_to_path(self, names, opts, allow_many=True, fail_on_missing=False): | ||||
|         """ | ||||
|         Walks the names path and turns them PathInfo tuples. Note that a | ||||
|         single name in 'names' can generate multiple PathInfos (m2m for | ||||
| @@ -1354,10 +1366,15 @@ class Query(object): | ||||
|                 final_field = field | ||||
|                 targets = (field,) | ||||
|                 break | ||||
|         if pos == -1: | ||||
|             raise FieldError('Whazaa') | ||||
|         if pos == -1 or (fail_on_missing and pos + 1 != len(names)): | ||||
|             self.raise_field_error(opts, name) | ||||
|         return path, final_field, targets, names[pos + 1:] | ||||
|  | ||||
|     def raise_field_error(self, opts, name): | ||||
|         available = opts.get_all_field_names() + list(self.aggregate_select) | ||||
|         raise FieldError("Cannot resolve keyword %r into field. " | ||||
|                          "Choices are: %s" % (name, ", ".join(available))) | ||||
|  | ||||
|     def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True): | ||||
|         """ | ||||
|         Compute the necessary table joins for the passage through the fields | ||||
| @@ -1386,9 +1403,8 @@ class Query(object): | ||||
|         joins = [alias] | ||||
|         # First, generate the path for the names | ||||
|         path, final_field, targets, rest = self.names_to_path( | ||||
|             names, opts, allow_many) | ||||
|         if rest: | ||||
|             raise FieldError('Invalid lookup') | ||||
|             names, opts, allow_many, fail_on_missing=True) | ||||
|  | ||||
|         # Then, add the path to the query's joins. Note that we can't trim | ||||
|         # joins at this stage - we will need the information about join type | ||||
|         # of the trimmed joins. | ||||
|   | ||||
| @@ -1,7 +1,13 @@ | ||||
| from django.db import models | ||||
| from django.utils.encoding import python_2_unicode_compatible | ||||
|  | ||||
|  | ||||
| @python_2_unicode_compatible | ||||
| class Author(models.Model): | ||||
|     name = models.CharField(max_length=20) | ||||
|     age = models.IntegerField(null=True) | ||||
|     birthdate = models.DateField(null=True) | ||||
|     average_rating = models.FloatField(null=True) | ||||
|  | ||||
|     def __str__(self): | ||||
|         return self.name | ||||
|   | ||||
| @@ -19,6 +19,56 @@ class Div3Lookup(models.lookups.Lookup): | ||||
|         return '%s %%%% 3 = %s' % (lhs, rhs), params | ||||
|  | ||||
|  | ||||
| class Div3Extract(models.lookups.Extract): | ||||
|     def as_sql(self, qn, connection): | ||||
|         lhs, lhs_params = qn.compile(self.lhs) | ||||
|         return '%s %%%% 3' % (lhs,), lhs_params | ||||
|  | ||||
|  | ||||
| class Div3LookupWithExtract(Div3Lookup): | ||||
|     lookup_name = 'div3' | ||||
|     extract_class = Div3Extract | ||||
|  | ||||
|  | ||||
| class YearLte(models.lookups.LessThanOrEqual): | ||||
|     """ | ||||
|     The purpose of this lookup is to efficiently compare the year of the field. | ||||
|     """ | ||||
|  | ||||
|     def as_sql(self, qn, connection): | ||||
|         # Skip the YearExtract above us (no possibility for efficient | ||||
|         # lookup otherwise). | ||||
|         real_lhs = self.lhs.lhs | ||||
|         lhs_sql, params = self.process_lhs(qn, connection, real_lhs) | ||||
|         rhs_sql, rhs_params = self.process_rhs(qn, connection) | ||||
|         params.extend(rhs_params) | ||||
|         # Build SQL where the integer year is concatenated with last month | ||||
|         # and day, then convert that to date. (We try to have SQL like: | ||||
|         #     WHERE somecol <= '2013-12-31') | ||||
|         # but also make it work if the rhs_sql is field reference. | ||||
|         return "%s <= (%s || '-12-31')::date" % (lhs_sql, rhs_sql), params | ||||
|  | ||||
|  | ||||
| class YearExtract(models.lookups.Extract): | ||||
|     def as_sql(self, qn, connection): | ||||
|         lhs_sql, params = qn.compile(self.lhs) | ||||
|         return connection.ops.date_extract_sql('year', lhs_sql), params | ||||
|  | ||||
|     @property | ||||
|     def output_type(self): | ||||
|         return models.IntegerField() | ||||
|  | ||||
|     def get_lookup(self, lookup): | ||||
|         if lookup == 'lte': | ||||
|             return YearLte | ||||
|         else: | ||||
|             return super(YearExtract, self).get_lookup(lookup) | ||||
|  | ||||
|  | ||||
| class YearWithExtract(models.lookups.Year): | ||||
|     extract_class = YearExtract | ||||
|  | ||||
|  | ||||
| class InMonth(models.lookups.Lookup): | ||||
|     """ | ||||
|     InMonth matches if the column's month is contained in the value's month. | ||||
| @@ -134,3 +184,72 @@ class LookupTests(TestCase): | ||||
|             ) | ||||
|         finally: | ||||
|             models.Field._unregister_lookup(AnotherEqual) | ||||
|  | ||||
|     def test_div3_extract(self): | ||||
|         models.IntegerField.register_lookup(Div3LookupWithExtract) | ||||
|         try: | ||||
|             a1 = Author.objects.create(name='a1', age=1) | ||||
|             a2 = Author.objects.create(name='a2', age=2) | ||||
|             a3 = Author.objects.create(name='a3', age=3) | ||||
|             a4 = Author.objects.create(name='a4', age=4) | ||||
|             baseqs = Author.objects.order_by('name') | ||||
|             self.assertQuerysetEqual( | ||||
|                 baseqs.filter(age__div3__lte=3), | ||||
|                 [a1, a2, a3, a4], lambda x: x) | ||||
|             self.assertQuerysetEqual( | ||||
|                 baseqs.filter(age__div3__in=[0, 2]), | ||||
|                 [a2, a3], lambda x: x) | ||||
|         finally: | ||||
|             models.IntegerField._unregister_lookup(Div3LookupWithExtract) | ||||
|  | ||||
|  | ||||
| class YearLteTests(TestCase): | ||||
|     def setUp(self): | ||||
|         models.DateField.register_lookup(YearWithExtract) | ||||
|         self.a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16)) | ||||
|         self.a2 = Author.objects.create(name='a2', birthdate=date(2012, 2, 29)) | ||||
|         self.a3 = Author.objects.create(name='a3', birthdate=date(2012, 1, 31)) | ||||
|         self.a4 = Author.objects.create(name='a4', birthdate=date(2012, 3, 1)) | ||||
|  | ||||
|     def tearDown(self): | ||||
|         models.DateField._unregister_lookup(YearWithExtract) | ||||
|  | ||||
|     @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used") | ||||
|     def test_year_lte(self): | ||||
|         baseqs = Author.objects.order_by('name') | ||||
|         self.assertQuerysetEqual( | ||||
|             baseqs.filter(birthdate__year__lte=2012), | ||||
|             [self.a1, self.a2, self.a3, self.a4], lambda x: x) | ||||
|         self.assertQuerysetEqual( | ||||
|             baseqs.filter(birthdate__year__lte=2011), | ||||
|             [self.a1], lambda x: x) | ||||
|         # The non-optimized version works, too. | ||||
|         self.assertQuerysetEqual( | ||||
|             baseqs.filter(birthdate__year__lt=2012), | ||||
|             [self.a1], lambda x: x) | ||||
|  | ||||
|     @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used") | ||||
|     def test_year_lte_fexpr(self): | ||||
|         self.a2.age = 2011 | ||||
|         self.a2.save() | ||||
|         self.a3.age = 2012 | ||||
|         self.a3.save() | ||||
|         self.a4.age = 2013 | ||||
|         self.a4.save() | ||||
|         baseqs = Author.objects.order_by('name') | ||||
|         self.assertQuerysetEqual( | ||||
|             baseqs.filter(birthdate__year__lte=models.F('age')), | ||||
|             [self.a3, self.a4], lambda x: x) | ||||
|         self.assertQuerysetEqual( | ||||
|             baseqs.filter(birthdate__year__lt=models.F('age')), | ||||
|             [self.a4], lambda x: x) | ||||
|  | ||||
|     def test_year_lte_sql(self): | ||||
|         # This test will just check the generated SQL for __lte. This | ||||
|         # doesn't require running on PostgreSQL and spots the most likely | ||||
|         # error - not running YearLte SQL at all. | ||||
|         baseqs = Author.objects.order_by('name') | ||||
|         self.assertIn( | ||||
|             '<= (2011 || ', str(baseqs.filter(birthdate__year__lte=2011).query)) | ||||
|         self.assertIn( | ||||
|             '-12-31', str(baseqs.filter(birthdate__year__lte=2011).query)) | ||||
|   | ||||
| @@ -41,9 +41,6 @@ class NullQueriesTests(TestCase): | ||||
|         # Can't use None on anything other than __exact | ||||
|         self.assertRaises(ValueError, Choice.objects.filter, id__gt=None) | ||||
|  | ||||
|         # Can't use None on anything other than __exact | ||||
|         self.assertRaises(ValueError, Choice.objects.filter, foo__gt=None) | ||||
|  | ||||
|         # Related managers use __exact=None implicitly if the object hasn't been saved. | ||||
|         p2 = Poll(question="How?") | ||||
|         self.assertEqual(repr(p2.choice_set.all()), '[]') | ||||
|   | ||||
		Reference in New Issue
	
	Block a user