From 425e4662a42e8ab324bcd056c1b0f16751692b3d Mon Sep 17 00:00:00 2001 From: Malcolm Tredinnick Date: Sun, 14 Oct 2007 02:15:28 +0000 Subject: [PATCH] queryset-refactor: Fixed the SQL construction when excluding items across nullable joins. This is #5324 plus a few more complex variations on that theme. git-svn-id: http://code.djangoproject.com/svn/django/branches/queryset-refactor@6494 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- django/db/models/sql/query.py | 34 +++++++++++++++---------- tests/regressiontests/queries/models.py | 29 ++++++++++++++++----- 2 files changed, 44 insertions(+), 19 deletions(-) diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 505f7124db..20fbadfe51 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -558,13 +558,14 @@ class Query(object): alias = self.join((None, opts.db_table, None, None)) dupe_multis = (connection == AND) join_list = [] - done_split = not self.where + split = not self.where + null_point = None # FIXME: Using enumerate() here is expensive. We only need 'i' to # check we aren't joining against a non-joinable field. Find a # better way to do this! for i, name in enumerate(parts): - joins, opts, orig_field, target_field, target_col = \ + joins, opts, orig_field, target_field, target_col, nullable = \ self.get_next_join(name, opts, alias, dupe_multis) if name == 'pk': name = target_field.name @@ -572,9 +573,11 @@ class Query(object): join_list.append(joins) last = joins alias = joins[-1] - if connection == OR and not done_split: + if not null_point and nullable: + null_point = len(join_list) + if connection == OR and not split: if self.alias_map[joins[0]][ALIAS_REFCOUNT] == 1: - done_split = True + split = True self.promote_alias(joins[0]) all_aliases = [] for a in join_list: @@ -611,10 +614,12 @@ class Query(object): self.where.add([alias, col, orig_field, lookup_type, value], connection) - if negate: + if negate and null_point: if join_list: - self.promote_alias(last[0]) + for join in last: + self.promote_alias(join) self.where.negate() + self.where.add([alias, col, orig_field, 'isnull', True], OR) def add_q(self, q_object): """ @@ -644,8 +649,9 @@ class Query(object): always create a new alias (necessary for disjunctive filters). Returns a list of aliases involved in the join, the next value for - 'opts' and the field class that was matched. For a non-joining field, - the first value (join alias) is None. + 'opts', the field instance that was matched, the new field to include + in the join, the column name on the rhs of the join and whether the + join can include NULL results. """ if name == 'pk': name = opts.pk.name @@ -660,7 +666,7 @@ class Query(object): field.m2m_reverse_name(), remote_opts.pk.column), dupe_multis, merge_separate=True) return ([int_alias, far_alias], remote_opts, field, remote_opts.pk, - None) + None, field.null) field = find_field(name, opts.get_all_related_many_to_many_objects(), True) @@ -675,7 +681,7 @@ class Query(object): dupe_multis, merge_separate=True) # XXX: Why is the final component able to be None here? return ([int_alias, far_alias], remote_opts, field, remote_opts.pk, - None) + None, True) field = find_field(name, opts.get_all_related_objects(), True) if field: @@ -686,7 +692,8 @@ class Query(object): alias = self.join((root_alias, remote_opts.db_table, local_field.column, field.column), dupe_multis, merge_separate=True) - return ([alias], remote_opts, field, field, remote_opts.pk.column) + return ([alias], remote_opts, field, field, remote_opts.pk.column, + True) field = find_field(name, opts.fields, False) @@ -701,11 +708,12 @@ class Query(object): target = field.rel.get_related_field() alias = self.join((root_alias, remote_opts.db_table, field.column, target.column)) - return [alias], remote_opts, field, target, target.column + return ([alias], remote_opts, field, target, target.column, + field.null) # Only remaining possibility is a normal (direct lookup) field. No # join is required. - return None, opts, field, field, None + return None, opts, field, field, None, False def set_limits(self, low=None, high=None): """ diff --git a/tests/regressiontests/queries/models.py b/tests/regressiontests/queries/models.py index 689304d57a..9fa9b1e1e6 100644 --- a/tests/regressiontests/queries/models.py +++ b/tests/regressiontests/queries/models.py @@ -3,7 +3,7 @@ Various combination queries that have been problematic in the past. """ from django.db import models -from django.db.models.query import Q +from django.db.models.query import Q, QNot class Tag(models.Model): name = models.CharField(maxlength=10) @@ -14,7 +14,7 @@ class Tag(models.Model): class Author(models.Model): name = models.CharField(maxlength=10) - num = models.IntegerField() + num = models.IntegerField(unique=True) def __unicode__(self): return self.name @@ -69,6 +69,8 @@ __test__ = {'API_TESTS':""" >>> r1 = Report(name='r1', creator=a1) >>> r1.save() +>>> r2 = Report(name='r2', creator=a3) +>>> r2.save() Bug #1050 >>> Item.objects.filter(tags__isnull=True) @@ -149,11 +151,26 @@ Bug #4510 Bug #5324 >>> Item.objects.filter(tags__name='t4') [] +>>> Item.objects.exclude(tags__name='t4').order_by('name').distinct() +[, , ] +>>> Author.objects.exclude(item__name='one').distinct().order_by('name') +[, , ] -# FIXME: We seem to be constructing the right SQL here, but maybe a NULL test -# for the pk of Tag is needed or something? -# >>> Item.objects.exclude(tags__name='t4').order_by('name').distinct() -# [, , ] +# Excluding from a relation that cannot be NULL should not use outer joins. +>>> query = Item.objects.exclude(creator__in=[a1, a2]).query +>>> query.LOUTER not in [x[2][2] for x in query.alias_map.values()] +True + +# When only one of the joins is nullable (here, the Author -> Item join), we +# should only get outer joins after that point (one, in this case). We also +# show that three tables (so, two joins) are involved. +>>> qs = Report.objects.exclude(creator__item__name='one') +>>> list(qs) +[] +>>> len([x[2][2] for x in qs.query.alias_map.values() if x[2][2] == query.LOUTER]) +1 +>>> len(qs.query.alias_map) +3 Bug #2091 >>> t = Tag.objects.get(name='t4')