From 33a0862215edeaa7848e625cd0b0777fb4885de1 Mon Sep 17 00:00:00 2001 From: Malcolm Tredinnick Date: Thu, 28 Feb 2008 12:57:10 +0000 Subject: [PATCH] queryset-refactor: Fixed exclude() filtering for the various N-to-many relations. This means we can now do nested SQL queries (since we need nested queries to get the right answer). It requires poking directly at the Query class. Might add support for this through QuerySets later. git-svn-id: http://code.djangoproject.com/svn/django/branches/queryset-refactor@7170 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- django/db/models/fields/related.py | 3 + django/db/models/sql/query.py | 107 +++++++++++++++++++++--- django/db/models/sql/where.py | 12 ++- tests/modeltests/many_to_many/models.py | 5 ++ tests/regressiontests/queries/models.py | 19 ++--- 5 files changed, 119 insertions(+), 27 deletions(-) diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 4ebf48beab..f6594d617a 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -100,6 +100,9 @@ class RelatedField(object): pass return v + if hasattr(value, 'as_sql'): + sql, params = value.as_sql() + return ('(%s)' % sql), params if lookup_type == 'exact': return [pk_trace(value)] if lookup_type == 'in': diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 205735f940..c4494a1b31 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -45,6 +45,7 @@ class Query(object): self.default_cols = True self.default_ordering = True self.standard_ordering = True + self.start_meta = None # SQL-related attributes self.select = [] @@ -81,6 +82,20 @@ class Query(object): sql, params = self.as_sql() return sql % params + def __deepcopy__(self, memo): + result= self.clone() + memo[id(self)] = result + return result + + def get_meta(self): + """ + Returns the Options instance (the model._meta) from which to start + processing. 
Normally, this is self.model._meta, but it can change. + """ + if self.start_meta: + return self.start_meta + return self.model._meta + def quote_name_unless_alias(self, name): """ A wrapper around connection.ops.quote_name that doesn't quote aliases @@ -114,6 +129,7 @@ class Query(object): obj.default_cols = self.default_cols obj.default_ordering = self.default_ordering obj.standard_ordering = self.standard_ordering + obj.start_meta = self.start_meta obj.select = self.select[:] obj.tables = self.tables[:] obj.where = copy.deepcopy(self.where) @@ -384,7 +400,7 @@ class Query(object): join_type = None alias_str = '' name = alias - if join_type: + if join_type and not first: result.append('%s %s%s ON (%s.%s = %s.%s)' % (join_type, qn(name), alias_str, qn(lhs), qn(lhs_col), qn(alias), qn(col))) @@ -649,7 +665,7 @@ class Query(object): # We've recursed far enough; bail out. return if not opts: - opts = self.model._meta + opts = self.get_meta() root_alias = self.tables[0] self.select.extend([(root_alias, f.column) for f in opts.fields]) if not used: @@ -681,9 +697,14 @@ class Query(object): self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1, used, next, restricted) - def add_filter(self, filter_expr, connector=AND, negate=False): + def add_filter(self, filter_expr, connector=AND, negate=False, trim=False): """ - Add a single filter to the query. + Add a single filter to the query. The 'filter_expr' is a pair: + (filter_string, value). E.g. ('name__contains', 'fred') + + If 'negate' is True, this is an exclude() filter. If 'trim' is True, we + automatically trim the final join group (used internally when + constructing nested queries). 
""" arg, value = filter_expr parts = arg.split(LOOKUP_SEP) @@ -706,12 +727,24 @@ class Query(object): elif callable(value): value = value() - opts = self.model._meta + opts = self.get_meta() alias = self.join((None, opts.db_table, None, None)) + allow_many = trim or not negate - field, target, opts, join_list, = self.setup_joins(parts, opts, - alias, (connector == AND)) - col = target.column + result = self.setup_joins(parts, opts, alias, (connector == AND), + allow_many) + if isinstance(result, int): + self.split_exclude(filter_expr, LOOKUP_SEP.join(parts[:result])) + return + field, target, opts, join_list = result + if trim and len(join_list) > 1: + extra = join_list[-1] + join_list = join_list[:-1] + col = self.alias_map[extra[0]][ALIAS_JOIN][LHS_JOIN_COL] + for alias in extra: + self.unref_alias(alias) + else: + col = target.column alias = join_list[-1][-1] if join_list: @@ -729,7 +762,7 @@ class Query(object): len(join_list[0]) > 1): # If the comparison is against NULL, we need to use a left outer # join when connecting to the previous model. We make that - # adjustment here. We don't do this unless needed because it's less + # adjustment here. We don't do this unless needed as it's less # efficient at the database level. self.promote_alias(join_list[-1][0]) @@ -767,6 +800,8 @@ class Query(object): flag = True self.where.negate() if flag: + # XXX: Change this to the field we joined against to allow + # for node sharing and where-tree optimisation? self.where.add([alias, col, field, 'isnull', True], OR) def add_q(self, q_object): @@ -797,7 +832,7 @@ class Query(object): if subtree: self.where.end_subtree() - def setup_joins(self, names, opts, alias, dupe_multis): + def setup_joins(self, names, opts, alias, dupe_multis, allow_many=True): """ Compute the necessary table joins for the passage through the fields given in 'names'. 'opts' is the Options class for the current model @@ -807,9 +842,8 @@ class Query(object): disjunctive filters). 
Returns the final field involved in the join, the target database - column (used for any 'where' constraint), the final 'opts' value, the - list of tables joined and a list indicating whether or not each join - can be null. + column (used for any 'where' constraint), the final 'opts' value and the + list of tables joined. """ joins = [[alias]] for pos, name in enumerate(names): @@ -822,6 +856,11 @@ class Query(object): names = opts.get_all_field_names() raise FieldError("Cannot resolve keyword %r into field. " "Choices are: %s" % (name, ", ".join(names))) + if not allow_many and (m2m or not direct): + for join in joins: + for alias in join: + self.unref_alias(alias) + return pos + 1 if model: # The field lives on a base class of the current model. alias_list = [] @@ -929,6 +968,19 @@ class Query(object): return field, target, opts, joins + def split_exclude(self, filter_expr, prefix): + """ + When doing an exclude against any kind of N-to-many relation, we need + to use a subquery. This method constructs the nested query, given the + original exclude filter (filter_expr) and the portion up to the first + N-to-many relation field. + """ + query = Query(self.model, self.connection) + query.add_filter(filter_expr) + query.set_start(prefix) + query.clear_ordering(True) + self.add_filter(('%s__in' % prefix, query), negate=True, trim=True) + def set_limits(self, low=None, high=None): """ Adjusts the limits on the rows retrieved. We use low/high to set these, @@ -1047,6 +1099,35 @@ class Query(object): d = d.setdefault(part, {}) self.select_related = field_dict + def set_start(self, start): + """ + Sets the table from which to start joining. The start position is + specified by the related attribute from the base model. This will + automatically set to the select column to be the column linked from the + previous table. + + This method is primarily for internal use and the error checking isn't + as friendly as add_filter(). 
Mostly useful for querying directly + against the join table of many-to-many relation in a subquery. + """ + opts = self.model._meta + alias = self.join((None, opts.db_table, None, None)) + field, col, opts, joins = self.setup_joins(start.split(LOOKUP_SEP), + opts, alias, False) + alias = joins[-1][0] + self.select = [(alias, self.alias_map[alias][ALIAS_JOIN][RHS_JOIN_COL])] + self.start_meta = opts + + # The call to setup_joins add an extra reference to everything in + # joins. So we need to unref everything once, and everything prior to + # the final join a second time. + for join in joins[:-1]: + for alias in join: + self.unref_alias(alias) + self.unref_alias(alias) + for alias in joins[-1]: + self.unref_alias(alias) + def execute_sql(self, result_type=MULTI): """ Run the query against the database and returns the result(s). The diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index e699e96375..c4d5637658 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -105,21 +105,27 @@ class WhereNode(tree.Node): else: cast_sql = '%s' - format = "%s %%s" % connection.ops.lookup_cast(lookup_type) params = field.get_db_prep_lookup(lookup_type, value) + if isinstance(params, tuple): + extra, params = params + else: + extra = '' if lookup_type in connection.operators: + format = "%s %%s %s" % (connection.ops.lookup_cast(lookup_type), + extra) return (format % (field_sql, connection.operators[lookup_type] % cast_sql), params) if lookup_type == 'in': if not value: raise EmptyResultSet + if extra: + return ('%s IN %s' % (field_sql, extra), params) return ('%s IN (%s)' % (field_sql, ', '.join(['%s'] * len(value))), params) elif lookup_type in ('range', 'year'): - return ('%s BETWEEN %%s and %%s' % field_sql, - params) + return ('%s BETWEEN %%s and %%s' % field_sql, params) elif lookup_type in ('month', 'day'): return ('%s = %%s' % connection.ops.date_extract_sql(lookup_type, field_sql), params) diff --git 
a/tests/modeltests/many_to_many/models.py b/tests/modeltests/many_to_many/models.py index 198c95c4d5..e09fd825f8 100644 --- a/tests/modeltests/many_to_many/models.py +++ b/tests/modeltests/many_to_many/models.py @@ -126,6 +126,11 @@ __test__ = {'API_TESTS':""" >>> Publication.objects.filter(article__in=[a1,a2]).distinct() [, , , ] +# Excluding a related item works as you would expect, too (although the SQL +# involved is a little complex). +>>> Article.objects.exclude(publications=p2) +[] + # If we delete a Publication, its Articles won't be able to access it. >>> p1.delete() >>> Publication.objects.all() diff --git a/tests/regressiontests/queries/models.py b/tests/regressiontests/queries/models.py index c95026a54e..50c002d14d 100644 --- a/tests/regressiontests/queries/models.py +++ b/tests/regressiontests/queries/models.py @@ -322,22 +322,19 @@ Bug #5324 >>> Author.objects.exclude(item__name='one').distinct().order_by('name') [, , ] + +# Excluding across a m2m relation when there is more than one related object +# associated was problematic. +>>> Item.objects.exclude(tags__name='t1').order_by('name') +[, ] +>>> Item.objects.exclude(tags__name='t1').exclude(tags__name='t4') +[] + # Excluding from a relation that cannot be NULL should not use outer joins. >>> query = Item.objects.exclude(creator__in=[a1, a2]).query >>> query.LOUTER not in [x[2][2] for x in query.alias_map.values()] True -# When only one of the joins is nullable (here, the Author -> Item join), we -# should only get outer joins after that point (one, in this case). We also -# show that three tables (so, two joins) are involved. ->>> qs = Report.objects.exclude(creator__item__name='one') ->>> list(qs) -[] ->>> len([x[2][2] for x in qs.query.alias_map.values() if x[2][2] == query.LOUTER]) -1 ->>> len(qs.query.alias_map) -3 - Similarly, when one of the joins cannot possibly, ever, involve NULL values (Author -> ExtraInfo, in the following), it should never be promoted to a left outer join. 
So the following query should only involve one "left outer" join (Author -> Item is 0-to-many). >>> qs = Author.objects.filter(id=a1.id).filter(Q(extra__note=n1)|Q(item__note=n3)) >>> len([x[2][2] for x in qs.query.alias_map.values() if x[2][2] == query.LOUTER])