diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 59b31dcb79..294e0b21fd 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -569,7 +569,7 @@ class SQLCompiler(object): if (len(self.query.get_meta().concrete_fields) == len(self.query.select) and self.connection.features.allows_group_by_pk): self.query.group_by = [ - (self.query.get_meta().db_table, self.query.get_meta().pk.column) + (self.query.get_initial_alias(), self.query.get_meta().pk.column) ] select_cols = [] seen = set() diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 691de177e0..c03d782440 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -1084,6 +1084,32 @@ class Query(object): (lookup, self.get_meta().model.__name__)) return lookup_parts, field_parts, False + def check_query_object_type(self, value, opts): + """ + Checks whether the object passed while querying is of the correct type. + If not, it raises a ValueError specifying the wrong object. + """ + if hasattr(value, '_meta'): + if not (value._meta.concrete_model == opts.concrete_model + or opts.concrete_model in value._meta.get_parent_list() + or value._meta.concrete_model in opts.get_parent_list()): + raise ValueError( + 'Cannot query "%s": Must be "%s" instance.' % + (value, opts.object_name)) + + def check_related_objects(self, field, value, opts): + """ + Checks the type of object passed to query relations. + """ + if field.rel: + # testing for iterable of models + if hasattr(value, '__iter__'): + for v in value: + self.check_query_object_type(v, opts) + else: + # expecting single model instance here + self.check_query_object_type(value, opts) + def build_lookup(self, lookups, lhs, rhs): lookups = lookups[:] while lookups: @@ -1159,6 +1185,9 @@ class Query(object): try: field, sources, opts, join_list, path = self.setup_joins( parts, opts, alias, can_reuse=can_reuse, allow_many=allow_many) + + self.check_related_objects(field, value, opts) + # split_exclude() needs to know which joins were generated for the # lookup parts self._lookup_joins = join_list diff --git a/docs/releases/1.8.txt b/docs/releases/1.8.txt index 7534a16e3d..793c781e2f 100644 --- a/docs/releases/1.8.txt +++ b/docs/releases/1.8.txt @@ -350,6 +350,21 @@ The check also applies to the columns generated in an implicit and then specify :attr:`~django.db.models.Field.db_column` on its column(s) as needed. +Query relation lookups now check object types +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Querying for model lookups now checks if the object passed is of correct type +and raises a :exc:`ValueError` if not. Previously, Django didn't care if the +object was of correct type; it just used the object's related field attribute +(e.g. ``id``) for the lookup. Now, an error is raised to prevent incorrect +lookups:: + + >>> book = Book.objects.create(name="Django") + >>> book = Book.objects.filter(author=book) + Traceback (most recent call last): + ... + ValueError: Cannot query "": Must be "Author" instance. + Miscellaneous ~~~~~~~~~~~~~ diff --git a/tests/one_to_one/tests.py b/tests/one_to_one/tests.py index f0a7a175a7..af983fed74 100644 --- a/tests/one_to_one/tests.py +++ b/tests/one_to_one/tests.py @@ -100,7 +100,6 @@ class OneToOneTests(TestCase): assert_filter_waiters(restaurant__place__exact=self.p1) assert_filter_waiters(restaurant__place__pk=self.p1.pk) assert_filter_waiters(restaurant__exact=self.p1.pk) - assert_filter_waiters(restaurant__exact=self.p1) assert_filter_waiters(restaurant__pk=self.p1.pk) assert_filter_waiters(restaurant=self.p1.pk) assert_filter_waiters(restaurant=self.r) diff --git a/tests/queries/models.py b/tests/queries/models.py index 0dd4833b53..465b8f2446 100644 --- a/tests/queries/models.py +++ b/tests/queries/models.py @@ -409,6 +409,15 @@ class ObjectA(models.Model): return self.name +class ProxyObjectA(ObjectA): + class Meta: + proxy = True + + +class ChildObjectA(ObjectA): + pass + + @python_2_unicode_compatible class ObjectB(models.Model): name = models.CharField(max_length=50) @@ -419,11 +428,17 @@ class ObjectB(models.Model): return self.name +class ProxyObjectB(ObjectB): + class Meta: + proxy = True + + @python_2_unicode_compatible class ObjectC(models.Model): name = models.CharField(max_length=50) objecta = models.ForeignKey(ObjectA, null=True) objectb = models.ForeignKey(ObjectB, null=True) + childobjecta = models.ForeignKey(ChildObjectA, null=True, related_name='ca_pk') def __str__(self): return self.name diff --git a/tests/queries/tests.py b/tests/queries/tests.py index f6e8d79309..a47461bb86 100644 --- a/tests/queries/tests.py +++ b/tests/queries/tests.py @@ -22,12 +22,12 @@ from .models import ( ExtraInfo, Fan, Item, LeafA, Join, LeafB, LoopX, LoopZ, ManagedModel, Member, NamedCategory, Note, Number, Plaything, PointerA, Ranking, Related, Report, ReservedName, Tag, TvChef, Valid, X, Food, Eaten, Node, ObjectA, - ObjectB, ObjectC, CategoryItem, SimpleCategory, SpecialCategory, - OneToOneCategory, NullableName, ProxyCategory, SingleObject, RelatedObject, - ModelA, ModelB, ModelC, ModelD, Responsibility, Job, JobResponsibilities, - BaseA, FK1, Identifier, Program, Channel, Page, Paragraph, Chapter, Book, - MyObject, Order, OrderItem, SharedConnection, Task, Staff, StaffUser, - CategoryRelationship, Ticket21203Parent, Ticket21203Child, Person, + ProxyObjectA, ChildObjectA, ObjectB, ProxyObjectB, ObjectC, CategoryItem, + SimpleCategory, SpecialCategory, OneToOneCategory, NullableName, ProxyCategory, + SingleObject, RelatedObject, ModelA, ModelB, ModelC, ModelD, Responsibility, Job, + JobResponsibilities, BaseA, FK1, Identifier, Program, Channel, Page, Paragraph, + Chapter, Book, MyObject, Order, OrderItem, SharedConnection, Task, Staff, + StaffUser, CategoryRelationship, Ticket21203Parent, Ticket21203Child, Person, Company, Employment, CustomPk, CustomPkTag, Classroom, School, Student) @@ -3361,20 +3361,85 @@ class Ticket12807Tests(TestCase): class RelatedLookupTypeTests(TestCase): + error = 'Cannot query "%s": Must be "%s" instance.' + + def setUp(self): + self.oa = ObjectA.objects.create(name="oa") + self.poa = ProxyObjectA.objects.get(name="oa") + self.coa = ChildObjectA.objects.create(name="coa") + self.wrong_type = Order.objects.create(id=self.oa.pk) + self.ob = ObjectB.objects.create(name="ob", objecta=self.oa, num=1) + ProxyObjectB.objects.create(name="pob", objecta=self.oa, num=2) + self.pob = ProxyObjectB.objects.all() + ObjectC.objects.create(childobjecta=self.coa) + def test_wrong_type_lookup(self): - oa = ObjectA.objects.create(name="oa") - wrong_type = Order.objects.create(id=oa.pk) - ob = ObjectB.objects.create(name="ob", objecta=oa, num=1) - # Currently Django doesn't care if the object is of correct - # type, it will just use the objecta's related fields attribute - # (id) for model lookup. Making things more restrictive could - # be a good idea... - self.assertQuerysetEqual( - ObjectB.objects.filter(objecta=wrong_type), - [ob], lambda x: x) - self.assertQuerysetEqual( - ObjectB.objects.filter(objecta__in=[wrong_type]), - [ob], lambda x: x) + """ + A ValueError is raised when the incorrect object type is passed to a + query lookup. + """ + # Passing incorrect object type + with self.assertRaisesMessage(ValueError, + self.error % (self.wrong_type, ObjectA._meta.object_name)): + ObjectB.objects.get(objecta=self.wrong_type) + + with self.assertRaisesMessage(ValueError, + self.error % (self.wrong_type, ObjectA._meta.object_name)): + ObjectB.objects.filter(objecta__in=[self.wrong_type]) + + with self.assertRaisesMessage(ValueError, + self.error % (self.wrong_type, ObjectA._meta.object_name)): + ObjectB.objects.filter(objecta=self.wrong_type) + + with self.assertRaisesMessage(ValueError, + self.error % (self.wrong_type, ObjectB._meta.object_name)): + ObjectA.objects.filter(objectb__in=[self.wrong_type, self.ob]) + + # Passing an object of the class on which query is done. + with self.assertRaisesMessage(ValueError, + self.error % (self.ob, ObjectA._meta.object_name)): + ObjectB.objects.filter(objecta__in=[self.poa, self.ob]) + + with self.assertRaisesMessage(ValueError, + self.error % (self.ob, ChildObjectA._meta.object_name)): + ObjectC.objects.exclude(childobjecta__in=[self.coa, self.ob]) + + def test_wrong_backward_lookup(self): + """ + A ValueError is raised when the incorrect object type is passed to a + query lookup for backward relations. + """ + with self.assertRaisesMessage(ValueError, + self.error % (self.oa, ObjectB._meta.object_name)): + ObjectA.objects.filter(objectb__in=[self.oa, self.ob]) + + with self.assertRaisesMessage(ValueError, + self.error % (self.oa, ObjectB._meta.object_name)): + ObjectA.objects.exclude(objectb=self.oa) + + with self.assertRaisesMessage(ValueError, + self.error % (self.wrong_type, ObjectB._meta.object_name)): + ObjectA.objects.get(objectb=self.wrong_type) + + def test_correct_lookup(self): + """ + When passing proxy model objects, child objects, or parent objects, + lookups work fine. + """ + out_a = ['', ] + out_b = ['', ''] + out_c = [''] + + # proxy model objects + self.assertQuerysetEqual(ObjectB.objects.filter(objecta=self.poa).order_by('name'), out_b) + self.assertQuerysetEqual(ObjectA.objects.filter(objectb__in=self.pob).order_by('pk'), out_a * 2) + + # child objects + self.assertQuerysetEqual(ObjectB.objects.filter(objecta__in=[self.coa]), []) + self.assertQuerysetEqual(ObjectB.objects.filter(objecta__in=[self.poa, self.coa]).order_by('name'), out_b) + + # parent objects + self.assertQuerysetEqual(ObjectC.objects.exclude(childobjecta=self.oa), out_c) class Ticket14056Tests(TestCase):