diff --git a/django/contrib/contenttypes/fields.py b/django/contrib/contenttypes/fields.py index f40c2c713c..0242c148fa 100644 --- a/django/contrib/contenttypes/fields.py +++ b/django/contrib/contenttypes/fields.py @@ -359,10 +359,46 @@ class GenericRelation(ForeignObject): self.to_fields = [self.model._meta.pk.name] return [(self.remote_field.model._meta.get_field(self.object_id_field_name), self.model._meta.pk)] + def _get_path_info_with_parent(self): + """ + Return the path that joins the current model through any parent models. + The idea is that if you have a GFK defined on a parent model then we + need to join the parent model first, then the child model. + """ + # With an inheritance chain ChildTag -> Tag and Tag defines the + # GenericForeignKey, and a TaggedItem model has a GenericRelation to + # ChildTag, then we need to generate a join from TaggedItem to Tag + # (as Tag.object_id == TaggedItem.pk), and another join from Tag to + # ChildTag (as that is where the relation is to). Do this by first + # generating a join to the parent model, then generating joins to the + # child models. + path = [] + opts = self.remote_field.model._meta + parent_opts = opts.get_field(self.object_id_field_name).model._meta + target = parent_opts.pk + path.append(PathInfo(self.model._meta, parent_opts, (target,), self.remote_field, True, False)) + # Collect joins needed for the parent -> child chain. This is easiest + # to do if we collect joins for the child -> parent chain and then + # reverse the direction (call to reverse() and use of + # field.remote_field.get_path_info()). + parent_field_chain = [] + while parent_opts != opts: + field = opts.get_ancestor_link(parent_opts.model) + parent_field_chain.append(field) + opts = field.remote_field.model._meta + parent_field_chain.reverse() + for field in parent_field_chain: + path.extend(field.remote_field.get_path_info()) + return path + def get_path_info(self): opts = self.remote_field.model._meta - target = opts.pk - return [PathInfo(self.model._meta, opts, (target,), self.remote_field, True, False)] + object_id_field = opts.get_field(self.object_id_field_name) + if object_id_field.model != opts.model: + return self._get_path_info_with_parent() + else: + target = opts.pk + return [PathInfo(self.model._meta, opts, (target,), self.remote_field, True, False)] def get_reverse_path_info(self): opts = self.model._meta diff --git a/tests/generic_relations/tests.py b/tests/generic_relations/tests.py index 3d3e2acfe5..1d3d3d373b 100644 --- a/tests/generic_relations/tests.py +++ b/tests/generic_relations/tests.py @@ -405,6 +405,13 @@ class GenericRelationsTests(TestCase): # GenericRelations to models that use multi-table inheritance work. granite = ValuableRock.objects.create(name='granite', hardness=5) ValuableTaggedItem.objects.create(content_object=granite, tag="countertop", value=1) + self.assertEqual(ValuableRock.objects.filter(tags__value=1).count(), 1) + # We're generating a slightly inefficient query for tags__tag - we + # first join ValuableRock -> TaggedItem -> ValuableTaggedItem, and then + # we fetch tag by joining TaggedItem from ValuableTaggedItem. The last + # join isn't necessary, as TaggedItem <-> ValuableTaggedItem is a + # one-to-one join. + self.assertEqual(ValuableRock.objects.filter(tags__tag="countertop").count(), 1) granite.delete() # deleting the rock should delete the related tag. self.assertEqual(ValuableTaggedItem.objects.count(), 0)