diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 9678ba4154..a8d402dcfd 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -692,6 +692,11 @@ class ForeignKey(RelatedField, Field): def contribute_to_class(self, cls, name): super(ForeignKey, self).contribute_to_class(cls, name) setattr(cls, self.name, ReverseSingleRelatedObjectDescriptor(self)) + if isinstance(self.rel.to, basestring): + target = self.rel.to + else: + target = self.rel.to._meta.db_table + cls._meta.duplicate_targets[self.column] = (target, "o2m") def contribute_to_related_class(self, cls, related): setattr(cls, related.get_accessor_name(), ForeignRelatedObjectsDescriptor(related)) @@ -826,6 +831,12 @@ class ManyToManyField(RelatedField, Field): # Set up the accessor for the m2m table name for the relation self.m2m_db_table = curry(self._get_m2m_db_table, cls._meta) + if isinstance(self.rel.to, basestring): + target = self.rel.to + else: + target = self.rel.to._meta.db_table + cls._meta.duplicate_targets[self.column] = (target, "m2m") + def contribute_to_related_class(self, cls, related): # m2m relations to self do not have a ManyRelatedObjectsDescriptor, # as it would be redundant - unless the field is non-symmetrical. diff --git a/django/db/models/options.py b/django/db/models/options.py index e5b30b4746..a81a34d722 100644 --- a/django/db/models/options.py +++ b/django/db/models/options.py @@ -44,6 +44,7 @@ class Options(object): self.one_to_one_field = None self.abstract = False self.parents = SortedDict() + self.duplicate_targets = {} def contribute_to_class(self, cls, name): from django.db import connection @@ -115,6 +116,24 @@ class Options(object): auto_created=True) model.add_to_class('id', auto) + # Determine any sets of fields that are pointing to the same targets + # (e.g. two ForeignKeys to the same remote model). The query + # construction code needs to know this. 
At the end of this, + # self.duplicate_targets will map each duplicate field column to the + # columns it duplicates. + collections = {} + for column, target in self.duplicate_targets.iteritems(): + try: + collections[target].add(column) + except KeyError: + collections[target] = set([column]) + self.duplicate_targets = {} + for elt in collections.itervalues(): + if len(elt) == 1: + continue + for column in elt: + self.duplicate_targets[column] = elt.difference(set([column])) + def add_field(self, field): # Insert the given field in the order in which it was created, using # the "creation_counter" attribute of the field. diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index fc2a6b2b5d..a50f1f0c15 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -57,6 +57,7 @@ class Query(object): self.start_meta = None self.select_fields = [] self.related_select_fields = [] + self.dupe_avoidance = {} # SQL-related attributes self.select = [] @@ -165,6 +166,7 @@ class Query(object): obj.start_meta = self.start_meta obj.select_fields = self.select_fields[:] obj.related_select_fields = self.related_select_fields[:] + obj.dupe_avoidance = self.dupe_avoidance.copy() obj.select = self.select[:] obj.tables = self.tables[:] obj.where = deepcopy(self.where) @@ -830,8 +832,8 @@ class Query(object): if reuse and always_create and table in self.table_map: # Convert the 'reuse' to case to be "exclude everything but the - # reusable set for this table". - exclusions = set(self.table_map[table]).difference(reuse) + # reusable set, minus exclusions, for this table". 
+ exclusions = set(self.table_map[table]).difference(reuse).union(set(exclusions)) always_create = False t_ident = (lhs_table, table, lhs_col, col) if not always_create: @@ -866,7 +868,8 @@ class Query(object): return alias def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1, - used=None, requested=None, restricted=None, nullable=None): + used=None, requested=None, restricted=None, nullable=None, + dupe_set=None): """ Fill in the information needed for a select_related query. The current depth is measured as the number of connections away from the root model @@ -876,6 +879,7 @@ class Query(object): if not restricted and self.max_depth and cur_depth > self.max_depth: # We've recursed far enough; bail out. return + if not opts: opts = self.get_meta() root_alias = self.get_initial_alias() @@ -883,6 +887,10 @@ class Query(object): self.related_select_fields = [] if not used: used = set() + if dupe_set is None: + dupe_set = set() + orig_dupe_set = dupe_set + orig_used = used # Setup for the case when only particular related fields should be # included in the related selection. 
@@ -897,6 +905,8 @@ class Query(object): if (not f.rel or (restricted and f.name not in requested) or (not restricted and f.null) or f.rel.parent_link): continue + dupe_set = orig_dupe_set.copy() + used = orig_used.copy() table = f.rel.to._meta.db_table if nullable or f.null: promote = True @@ -907,12 +917,26 @@ class Query(object): alias = root_alias for int_model in opts.get_base_chain(model): lhs_col = int_opts.parents[int_model].column + dedupe = lhs_col in opts.duplicate_targets + if dedupe: + used.update(self.dupe_avoidance.get((id(opts), lhs_col), + ())) + dupe_set.add((opts, lhs_col)) int_opts = int_model._meta alias = self.join((alias, int_opts.db_table, lhs_col, int_opts.pk.column), exclusions=used, promote=promote) + for (dupe_opts, dupe_col) in dupe_set: + self.update_dupe_avoidance(dupe_opts, dupe_col, alias) else: alias = root_alias + + dedupe = f.column in opts.duplicate_targets + if dupe_set or dedupe: + used.update(self.dupe_avoidance.get((id(opts), f.column), ())) + if dedupe: + dupe_set.add((opts, f.column)) + alias = self.join((alias, table, f.column, f.rel.get_related_field().column), exclusions=used, promote=promote) @@ -928,8 +952,10 @@ class Query(object): new_nullable = f.null else: new_nullable = None + for dupe_opts, dupe_col in dupe_set: + self.update_dupe_avoidance(dupe_opts, dupe_col, alias) self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1, - used, next, restricted, new_nullable) + used, next, restricted, new_nullable, dupe_set) def add_filter(self, filter_expr, connector=AND, negate=False, trim=False, can_reuse=None): @@ -1128,7 +1154,9 @@ class Query(object): (which gives the table we are joining to), 'alias' is the alias for the table we are joining to. If dupe_multis is True, any many-to-many or many-to-one joins will always create a new alias (necessary for - disjunctive filters). 
If can_reuse is not None, it's a list of aliases + that can be reused in these joins (nothing else can be reused in this + case). Returns the final field involved in the join, the target database column (used for any 'where' constraint), the final 'opts' value and the @@ -1136,7 +1164,14 @@ class Query(object): """ joins = [alias] last = [0] + dupe_set = set() + exclusions = set() for pos, name in enumerate(names): + try: + exclusions.add(int_alias) + except NameError: + pass + exclusions.add(alias) last.append(len(joins)) if name == 'pk': name = opts.pk.name @@ -1155,6 +1190,7 @@ class Query(object): names = opts.get_all_field_names() raise FieldError("Cannot resolve keyword %r into field. " "Choices are: %s" % (name, ", ".join(names))) + if not allow_many and (m2m or not direct): for alias in joins: self.unref_alias(alias) @@ -1164,12 +1200,27 @@ class Query(object): alias_list = [] for int_model in opts.get_base_chain(model): lhs_col = opts.parents[int_model].column + dedupe = lhs_col in opts.duplicate_targets + if dedupe: + exclusions.update(self.dupe_avoidance.get( + (id(opts), lhs_col), ())) + dupe_set.add((opts, lhs_col)) opts = int_model._meta alias = self.join((alias, opts.db_table, lhs_col, - opts.pk.column), exclusions=joins) + opts.pk.column), exclusions=exclusions) joins.append(alias) + exclusions.add(alias) + for (dupe_opts, dupe_col) in dupe_set: + self.update_dupe_avoidance(dupe_opts, dupe_col, alias) cached_data = opts._join_cache.get(name) orig_opts = opts + dupe_col = direct and field.column or field.field.column + dedupe = dupe_col in opts.duplicate_targets + if dupe_set or dedupe: + if dedupe: + dupe_set.add((opts, dupe_col)) + exclusions.update(self.dupe_avoidance.get((id(opts), dupe_col), + ())) if direct: if m2m: @@ -1191,9 +1242,11 @@ class Query(object): target) int_alias = self.join((alias, table1, from_col1, to_col1), - dupe_multis, joins, nullable=True, reuse=can_reuse) + dupe_multis, exclusions, nullable=True, + reuse=can_reuse) alias 
= self.join((int_alias, table2, from_col2, to_col2), - dupe_multis, joins, nullable=True, reuse=can_reuse) + dupe_multis, exclusions, nullable=True, + reuse=can_reuse) joins.extend([int_alias, alias]) elif field.rel: # One-to-one or many-to-one field @@ -1209,7 +1262,7 @@ class Query(object): opts, target) alias = self.join((alias, table, from_col, to_col), - exclusions=joins, nullable=field.null) + exclusions=exclusions, nullable=field.null) joins.append(alias) else: # Non-relation fields. @@ -1237,9 +1290,11 @@ class Query(object): target) int_alias = self.join((alias, table1, from_col1, to_col1), - dupe_multis, joins, nullable=True, reuse=can_reuse) + dupe_multis, exclusions, nullable=True, + reuse=can_reuse) alias = self.join((int_alias, table2, from_col2, to_col2), - dupe_multis, joins, nullable=True, reuse=can_reuse) + dupe_multis, exclusions, nullable=True, + reuse=can_reuse) joins.extend([int_alias, alias]) else: # One-to-many field (ForeignKey defined on the target model) @@ -1257,14 +1312,34 @@ class Query(object): opts, target) alias = self.join((alias, table, from_col, to_col), - dupe_multis, joins, nullable=True, reuse=can_reuse) + dupe_multis, exclusions, nullable=True, + reuse=can_reuse) joins.append(alias) + for (dupe_opts, dupe_col) in dupe_set: + try: + self.update_dupe_avoidance(dupe_opts, dupe_col, int_alias) + except NameError: + self.update_dupe_avoidance(dupe_opts, dupe_col, alias) + if pos != len(names) - 1: raise FieldError("Join on field %r not permitted." % name) return field, target, opts, joins, last + def update_dupe_avoidance(self, opts, col, alias): + """ + For a column that is one of multiple pointing to the same table, update + the internal data structures to note that this alias shouldn't be used + for those other columns. 
+ """ + ident = id(opts) + for name in opts.duplicate_targets[col]: + try: + self.dupe_avoidance[ident, name].add(alias) + except KeyError: + self.dupe_avoidance[ident, name] = set([alias]) + def split_exclude(self, filter_expr, prefix): """ When doing an exclude against any kind of N-to-many relation, we need diff --git a/tests/regressiontests/many_to_one_regress/models.py b/tests/regressiontests/many_to_one_regress/models.py index 4e49df1555..429bdd7558 100644 --- a/tests/regressiontests/many_to_one_regress/models.py +++ b/tests/regressiontests/many_to_one_regress/models.py @@ -28,6 +28,24 @@ class Child(models.Model): parent = models.ForeignKey(Parent) +# Multiple paths to the same model (#7110, #7125) +class Category(models.Model): + name = models.CharField(max_length=20) + + def __unicode__(self): + return self.name + +class Record(models.Model): + category = models.ForeignKey(Category) + +class Relation(models.Model): + left = models.ForeignKey(Record, related_name='left_set') + right = models.ForeignKey(Record, related_name='right_set') + + def __unicode__(self): + return u"%s - %s" % (self.left.category.name, self.right.category.name) + + __test__ = {'API_TESTS':""" >>> Third.objects.create(id='3', name='An example') @@ -73,4 +91,26 @@ Traceback (most recent call last): ... ValueError: Cannot assign "": "Child.parent" must be a "Parent" instance. 
+# Test of multiple ForeignKeys to the same model (bug #7125) + +>>> c1 = Category.objects.create(name='First') +>>> c2 = Category.objects.create(name='Second') +>>> c3 = Category.objects.create(name='Third') +>>> r1 = Record.objects.create(category=c1) +>>> r2 = Record.objects.create(category=c1) +>>> r3 = Record.objects.create(category=c2) +>>> r4 = Record.objects.create(category=c2) +>>> r5 = Record.objects.create(category=c3) +>>> r = Relation.objects.create(left=r1, right=r2) +>>> r = Relation.objects.create(left=r3, right=r4) +>>> r = Relation.objects.create(left=r1, right=r3) +>>> r = Relation.objects.create(left=r5, right=r2) +>>> r = Relation.objects.create(left=r3, right=r2) + +>>> Relation.objects.filter(left__category__name__in=['First'], right__category__name__in=['Second']) +[<Relation: First - Second>] + +>>> Category.objects.filter(record__left_set__right__category__name='Second').order_by('name') +[<Category: First>, <Category: Second>] + """} diff --git a/tests/regressiontests/select_related_regress/__init__.py b/tests/regressiontests/select_related_regress/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/regressiontests/select_related_regress/models.py b/tests/regressiontests/select_related_regress/models.py new file mode 100644 index 0000000000..1688053e2d --- /dev/null +++ b/tests/regressiontests/select_related_regress/models.py @@ -0,0 +1,60 @@ +from django.db import models + +class Building(models.Model): + name = models.CharField(max_length=10) + + def __unicode__(self): + return u"Building: %s" % self.name + +class Device(models.Model): + building = models.ForeignKey('Building') + name = models.CharField(max_length=10) + + def __unicode__(self): + return u"device '%s' in building %s" % (self.name, self.building) + +class Port(models.Model): + device = models.ForeignKey('Device') + number = models.CharField(max_length=10) + + def __unicode__(self): + return u"%s/%s" % (self.device.name, self.number) + +class Connection(models.Model): + start = models.ForeignKey(Port, 
related_name='connection_start', + unique=True) + end = models.ForeignKey(Port, related_name='connection_end', unique=True) + + def __unicode__(self): + return u"%s to %s" % (self.start, self.end) + +__test__ = {'API_TESTS': """ +Regression test for bug #7110. When using select_related(), we must query the +Device and Building tables using two different aliases (each) in order to +differentiate the start and end Connection fields. The net result is that both +the "connections = ..." queries here should give the same results. + +>>> b=Building.objects.create(name='101') +>>> dev1=Device.objects.create(name="router", building=b) +>>> dev2=Device.objects.create(name="switch", building=b) +>>> dev3=Device.objects.create(name="server", building=b) +>>> port1=Port.objects.create(number='4',device=dev1) +>>> port2=Port.objects.create(number='7',device=dev2) +>>> port3=Port.objects.create(number='1',device=dev3) +>>> c1=Connection.objects.create(start=port1, end=port2) +>>> c2=Connection.objects.create(start=port2, end=port3) + +>>> connections=Connection.objects.filter(start__device__building=b, end__device__building=b).order_by('id') +>>> [(c.id, unicode(c.start), unicode(c.end)) for c in connections] +[(1, u'router/4', u'switch/7'), (2, u'switch/7', u'server/1')] + +>>> connections=Connection.objects.filter(start__device__building=b, end__device__building=b).select_related().order_by('id') +>>> [(c.id, unicode(c.start), unicode(c.end)) for c in connections] +[(1, u'router/4', u'switch/7'), (2, u'switch/7', u'server/1')] + +# This final query should only join seven tables (port, device and building +# twice each, plus connection once). +>>> connections.query.count_active_tables() +7 + +"""}