diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 4ebf48beab..f6594d617a 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -100,6 +100,9 @@ class RelatedField(object): pass return v + if hasattr(value, 'as_sql'): + sql, params = value.as_sql() + return ('(%s)' % sql), params if lookup_type == 'exact': return [pk_trace(value)] if lookup_type == 'in': diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 205735f940..c4494a1b31 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -45,6 +45,7 @@ class Query(object): self.default_cols = True self.default_ordering = True self.standard_ordering = True + self.start_meta = None # SQL-related attributes self.select = [] @@ -81,6 +82,20 @@ class Query(object): sql, params = self.as_sql() return sql % params + def __deepcopy__(self, memo): + result= self.clone() + memo[id(self)] = result + return result + + def get_meta(self): + """ + Returns the Options instance (the model._meta) from which to start + processing. Normally, this is self.model._meta, but it can change. + """ + if self.start_meta: + return self.start_meta + return self.model._meta + def quote_name_unless_alias(self, name): """ A wrapper around connection.ops.quote_name that doesn't quote aliases @@ -114,6 +129,7 @@ class Query(object): obj.default_cols = self.default_cols obj.default_ordering = self.default_ordering obj.standard_ordering = self.standard_ordering + obj.start_meta = self.start_meta obj.select = self.select[:] obj.tables = self.tables[:] obj.where = copy.deepcopy(self.where) @@ -384,7 +400,7 @@ class Query(object): join_type = None alias_str = '' name = alias - if join_type: + if join_type and not first: result.append('%s %s%s ON (%s.%s = %s.%s)' % (join_type, qn(name), alias_str, qn(lhs), qn(lhs_col), qn(alias), qn(col))) @@ -649,7 +665,7 @@ class Query(object): # We've recursed far enough; bail out. return if not opts: - opts = self.model._meta + opts = self.get_meta() root_alias = self.tables[0] self.select.extend([(root_alias, f.column) for f in opts.fields]) if not used: @@ -681,9 +697,14 @@ class Query(object): self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1, used, next, restricted) - def add_filter(self, filter_expr, connector=AND, negate=False): + def add_filter(self, filter_expr, connector=AND, negate=False, trim=False): """ - Add a single filter to the query. + Add a single filter to the query. The 'filter_expr' is a pair: + (filter_string, value). E.g. ('name__contains', 'fred') + + If 'negate' is True, this is an exclude() filter. If 'trim' is True, we + automatically trim the final join group (used internally when + constructing nested queries). """ arg, value = filter_expr parts = arg.split(LOOKUP_SEP) @@ -706,12 +727,24 @@ class Query(object): elif callable(value): value = value() - opts = self.model._meta + opts = self.get_meta() alias = self.join((None, opts.db_table, None, None)) + allow_many = trim or not negate - field, target, opts, join_list, = self.setup_joins(parts, opts, - alias, (connector == AND)) - col = target.column + result = self.setup_joins(parts, opts, alias, (connector == AND), + allow_many) + if isinstance(result, int): + self.split_exclude(filter_expr, LOOKUP_SEP.join(parts[:result])) + return + field, target, opts, join_list = result + if trim and len(join_list) > 1: + extra = join_list[-1] + join_list = join_list[:-1] + col = self.alias_map[extra[0]][ALIAS_JOIN][LHS_JOIN_COL] + for alias in extra: + self.unref_alias(alias) + else: + col = target.column alias = join_list[-1][-1] if join_list: @@ -729,7 +762,7 @@ class Query(object): len(join_list[0]) > 1): # If the comparison is against NULL, we need to use a left outer # join when connecting to the previous model. We make that - # adjustment here. We don't do this unless needed because it's less + # adjustment here. We don't do this unless needed as it's less # efficient at the database level. self.promote_alias(join_list[-1][0]) @@ -767,6 +800,8 @@ class Query(object): flag = True self.where.negate() if flag: + # XXX: Change this to the field we joined against to allow + # for node sharing and where-tree optimisation? self.where.add([alias, col, field, 'isnull', True], OR) def add_q(self, q_object): @@ -797,7 +832,7 @@ class Query(object): if subtree: self.where.end_subtree() - def setup_joins(self, names, opts, alias, dupe_multis): + def setup_joins(self, names, opts, alias, dupe_multis, allow_many=True): """ Compute the necessary table joins for the passage through the fields given in 'names'. 'opts' is the Options class for the current model @@ -807,9 +842,8 @@ class Query(object): disjunctive filters). Returns the final field involved in the join, the target database - column (used for any 'where' constraint), the final 'opts' value, the - list of tables joined and a list indicating whether or not each join - can be null. + column (used for any 'where' constraint), the final 'opts' value and the + list of tables joined. """ joins = [[alias]] for pos, name in enumerate(names): @@ -822,6 +856,11 @@ class Query(object): names = opts.get_all_field_names() raise FieldError("Cannot resolve keyword %r into field. " "Choices are: %s" % (name, ", ".join(names))) + if not allow_many and (m2m or not direct): + for join in joins: + for alias in join: + self.unref_alias(alias) + return pos + 1 if model: # The field lives on a base class of the current model. alias_list = [] @@ -929,6 +968,19 @@ class Query(object): return field, target, opts, joins + def split_exclude(self, filter_expr, prefix): + """ + When doing an exclude against any kind of N-to-many relation, we need + to use a subquery. This method constructs the nested query, given the + original exclude filter (filter_expr) and the portion up to the first + N-to-many relation field. + """ + query = Query(self.model, self.connection) + query.add_filter(filter_expr) + query.set_start(prefix) + query.clear_ordering(True) + self.add_filter(('%s__in' % prefix, query), negate=True, trim=True) + def set_limits(self, low=None, high=None): """ Adjusts the limits on the rows retrieved. We use low/high to set these, @@ -1047,6 +1099,35 @@ class Query(object): d = d.setdefault(part, {}) self.select_related = field_dict + def set_start(self, start): + """ + Sets the table from which to start joining. The start position is + specified by the related attribute from the base model. This will + automatically set to the select column to be the column linked from the + previous table. + + This method is primarily for internal use and the error checking isn't + as friendly as add_filter(). Mostly useful for querying directly + against the join table of many-to-many relation in a subquery. + """ + opts = self.model._meta + alias = self.join((None, opts.db_table, None, None)) + field, col, opts, joins = self.setup_joins(start.split(LOOKUP_SEP), + opts, alias, False) + alias = joins[-1][0] + self.select = [(alias, self.alias_map[alias][ALIAS_JOIN][RHS_JOIN_COL])] + self.start_meta = opts + + # The call to setup_joins add an extra reference to everything in + # joins. So we need to unref everything once, and everything prior to + # the final join a second time. + for join in joins[:-1]: + for alias in join: + self.unref_alias(alias) + self.unref_alias(alias) + for alias in joins[-1]: + self.unref_alias(alias) + def execute_sql(self, result_type=MULTI): """ Run the query against the database and returns the result(s). The diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index e699e96375..c4d5637658 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -105,21 +105,27 @@ class WhereNode(tree.Node): else: cast_sql = '%s' - format = "%s %%s" % connection.ops.lookup_cast(lookup_type) params = field.get_db_prep_lookup(lookup_type, value) + if isinstance(params, tuple): + extra, params = params + else: + extra = '' if lookup_type in connection.operators: + format = "%s %%s %s" % (connection.ops.lookup_cast(lookup_type), + extra) return (format % (field_sql, connection.operators[lookup_type] % cast_sql), params) if lookup_type == 'in': if not value: raise EmptyResultSet + if extra: + return ('%s IN %s' % (field_sql, extra), params) return ('%s IN (%s)' % (field_sql, ', '.join(['%s'] * len(value))), params) elif lookup_type in ('range', 'year'): - return ('%s BETWEEN %%s and %%s' % field_sql, - params) + return ('%s BETWEEN %%s and %%s' % field_sql, params) elif lookup_type in ('month', 'day'): return ('%s = %%s' % connection.ops.date_extract_sql(lookup_type, field_sql), params) diff --git a/tests/modeltests/many_to_many/models.py b/tests/modeltests/many_to_many/models.py index 198c95c4d5..e09fd825f8 100644 --- a/tests/modeltests/many_to_many/models.py +++ b/tests/modeltests/many_to_many/models.py @@ -126,6 +126,11 @@ __test__ = {'API_TESTS':""" >>> Publication.objects.filter(article__in=[a1,a2]).distinct() [, , , ] +# Excluding a related item works as you would expect, too (although the SQL +# involved is a little complex). +>>> Article.objects.exclude(publications=p2) +[] + # If we delete a Publication, its Articles won't be able to access it. >>> p1.delete() >>> Publication.objects.all() diff --git a/tests/regressiontests/queries/models.py b/tests/regressiontests/queries/models.py index c95026a54e..50c002d14d 100644 --- a/tests/regressiontests/queries/models.py +++ b/tests/regressiontests/queries/models.py @@ -322,22 +322,19 @@ Bug #5324 >>> Author.objects.exclude(item__name='one').distinct().order_by('name') [, , ] + +# Excluding across a m2m relation when there is more than one related object +# associated was problematic. +>>> Item.objects.exclude(tags__name='t1').order_by('name') +[, ] +>>> Item.objects.exclude(tags__name='t1').exclude(tags__name='t4') +[] + # Excluding from a relation that cannot be NULL should not use outer joins. >>> query = Item.objects.exclude(creator__in=[a1, a2]).query >>> query.LOUTER not in [x[2][2] for x in query.alias_map.values()] True -# When only one of the joins is nullable (here, the Author -> Item join), we -# should only get outer joins after that point (one, in this case). We also -# show that three tables (so, two joins) are involved. ->>> qs = Report.objects.exclude(creator__item__name='one') ->>> list(qs) -[] ->>> len([x[2][2] for x in qs.query.alias_map.values() if x[2][2] == query.LOUTER]) -1 ->>> len(qs.query.alias_map) -3 - Similarly, when one of the joins cannot possibly, ever, involve NULL values (Author -> ExtraInfo, in the following), it should never be promoted to a left outer join. So hte following query should only involve one "left outer" join (Author -> Item is 0-to-many). >>> qs = Author.objects.filter(id=a1.id).filter(Q(extra__note=n1)|Q(item__note=n3)) >>> len([x[2][2] for x in qs.query.alias_map.values() if x[2][2] == query.LOUTER])