diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 0ca4e59633..e91de27876 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -15,7 +15,10 @@ from django.db.models.fields import ( BLANK_CHOICE_DASH, AutoField, Field, IntegerField, PositiveIntegerField, PositiveSmallIntegerField, ) -from django.db.models.lookups import IsNull +from django.db.models.fields.related_lookups import ( + RelatedExact, RelatedGreaterThan, RelatedGreaterThanOrEqual, RelatedIn, + RelatedLessThan, RelatedLessThanOrEqual, +) from django.db.models.query import QuerySet from django.db.models.query_utils import PathInfo from django.utils import six @@ -1336,6 +1339,16 @@ class ForeignObjectRel(object): def one_to_one(self): return self.field.one_to_one + def get_prep_lookup(self, lookup_name, value): + return self.field.get_prep_lookup(lookup_name, value) + + def get_internal_type(self): + return self.field.get_internal_type() + + @property + def db_type(self): + return self.field.db_type + def __repr__(self): return '<%s: %s.%s>' % ( type(self).__name__, @@ -1760,67 +1773,25 @@ class ForeignObject(RelatedField): pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.rel, not self.unique, False)] return pathinfos - def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookups, - raw_value): - from django.db.models.sql.where import SubqueryConstraint, AND, OR - root_constraint = constraint_class() - assert len(targets) == len(sources) - if len(lookups) > 1: - raise exceptions.FieldError( - "Cannot resolve keyword %r into field. Choices are: %s" % ( - lookups[0], - ", ".join(f.name for f in self.model._meta.get_fields()), - ) - ) - lookup_type = lookups[0] + def get_lookup(self, lookup_name): + if lookup_name == 'in': + return RelatedIn + elif lookup_name == 'exact': + return RelatedExact + elif lookup_name == 'gt': + return RelatedGreaterThan + elif lookup_name == 'gte': + return RelatedGreaterThanOrEqual + elif lookup_name == 'lt': + return RelatedLessThan + elif lookup_name == 'lte': + return RelatedLessThanOrEqual + elif lookup_name != 'isnull': + raise TypeError('Related Field got invalid lookup: %s' % lookup_name) + return super(ForeignObject, self).get_lookup(lookup_name) - def get_normalized_value(value): - from django.db.models import Model - if isinstance(value, Model): - value_list = [] - for source in sources: - # Account for one-to-one relations when sent a different model - while not isinstance(value, source.model) and source.rel: - source = source.rel.to._meta.get_field(source.rel.field_name) - value_list.append(getattr(value, source.attname)) - return tuple(value_list) - elif not isinstance(value, tuple): - return (value,) - return value - - is_multicolumn = len(self.related_fields) > 1 - if (hasattr(raw_value, '_as_sql') or - hasattr(raw_value, 'get_compiler')): - root_constraint.add(SubqueryConstraint(alias, [target.column for target in targets], - [source.name for source in sources], raw_value), - AND) - elif lookup_type == 'isnull': - root_constraint.add(IsNull(targets[0].get_col(alias, sources[0]), raw_value), AND) - elif (lookup_type == 'exact' or (lookup_type in ['gt', 'lt', 'gte', 'lte'] - and not is_multicolumn)): - value = get_normalized_value(raw_value) - for target, source, val in zip(targets, sources, value): - lookup_class = target.get_lookup(lookup_type) - root_constraint.add( - lookup_class(target.get_col(alias, source), val), AND) - elif lookup_type in ['range', 'in'] and not is_multicolumn: - values = [get_normalized_value(value) for value in raw_value] - value = [val[0] for val in values] - lookup_class = targets[0].get_lookup(lookup_type) - root_constraint.add(lookup_class(targets[0].get_col(alias, sources[0]), value), AND) - elif lookup_type == 'in': - values = [get_normalized_value(value) for value in raw_value] - root_constraint.connector = OR - for value in values: - value_constraint = constraint_class() - for source, target, val in zip(sources, targets, value): - lookup_class = target.get_lookup('exact') - lookup = lookup_class(target.get_col(alias, source), val) - value_constraint.add(lookup, AND) - root_constraint.add(value_constraint, OR) - else: - raise TypeError('Related Field got invalid lookup: %s' % lookup_type) - return root_constraint + def get_transform(self, *args, **kwargs): + raise NotImplementedError('Relational fields do not support transforms.') @property def attnames(self): @@ -2017,6 +1988,9 @@ class ForeignKey(ForeignObject): else: return self.related_field.get_db_prep_save(value, connection=connection) + def get_db_prep_value(self, value, connection, prepared=False): + return self.related_field.get_db_prep_value(value, connection, prepared) + def value_to_string(self, obj): if not obj: # In required many-to-one fields with only one available choice, diff --git a/django/db/models/fields/related_lookups.py b/django/db/models/fields/related_lookups.py new file mode 100644 index 0000000000..b689c9928e --- /dev/null +++ b/django/db/models/fields/related_lookups.py @@ -0,0 +1,130 @@ +from django.db.models.lookups import ( + Exact, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual, +) + + +class MultiColSource(object): + contains_aggregate = False + + def __init__(self, alias, targets, sources, field): + self.targets, self.sources, self.field, self.alias = targets, sources, field, alias + self.output_field = self.field + + def __repr__(self): + return "{}({}, {})".format( + self.__class__.__name__, self.alias, self.field) + + def relabeled_clone(self, relabels): + return self.__class__(relabels.get(self.alias, self.alias), + self.targets, self.sources, self.field) + + +def get_normalized_value(value, lhs): + from django.db.models import Model + if isinstance(value, Model): + value_list = [] + # Account for one-to-one relations when sent a different model + sources = lhs.output_field.get_path_info()[-1].target_fields + for source in sources: + while not isinstance(value, source.model) and source.rel: + source = source.rel.to._meta.get_field(source.rel.field_name) + value_list.append(getattr(value, source.attname)) + return tuple(value_list) + if not isinstance(value, tuple): + return (value,) + return value + + +class RelatedIn(In): + def get_prep_lookup(self): + if not isinstance(self.lhs, MultiColSource) and self.rhs_is_direct_value(): + # If we get here, we are dealing with single-column relations. + self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs] + # We need to run the related field's get_prep_lookup(). Consider case + # ForeignKey to IntegerField given value 'abc'. The ForeignKey itself + # doesn't have validation for non-integers, so we must run validation + # using the target field. + if hasattr(self.lhs.output_field, 'get_path_info'): + # Run the target field's get_prep_lookup. We can safely assume there is + # only one as we don't get to the direct value branch otherwise. + self.rhs = self.lhs.output_field.get_path_info()[-1].target_fields[-1].get_prep_lookup( + self.lookup_name, self.rhs) + return super(RelatedIn, self).get_prep_lookup() + + def as_sql(self, compiler, connection): + if isinstance(self.lhs, MultiColSource): + # For multicolumn lookups we need to build a multicolumn where clause. + # This clause is either a SubqueryConstraint (for values that need to be compiled to + # SQL) or a OR-combined list of (col1 = val1 AND col2 = val2 AND ...) clauses. + from django.db.models.sql.where import WhereNode, SubqueryConstraint, AND, OR + + root_constraint = WhereNode(connector=OR) + if self.rhs_is_direct_value(): + values = [get_normalized_value(value, self.lhs) for value in self.rhs] + for value in values: + value_constraint = WhereNode() + for source, target, val in zip(self.lhs.sources, self.lhs.targets, value): + lookup_class = target.get_lookup('exact') + lookup = lookup_class(target.get_col(self.lhs.alias, source), val) + value_constraint.add(lookup, AND) + root_constraint.add(value_constraint, OR) + else: + root_constraint.add( + SubqueryConstraint( + self.lhs.alias, [target.column for target in self.lhs.targets], + [source.name for source in self.lhs.sources], self.rhs), + AND) + return root_constraint.as_sql(compiler, connection) + else: + return super(RelatedIn, self).as_sql(compiler, connection) + + +class RelatedLookupMixin(object): + def get_prep_lookup(self): + if not isinstance(self.lhs, MultiColSource) and self.rhs_is_direct_value(): + # If we get here, we are dealing with single-column relations. + self.rhs = get_normalized_value(self.rhs, self.lhs)[0] + # We need to run the related field's get_prep_lookup(). Consider case + # ForeignKey to IntegerField given value 'abc'. The ForeignKey itself + # doesn't have validation for non-integers, so we must run validation + # using the target field. + if hasattr(self.lhs.output_field, 'get_path_info'): + # Get the target field. We can safely assume there is only one + # as we don't get to the direct value branch otherwise. + self.rhs = self.lhs.output_field.get_path_info()[-1].target_fields[-1].get_prep_lookup( + self.lookup_name, self.rhs) + + return super(RelatedLookupMixin, self).get_prep_lookup() + + def as_sql(self, compiler, connection): + if isinstance(self.lhs, MultiColSource): + assert self.rhs_is_direct_value() + self.rhs = get_normalized_value(self.rhs, self.lhs) + from django.db.models.sql.where import WhereNode, AND + root_constraint = WhereNode() + for target, source, val in zip(self.lhs.targets, self.lhs.sources, self.rhs): + lookup_class = target.get_lookup(self.lookup_name) + root_constraint.add( + lookup_class(target.get_col(self.lhs.alias, source), val), AND) + return root_constraint.as_sql(compiler, connection) + return super(RelatedLookupMixin, self).as_sql(compiler, connection) + + +class RelatedExact(RelatedLookupMixin, Exact): + pass + + +class RelatedLessThan(RelatedLookupMixin, LessThan): + pass + + +class RelatedGreaterThan(RelatedLookupMixin, GreaterThan): + pass + + +class RelatedGreaterThanOrEqual(RelatedLookupMixin, GreaterThanOrEqual): + pass + + +class RelatedLessThanOrEqual(RelatedLookupMixin, LessThanOrEqual): + pass diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py index 005bd99956..35e3fe58bc 100644 --- a/django/db/models/query_utils.py +++ b/django/db/models/query_utils.py @@ -250,7 +250,7 @@ deferred_class_factory.__safe_for_unpickling__ = True def refs_aggregate(lookup_parts, aggregates): """ - A little helper method to check if the lookup_parts contains references + A helper method to check if the lookup_parts contains references to the given aggregates set. Because the LOOKUP_SEP is contained in the default annotation names we must check each prefix of the lookup_parts for a match. @@ -260,3 +260,17 @@ def refs_aggregate(lookup_parts, aggregates): if level_n_lookup in aggregates and aggregates[level_n_lookup].contains_aggregate: return aggregates[level_n_lookup], lookup_parts[n:] return False, () + + +def refs_expression(lookup_parts, annotations): + """ + A helper method to check if the lookup_parts contains references + to the given annotations set. Because the LOOKUP_SEP is contained in the + default annotation names we must check each prefix of the lookup_parts + for a match. + """ + for n in range(len(lookup_parts) + 1): + level_n_lookup = LOOKUP_SEP.join(lookup_parts[0:n]) + if level_n_lookup in annotations and annotations[level_n_lookup]: + return annotations[level_n_lookup], lookup_parts[n:] + return False, () diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 14e079ee84..9c19e819fd 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -17,7 +17,8 @@ from django.db import DEFAULT_DB_ALIAS, connections from django.db.models.aggregates import Count from django.db.models.constants import LOOKUP_SEP from django.db.models.expressions import Col, Ref -from django.db.models.query_utils import Q, PathInfo, refs_aggregate +from django.db.models.fields.related_lookups import MultiColSource +from django.db.models.query_utils import Q, PathInfo, refs_expression from django.db.models.sql.constants import ( INNER, LOUTER, ORDER_DIR, ORDER_PATTERN, QUERY_TERMS, SINGLE, ) @@ -1006,7 +1007,7 @@ class Query(object): """ lookup_splitted = lookup.split(LOOKUP_SEP) if self._annotations: - aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.annotations) + aggregate, aggregate_lookups = refs_expression(lookup_splitted, self.annotations) if aggregate: return aggregate_lookups, (), aggregate _, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta()) @@ -1157,24 +1158,26 @@ class Query(object): if can_reuse is not None: can_reuse.update(join_list) used_joins = set(used_joins).union(set(join_list)) - - # Process the join list to see if we can remove any non-needed joins from - # the far end (fewer tables in a query is better). targets, alias, join_list = self.trim_joins(sources, join_list, path) - if hasattr(field, 'get_lookup_constraint'): - # For now foreign keys get special treatment. This should be - # refactored when composite fields lands. - condition = field.get_lookup_constraint(self.where_class, alias, targets, sources, - lookups, value) - lookup_type = lookups[-1] - else: - assert(len(targets) == 1) - if hasattr(targets[0], 'as_sql'): - # handle Expressions as annotations - col = targets[0] + if field.is_relation: + # No support for transforms for relational fields + assert len(lookups) == 1 + lookup_class = field.get_lookup(lookups[0]) + # Undo the changes done in setup_joins() if hasattr(final_field, 'field') branch + # This hack is needed as long as the field.rel isn't like a real field. + if field.get_path_info()[-1].target_fields != sources: + target_field = field.rel else: - col = targets[0].get_col(alias, field) + target_field = field + if len(targets) == 1: + lhs = targets[0].get_col(alias, target_field) + else: + lhs = MultiColSource(alias, targets, sources, target_field) + condition = lookup_class(lhs, value) + lookup_type = lookup_class.lookup_name + else: + col = targets[0].get_col(alias, field) condition = self.build_lookup(lookups, col, value) lookup_type = condition.lookup_name @@ -1284,14 +1287,6 @@ class Query(object): ) model = field.model._meta.concrete_model except FieldDoesNotExist: - # is it an annotation? - if self._annotations and name in self._annotations: - field, model = self._annotations[name], None - if not field.contains_aggregate: - # Local non-relational field. - final_field = field - targets = (field,) - break # We didn't find the current field, so move position back # one step. pos -= 1 @@ -1985,7 +1980,7 @@ def is_reverse_o2o(field): A little helper to check if the given field is reverse-o2o. The field is expected to be some sort of relation field or related object. """ - return not hasattr(field, 'rel') and field.field.unique + return field.is_relation and field.one_to_one and not field.concrete class JoinPromoter(object): diff --git a/tests/generic_relations_regress/tests.py b/tests/generic_relations_regress/tests.py index 0d78223725..b6782fe13f 100644 --- a/tests/generic_relations_regress/tests.py +++ b/tests/generic_relations_regress/tests.py @@ -144,22 +144,26 @@ class GenericRelationTests(TestCase): tag.save() def test_ticket_20378(self): + # Create a couple of extra HasLinkThing so that the autopk value + # isn't the same for Link and HasLinkThing. hs1 = HasLinkThing.objects.create() hs2 = HasLinkThing.objects.create() - l1 = Link.objects.create(content_object=hs1) - l2 = Link.objects.create(content_object=hs2) + hs3 = HasLinkThing.objects.create() + hs4 = HasLinkThing.objects.create() + l1 = Link.objects.create(content_object=hs3) + l2 = Link.objects.create(content_object=hs4) self.assertQuerysetEqual( HasLinkThing.objects.filter(links=l1), - [hs1], lambda x: x) + [hs3], lambda x: x) self.assertQuerysetEqual( HasLinkThing.objects.filter(links=l2), - [hs2], lambda x: x) + [hs4], lambda x: x) self.assertQuerysetEqual( HasLinkThing.objects.exclude(links=l2), - [hs1], lambda x: x) + [hs1, hs2, hs3], lambda x: x, ordered=False) self.assertQuerysetEqual( HasLinkThing.objects.exclude(links=l1), - [hs2], lambda x: x) + [hs1, hs2, hs4], lambda x: x, ordered=False) def test_ticket_20564(self): b1 = B.objects.create() diff --git a/tests/queries/tests.py b/tests/queries/tests.py index 9c521c1c03..82d6d1fe7c 100644 --- a/tests/queries/tests.py +++ b/tests/queries/tests.py @@ -3678,3 +3678,11 @@ class TestTicket24279(TestCase): School.objects.create() qs = School.objects.filter(Q(pk__in=()) | Q()) self.assertQuerysetEqual(qs, []) + + +class TestInvalidValuesRelation(TestCase): + def test_invalid_values(self): + with self.assertRaises(ValueError): + Annotation.objects.filter(tag='abc') + with self.assertRaises(ValueError): + Annotation.objects.filter(tag__in=[123, 'abc'])