mirror of
				https://github.com/django/django.git
				synced 2025-10-31 09:41:08 +00:00 
			
		
		
		
	Fixed handling of multiple fields in a model pointing to the same related model.
Thanks to ElliotM, mk and oyvind for some excellent test cases for this. Fixed #7110, #7125. git-svn-id: http://code.djangoproject.com/svn/django/trunk@7778 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
		| @@ -692,6 +692,11 @@ class ForeignKey(RelatedField, Field): | ||||
|     def contribute_to_class(self, cls, name): | ||||
|         super(ForeignKey, self).contribute_to_class(cls, name) | ||||
|         setattr(cls, self.name, ReverseSingleRelatedObjectDescriptor(self)) | ||||
|         if isinstance(self.rel.to, basestring): | ||||
|             target = self.rel.to | ||||
|         else: | ||||
|             target = self.rel.to._meta.db_table | ||||
|         cls._meta.duplicate_targets[self.column] = (target, "o2m") | ||||
|  | ||||
|     def contribute_to_related_class(self, cls, related): | ||||
|         setattr(cls, related.get_accessor_name(), ForeignRelatedObjectsDescriptor(related)) | ||||
| @@ -826,6 +831,12 @@ class ManyToManyField(RelatedField, Field): | ||||
|         # Set up the accessor for the m2m table name for the relation | ||||
|         self.m2m_db_table = curry(self._get_m2m_db_table, cls._meta) | ||||
|  | ||||
|         if isinstance(self.rel.to, basestring): | ||||
|             target = self.rel.to | ||||
|         else: | ||||
|             target = self.rel.to._meta.db_table | ||||
|         cls._meta.duplicate_targets[self.column] = (target, "m2m") | ||||
|  | ||||
|     def contribute_to_related_class(self, cls, related): | ||||
|         # m2m relations to self do not have a ManyRelatedObjectsDescriptor, | ||||
|         # as it would be redundant - unless the field is non-symmetrical. | ||||
|   | ||||
| @@ -44,6 +44,7 @@ class Options(object): | ||||
|         self.one_to_one_field = None | ||||
|         self.abstract = False | ||||
|         self.parents = SortedDict() | ||||
|         self.duplicate_targets = {} | ||||
|  | ||||
|     def contribute_to_class(self, cls, name): | ||||
|         from django.db import connection | ||||
| @@ -115,6 +116,24 @@ class Options(object): | ||||
|                         auto_created=True) | ||||
|                 model.add_to_class('id', auto) | ||||
|  | ||||
|         # Determine any sets of fields that are pointing to the same targets | ||||
|         # (e.g. two ForeignKeys to the same remote model). The query | ||||
|         # construction code needs to know this. At the end of this, | ||||
|         # self.duplicate_targets will map each duplicate field column to the | ||||
|         # columns it duplicates. | ||||
|         collections = {} | ||||
|         for column, target in self.duplicate_targets.iteritems(): | ||||
|             try: | ||||
|                 collections[target].add(column) | ||||
|             except KeyError: | ||||
|                 collections[target] = set([column]) | ||||
|         self.duplicate_targets = {} | ||||
|         for elt in collections.itervalues(): | ||||
|             if len(elt) == 1: | ||||
|                 continue | ||||
|             for column in elt: | ||||
|                 self.duplicate_targets[column] = elt.difference(set([column])) | ||||
|  | ||||
|     def add_field(self, field): | ||||
|         # Insert the given field in the order in which it was created, using | ||||
|         # the "creation_counter" attribute of the field. | ||||
|   | ||||
| @@ -57,6 +57,7 @@ class Query(object): | ||||
|         self.start_meta = None | ||||
|         self.select_fields = [] | ||||
|         self.related_select_fields = [] | ||||
|         self.dupe_avoidance = {} | ||||
|  | ||||
|         # SQL-related attributes | ||||
|         self.select = [] | ||||
| @@ -165,6 +166,7 @@ class Query(object): | ||||
|         obj.start_meta = self.start_meta | ||||
|         obj.select_fields = self.select_fields[:] | ||||
|         obj.related_select_fields = self.related_select_fields[:] | ||||
|         obj.dupe_avoidance = self.dupe_avoidance.copy() | ||||
|         obj.select = self.select[:] | ||||
|         obj.tables = self.tables[:] | ||||
|         obj.where = deepcopy(self.where) | ||||
| @@ -830,8 +832,8 @@ class Query(object): | ||||
|  | ||||
|         if reuse and always_create and table in self.table_map: | ||||
|             # Convert the 'reuse' to case to be "exclude everything but the | ||||
|             # reusable set for this table". | ||||
|             exclusions = set(self.table_map[table]).difference(reuse) | ||||
|             # reusable set, minus exclusions, for this table". | ||||
|             exclusions = set(self.table_map[table]).difference(reuse).union(set(exclusions)) | ||||
|             always_create = False | ||||
|         t_ident = (lhs_table, table, lhs_col, col) | ||||
|         if not always_create: | ||||
| @@ -866,7 +868,8 @@ class Query(object): | ||||
|         return alias | ||||
|  | ||||
|     def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1, | ||||
|             used=None, requested=None, restricted=None, nullable=None): | ||||
|             used=None, requested=None, restricted=None, nullable=None, | ||||
|             dupe_set=None): | ||||
|         """ | ||||
|         Fill in the information needed for a select_related query. The current | ||||
|         depth is measured as the number of connections away from the root model | ||||
| @@ -876,6 +879,7 @@ class Query(object): | ||||
|         if not restricted and self.max_depth and cur_depth > self.max_depth: | ||||
|             # We've recursed far enough; bail out. | ||||
|             return | ||||
|  | ||||
|         if not opts: | ||||
|             opts = self.get_meta() | ||||
|             root_alias = self.get_initial_alias() | ||||
| @@ -883,6 +887,10 @@ class Query(object): | ||||
|             self.related_select_fields = [] | ||||
|         if not used: | ||||
|             used = set() | ||||
|         if dupe_set is None: | ||||
|             dupe_set = set() | ||||
|         orig_dupe_set = dupe_set | ||||
|         orig_used = used | ||||
|  | ||||
|         # Setup for the case when only particular related fields should be | ||||
|         # included in the related selection. | ||||
| @@ -897,6 +905,8 @@ class Query(object): | ||||
|             if (not f.rel or (restricted and f.name not in requested) or | ||||
|                     (not restricted and f.null) or f.rel.parent_link): | ||||
|                 continue | ||||
|             dupe_set = orig_dupe_set.copy() | ||||
|             used = orig_used.copy() | ||||
|             table = f.rel.to._meta.db_table | ||||
|             if nullable or f.null: | ||||
|                 promote = True | ||||
| @@ -907,12 +917,26 @@ class Query(object): | ||||
|                 alias = root_alias | ||||
|                 for int_model in opts.get_base_chain(model): | ||||
|                     lhs_col = int_opts.parents[int_model].column | ||||
|                     dedupe = lhs_col in opts.duplicate_targets | ||||
|                     if dedupe: | ||||
|                         used.update(self.dupe_avoidance.get(id(opts), lhs_col), | ||||
|                                 ()) | ||||
|                         dupe_set.add((opts, lhs_col)) | ||||
|                     int_opts = int_model._meta | ||||
|                     alias = self.join((alias, int_opts.db_table, lhs_col, | ||||
|                             int_opts.pk.column), exclusions=used, | ||||
|                             promote=promote) | ||||
|                     for (dupe_opts, dupe_col) in dupe_set: | ||||
|                         self.update_dupe_avoidance(dupe_opts, dupe_col, alias) | ||||
|             else: | ||||
|                 alias = root_alias | ||||
|  | ||||
|             dedupe = f.column in opts.duplicate_targets | ||||
|             if dupe_set or dedupe: | ||||
|                 used.update(self.dupe_avoidance.get((id(opts), f.column), ())) | ||||
|                 if dedupe: | ||||
|                     dupe_set.add((opts, f.column)) | ||||
|  | ||||
|             alias = self.join((alias, table, f.column, | ||||
|                     f.rel.get_related_field().column), exclusions=used, | ||||
|                     promote=promote) | ||||
| @@ -928,8 +952,10 @@ class Query(object): | ||||
|                 new_nullable = f.null | ||||
|             else: | ||||
|                 new_nullable = None | ||||
|             for dupe_opts, dupe_col in dupe_set: | ||||
|                 self.update_dupe_avoidance(dupe_opts, dupe_col, alias) | ||||
|             self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1, | ||||
|                     used, next, restricted, new_nullable) | ||||
|                     used, next, restricted, new_nullable, dupe_set) | ||||
|  | ||||
|     def add_filter(self, filter_expr, connector=AND, negate=False, trim=False, | ||||
|             can_reuse=None): | ||||
| @@ -1128,7 +1154,9 @@ class Query(object): | ||||
|         (which gives the table we are joining to), 'alias' is the alias for the | ||||
|         table we are joining to. If dupe_multis is True, any many-to-many or | ||||
|         many-to-one joins will always create a new alias (necessary for | ||||
|         disjunctive filters). | ||||
|         disjunctive filters). If can_reuse is not None, it's a list of aliases | ||||
|         that can be reused in these joins (nothing else can be reused in this | ||||
|         case). | ||||
|  | ||||
|         Returns the final field involved in the join, the target database | ||||
|         column (used for any 'where' constraint), the final 'opts' value and the | ||||
| @@ -1136,7 +1164,14 @@ class Query(object): | ||||
|         """ | ||||
|         joins = [alias] | ||||
|         last = [0] | ||||
|         dupe_set = set() | ||||
|         exclusions = set() | ||||
|         for pos, name in enumerate(names): | ||||
|             try: | ||||
|                 exclusions.add(int_alias) | ||||
|             except NameError: | ||||
|                 pass | ||||
|             exclusions.add(alias) | ||||
|             last.append(len(joins)) | ||||
|             if name == 'pk': | ||||
|                 name = opts.pk.name | ||||
| @@ -1155,6 +1190,7 @@ class Query(object): | ||||
|                     names = opts.get_all_field_names() | ||||
|                     raise FieldError("Cannot resolve keyword %r into field. " | ||||
|                             "Choices are: %s" % (name, ", ".join(names))) | ||||
|  | ||||
|             if not allow_many and (m2m or not direct): | ||||
|                 for alias in joins: | ||||
|                     self.unref_alias(alias) | ||||
| @@ -1164,12 +1200,27 @@ class Query(object): | ||||
|                 alias_list = [] | ||||
|                 for int_model in opts.get_base_chain(model): | ||||
|                     lhs_col = opts.parents[int_model].column | ||||
|                     dedupe = lhs_col in opts.duplicate_targets | ||||
|                     if dedupe: | ||||
|                         exclusions.update(self.dupe_avoidance.get( | ||||
|                                 (id(opts), lhs_col), ())) | ||||
|                         dupe_set.add((opts, lhs_col)) | ||||
|                     opts = int_model._meta | ||||
|                     alias = self.join((alias, opts.db_table, lhs_col, | ||||
|                             opts.pk.column), exclusions=joins) | ||||
|                             opts.pk.column), exclusions=exclusions) | ||||
|                     joins.append(alias) | ||||
|                     exclusions.add(alias) | ||||
|                     for (dupe_opts, dupe_col) in dupe_set: | ||||
|                         self.update_dupe_avoidance(dupe_opts, dupe_col, alias) | ||||
|             cached_data = opts._join_cache.get(name) | ||||
|             orig_opts = opts | ||||
|             dupe_col = direct and field.column or field.field.column | ||||
|             dedupe = dupe_col in opts.duplicate_targets | ||||
|             if dupe_set or dedupe: | ||||
|                 if dedupe: | ||||
|                     dupe_set.add((opts, dupe_col)) | ||||
|                 exclusions.update(self.dupe_avoidance.get((id(opts), dupe_col), | ||||
|                         ())) | ||||
|  | ||||
|             if direct: | ||||
|                 if m2m: | ||||
| @@ -1191,9 +1242,11 @@ class Query(object): | ||||
|                                 target) | ||||
|  | ||||
|                     int_alias = self.join((alias, table1, from_col1, to_col1), | ||||
|                             dupe_multis, joins, nullable=True, reuse=can_reuse) | ||||
|                             dupe_multis, exclusions, nullable=True, | ||||
|                             reuse=can_reuse) | ||||
|                     alias = self.join((int_alias, table2, from_col2, to_col2), | ||||
|                             dupe_multis, joins, nullable=True, reuse=can_reuse) | ||||
|                             dupe_multis, exclusions, nullable=True, | ||||
|                             reuse=can_reuse) | ||||
|                     joins.extend([int_alias, alias]) | ||||
|                 elif field.rel: | ||||
|                     # One-to-one or many-to-one field | ||||
| @@ -1209,7 +1262,7 @@ class Query(object): | ||||
|                                 opts, target) | ||||
|  | ||||
|                     alias = self.join((alias, table, from_col, to_col), | ||||
|                             exclusions=joins, nullable=field.null) | ||||
|                             exclusions=exclusions, nullable=field.null) | ||||
|                     joins.append(alias) | ||||
|                 else: | ||||
|                     # Non-relation fields. | ||||
| @@ -1237,9 +1290,11 @@ class Query(object): | ||||
|                                 target) | ||||
|  | ||||
|                     int_alias = self.join((alias, table1, from_col1, to_col1), | ||||
|                             dupe_multis, joins, nullable=True, reuse=can_reuse) | ||||
|                             dupe_multis, exclusions, nullable=True, | ||||
|                             reuse=can_reuse) | ||||
|                     alias = self.join((int_alias, table2, from_col2, to_col2), | ||||
|                             dupe_multis, joins, nullable=True, reuse=can_reuse) | ||||
|                             dupe_multis, exclusions, nullable=True, | ||||
|                             reuse=can_reuse) | ||||
|                     joins.extend([int_alias, alias]) | ||||
|                 else: | ||||
|                     # One-to-many field (ForeignKey defined on the target model) | ||||
| @@ -1257,14 +1312,34 @@ class Query(object): | ||||
|                                 opts, target) | ||||
|  | ||||
|                     alias = self.join((alias, table, from_col, to_col), | ||||
|                             dupe_multis, joins, nullable=True, reuse=can_reuse) | ||||
|                             dupe_multis, exclusions, nullable=True, | ||||
|                             reuse=can_reuse) | ||||
|                     joins.append(alias) | ||||
|  | ||||
|             for (dupe_opts, dupe_col) in dupe_set: | ||||
|                 try: | ||||
|                     self.update_dupe_avoidance(dupe_opts, dupe_col, int_alias) | ||||
|                 except NameError: | ||||
|                     self.update_dupe_avoidance(dupe_opts, dupe_col, alias) | ||||
|  | ||||
|         if pos != len(names) - 1: | ||||
|             raise FieldError("Join on field %r not permitted." % name) | ||||
|  | ||||
|         return field, target, opts, joins, last | ||||
|  | ||||
|     def update_dupe_avoidance(self, opts, col, alias): | ||||
|         """ | ||||
|         For a column that is one of multiple pointing to the same table, update | ||||
|         the internal data structures to note that this alias shouldn't be used | ||||
|         for those other columns. | ||||
|         """ | ||||
|         ident = id(opts) | ||||
|         for name in opts.duplicate_targets[col]: | ||||
|             try: | ||||
|                 self.dupe_avoidance[ident, name].add(alias) | ||||
|             except KeyError: | ||||
|                 self.dupe_avoidance[ident, name] = set([alias]) | ||||
|  | ||||
|     def split_exclude(self, filter_expr, prefix): | ||||
|         """ | ||||
|         When doing an exclude against any kind of N-to-many relation, we need | ||||
|   | ||||
| @@ -28,6 +28,24 @@ class Child(models.Model): | ||||
|     parent = models.ForeignKey(Parent) | ||||
|  | ||||
|  | ||||
| # Multiple paths to the same model (#7110, #7125) | ||||
| class Category(models.Model): | ||||
|     name = models.CharField(max_length=20) | ||||
|  | ||||
|     def __unicode__(self): | ||||
|         return self.name | ||||
|  | ||||
| class Record(models.Model): | ||||
|     category = models.ForeignKey(Category) | ||||
|  | ||||
| class Relation(models.Model): | ||||
|     left = models.ForeignKey(Record, related_name='left_set') | ||||
|     right = models.ForeignKey(Record, related_name='right_set') | ||||
|  | ||||
|     def __unicode__(self): | ||||
|         return u"%s - %s" % (self.left.category.name, self.right.category.name) | ||||
|  | ||||
|  | ||||
| __test__ = {'API_TESTS':""" | ||||
| >>> Third.objects.create(id='3', name='An example') | ||||
| <Third: Third object> | ||||
| @@ -73,4 +91,26 @@ Traceback (most recent call last): | ||||
|     ... | ||||
| ValueError: Cannot assign "<First: First object>": "Child.parent" must be a "Parent" instance. | ||||
|  | ||||
| # Test of multiple ForeignKeys to the same model (bug #7125) | ||||
|  | ||||
| >>> c1 = Category.objects.create(name='First') | ||||
| >>> c2 = Category.objects.create(name='Second') | ||||
| >>> c3 = Category.objects.create(name='Third') | ||||
| >>> r1 = Record.objects.create(category=c1) | ||||
| >>> r2 = Record.objects.create(category=c1) | ||||
| >>> r3 = Record.objects.create(category=c2) | ||||
| >>> r4 = Record.objects.create(category=c2) | ||||
| >>> r5 = Record.objects.create(category=c3) | ||||
| >>> r = Relation.objects.create(left=r1, right=r2) | ||||
| >>> r = Relation.objects.create(left=r3, right=r4) | ||||
| >>> r = Relation.objects.create(left=r1, right=r3) | ||||
| >>> r = Relation.objects.create(left=r5, right=r2) | ||||
| >>> r = Relation.objects.create(left=r3, right=r2) | ||||
|  | ||||
| >>> Relation.objects.filter(left__category__name__in=['First'], right__category__name__in=['Second']) | ||||
| [<Relation: First - Second>] | ||||
|  | ||||
| >>> Category.objects.filter(record__left_set__right__category__name='Second').order_by('name') | ||||
| [<Category: First>, <Category: Second>] | ||||
|  | ||||
| """} | ||||
|   | ||||
							
								
								
									
										60
									
								
								tests/regressiontests/select_related_regress/models.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										60
									
								
								tests/regressiontests/select_related_regress/models.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,60 @@ | ||||
| from django.db import models | ||||
|  | ||||
| class Building(models.Model): | ||||
|     name = models.CharField(max_length=10) | ||||
|  | ||||
|     def __unicode__(self): | ||||
|         return u"Building: %s" % self.name | ||||
|  | ||||
| class Device(models.Model): | ||||
|     building = models.ForeignKey('Building') | ||||
|     name = models.CharField(max_length=10) | ||||
|  | ||||
|     def __unicode__(self): | ||||
|         return u"device '%s' in building %s" % (self.name, self.building) | ||||
|  | ||||
| class Port(models.Model): | ||||
|     device = models.ForeignKey('Device') | ||||
|     number = models.CharField(max_length=10) | ||||
|  | ||||
|     def __unicode__(self): | ||||
|         return u"%s/%s" % (self.device.name, self.number) | ||||
|  | ||||
| class Connection(models.Model): | ||||
|     start = models.ForeignKey(Port, related_name='connection_start', | ||||
|             unique=True) | ||||
|     end = models.ForeignKey(Port, related_name='connection_end', unique=True) | ||||
|  | ||||
|     def __unicode__(self): | ||||
|         return u"%s to %s" % (self.start, self.end) | ||||
|  | ||||
| __test__ = {'API_TESTS': """ | ||||
| Regression test for bug #7110. When using select_related(), we must query the | ||||
| Device and Building tables using two different aliases (each) in order to | ||||
| differentiate the start and end Connection fields. The net result is that both | ||||
| the "connections = ..." queries here should give the same results. | ||||
|  | ||||
| >>> b=Building.objects.create(name='101') | ||||
| >>> dev1=Device.objects.create(name="router", building=b) | ||||
| >>> dev2=Device.objects.create(name="switch", building=b) | ||||
| >>> dev3=Device.objects.create(name="server", building=b) | ||||
| >>> port1=Port.objects.create(number='4',device=dev1) | ||||
| >>> port2=Port.objects.create(number='7',device=dev2) | ||||
| >>> port3=Port.objects.create(number='1',device=dev3) | ||||
| >>> c1=Connection.objects.create(start=port1, end=port2) | ||||
| >>> c2=Connection.objects.create(start=port2, end=port3) | ||||
|  | ||||
| >>> connections=Connection.objects.filter(start__device__building=b, end__device__building=b).order_by('id') | ||||
| >>> [(c.id, unicode(c.start), unicode(c.end)) for c in connections] | ||||
| [(1, u'router/4', u'switch/7'), (2, u'switch/7', u'server/1')] | ||||
|  | ||||
| >>> connections=Connection.objects.filter(start__device__building=b, end__device__building=b).select_related().order_by('id') | ||||
| >>> [(c.id, unicode(c.start), unicode(c.end)) for c in connections] | ||||
| [(1, u'router/4', u'switch/7'), (2, u'switch/7', u'server/1')] | ||||
|  | ||||
| # This final query should only join seven tables (port, device and building | ||||
| # twice each, plus connection once). | ||||
| >>> connections.query.count_active_tables() | ||||
| 7 | ||||
|  | ||||
| """} | ||||
		Reference in New Issue
	
	Block a user