From b68212f539f206679580afbfd008e7d329c9cd31 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Anssi=20K=C3=A4=C3=A4ri=C3=A4inen?=
 <anssi.kaariainen@thl.fi>
Date: Mon, 2 Feb 2015 13:48:30 +0200
Subject: [PATCH] Refs #24267 -- Implemented lookups for related fields

Previously, related fields didn't implement get_lookup(); instead,
they were treated specially. This commit removes some of that
special handling. In particular, related fields now return Lookup
instances, too.

Another notable change in this commit is the removal of support for
annotations in names_to_path().
---
 django/db/models/fields/related.py         |  96 ++++++---------
 django/db/models/fields/related_lookups.py | 130 +++++++++++++++++++++
 django/db/models/query_utils.py            |  16 ++-
 django/db/models/sql/query.py              |  47 ++++----
 tests/generic_relations_regress/tests.py   |  16 ++-
 tests/queries/tests.py                     |   8 ++
 6 files changed, 219 insertions(+), 94 deletions(-)
 create mode 100644 django/db/models/fields/related_lookups.py

diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py
index 0ca4e59633..e91de27876 100644
--- a/django/db/models/fields/related.py
+++ b/django/db/models/fields/related.py
@@ -15,7 +15,10 @@ from django.db.models.fields import (
     BLANK_CHOICE_DASH, AutoField, Field, IntegerField, PositiveIntegerField,
     PositiveSmallIntegerField,
 )
-from django.db.models.lookups import IsNull
+from django.db.models.fields.related_lookups import (
+    RelatedExact, RelatedGreaterThan, RelatedGreaterThanOrEqual, RelatedIn,
+    RelatedLessThan, RelatedLessThanOrEqual,
+)
 from django.db.models.query import QuerySet
 from django.db.models.query_utils import PathInfo
 from django.utils import six
@@ -1336,6 +1339,16 @@ class ForeignObjectRel(object):
     def one_to_one(self):
         return self.field.one_to_one
 
+    def get_prep_lookup(self, lookup_name, value):
+        return self.field.get_prep_lookup(lookup_name, value)
+
+    def get_internal_type(self):
+        return self.field.get_internal_type()
+
+    @property
+    def db_type(self):
+        return self.field.db_type
+
     def __repr__(self):
         return '<%s: %s.%s>' % (
             type(self).__name__,
@@ -1760,67 +1773,25 @@ class ForeignObject(RelatedField):
         pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.rel, not self.unique, False)]
         return pathinfos
 
-    def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookups,
-                              raw_value):
-        from django.db.models.sql.where import SubqueryConstraint, AND, OR
-        root_constraint = constraint_class()
-        assert len(targets) == len(sources)
-        if len(lookups) > 1:
-            raise exceptions.FieldError(
-                "Cannot resolve keyword %r into field. Choices are: %s" % (
-                    lookups[0],
-                    ", ".join(f.name for f in self.model._meta.get_fields()),
-                )
-            )
-        lookup_type = lookups[0]
+    def get_lookup(self, lookup_name):
+        if lookup_name == 'in':
+            return RelatedIn
+        elif lookup_name == 'exact':
+            return RelatedExact
+        elif lookup_name == 'gt':
+            return RelatedGreaterThan
+        elif lookup_name == 'gte':
+            return RelatedGreaterThanOrEqual
+        elif lookup_name == 'lt':
+            return RelatedLessThan
+        elif lookup_name == 'lte':
+            return RelatedLessThanOrEqual
+        elif lookup_name != 'isnull':
+            raise TypeError('Related Field got invalid lookup: %s' % lookup_name)
+        return super(ForeignObject, self).get_lookup(lookup_name)
 
-        def get_normalized_value(value):
-            from django.db.models import Model
-            if isinstance(value, Model):
-                value_list = []
-                for source in sources:
-                    # Account for one-to-one relations when sent a different model
-                    while not isinstance(value, source.model) and source.rel:
-                        source = source.rel.to._meta.get_field(source.rel.field_name)
-                    value_list.append(getattr(value, source.attname))
-                return tuple(value_list)
-            elif not isinstance(value, tuple):
-                return (value,)
-            return value
-
-        is_multicolumn = len(self.related_fields) > 1
-        if (hasattr(raw_value, '_as_sql') or
-                hasattr(raw_value, 'get_compiler')):
-            root_constraint.add(SubqueryConstraint(alias, [target.column for target in targets],
-                                                   [source.name for source in sources], raw_value),
-                                AND)
-        elif lookup_type == 'isnull':
-            root_constraint.add(IsNull(targets[0].get_col(alias, sources[0]), raw_value), AND)
-        elif (lookup_type == 'exact' or (lookup_type in ['gt', 'lt', 'gte', 'lte']
-                                         and not is_multicolumn)):
-            value = get_normalized_value(raw_value)
-            for target, source, val in zip(targets, sources, value):
-                lookup_class = target.get_lookup(lookup_type)
-                root_constraint.add(
-                    lookup_class(target.get_col(alias, source), val), AND)
-        elif lookup_type in ['range', 'in'] and not is_multicolumn:
-            values = [get_normalized_value(value) for value in raw_value]
-            value = [val[0] for val in values]
-            lookup_class = targets[0].get_lookup(lookup_type)
-            root_constraint.add(lookup_class(targets[0].get_col(alias, sources[0]), value), AND)
-        elif lookup_type == 'in':
-            values = [get_normalized_value(value) for value in raw_value]
-            root_constraint.connector = OR
-            for value in values:
-                value_constraint = constraint_class()
-                for source, target, val in zip(sources, targets, value):
-                    lookup_class = target.get_lookup('exact')
-                    lookup = lookup_class(target.get_col(alias, source), val)
-                    value_constraint.add(lookup, AND)
-                root_constraint.add(value_constraint, OR)
-        else:
-            raise TypeError('Related Field got invalid lookup: %s' % lookup_type)
-        return root_constraint
+    def get_transform(self, *args, **kwargs):
+        raise NotImplementedError('Relational fields do not support transforms.')
 
     @property
     def attnames(self):
@@ -2017,6 +1988,9 @@ class ForeignKey(ForeignObject):
         else:
             return self.related_field.get_db_prep_save(value, connection=connection)
 
+    def get_db_prep_value(self, value, connection, prepared=False):
+        return self.related_field.get_db_prep_value(value, connection, prepared)
+
     def value_to_string(self, obj):
         if not obj:
             # In required many-to-one fields with only one available choice,
diff --git a/django/db/models/fields/related_lookups.py b/django/db/models/fields/related_lookups.py
new file mode 100644
index 0000000000..b689c9928e
--- /dev/null
+++ b/django/db/models/fields/related_lookups.py
@@ -0,0 +1,130 @@
+from django.db.models.lookups import (
+    Exact, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual,
+)
+
+
+class MultiColSource(object):
+    contains_aggregate = False
+
+    def __init__(self, alias, targets, sources, field):
+        self.targets, self.sources, self.field, self.alias = targets, sources, field, alias
+        self.output_field = self.field
+
+    def __repr__(self):
+        return "{}({}, {})".format(
+            self.__class__.__name__, self.alias, self.field)
+
+    def relabeled_clone(self, relabels):
+        return self.__class__(relabels.get(self.alias, self.alias),
+                              self.targets, self.sources, self.field)
+
+
+def get_normalized_value(value, lhs):
+    from django.db.models import Model
+    if isinstance(value, Model):
+        value_list = []
+        # Account for one-to-one relations when sent a different model
+        sources = lhs.output_field.get_path_info()[-1].target_fields
+        for source in sources:
+            while not isinstance(value, source.model) and source.rel:
+                source = source.rel.to._meta.get_field(source.rel.field_name)
+            value_list.append(getattr(value, source.attname))
+        return tuple(value_list)
+    if not isinstance(value, tuple):
+        return (value,)
+    return value
+
+
+class RelatedIn(In):
+    def get_prep_lookup(self):
+        if not isinstance(self.lhs, MultiColSource) and self.rhs_is_direct_value():
+            # If we get here, we are dealing with single-column relations.
+            self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs]
+            # We need to run the related field's get_prep_lookup(). Consider a
+            # ForeignKey to an IntegerField given the value 'abc'. The ForeignKey
+            # itself doesn't validate non-integers, so we must run validation
+            # using the target field.
+            if hasattr(self.lhs.output_field, 'get_path_info'):
+                # Run the target field's get_prep_lookup. We can safely assume there is
+                # only one as we don't get to the direct value branch otherwise.
+                self.rhs = self.lhs.output_field.get_path_info()[-1].target_fields[-1].get_prep_lookup(
+                    self.lookup_name, self.rhs)
+        return super(RelatedIn, self).get_prep_lookup()
+
+    def as_sql(self, compiler, connection):
+        if isinstance(self.lhs, MultiColSource):
+            # For multicolumn lookups we need to build a multicolumn where clause.
+            # This clause is either a SubqueryConstraint (for values that need to be compiled to
+            # SQL) or an OR-combined list of (col1 = val1 AND col2 = val2 AND ...) clauses.
+            from django.db.models.sql.where import WhereNode, SubqueryConstraint, AND, OR
+
+            root_constraint = WhereNode(connector=OR)
+            if self.rhs_is_direct_value():
+                values = [get_normalized_value(value, self.lhs) for value in self.rhs]
+                for value in values:
+                    value_constraint = WhereNode()
+                    for source, target, val in zip(self.lhs.sources, self.lhs.targets, value):
+                        lookup_class = target.get_lookup('exact')
+                        lookup = lookup_class(target.get_col(self.lhs.alias, source), val)
+                        value_constraint.add(lookup, AND)
+                    root_constraint.add(value_constraint, OR)
+            else:
+                root_constraint.add(
+                    SubqueryConstraint(
+                        self.lhs.alias, [target.column for target in self.lhs.targets],
+                        [source.name for source in self.lhs.sources], self.rhs),
+                    AND)
+            return root_constraint.as_sql(compiler, connection)
+        else:
+            return super(RelatedIn, self).as_sql(compiler, connection)
+
+
+class RelatedLookupMixin(object):
+    def get_prep_lookup(self):
+        if not isinstance(self.lhs, MultiColSource) and self.rhs_is_direct_value():
+            # If we get here, we are dealing with single-column relations.
+            self.rhs = get_normalized_value(self.rhs, self.lhs)[0]
+            # We need to run the related field's get_prep_lookup(). Consider a
+            # ForeignKey to an IntegerField given the value 'abc'. The ForeignKey
+            # itself doesn't validate non-integers, so we must run validation
+            # using the target field.
+            if hasattr(self.lhs.output_field, 'get_path_info'):
+                # Get the target field. We can safely assume there is only one
+                # as we don't get to the direct value branch otherwise.
+                self.rhs = self.lhs.output_field.get_path_info()[-1].target_fields[-1].get_prep_lookup(
+                    self.lookup_name, self.rhs)
+
+        return super(RelatedLookupMixin, self).get_prep_lookup()
+
+    def as_sql(self, compiler, connection):
+        if isinstance(self.lhs, MultiColSource):
+            assert self.rhs_is_direct_value()
+            self.rhs = get_normalized_value(self.rhs, self.lhs)
+            from django.db.models.sql.where import WhereNode, AND
+            root_constraint = WhereNode()
+            for target, source, val in zip(self.lhs.targets, self.lhs.sources, self.rhs):
+                lookup_class = target.get_lookup(self.lookup_name)
+                root_constraint.add(
+                    lookup_class(target.get_col(self.lhs.alias, source), val), AND)
+            return root_constraint.as_sql(compiler, connection)
+        return super(RelatedLookupMixin, self).as_sql(compiler, connection)
+
+
+class RelatedExact(RelatedLookupMixin, Exact):
+    pass
+
+
+class RelatedLessThan(RelatedLookupMixin, LessThan):
+    pass
+
+
+class RelatedGreaterThan(RelatedLookupMixin, GreaterThan):
+    pass
+
+
+class RelatedGreaterThanOrEqual(RelatedLookupMixin, GreaterThanOrEqual):
+    pass
+
+
+class RelatedLessThanOrEqual(RelatedLookupMixin, LessThanOrEqual):
+    pass
diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py
index 005bd99956..35e3fe58bc 100644
--- a/django/db/models/query_utils.py
+++ b/django/db/models/query_utils.py
@@ -250,7 +250,7 @@ deferred_class_factory.__safe_for_unpickling__ = True
 
 def refs_aggregate(lookup_parts, aggregates):
     """
-    A little helper method to check if the lookup_parts contains references
+    A helper method to check if the lookup_parts contains references
     to the given aggregates set. Because the LOOKUP_SEP is contained in the
     default annotation names we must check each prefix of the lookup_parts
     for a match.
@@ -260,3 +260,17 @@ def refs_aggregate(lookup_parts, aggregates):
         if level_n_lookup in aggregates and aggregates[level_n_lookup].contains_aggregate:
             return aggregates[level_n_lookup], lookup_parts[n:]
     return False, ()
+
+
+def refs_expression(lookup_parts, annotations):
+    """
+    A helper method to check if the lookup_parts contains references
+    to the given annotations set. Because the LOOKUP_SEP is contained in the
+    default annotation names we must check each prefix of the lookup_parts
+    for a match.
+    """
+    for n in range(len(lookup_parts) + 1):
+        level_n_lookup = LOOKUP_SEP.join(lookup_parts[0:n])
+        if level_n_lookup in annotations and annotations[level_n_lookup]:
+            return annotations[level_n_lookup], lookup_parts[n:]
+    return False, ()
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
index 14e079ee84..9c19e819fd 100644
--- a/django/db/models/sql/query.py
+++ b/django/db/models/sql/query.py
@@ -17,7 +17,8 @@ from django.db import DEFAULT_DB_ALIAS, connections
 from django.db.models.aggregates import Count
 from django.db.models.constants import LOOKUP_SEP
 from django.db.models.expressions import Col, Ref
-from django.db.models.query_utils import Q, PathInfo, refs_aggregate
+from django.db.models.fields.related_lookups import MultiColSource
+from django.db.models.query_utils import Q, PathInfo, refs_expression
 from django.db.models.sql.constants import (
     INNER, LOUTER, ORDER_DIR, ORDER_PATTERN, QUERY_TERMS, SINGLE,
 )
@@ -1006,7 +1007,7 @@ class Query(object):
         """
         lookup_splitted = lookup.split(LOOKUP_SEP)
         if self._annotations:
-            aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.annotations)
+            aggregate, aggregate_lookups = refs_expression(lookup_splitted, self.annotations)
             if aggregate:
                 return aggregate_lookups, (), aggregate
         _, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta())
@@ -1157,24 +1158,26 @@ class Query(object):
         if can_reuse is not None:
             can_reuse.update(join_list)
         used_joins = set(used_joins).union(set(join_list))
-
-        # Process the join list to see if we can remove any non-needed joins from
-        # the far end (fewer tables in a query is better).
         targets, alias, join_list = self.trim_joins(sources, join_list, path)
 
-        if hasattr(field, 'get_lookup_constraint'):
-            # For now foreign keys get special treatment. This should be
-            # refactored when composite fields lands.
-            condition = field.get_lookup_constraint(self.where_class, alias, targets, sources,
-                                                    lookups, value)
-            lookup_type = lookups[-1]
-        else:
-            assert(len(targets) == 1)
-            if hasattr(targets[0], 'as_sql'):
-                # handle Expressions as annotations
-                col = targets[0]
+        if field.is_relation:
+            # No support for transforms for relational fields
+            assert len(lookups) == 1
+            lookup_class = field.get_lookup(lookups[0])
+            # Undo the changes done in the hasattr(final_field, 'field') branch of
+            # setup_joins(). This hack is needed as long as field.rel isn't like a real field.
+            if field.get_path_info()[-1].target_fields != sources:
+                target_field = field.rel
             else:
-                col = targets[0].get_col(alias, field)
+                target_field = field
+            if len(targets) == 1:
+                lhs = targets[0].get_col(alias, target_field)
+            else:
+                lhs = MultiColSource(alias, targets, sources, target_field)
+            condition = lookup_class(lhs, value)
+            lookup_type = lookup_class.lookup_name
+        else:
+            col = targets[0].get_col(alias, field)
             condition = self.build_lookup(lookups, col, value)
             lookup_type = condition.lookup_name
 
@@ -1284,14 +1287,6 @@ class Query(object):
                     )
                 model = field.model._meta.concrete_model
             except FieldDoesNotExist:
-                # is it an annotation?
-                if self._annotations and name in self._annotations:
-                    field, model = self._annotations[name], None
-                    if not field.contains_aggregate:
-                        # Local non-relational field.
-                        final_field = field
-                        targets = (field,)
-                        break
                 # We didn't find the current field, so move position back
                 # one step.
                 pos -= 1
@@ -1985,7 +1980,7 @@ def is_reverse_o2o(field):
     A little helper to check if the given field is reverse-o2o. The field is
     expected to be some sort of relation field or related object.
     """
-    return not hasattr(field, 'rel') and field.field.unique
+    return field.is_relation and field.one_to_one and not field.concrete
 
 
 class JoinPromoter(object):
diff --git a/tests/generic_relations_regress/tests.py b/tests/generic_relations_regress/tests.py
index 0d78223725..b6782fe13f 100644
--- a/tests/generic_relations_regress/tests.py
+++ b/tests/generic_relations_regress/tests.py
@@ -144,22 +144,26 @@ class GenericRelationTests(TestCase):
         tag.save()
 
     def test_ticket_20378(self):
+        # Create a couple of extra HasLinkThing objects so that the auto pk
+        # values aren't the same for Link and HasLinkThing.
         hs1 = HasLinkThing.objects.create()
         hs2 = HasLinkThing.objects.create()
-        l1 = Link.objects.create(content_object=hs1)
-        l2 = Link.objects.create(content_object=hs2)
+        hs3 = HasLinkThing.objects.create()
+        hs4 = HasLinkThing.objects.create()
+        l1 = Link.objects.create(content_object=hs3)
+        l2 = Link.objects.create(content_object=hs4)
         self.assertQuerysetEqual(
             HasLinkThing.objects.filter(links=l1),
-            [hs1], lambda x: x)
+            [hs3], lambda x: x)
         self.assertQuerysetEqual(
             HasLinkThing.objects.filter(links=l2),
-            [hs2], lambda x: x)
+            [hs4], lambda x: x)
         self.assertQuerysetEqual(
             HasLinkThing.objects.exclude(links=l2),
-            [hs1], lambda x: x)
+            [hs1, hs2, hs3], lambda x: x, ordered=False)
         self.assertQuerysetEqual(
             HasLinkThing.objects.exclude(links=l1),
-            [hs2], lambda x: x)
+            [hs1, hs2, hs4], lambda x: x, ordered=False)
 
     def test_ticket_20564(self):
         b1 = B.objects.create()
diff --git a/tests/queries/tests.py b/tests/queries/tests.py
index 9c521c1c03..82d6d1fe7c 100644
--- a/tests/queries/tests.py
+++ b/tests/queries/tests.py
@@ -3678,3 +3678,11 @@ class TestTicket24279(TestCase):
         School.objects.create()
         qs = School.objects.filter(Q(pk__in=()) | Q())
         self.assertQuerysetEqual(qs, [])
+
+
+class TestInvalidValuesRelation(TestCase):
+    def test_invalid_values(self):
+        with self.assertRaises(ValueError):
+            Annotation.objects.filter(tag='abc')
+        with self.assertRaises(ValueError):
+            Annotation.objects.filter(tag__in=[123, 'abc'])