From 227c5e80dbff2b731a1cba8c4dd94c00526a3e76 Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Tue, 25 Jan 2011 03:14:28 +0000 Subject: [PATCH] Fixed #11319 - Added lookup support for ForeignKey.to_field. Also reverted no-longer-needed model formsets workaround for lack of such support from r10756. Thanks Russell and Alex for review. git-svn-id: http://code.djangoproject.com/svn/django/trunk@15303 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- django/db/models/fields/related.py | 13 +++- django/db/models/sql/query.py | 7 ++- django/forms/models.py | 6 +- tests/modeltests/custom_pk/tests.py | 8 +-- .../regressiontests/delete_regress/models.py | 7 +++ tests/regressiontests/delete_regress/tests.py | 12 +++- tests/regressiontests/queries/models.py | 20 ++++++ tests/regressiontests/queries/tests.py | 63 ++++++++++++++++++- 8 files changed, 122 insertions(+), 14 deletions(-) diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index f2adc7af0c..b9ffcbd238 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -178,9 +178,20 @@ class RelatedField(object): # the primary key may itself be an object - so we need to keep drilling # down until we hit a value that can be used for a comparison. v = value + + # In the case of an FK to 'self', this check allows to_field to be used + # for both forwards and reverse lookups across the FK. (For normal FKs, + # it's only relevant for forward lookups). + if isinstance(v, self.rel.to): + field_name = getattr(self.rel, "field_name", None) + else: + field_name = None try: while True: - v = getattr(v, v._meta.pk.name) + if field_name is None: + field_name = v._meta.pk.name + v = getattr(v, field_name) + field_name = None except AttributeError: pass except exceptions.ObjectDoesNotExist: diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 1c58a24d45..e028800d72 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -1364,7 +1364,12 @@ class Query(object): table = opts.db_table from_col = local_field.column to_col = field.column - target = opts.pk + # In case of a recursive FK, use the to_field for + # reverse lookups as well + if orig_field.model is local_field.model: + target = opts.get_field(field.rel.field_name) + else: + target = opts.pk orig_opts._join_cache[name] = (table, from_col, to_col, opts, target) diff --git a/django/forms/models.py b/django/forms/models.py index de52b6a7e0..6babebbe36 100644 --- a/django/forms/models.py +++ b/django/forms/models.py @@ -700,13 +700,9 @@ class BaseInlineFormSet(BaseModelFormSet): self.save_as_new = save_as_new # is there a better way to get the object descriptor? self.rel_name = RelatedObject(self.fk.rel.to, self.model, self.fk).get_accessor_name() - if self.fk.rel.field_name == self.fk.rel.to._meta.pk.name: - backlink_value = self.instance - else: - backlink_value = getattr(self.instance, self.fk.rel.field_name) if queryset is None: queryset = self.model._default_manager - qs = queryset.filter(**{self.fk.name: backlink_value}) + qs = queryset.filter(**{self.fk.name: self.instance}) super(BaseInlineFormSet, self).__init__(data, files, prefix=prefix, queryset=qs) diff --git a/tests/modeltests/custom_pk/tests.py b/tests/modeltests/custom_pk/tests.py index 22975a8417..c410ad17e3 100644 --- a/tests/modeltests/custom_pk/tests.py +++ b/tests/modeltests/custom_pk/tests.py @@ -158,11 +158,9 @@ class CustomPKTests(TestCase): new_bar = Bar.objects.create() new_foo = Foo.objects.create(bar=new_bar) - # FIXME: This still doesn't work, but will require some changes in - # get_db_prep_lookup to fix it. - # f = Foo.objects.get(bar=new_bar.pk) - # self.assertEqual(f, new_foo) - # self.assertEqual(f.bar, new_bar) + f = Foo.objects.get(bar=new_bar.pk) + self.assertEqual(f, new_foo) + self.assertEqual(f.bar, new_bar) f = Foo.objects.get(bar=new_bar) self.assertEqual(f, new_foo), diff --git a/tests/regressiontests/delete_regress/models.py b/tests/regressiontests/delete_regress/models.py index 07b58317ae..5c77117719 100644 --- a/tests/regressiontests/delete_regress/models.py +++ b/tests/regressiontests/delete_regress/models.py @@ -44,3 +44,10 @@ class Email(Contact): class Researcher(models.Model): contacts = models.ManyToManyField(Contact, related_name="research_contacts") + +class Food(models.Model): + name = models.CharField(max_length=20, unique=True) + +class Eaten(models.Model): + food = models.ForeignKey(Food, to_field="name") + meal = models.CharField(max_length=20) diff --git a/tests/regressiontests/delete_regress/tests.py b/tests/regressiontests/delete_regress/tests.py index 7e243c841c..06f3b5c866 100644 --- a/tests/regressiontests/delete_regress/tests.py +++ b/tests/regressiontests/delete_regress/tests.py @@ -5,7 +5,7 @@ from django.db import backend, connection, transaction, DEFAULT_DB_ALIAS from django.test import TestCase, TransactionTestCase, skipUnlessDBFeature from models import (Book, Award, AwardNote, Person, Child, Toy, PlayedWith, - PlayedWithNote, Contact, Email, Researcher) + PlayedWithNote, Contact, Email, Researcher, Food, Eaten) # Can't run this test under SQLite, because you can't @@ -119,6 +119,16 @@ class DeleteCascadeTransactionTests(TransactionTestCase): email.delete() + def test_to_field(self): + """ + Cascade deletion works with ForeignKey.to_field set to non-PK. + + """ + apple = Food.objects.create(name="apple") + eaten = Eaten.objects.create(food=apple, meal="lunch") + + apple.delete() + class LargeDeleteTests(TestCase): def test_large_deletes(self): "Regression for #13309 -- if the number of objects > chunk size, deletion still occurs" diff --git a/tests/regressiontests/queries/models.py b/tests/regressiontests/queries/models.py index 5247ef90ce..3b7a08aba2 100644 --- a/tests/regressiontests/queries/models.py +++ b/tests/regressiontests/queries/models.py @@ -274,3 +274,23 @@ class Plaything(models.Model): class Article(models.Model): name = models.CharField(max_length=20) created = models.DateTimeField() + +class Food(models.Model): + name = models.CharField(max_length=20, unique=True) + + def __unicode__(self): + return self.name + +class Eaten(models.Model): + food = models.ForeignKey(Food, to_field="name") + meal = models.CharField(max_length=20) + + def __unicode__(self): + return u"%s at %s" % (self.food, self.meal) + +class Node(models.Model): + num = models.IntegerField(unique=True) + parent = models.ForeignKey("self", to_field="num", null=True) + + def __unicode__(self): + return u"%s" % self.num diff --git a/tests/regressiontests/queries/tests.py b/tests/regressiontests/queries/tests.py index 9993acf942..4099fb6dad 100644 --- a/tests/regressiontests/queries/tests.py +++ b/tests/regressiontests/queries/tests.py @@ -14,7 +14,7 @@ from django.utils.datastructures import SortedDict from models import (Annotation, Article, Author, Celebrity, Child, Cover, Detail, DumbCategory, ExtraInfo, Fan, Item, LeafA, LoopX, LoopZ, ManagedModel, Member, NamedCategory, Note, Number, Plaything, PointerA, Ranking, Related, - Report, ReservedName, Tag, TvChef, Valid, X) + Report, ReservedName, Tag, TvChef, Valid, X, Food, Eaten, Node) class BaseQuerysetTest(TestCase): @@ -1515,6 +1515,67 @@ class EscapingTests(TestCase): ) +class ToFieldTests(TestCase): + def test_in_query(self): + apple = Food.objects.create(name="apple") + pear = Food.objects.create(name="pear") + lunch = Eaten.objects.create(food=apple, meal="lunch") + dinner = Eaten.objects.create(food=pear, meal="dinner") + + self.assertEqual( + set(Eaten.objects.filter(food__in=[apple, pear])), + set([lunch, dinner]), + ) + + def test_reverse_in(self): + apple = Food.objects.create(name="apple") + pear = Food.objects.create(name="pear") + lunch_apple = Eaten.objects.create(food=apple, meal="lunch") + lunch_pear = Eaten.objects.create(food=pear, meal="dinner") + + self.assertEqual( + set(Food.objects.filter(eaten__in=[lunch_apple, lunch_pear])), + set([apple, pear]) + ) + + def test_single_object(self): + apple = Food.objects.create(name="apple") + lunch = Eaten.objects.create(food=apple, meal="lunch") + dinner = Eaten.objects.create(food=apple, meal="dinner") + + self.assertEqual( + set(Eaten.objects.filter(food=apple)), + set([lunch, dinner]) + ) + + def test_single_object_reverse(self): + apple = Food.objects.create(name="apple") + lunch = Eaten.objects.create(food=apple, meal="lunch") + + self.assertEqual( + set(Food.objects.filter(eaten=lunch)), + set([apple]) + ) + + def test_recursive_fk(self): + node1 = Node.objects.create(num=42) + node2 = Node.objects.create(num=1, parent=node1) + + self.assertEqual( + list(Node.objects.filter(parent=node1)), + [node2] + ) + + def test_recursive_fk_reverse(self): + node1 = Node.objects.create(num=42) + node2 = Node.objects.create(num=1, parent=node1) + + self.assertEqual( + list(Node.objects.filter(node=node2)), + [node1] + ) + + class ConditionalTests(BaseQuerysetTest): """Tests whose execution depend on dfferent environment conditions like Python version or DB backend features"""