diff --git a/django/contrib/contenttypes/generic.py b/django/contrib/contenttypes/generic.py index 862f2787bc..6aff07e568 100644 --- a/django/contrib/contenttypes/generic.py +++ b/django/contrib/contenttypes/generic.py @@ -205,17 +205,16 @@ class GenericRelation(RelatedField, Field): # same db_type as well. return None - def extra_filters(self, pieces, pos, negate): + def get_content_type(self): """ - Return an extra filter to the queryset so that the results are filtered - on the appropriate content type. + Returns the content type associated with this field's model. """ - if negate: - return [] - content_type = ContentType.objects.get_for_model(self.model) - prefix = "__".join(pieces[:pos + 1]) - return [("%s__%s" % (prefix, self.content_type_field_name), - content_type)] + return ContentType.objects.get_for_model(self.model) + + def get_extra_join_sql(self, connection, qn, lhs_alias, rhs_alias): + extra_col = self.rel.to._meta.get_field_by_name(self.content_type_field_name)[0].column + contenttype = self.get_content_type().pk + return " AND %s.%s = %%s" % (qn(rhs_alias), qn(extra_col)), [contenttype] def bulk_related_objects(self, objs, using=DEFAULT_DB_ALIAS): """ @@ -246,9 +245,6 @@ class ReverseGenericRelatedObjectsDescriptor(object): if instance is None: return self - # This import is done here to avoid circular import importing this module - from django.contrib.contenttypes.models import ContentType - # Dynamically create a class that subclasses the related model's # default manager. rel_model = self.field.rel.to @@ -379,8 +375,6 @@ class BaseGenericInlineFormSet(BaseModelFormSet): def __init__(self, data=None, files=None, instance=None, save_as_new=None, prefix=None, queryset=None): - # Avoid a circular import. - from django.contrib.contenttypes.models import ContentType opts = self.model._meta self.instance = instance self.rel_name = '-'.join(( @@ -409,8 +403,6 @@ class BaseGenericInlineFormSet(BaseModelFormSet): )) def save_new(self, form, commit=True): - # Avoid a circular import. - from django.contrib.contenttypes.models import ContentType kwargs = { self.ct_field.get_attname(): ContentType.objects.get_for_model(self.instance).pk, self.ct_fk_field.get_attname(): self.instance.pk, @@ -432,8 +424,6 @@ def generic_inlineformset_factory(model, form=ModelForm, defaults ``content_type`` and ``object_id`` respectively. """ opts = model._meta - # Avoid a circular import. 
- from django.contrib.contenttypes.models import ContentType # if there is no field called `ct_field` let the exception propagate ct_field = opts.get_field(ct_field) if not isinstance(ct_field, models.ForeignKey) or ct_field.rel.to != ContentType: diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 8cfb12a8e3..4d846fb438 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -274,7 +274,8 @@ class SQLCompiler(object): except KeyError: link_field = opts.get_ancestor_link(model) alias = self.query.join((start_alias, model._meta.db_table, - link_field.column, model._meta.pk.column)) + link_field.column, model._meta.pk.column), + join_field=link_field) seen[model] = alias else: # If we're starting from the base model of the queryset, the @@ -448,8 +449,8 @@ class SQLCompiler(object): """ if not alias: alias = self.query.get_initial_alias() - field, target, opts, joins, _, _ = self.query.setup_joins(pieces, - opts, alias, REUSE_ALL) + field, target, opts, joins, _ = self.query.setup_joins( + pieces, opts, alias, REUSE_ALL) # We will later on need to promote those joins that were added to the # query afresh above. joins_to_promote = [j for j in joins if self.query.alias_refcount[j] < 2] @@ -501,20 +502,27 @@ class SQLCompiler(object): qn = self.quote_name_unless_alias qn2 = self.connection.ops.quote_name first = True + from_params = [] for alias in self.query.tables: if not self.query.alias_refcount[alias]: continue try: - name, alias, join_type, lhs, lhs_col, col, nullable = self.query.alias_map[alias] + name, alias, join_type, lhs, lhs_col, col, _, join_field = self.query.alias_map[alias] except KeyError: # Extra tables can end up in self.tables, but not in the # alias_map if they aren't in a join. That's OK. We skip them. 
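Illustration (not part of the patch): the new GenericRelation.get_extra_join_sql() returns an SQL fragment plus its parameters, and the get_from_clause() change in the next hunk splices that fragment into the JOIN's ON (...) condition while collecting the parameters into from_params. The sketch below only mirrors that shape; the table and column names ("myapp_post", "myapp_comment", content type pk 7) are made up.

def sketch_extra_join_sql(qn, rhs_alias, ct_column, ct_pk):
    # Same return shape as GenericRelation.get_extra_join_sql(): (fragment, params).
    return " AND %s.%s = %%s" % (qn(rhs_alias), qn(ct_column)), [ct_pk]

qn = lambda name: '"%s"' % name  # stand-in for the compiler's quote_name
extra_cond, extra_params = sketch_extra_join_sql(qn, "myapp_comment", "content_type_id", 7)
join_sql = 'LEFT OUTER JOIN %s ON (%s.%s = %s.%s%s)' % (
    qn("myapp_comment"), qn("myapp_post"), qn("id"),
    qn("myapp_comment"), qn("object_id"), extra_cond)
# join_sql is now:
#   LEFT OUTER JOIN "myapp_comment" ON ("myapp_post"."id" = "myapp_comment"."object_id"
#   AND "myapp_comment"."content_type_id" = %s)
# and extra_params ([7]) would end up in the from_params list that
# get_from_clause() now returns alongside the FROM clause.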
continue alias_str = (alias != name and ' %s' % alias or '') if join_type and not first: - result.append('%s %s%s ON (%s.%s = %s.%s)' - % (join_type, qn(name), alias_str, qn(lhs), - qn2(lhs_col), qn(alias), qn2(col))) + if join_field and hasattr(join_field, 'get_extra_join_sql'): + extra_cond, extra_params = join_field.get_extra_join_sql( + self.connection, qn, lhs, alias) + from_params.extend(extra_params) + else: + extra_cond = "" + result.append('%s %s%s ON (%s.%s = %s.%s%s)' % + (join_type, qn(name), alias_str, qn(lhs), + qn2(lhs_col), qn(alias), qn2(col), extra_cond)) else: connector = not first and ', ' or '' result.append('%s%s%s' % (connector, qn(name), alias_str)) @@ -528,7 +536,7 @@ class SQLCompiler(object): connector = not first and ', ' or '' result.append('%s%s' % (connector, qn(alias))) first = False - return result, [] + return result, from_params def get_grouping(self, ordering_group_by): """ @@ -638,7 +646,7 @@ class SQLCompiler(object): alias = self.query.join((alias, table, f.column, f.rel.get_related_field().column), - promote=promote) + promote=promote, join_field=f) columns, aliases = self.get_default_columns(start_alias=alias, opts=f.rel.to._meta, as_pairs=True) self.query.related_select_cols.extend( @@ -685,7 +693,7 @@ class SQLCompiler(object): alias_chain.append(alias) alias = self.query.join( (alias, table, f.rel.get_related_field().column, f.column), - promote=True + promote=True, join_field=f ) from_parent = (opts.model if issubclass(model, opts.model) else None) diff --git a/django/db/models/sql/constants.py b/django/db/models/sql/constants.py index 6e1d2dd87a..1c34f70169 100644 --- a/django/db/models/sql/constants.py +++ b/django/db/models/sql/constants.py @@ -18,12 +18,19 @@ QUERY_TERMS = set([ # Larger values are slightly faster at the expense of more storage space. GET_ITERATOR_CHUNK_SIZE = 100 -# Constants to make looking up tuple values clearer. +# Namedtuples for sql.* internal use. + # Join lists (indexes into the tuples that are values in the alias_map # dictionary in the Query class). JoinInfo = namedtuple('JoinInfo', 'table_name rhs_alias join_type lhs_alias ' - 'lhs_join_col rhs_join_col nullable') + 'lhs_join_col rhs_join_col nullable join_field') + +# PathInfo is used when converting lookups (fk__somecol). The contents +# describe the join in Model terms (model Options and Fields for both +# sides of the join). The join_field is the field we are joining along. +PathInfo = namedtuple('PathInfo', + 'from_field to_field from_opts to_opts join_field') # Pairs of column clauses to select, and (possibly None) field for the clause.
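A rough sketch (not part of the patch) of what the new PathInfo tuples carry: a two-hop lookup such as Report.objects.filter(creator__extra__info='e1') (models from the queries regression tests) produces one PathInfo per traversed foreign key. The real code stores Field and Options instances; plain strings are used below purely as placeholders.

from collections import namedtuple

PathInfo = namedtuple('PathInfo',
                      'from_field to_field from_opts to_opts join_field')

# Placeholder decomposition of the lookup creator__extra__info:
path = [
    PathInfo('Report.creator', 'Author.pk',
             '<Report opts>', '<Author opts>', 'Report.creator'),
    PathInfo('Author.extra', 'ExtraInfo.pk',
             '<Author opts>', '<ExtraInfo opts>', 'Author.extra'),
]
# names_to_path() builds such a list, setup_joins() turns each entry into a
# JoinInfo in alias_map, and trim_joins() may later drop trailing direct
# joins when the compared column already exists on the previous table.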
SelectInfo = namedtuple('SelectInfo', 'col field') diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py index c809e25b49..af7e45e74e 100644 --- a/django/db/models/sql/expressions.py +++ b/django/db/models/sql/expressions.py @@ -50,10 +50,10 @@ class SQLEvaluator(object): self.cols.append((node, query.aggregate_select[node.name])) else: try: - field, source, opts, join_list, last, _ = query.setup_joins( + field, source, opts, join_list, path = query.setup_joins( field_list, query.get_meta(), query.get_initial_alias(), self.reuse) - col, _, join_list = query.trim_joins(source, join_list, last, False) + col, _, join_list = query.trim_joins(source, join_list, path) if self.reuse is not None and self.reuse != REUSE_ALL: self.reuse.update(join_list) self.cols.append((node, (join_list[-1], col))) diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 0e7bd92aa2..f6b812c54d 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -14,13 +14,13 @@ from django.utils.encoding import force_text from django.utils.tree import Node from django.utils import six from django.db import connections, DEFAULT_DB_ALIAS -from django.db.models import signals from django.db.models.constants import LOOKUP_SEP from django.db.models.expressions import ExpressionNode from django.db.models.fields import FieldDoesNotExist +from django.db.models.loading import get_model from django.db.models.sql import aggregates as base_aggregates_module from django.db.models.sql.constants import (QUERY_TERMS, ORDER_DIR, SINGLE, - ORDER_PATTERN, REUSE_ALL, JoinInfo, SelectInfo) + ORDER_PATTERN, REUSE_ALL, JoinInfo, SelectInfo, PathInfo) from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin from django.db.models.sql.expressions import SQLEvaluator from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode, @@ -119,7 +119,7 @@ class Query(object): self.filter_is_sticky = False self.included_inherited_models = {} - # SQL-related attributes + # SQL-related attributes # Select and related select clauses as SelectInfo instances. # The select is used for cases where we want to set up the select # clause to contain other than default fields (values(), annotate(), @@ -201,6 +201,16 @@ class Query(object): (s.col, s.field is not None and s.field.name or None) for s in obj_dict['select'] ] + # alias_map can also contain references to fields. + new_alias_map = {} + for alias, join_info in obj_dict['alias_map'].items(): + if join_info.join_field is None: + new_alias_map[alias] = join_info + else: + model = join_info.join_field.model._meta + field_id = (model.app_label, model.object_name, join_info.join_field.name) + new_alias_map[alias] = join_info._replace(join_field=field_id) + obj_dict['alias_map'] = new_alias_map return obj_dict def __setstate__(self, obj_dict): @@ -213,6 +223,15 @@ class Query(object): SelectInfo(tpl[0], tpl[1] is not None and opts.get_field(tpl[1]) or None) for tpl in obj_dict['select'] ] + new_alias_map = {} + for alias, join_info in obj_dict['alias_map'].items(): + if join_info.join_field is None: + new_alias_map[alias] = join_info + else: + field_id = join_info.join_field + new_alias_map[alias] = join_info._replace( + join_field=get_model(field_id[0], field_id[1])._meta.get_field(field_id[2])) + obj_dict['alias_map'] = new_alias_map self.__dict__.update(obj_dict) @@ -479,21 +498,26 @@ class Query(object): # Now, add the joins from rhs query into the new query (skipping base # table). 
for alias in rhs.tables[1:]: - if not rhs.alias_refcount[alias]: - continue - table, _, join_type, lhs, lhs_col, col, nullable = rhs.alias_map[alias] + table, _, join_type, lhs, lhs_col, col, nullable, join_field = rhs.alias_map[alias] promote = (join_type == self.LOUTER) # If the left side of the join was already relabeled, use the # updated alias. lhs = change_map.get(lhs, lhs) new_alias = self.join( (lhs, table, lhs_col, col), reuse=reuse, promote=promote, - outer_if_first=not conjunction, nullable=nullable) + outer_if_first=not conjunction, nullable=nullable, + join_field=join_field) # We can't reuse the same join again in the query. If we have two # distinct joins for the same connection in rhs query, then the # combined query must have two joins, too. reuse.discard(new_alias) change_map[alias] = new_alias + if not rhs.alias_refcount[alias]: + # The alias was unused in the rhs query. Unref it so that it + # will be unused in the new query, too. We have to add and + # unref the alias so that join promotion has information of + # the join type for the unused alias. + self.unref_alias(new_alias) # So that we don't exclude valid results in an "or" query combination, # all joins exclusive to either the lhs or the rhs must be converted @@ -868,7 +892,7 @@ class Query(object): return len([1 for count in self.alias_refcount.values() if count]) def join(self, connection, reuse=REUSE_ALL, promote=False, - outer_if_first=False, nullable=False): + outer_if_first=False, nullable=False, join_field=None): """ Returns an alias for the join in 'connection', either reusing an existing alias for that join or creating a new one. 'connection' is a @@ -897,6 +921,8 @@ class Query(object): If 'nullable' is True, the join can potentially involve NULL values and is a candidate for promotion (to "left outer") when combining querysets. + + The 'join_field' is the field we are joining along (if any). """ lhs, table, lhs_col, col = connection existing = self.join_map.get(connection, ()) @@ -906,8 +932,13 @@ class Query(object): reuse = set() else: reuse = [a for a in existing if a in reuse] - if reuse: - alias = reuse[0] + for alias in reuse: + if join_field and self.alias_map[alias].join_field != join_field: + # The join_map doesn't contain join_field (mainly because + # fields in Query structs are problematic in pickling), so + # check that the existing join is created using the same + # join_field used for the under work join. + continue self.ref_alias(alias) if promote or (lhs and self.alias_map[lhs].join_type == self.LOUTER): self.promote_joins([alias]) @@ -926,7 +957,8 @@ class Query(object): join_type = self.LOUTER else: join_type = self.INNER - join = JoinInfo(table, alias, join_type, lhs, lhs_col, col, nullable) + join = JoinInfo(table, alias, join_type, lhs, lhs_col, col, nullable, + join_field) self.alias_map[alias] = join if connection in self.join_map: self.join_map[connection] += (alias,) @@ -1007,11 +1039,11 @@ class Query(object): # - this is an annotation over a model field # then we need to explore the joins that are required. 
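Sidebar on the __getstate__/__setstate__ change earlier in this file: join_field entries in alias_map are reduced to an (app_label, object_name, field_name) triple for pickling and resolved back through get_model() on unpickling. A minimal stand-alone sketch of that round trip follows; FakeField and the names used are stand-ins, not real project models.

from collections import namedtuple

JoinInfo = namedtuple('JoinInfo', 'table_name rhs_alias join_type lhs_alias '
                      'lhs_join_col rhs_join_col nullable join_field')

class FakeField(object):
    # Stand-in for a model Field; only the attributes this sketch needs.
    def __init__(self, app_label, object_name, name):
        self.app_label, self.object_name, self.name = app_label, object_name, name

field = FakeField('queries', 'Report', 'creator')
join = JoinInfo('queries_author', 'T2', 'INNER JOIN', 'queries_report',
                'creator_id', 'id', True, field)

# __getstate__: swap the Field instance for an identifying triple. In the
# real code app_label and object_name come from join_field.model._meta.
field_id = (field.app_label, field.object_name, field.name)
picklable = join._replace(join_field=field_id)

# __setstate__: resolve the triple back to a Field; the real code does
# get_model(app_label, object_name)._meta.get_field(field_name).
restored = picklable._replace(join_field=FakeField(*picklable.join_field))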
- field, source, opts, join_list, last, _ = self.setup_joins( + field, source, opts, join_list, path = self.setup_joins( field_list, opts, self.get_initial_alias(), REUSE_ALL) # Process the join chain to see if it can be trimmed - col, _, join_list = self.trim_joins(source, join_list, last, False) + col, _, join_list = self.trim_joins(source, join_list, path) # If the aggregate references a model or field that requires a join, # those joins must be LEFT OUTER - empty join rows must be returned @@ -1030,7 +1062,7 @@ class Query(object): aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary) def add_filter(self, filter_expr, connector=AND, negate=False, - can_reuse=None, process_extras=True, force_having=False): + can_reuse=None, force_having=False): """ Add a single filter to the query. The 'filter_expr' is a pair: (filter_string, value). E.g. ('name__contains', 'fred') @@ -1047,10 +1079,6 @@ class Query(object): will be a set of table aliases that can be reused in this filter, even if we would otherwise force the creation of new aliases for a join (needed for nested Q-filters). The set is updated by this method. - - If 'process_extras' is set, any extra filters returned from the table - joining process will be processed. This parameter is set to False - during the processing of extra filters to avoid infinite recursion. """ arg, value = filter_expr parts = arg.split(LOOKUP_SEP) @@ -1115,10 +1143,11 @@ class Query(object): allow_many = not negate try: - field, target, opts, join_list, last, extra_filters = self.setup_joins( + field, target, opts, join_list, path = self.setup_joins( parts, opts, alias, can_reuse, allow_many, - allow_explicit_fk=True, negate=negate, - process_extras=process_extras) + allow_explicit_fk=True) + if can_reuse is not None: + can_reuse.update(join_list) except MultiJoin as e: self.split_exclude(filter_expr, LOOKUP_SEP.join(parts[:e.level]), can_reuse) @@ -1136,10 +1165,10 @@ class Query(object): join_promote = True # Process the join list to see if we can remove any inner joins from - # the far end (fewer tables in a query is better). - nonnull_comparison = (lookup_type == 'isnull' and value is False) - col, alias, join_list = self.trim_joins(target, join_list, last, - nonnull_comparison) + # the far end (fewer tables in a query is better). Note that join + # promotion must happen before join trimming to have the join type + # information available when reusing joins. + col, alias, join_list = self.trim_joins(target, join_list, path) if connector == OR: # Some joins may need to be promoted when adding a new filter to a @@ -1212,12 +1241,6 @@ class Query(object): # is added in upper layers of the code. self.where.add((Constraint(alias, col, None), 'isnull', False), AND) - if can_reuse is not None: - can_reuse.update(join_list) - if process_extras: - for filter in extra_filters: - self.add_filter(filter, negate=negate, can_reuse=can_reuse, - process_extras=False) def add_q(self, q_object, used_aliases=None, force_having=False): """ @@ -1270,37 +1293,24 @@ class Query(object): if self.filter_is_sticky: self.used_aliases = used_aliases - def setup_joins(self, names, opts, alias, can_reuse, allow_many=True, - allow_explicit_fk=False, negate=False, process_extras=True): + def names_to_path(self, names, opts, allow_many=False, + allow_explicit_fk=True): """ - Compute the necessary table joins for the passage through the fields - given in 'names'. 
'opts' is the Options class for the current model - (which gives the table we are joining to), 'alias' is the alias for the - table we are joining to. + Walks the names path and turns them into PathInfo tuples. Note that a + single name in 'names' can generate multiple PathInfos (m2m for + example). - The 'can_reuse' defines the reverse foreign key joins we can reuse. It - can be either sql.constants.REUSE_ALL in which case all joins are - reusable or a set of aliases that can be reused. Non-reverse foreign - keys are always reusable. + 'names' is the path of names to travel, 'opts' is the model Options we + start the name resolving from, 'allow_many' and 'allow_explicit_fk' + are as for setup_joins(). - The 'allow_explicit_fk' controls if field.attname is allowed in the - lookups. - - Finally, 'negate' is used in the same sense as for add_filter() - -- it indicates an exclude() filter, or something similar. It is only - passed in here so that it can be passed to a field's extra_filter() for - customized behavior. - - Returns the final field involved in the join, the target database - column (used for any 'where' constraint), the final 'opts' value and the - list of tables joined. + Returns a list of PathInfo tuples. In addition returns the final field + (the last used join field), and target (which is a field guaranteed to + contain the same value as the final field). """ - joins = [alias] - last = [0] - extra_filters = [] - int_alias = None + path = [] + multijoin_pos = None for pos, name in enumerate(names): - last.append(len(joins)) if name == 'pk': name = opts.pk.name try: @@ -1314,14 +1324,12 @@ class Query(object): field, model, direct, m2m = opts.get_field_by_name(f.name) break else: - names = opts.get_all_field_names() + list(self.aggregate_select) + available = opts.get_all_field_names() + list(self.aggregate_select) raise FieldError("Cannot resolve keyword %r into field. " - "Choices are: %s" % (name, ", ".join(names))) - - if not allow_many and (m2m or not direct): - for alias in joins: - self.unref_alias(alias) - raise MultiJoin(pos + 1) + "Choices are: %s" % (name, ", ".join(available))) + # Check if we need any joins for concrete inheritance cases (the + # field lives in parent, but we are currently in one of its + # children) if model: # The field lives on a base class of the current model. # Skip the chain of proxy to the concrete proxied model @@ -1331,172 +1339,179 @@ class Query(object): if int_model is proxied_model: opts = int_model._meta else: - lhs_col = opts.parents[int_model].column + final_field = opts.parents[int_model] + target = final_field.rel.get_related_field() opts = int_model._meta - alias = self.join((alias, opts.db_table, lhs_col, - opts.pk.column)) - joins.append(alias) - cached_data = opts._join_cache.get(name) - orig_opts = opts - - if process_extras and hasattr(field, 'extra_filters'): - extra_filters.extend(field.extra_filters(names, pos, negate)) - if direct: - if m2m: - # Many-to-many field defined on the current model.
- if cached_data: - (table1, from_col1, to_col1, table2, from_col2, - to_col2, opts, target) = cached_data - else: - table1 = field.m2m_db_table() - from_col1 = opts.get_field_by_name( - field.m2m_target_field_name())[0].column - to_col1 = field.m2m_column_name() - opts = field.rel.to._meta - table2 = opts.db_table - from_col2 = field.m2m_reverse_name() - to_col2 = opts.get_field_by_name( - field.m2m_reverse_target_field_name())[0].column - target = opts.pk - orig_opts._join_cache[name] = (table1, from_col1, - to_col1, table2, from_col2, to_col2, opts, - target) - - int_alias = self.join((alias, table1, from_col1, to_col1), - reuse=can_reuse, nullable=True) - if int_alias == table2 and from_col2 == to_col2: - joins.append(int_alias) - alias = int_alias - else: - alias = self.join( - (int_alias, table2, from_col2, to_col2), - reuse=can_reuse, nullable=True) - joins.extend([int_alias, alias]) - elif field.rel: - # One-to-one or many-to-one field - if cached_data: - (table, from_col, to_col, opts, target) = cached_data - else: - opts = field.rel.to._meta - target = field.rel.get_related_field() - table = opts.db_table - from_col = field.column - to_col = target.column - orig_opts._join_cache[name] = (table, from_col, to_col, - opts, target) - - alias = self.join((alias, table, from_col, to_col), - nullable=self.is_nullable(field)) - joins.append(alias) + path.append(PathInfo(final_field, target, final_field.model._meta, + opts, final_field)) + # We have five different cases to solve: foreign keys, reverse + # foreign keys, m2m fields (also reverse) and non-relational + # fields. We are mostly just using the related field API to + # fetch the from and to fields. The m2m fields are handled as + # two foreign keys, first one reverse, the second one direct. + if direct and not field.rel and not m2m: + # Local non-relational field. + final_field = target = field + break + elif direct and not m2m: + # Foreign Key + opts = field.rel.to._meta + target = field.rel.get_related_field() + final_field = field + from_opts = field.model._meta + path.append(PathInfo(field, target, from_opts, opts, field)) + elif not direct and not m2m: + # Reverse foreign key + final_field = to_field = field.field + opts = to_field.model._meta + from_field = to_field.rel.get_related_field() + from_opts = from_field.model._meta + path.append( + PathInfo(from_field, to_field, from_opts, opts, to_field)) + if from_field.model is to_field.model: + # Recursive foreign key to self. + target = opts.get_field_by_name( + field.field.rel.field_name)[0] else: - # Non-relation fields. - target = field - break - else: - orig_field = field + target = opts.pk + elif direct and m2m: + if not field.rel.through: + # Gotcha! This is just a fake m2m field - a generic relation + # field. + from_field = opts.pk + opts = field.rel.to._meta + target = opts.get_field_by_name(field.object_id_field_name)[0] + final_field = field + # Note that we are using a different field for the join_field + # than from_field or to_field. This is a hack, but we need the + # GenericRelation to generate the extra SQL. + path.append(PathInfo(from_field, target, field.model._meta, opts, + field)) + else: + # m2m field. We are travelling first to the m2m table along a + # reverse relation, then from m2m table to the target table.
+ from_field1 = opts.get_field_by_name( + field.m2m_target_field_name())[0] + opts = field.rel.through._meta + to_field1 = opts.get_field_by_name(field.m2m_field_name())[0] + path.append( + PathInfo(from_field1, to_field1, from_field1.model._meta, + opts, to_field1)) + final_field = from_field2 = opts.get_field_by_name( + field.m2m_reverse_field_name())[0] + opts = field.rel.to._meta + target = to_field2 = opts.get_field_by_name( + field.m2m_reverse_target_field_name())[0] + path.append( + PathInfo(from_field2, to_field2, from_field2.model._meta, + opts, from_field2)) + elif not direct and m2m: + # This one is just like above, except we are travelling the + # fields in opposite direction. field = field.field - if m2m: - # Many-to-many field defined on the target model. - if cached_data: - (table1, from_col1, to_col1, table2, from_col2, - to_col2, opts, target) = cached_data - else: - table1 = field.m2m_db_table() - from_col1 = opts.get_field_by_name( - field.m2m_reverse_target_field_name())[0].column - to_col1 = field.m2m_reverse_name() - opts = orig_field.opts - table2 = opts.db_table - from_col2 = field.m2m_column_name() - to_col2 = opts.get_field_by_name( - field.m2m_target_field_name())[0].column - target = opts.pk - orig_opts._join_cache[name] = (table1, from_col1, - to_col1, table2, from_col2, to_col2, opts, - target) + from_field1 = opts.get_field_by_name( + field.m2m_reverse_target_field_name())[0] + int_opts = field.rel.through._meta + to_field1 = int_opts.get_field_by_name( + field.m2m_reverse_field_name())[0] + path.append( + PathInfo(from_field1, to_field1, from_field1.model._meta, + int_opts, to_field1)) + final_field = from_field2 = int_opts.get_field_by_name( + field.m2m_field_name())[0] + opts = field.opts + target = to_field2 = opts.get_field_by_name( + field.m2m_target_field_name())[0] + path.append(PathInfo(from_field2, to_field2, from_field2.model._meta, + opts, from_field2)) - int_alias = self.join((alias, table1, from_col1, to_col1), - reuse=can_reuse, nullable=True) - alias = self.join((int_alias, table2, from_col2, to_col2), - reuse=can_reuse, nullable=True) - joins.extend([int_alias, alias]) - else: - # One-to-many field (ForeignKey defined on the target model) - if cached_data: - (table, from_col, to_col, opts, target) = cached_data - else: - local_field = opts.get_field_by_name( - field.rel.field_name)[0] - opts = orig_field.opts - table = opts.db_table - from_col = local_field.column - to_col = field.column - # In case of a recursive FK, use the to_field for - # reverse lookups as well - if orig_field.model is local_field.model: - target = opts.get_field_by_name( - field.rel.field_name)[0] - else: - target = opts.pk - orig_opts._join_cache[name] = (table, from_col, to_col, - opts, target) - - alias = self.join((alias, table, from_col, to_col), - reuse=can_reuse, nullable=True) - joins.append(alias) + if m2m and multijoin_pos is None: + multijoin_pos = pos + if not direct and not path[-1].to_field.unique and multijoin_pos is None: + multijoin_pos = pos if pos != len(names) - 1: if pos == len(names) - 2: - raise FieldError("Join on field %r not permitted. Did you misspell %r for the lookup type?" % (name, names[pos + 1])) + raise FieldError( + "Join on field %r not permitted. Did you misspell %r for " + "the lookup type?" % (name, names[pos + 1])) else: raise FieldError("Join on field %r not permitted." 
% name) + if multijoin_pos is not None and len(path) >= multijoin_pos and not allow_many: + raise MultiJoin(multijoin_pos + 1) + return path, final_field, target - return field, target, opts, joins, last, extra_filters - - def trim_joins(self, target, join_list, last, nonnull_check=False): + def setup_joins(self, names, opts, alias, can_reuse, allow_many=True, + allow_explicit_fk=False): """ - Sometimes joins at the end of a multi-table sequence can be trimmed. If - the final join is against the same column as we are comparing against, - and is an inner join, we can go back one step in a join chain and - compare against the LHS of the join instead (and then repeat the - optimization). The result, potentially, involves fewer table joins. + Compute the necessary table joins for the passage through the fields + given in 'names'. 'opts' is the Options class for the current model + (which gives the table we are starting from), 'alias' is the alias for + the table to start the joining from. - The 'target' parameter is the final field being joined to, 'join_list' - is the full list of join aliases. + The 'can_reuse' defines the reverse foreign key joins we can reuse. It + can be sql.constants.REUSE_ALL in which case all joins are reusable + or a set of aliases that can be reused. Note that Non-reverse foreign + keys are always reusable. - The 'last' list contains offsets into 'join_list', corresponding to - each component of the filter. Many-to-many relations, for example, add - two tables to the join list and we want to deal with both tables the - same way, so 'last' has an entry for the first of the two tables and - then the table immediately after the second table, in that case. + If 'allow_many' is False, then any reverse foreign key seen will + generate a MultiJoin exception. - The 'nonnull_check' parameter is True when we are using inner joins - between tables explicitly to exclude NULL entries. In that case, the - tables shouldn't be trimmed, because the very action of joining to them - alters the result set. + The 'allow_explicit_fk' controls if field.attname is allowed in the + lookups. + + Returns the final field involved in the joins, the target field (used + for any 'where' constraint), the final 'opts' value, the joins and the + field path travelled to generate the joins. + + The target field is the field containing the concrete value. Final + field can be something different, for example foreign key pointing to + that value. Final field is needed for example in some value + conversions (convert 'obj' in fk__id=obj to pk val using the foreign + key field for example). + """ + joins = [alias] + # First, generate the path for the names + path, final_field, target = self.names_to_path( + names, opts, allow_many, allow_explicit_fk) + # Then, add the path to the query's joins. Note that we can't trim + # joins at this stage - we will need the information about join type + # of the trimmed joins. + for pos, join in enumerate(path): + from_field, to_field, from_opts, opts, join_field = join + direct = join_field == from_field + if direct: + nullable = self.is_nullable(from_field) + else: + nullable = True + connection = alias, opts.db_table, from_field.column, to_field.column + alias = self.join(connection, reuse=can_reuse, nullable=nullable, + join_field=join_field) + joins.append(alias) + return final_field, target, opts, joins, path + + def trim_joins(self, target, joins, path): + """ + The 'target' parameter is the final field being joined to, 'joins' + is the full list of join aliases. 
The 'path' contain the PathInfos + used to create the joins. Returns the final active column and table alias and the new active - join_list. + joins. + + We will always trim any direct join if we have the target column + available already in the previous table. Reverse joins can't be + trimmed as we don't know if there is anything on the other side of + the join. """ - final = len(join_list) - penultimate = last.pop() - if penultimate == final: - penultimate = last.pop() - col = target.column - alias = join_list[-1] - while final > 1: - join = self.alias_map[alias] - if (col != join.rhs_join_col or join.join_type != self.INNER or - nonnull_check): + for info in reversed(path): + direct = info.join_field == info.from_field + if info.to_field == target and direct: + target = info.from_field + self.unref_alias(joins.pop()) + else: break - self.unref_alias(alias) - alias = join.lhs_alias - col = join.lhs_join_col - join_list.pop() - final -= 1 - if final == penultimate: - penultimate = last.pop() - return col, alias, join_list + return target.column, joins[-1], joins def split_exclude(self, filter_expr, prefix, can_reuse): """ @@ -1627,9 +1642,9 @@ class Query(object): try: for name in field_names: - field, target, u2, joins, u3, u4 = self.setup_joins( - name.split(LOOKUP_SEP), opts, alias, REUSE_ALL, - allow_m2m, True) + field, target, u2, joins, u3 = self.setup_joins( + name.split(LOOKUP_SEP), opts, alias, REUSE_ALL, allow_m2m, + True) final_alias = joins[-1] col = target.column if len(joins) > 1: @@ -1918,7 +1933,7 @@ class Query(object): """ opts = self.model._meta alias = self.get_initial_alias() - field, col, opts, joins, last, extra = self.setup_joins( + field, col, opts, joins, extra = self.setup_joins( start.split(LOOKUP_SEP), opts, alias, REUSE_ALL) select_col = self.alias_map[joins[1]].lhs_join_col select_alias = alias @@ -1975,18 +1990,6 @@ def get_order_dir(field, default='ASC'): return field, dirn[0] -def setup_join_cache(sender, **kwargs): - """ - The information needed to join between model fields is something that is - invariant over the life of the model, so we cache it in the model's Options - class, rather than recomputing it all the time. - - This method initialises the (empty) cache when the model is created. 
- """ - sender._meta._join_cache = {} - -signals.class_prepared.connect(setup_join_cache) - def add_to_dict(data, key, value): """ A helper function to add "value" to the set of values for "key", whether or diff --git a/tests/regressiontests/aggregation_regress/tests.py b/tests/regressiontests/aggregation_regress/tests.py index 9b3cd41e41..596ebbfaec 100644 --- a/tests/regressiontests/aggregation_regress/tests.py +++ b/tests/regressiontests/aggregation_regress/tests.py @@ -978,3 +978,7 @@ class AggregationTests(TestCase): ('The Definitive Guide to Django: Web Development Done Right', 2) ] ) + + def test_reverse_join_trimming(self): + qs = Author.objects.annotate(Count('book_contact_set__contact')) + self.assertIn(' JOIN ', str(qs.query)) diff --git a/tests/regressiontests/queries/models.py b/tests/regressiontests/queries/models.py index f0178a0256..73b9762150 100644 --- a/tests/regressiontests/queries/models.py +++ b/tests/regressiontests/queries/models.py @@ -283,6 +283,7 @@ class SingleObject(models.Model): class RelatedObject(models.Model): single = models.ForeignKey(SingleObject, null=True) + f = models.IntegerField(null=True) class Meta: ordering = ['single'] @@ -311,7 +312,7 @@ class Food(models.Model): @python_2_unicode_compatible class Eaten(models.Model): - food = models.ForeignKey(Food, to_field="name") + food = models.ForeignKey(Food, to_field="name", null=True) meal = models.CharField(max_length=20) def __str__(self): @@ -400,3 +401,23 @@ class ModelA(models.Model): name = models.TextField() b = models.ForeignKey(ModelB, null=True) d = models.ForeignKey(ModelD) + +@python_2_unicode_compatible +class Job(models.Model): + name = models.CharField(max_length=20, unique=True) + + def __str__(self): + return self.name + +class JobResponsibilities(models.Model): + job = models.ForeignKey(Job, to_field='name') + responsibility = models.ForeignKey('Responsibility', to_field='description') + +@python_2_unicode_compatible +class Responsibility(models.Model): + description = models.CharField(max_length=20, unique=True) + jobs = models.ManyToManyField(Job, through=JobResponsibilities, + related_name='responsibilities') + + def __str__(self): + return self.description diff --git a/tests/regressiontests/queries/tests.py b/tests/regressiontests/queries/tests.py index e4009cdf20..75e27769b4 100644 --- a/tests/regressiontests/queries/tests.py +++ b/tests/regressiontests/queries/tests.py @@ -23,7 +23,8 @@ from .models import (Annotation, Article, Author, Celebrity, Child, Cover, Ranking, Related, Report, ReservedName, Tag, TvChef, Valid, X, Food, Eaten, Node, ObjectA, ObjectB, ObjectC, CategoryItem, SimpleCategory, SpecialCategory, OneToOneCategory, NullableName, ProxyCategory, - SingleObject, RelatedObject, ModelA, ModelD) + SingleObject, RelatedObject, ModelA, ModelD, Responsibility, Job, + JobResponsibilities) class BaseQuerysetTest(TestCase): @@ -243,7 +244,10 @@ class Queries1Tests(BaseQuerysetTest): q1 = Item.objects.order_by('name') q2 = Item.objects.filter(id=self.i1.id) list(q2) - self.assertEqual(len((q1 & q2).order_by('name').query.tables), 1) + combined_query = (q1 & q2).order_by('name').query + self.assertEqual(len([ + t for t in combined_query.tables if combined_query.alias_refcount[t] + ]), 1) def test_order_by_join_unref(self): """ @@ -883,6 +887,225 @@ class Queries1Tests(BaseQuerysetTest): Item.objects.filter(Q(tags__name__in=['t4', 't3'])), [repr(i) for i in Item.objects.filter(~~Q(tags__name__in=['t4', 't3']))]) + def test_ticket_10790_1(self): + # Querying direct fields with 
isnull should trim the left outer join. + # It also should not create INNER JOIN. + q = Tag.objects.filter(parent__isnull=True) + + self.assertQuerysetEqual(q, ['']) + self.assertTrue('JOIN' not in str(q.query)) + + q = Tag.objects.filter(parent__isnull=False) + + self.assertQuerysetEqual( + q, + ['', '', '', ''], + ) + self.assertTrue('JOIN' not in str(q.query)) + + q = Tag.objects.exclude(parent__isnull=True) + self.assertQuerysetEqual( + q, + ['', '', '', ''], + ) + self.assertTrue('JOIN' not in str(q.query)) + + q = Tag.objects.exclude(parent__isnull=False) + self.assertQuerysetEqual(q, ['']) + self.assertTrue('JOIN' not in str(q.query)) + + q = Tag.objects.exclude(parent__parent__isnull=False) + + self.assertQuerysetEqual( + q, + ['', '', ''], + ) + self.assertTrue(str(q.query).count('LEFT OUTER JOIN') == 1) + self.assertTrue('INNER JOIN' not in str(q.query)) + + def test_ticket_10790_2(self): + # Querying across several tables should strip only the last outer join, + # while preserving the preceeding inner joins. + q = Tag.objects.filter(parent__parent__isnull=False) + + self.assertQuerysetEqual( + q, + ['', ''], + ) + self.assertTrue(str(q.query).count('LEFT OUTER JOIN') == 0) + self.assertTrue(str(q.query).count('INNER JOIN') == 1) + + # Querying without isnull should not convert anything to left outer join. + q = Tag.objects.filter(parent__parent=self.t1) + self.assertQuerysetEqual( + q, + ['', ''], + ) + self.assertTrue(str(q.query).count('LEFT OUTER JOIN') == 0) + self.assertTrue(str(q.query).count('INNER JOIN') == 1) + + def test_ticket_10790_3(self): + # Querying via indirect fields should populate the left outer join + q = NamedCategory.objects.filter(tag__isnull=True) + self.assertTrue(str(q.query).count('LEFT OUTER JOIN') == 1) + # join to dumbcategory ptr_id + self.assertTrue(str(q.query).count('INNER JOIN') == 1) + self.assertQuerysetEqual(q, []) + + # Querying across several tables should strip only the last join, while + # preserving the preceding left outer joins. + q = NamedCategory.objects.filter(tag__parent__isnull=True) + self.assertTrue(str(q.query).count('INNER JOIN') == 1) + self.assertTrue(str(q.query).count('LEFT OUTER JOIN') == 1) + self.assertQuerysetEqual( q, ['']) + + def test_ticket_10790_4(self): + # Querying across m2m field should not strip the m2m table from join. 
+ q = Author.objects.filter(item__tags__isnull=True) + self.assertQuerysetEqual( + q, + ['', ''], + ) + self.assertTrue(str(q.query).count('LEFT OUTER JOIN') == 2) + self.assertTrue('INNER JOIN' not in str(q.query)) + + q = Author.objects.filter(item__tags__parent__isnull=True) + self.assertQuerysetEqual( + q, + ['', '', '', ''], + ) + self.assertTrue(str(q.query).count('LEFT OUTER JOIN') == 3) + self.assertTrue('INNER JOIN' not in str(q.query)) + + def test_ticket_10790_5(self): + # Querying with isnull=False across m2m field should not create outer joins + q = Author.objects.filter(item__tags__isnull=False) + self.assertQuerysetEqual( + q, + ['', '', '', '', ''] + ) + self.assertTrue(str(q.query).count('LEFT OUTER JOIN') == 0) + self.assertTrue(str(q.query).count('INNER JOIN') == 2) + + q = Author.objects.filter(item__tags__parent__isnull=False) + self.assertQuerysetEqual( + q, + ['', '', ''] + ) + self.assertTrue(str(q.query).count('LEFT OUTER JOIN') == 0) + self.assertTrue(str(q.query).count('INNER JOIN') == 3) + + q = Author.objects.filter(item__tags__parent__parent__isnull=False) + self.assertQuerysetEqual( + q, + [''] + ) + self.assertTrue(str(q.query).count('LEFT OUTER JOIN') == 0) + self.assertTrue(str(q.query).count('INNER JOIN') == 4) + + def test_ticket_10790_6(self): + # Querying with isnull=True across m2m field should not create inner joins + # and strip last outer join + q = Author.objects.filter(item__tags__parent__parent__isnull=True) + self.assertQuerysetEqual( + q, + ['', '', '', '', + '', ''] + ) + self.assertTrue(str(q.query).count('LEFT OUTER JOIN') == 4) + self.assertTrue(str(q.query).count('INNER JOIN') == 0) + + q = Author.objects.filter(item__tags__parent__isnull=True) + self.assertQuerysetEqual( + q, + ['', '', '', ''] + ) + self.assertTrue(str(q.query).count('LEFT OUTER JOIN') == 3) + self.assertTrue(str(q.query).count('INNER JOIN') == 0) + + def test_ticket_10790_7(self): + # Reverse querying with isnull should not strip the join + q = Author.objects.filter(item__isnull=True) + self.assertQuerysetEqual( + q, + [''] + ) + self.assertTrue(str(q.query).count('LEFT OUTER JOIN') == 1) + self.assertTrue(str(q.query).count('INNER JOIN') == 0) + + q = Author.objects.filter(item__isnull=False) + self.assertQuerysetEqual( + q, + ['', '', '', ''] + ) + self.assertTrue(str(q.query).count('LEFT OUTER JOIN') == 0) + self.assertTrue(str(q.query).count('INNER JOIN') == 1) + + def test_ticket_10790_8(self): + # Querying with combined q-objects should also strip the left outer join + q = Tag.objects.filter(Q(parent__isnull=True) | Q(parent=self.t1)) + self.assertQuerysetEqual( + q, + ['', '', ''] + ) + self.assertTrue(str(q.query).count('LEFT OUTER JOIN') == 0) + self.assertTrue(str(q.query).count('INNER JOIN') == 0) + + def test_ticket_10790_combine(self): + # Combining queries should not re-populate the left outer join + q1 = Tag.objects.filter(parent__isnull=True) + q2 = Tag.objects.filter(parent__isnull=False) + + q3 = q1 | q2 + self.assertQuerysetEqual( + q3, + ['', '', '', '', ''], + ) + self.assertTrue(str(q3.query).count('LEFT OUTER JOIN') == 0) + self.assertTrue(str(q3.query).count('INNER JOIN') == 0) + + q3 = q1 & q2 + self.assertQuerysetEqual(q3, []) + self.assertTrue(str(q3.query).count('LEFT OUTER JOIN') == 0) + self.assertTrue(str(q3.query).count('INNER JOIN') == 0) + + q2 = Tag.objects.filter(parent=self.t1) + q3 = q1 | q2 + self.assertQuerysetEqual( + q3, + ['', '', ''] + ) + self.assertTrue(str(q3.query).count('LEFT OUTER JOIN') == 0) + 
self.assertTrue(str(q3.query).count('INNER JOIN') == 0) + + q3 = q2 | q1 + self.assertQuerysetEqual( + q3, + ['', '', ''] + ) + self.assertTrue(str(q3.query).count('LEFT OUTER JOIN') == 0) + self.assertTrue(str(q3.query).count('INNER JOIN') == 0) + + q1 = Tag.objects.filter(parent__isnull=True) + q2 = Tag.objects.filter(parent__parent__isnull=True) + + q3 = q1 | q2 + self.assertQuerysetEqual( + q3, + ['', '', ''] + ) + self.assertTrue(str(q3.query).count('LEFT OUTER JOIN') == 1) + self.assertTrue(str(q3.query).count('INNER JOIN') == 0) + + q3 = q2 | q1 + self.assertQuerysetEqual( + q3, + ['', '', ''] + ) + self.assertTrue(str(q3.query).count('LEFT OUTER JOIN') == 1) + self.assertTrue(str(q3.query).count('INNER JOIN') == 0) + + class Queries2Tests(TestCase): def setUp(self): Number.objects.create(num=4) @@ -1037,6 +1260,10 @@ class Queries4Tests(BaseQuerysetTest): Item.objects.create(name='i2', created=datetime.datetime.now(), note=n1, creator=self.a3) def test_ticket14876(self): + # Note: when combining the query we need to have information available + # about the join type of the trimmed "creator__isnull" join. If we + # don't have that information, then the join is created as INNER JOIN + # and results will be incorrect. q1 = Report.objects.filter(Q(creator__isnull=True) | Q(creator__extra__info='e1')) q2 = Report.objects.filter(Q(creator__isnull=True)) | Report.objects.filter(Q(creator__extra__info='e1')) self.assertQuerysetEqual(q1, ["", ""], ordered=False) @@ -1405,17 +1632,19 @@ class NullableRelOrderingTests(TestCase): # the join type of already existing joins. Plaything.objects.create(name="p1") s = SingleObject.objects.create(name='s') - r = RelatedObject.objects.create(single=s) + r = RelatedObject.objects.create(single=s, f=1) Plaything.objects.create(name="p2", others=r) qs = Plaything.objects.all().filter(others__isnull=False).order_by('pk') + self.assertTrue('JOIN' not in str(qs.query)) + qs = Plaything.objects.all().filter(others__f__isnull=False).order_by('pk') self.assertTrue('INNER' in str(qs.query)) qs = qs.order_by('others__single__name') # The ordering by others__single__pk will add one new join (to single) # and that join must be LEFT join. The already existing join to related # objects must be kept INNER. So, we have both a INNER and a LEFT join # in the query. - self.assertTrue('LEFT' in str(qs.query)) - self.assertTrue('INNER' in str(qs.query)) + self.assertEquals(str(qs.query).count('LEFT'), 1) + self.assertEquals(str(qs.query).count('INNER'), 1) self.assertQuerysetEqual( qs, [''] @@ -1466,6 +1695,7 @@ class Queries6Tests(TestCase): # This next test used to cause really weird PostgreSQL behavior, but it was # only apparent much later when the full test suite ran. + # - Yeah, it leaves global ITER_CHUNK_SIZE to 2 instead of 100... 
#@unittest.expectedFailure def test_slicing_and_cache_interaction(self): # We can do slicing beyond what is currently in the result cache, @@ -1993,6 +2223,29 @@ class DefaultValuesInsertTest(TestCase): except TypeError: self.fail("Creation of an instance of a model with only the PK field shouldn't error out after bulk insert refactoring (#17056)") +class ExcludeTest(TestCase): + def setUp(self): + f1 = Food.objects.create(name='apples') + Food.objects.create(name='oranges') + Eaten.objects.create(food=f1, meal='dinner') + j1 = Job.objects.create(name='Manager') + r1 = Responsibility.objects.create(description='Playing golf') + j2 = Job.objects.create(name='Programmer') + r2 = Responsibility.objects.create(description='Programming') + JobResponsibilities.objects.create(job=j1, responsibility=r1) + JobResponsibilities.objects.create(job=j2, responsibility=r2) + + def test_to_field(self): + self.assertQuerysetEqual( + Food.objects.exclude(eaten__meal='dinner'), + ['']) + self.assertQuerysetEqual( + Job.objects.exclude(responsibilities__description='Playing golf'), + ['']) + self.assertQuerysetEqual( + Responsibility.objects.exclude(jobs__name='Manager'), + ['']) + class NullInExcludeTest(TestCase): def setUp(self): NullableName.objects.create(name='i1') @@ -2155,3 +2408,13 @@ class NullJoinPromotionOrTest(TestCase): # so we can use INNER JOIN for it. However, we can NOT use INNER JOIN # for the b->c join, as a->b is nullable. self.assertEqual(str(qset.query).count('INNER JOIN'), 1) + +class ReverseJoinTrimmingTest(TestCase): + def test_reverse_trimming(self): + # Check that we don't accidentally trim reverse joins - we can't know + # if there is anything on the other side of the join, so trimming + # reverse joins can't be done, ever. + t = Tag.objects.create() + qs = Tag.objects.filter(annotation__tag=t.pk) + self.assertIn('INNER JOIN', str(qs.query)) + self.assertEquals(list(qs), [])
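As context for the trimming tests above, an approximate, hand-written sketch of the SQL difference between a direct join that trim_joins() can drop and a reverse join that must be kept (Tag and Annotation are the models used in these tests; the exact table and column names are assumptions):

# Direct FK: the compared column (parent_id) already lives on queries_tag,
# so the join is trimmed away entirely:
#   Tag.objects.filter(parent__isnull=True)
#   SELECT ... FROM "queries_tag" WHERE "queries_tag"."parent_id" IS NULL
#
# Reverse FK: matching rows live on the other side of the join and there may
# be zero or many of them, so the join cannot be trimmed:
#   Tag.objects.filter(annotation__tag=t.pk)
#   SELECT ... FROM "queries_tag"
#   INNER JOIN "queries_annotation"
#       ON ("queries_tag"."id" = "queries_annotation"."tag_id")
#   WHERE "queries_annotation"."tag_id" = %s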