mirror of
https://github.com/django/django.git
synced 2025-10-24 06:06:09 +00:00
Fixed #20946 -- model inheritance + m2m failure
Cleaned up the internal implementation of m2m fields by removing related.py _get_fk_val(). The _get_fk_val() was doing the wrong thing if asked for the foreign key value on foreign key to parent model's primary key when child model had different primary key field.
This commit is contained in:
@@ -503,8 +503,6 @@ def create_many_related_manager(superclass, rel):
|
|||||||
self.through = through
|
self.through = through
|
||||||
self.prefetch_cache_name = prefetch_cache_name
|
self.prefetch_cache_name = prefetch_cache_name
|
||||||
self.related_val = source_field.get_foreign_related_value(instance)
|
self.related_val = source_field.get_foreign_related_value(instance)
|
||||||
# Used for single column related auto created models
|
|
||||||
self._fk_val = self.related_val[0]
|
|
||||||
if None in self.related_val:
|
if None in self.related_val:
|
||||||
raise ValueError('"%r" needs to have a value for field "%s" before '
|
raise ValueError('"%r" needs to have a value for field "%s" before '
|
||||||
'this many-to-many relationship can be used.' %
|
'this many-to-many relationship can be used.' %
|
||||||
@@ -517,18 +515,6 @@ def create_many_related_manager(superclass, rel):
|
|||||||
"a many-to-many relationship can be used." %
|
"a many-to-many relationship can be used." %
|
||||||
instance.__class__.__name__)
|
instance.__class__.__name__)
|
||||||
|
|
||||||
def _get_fk_val(self, obj, field_name):
|
|
||||||
"""
|
|
||||||
Returns the correct value for this relationship's foreign key. This
|
|
||||||
might be something else than pk value when to_field is used.
|
|
||||||
"""
|
|
||||||
fk = self.through._meta.get_field(field_name)
|
|
||||||
if fk.rel.field_name and fk.rel.field_name != fk.rel.to._meta.pk.attname:
|
|
||||||
attname = fk.rel.get_related_field().get_attname()
|
|
||||||
return fk.get_prep_lookup('exact', getattr(obj, attname))
|
|
||||||
else:
|
|
||||||
return obj.pk
|
|
||||||
|
|
||||||
def get_queryset(self):
|
def get_queryset(self):
|
||||||
try:
|
try:
|
||||||
return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
|
return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
|
||||||
@@ -626,11 +612,12 @@ def create_many_related_manager(superclass, rel):
|
|||||||
if not router.allow_relation(obj, self.instance):
|
if not router.allow_relation(obj, self.instance):
|
||||||
raise ValueError('Cannot add "%r": instance is on database "%s", value is on database "%s"' %
|
raise ValueError('Cannot add "%r": instance is on database "%s", value is on database "%s"' %
|
||||||
(obj, self.instance._state.db, obj._state.db))
|
(obj, self.instance._state.db, obj._state.db))
|
||||||
fk_val = self._get_fk_val(obj, target_field_name)
|
fk_val = self.through._meta.get_field(
|
||||||
|
target_field_name).get_foreign_related_value(obj)[0]
|
||||||
if fk_val is None:
|
if fk_val is None:
|
||||||
raise ValueError('Cannot add "%r": the value for field "%s" is None' %
|
raise ValueError('Cannot add "%r": the value for field "%s" is None' %
|
||||||
(obj, target_field_name))
|
(obj, target_field_name))
|
||||||
new_ids.add(self._get_fk_val(obj, target_field_name))
|
new_ids.add(fk_val)
|
||||||
elif isinstance(obj, Model):
|
elif isinstance(obj, Model):
|
||||||
raise TypeError("'%s' instance expected, got %r" % (self.model._meta.object_name, obj))
|
raise TypeError("'%s' instance expected, got %r" % (self.model._meta.object_name, obj))
|
||||||
else:
|
else:
|
||||||
@@ -638,7 +625,7 @@ def create_many_related_manager(superclass, rel):
|
|||||||
db = router.db_for_write(self.through, instance=self.instance)
|
db = router.db_for_write(self.through, instance=self.instance)
|
||||||
vals = self.through._default_manager.using(db).values_list(target_field_name, flat=True)
|
vals = self.through._default_manager.using(db).values_list(target_field_name, flat=True)
|
||||||
vals = vals.filter(**{
|
vals = vals.filter(**{
|
||||||
source_field_name: self._fk_val,
|
source_field_name: self.related_val[0],
|
||||||
'%s__in' % target_field_name: new_ids,
|
'%s__in' % target_field_name: new_ids,
|
||||||
})
|
})
|
||||||
new_ids = new_ids - set(vals)
|
new_ids = new_ids - set(vals)
|
||||||
@@ -652,7 +639,7 @@ def create_many_related_manager(superclass, rel):
|
|||||||
# Add the ones that aren't there already
|
# Add the ones that aren't there already
|
||||||
self.through._default_manager.using(db).bulk_create([
|
self.through._default_manager.using(db).bulk_create([
|
||||||
self.through(**{
|
self.through(**{
|
||||||
'%s_id' % source_field_name: self._fk_val,
|
'%s_id' % source_field_name: self.related_val[0],
|
||||||
'%s_id' % target_field_name: obj_id,
|
'%s_id' % target_field_name: obj_id,
|
||||||
})
|
})
|
||||||
for obj_id in new_ids
|
for obj_id in new_ids
|
||||||
@@ -676,7 +663,9 @@ def create_many_related_manager(superclass, rel):
|
|||||||
old_ids = set()
|
old_ids = set()
|
||||||
for obj in objs:
|
for obj in objs:
|
||||||
if isinstance(obj, self.model):
|
if isinstance(obj, self.model):
|
||||||
old_ids.add(self._get_fk_val(obj, target_field_name))
|
fk_val = self.through._meta.get_field(
|
||||||
|
target_field_name).get_foreign_related_value(obj)[0]
|
||||||
|
old_ids.add(fk_val)
|
||||||
else:
|
else:
|
||||||
old_ids.add(obj)
|
old_ids.add(obj)
|
||||||
# Work out what DB we're operating on
|
# Work out what DB we're operating on
|
||||||
@@ -690,7 +679,7 @@ def create_many_related_manager(superclass, rel):
|
|||||||
model=self.model, pk_set=old_ids, using=db)
|
model=self.model, pk_set=old_ids, using=db)
|
||||||
# Remove the specified objects from the join table
|
# Remove the specified objects from the join table
|
||||||
self.through._default_manager.using(db).filter(**{
|
self.through._default_manager.using(db).filter(**{
|
||||||
source_field_name: self._fk_val,
|
source_field_name: self.related_val[0],
|
||||||
'%s__in' % target_field_name: old_ids
|
'%s__in' % target_field_name: old_ids
|
||||||
}).delete()
|
}).delete()
|
||||||
if self.reverse or source_field_name == self.source_field_name:
|
if self.reverse or source_field_name == self.source_field_name:
|
||||||
@@ -994,10 +983,13 @@ class ForeignObject(RelatedField):
|
|||||||
# Gotcha: in some cases (like fixture loading) a model can have
|
# Gotcha: in some cases (like fixture loading) a model can have
|
||||||
# different values in parent_ptr_id and parent's id. So, use
|
# different values in parent_ptr_id and parent's id. So, use
|
||||||
# instance.pk (that is, parent_ptr_id) when asked for instance.id.
|
# instance.pk (that is, parent_ptr_id) when asked for instance.id.
|
||||||
|
opts = instance._meta
|
||||||
if field.primary_key:
|
if field.primary_key:
|
||||||
ret.append(instance.pk)
|
possible_parent_link = opts.get_ancestor_link(field.model)
|
||||||
else:
|
if not possible_parent_link or possible_parent_link.primary_key:
|
||||||
ret.append(getattr(instance, field.attname))
|
ret.append(instance.pk)
|
||||||
|
continue
|
||||||
|
ret.append(getattr(instance, field.attname))
|
||||||
return tuple(ret)
|
return tuple(ret)
|
||||||
|
|
||||||
def get_attname_column(self):
|
def get_attname_column(self):
|
||||||
|
@@ -162,3 +162,9 @@ class Mixin(object):
|
|||||||
|
|
||||||
class MixinModel(models.Model, Mixin):
|
class MixinModel(models.Model, Mixin):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
class Base(models.Model):
|
||||||
|
titles = models.ManyToManyField(Title)
|
||||||
|
|
||||||
|
class SubBase(Base):
|
||||||
|
sub_id = models.IntegerField(primary_key=True)
|
||||||
|
@@ -10,7 +10,8 @@ from django.utils import six
|
|||||||
|
|
||||||
from .models import (
|
from .models import (
|
||||||
Chef, CommonInfo, ItalianRestaurant, ParkingLot, Place, Post,
|
Chef, CommonInfo, ItalianRestaurant, ParkingLot, Place, Post,
|
||||||
Restaurant, Student, StudentWorker, Supplier, Worker, MixinModel)
|
Restaurant, Student, StudentWorker, Supplier, Worker, MixinModel,
|
||||||
|
Title, Base, SubBase)
|
||||||
|
|
||||||
|
|
||||||
class ModelInheritanceTests(TestCase):
|
class ModelInheritanceTests(TestCase):
|
||||||
@@ -357,3 +358,16 @@ class ModelInheritanceTests(TestCase):
|
|||||||
[Place.objects.get(pk=s.pk)],
|
[Place.objects.get(pk=s.pk)],
|
||||||
lambda x: x
|
lambda x: x
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_custompk_m2m(self):
|
||||||
|
b = Base.objects.create()
|
||||||
|
b.titles.add(Title.objects.create(title="foof"))
|
||||||
|
s = SubBase.objects.create(sub_id=b.id)
|
||||||
|
b = Base.objects.get(pk=s.id)
|
||||||
|
self.assertNotEqual(b.pk, s.pk)
|
||||||
|
# Low-level test for related_val
|
||||||
|
self.assertEqual(s.titles.related_val, (s.id,))
|
||||||
|
# Higher level test for correct query values (title foof not
|
||||||
|
# accidentally found).
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
s.titles.all(), [])
|
||||||
|
Reference in New Issue
Block a user