From 04a2a6b0f9cb6bb98edfe84bf4361216d60a4e38 Mon Sep 17 00:00:00 2001 From: Loic Bistuer Date: Thu, 19 Sep 2013 00:31:07 +0700 Subject: [PATCH] Fixed #3871 -- Custom managers when traversing reverse relations. --- django/contrib/contenttypes/generic.py | 17 +++ django/db/models/fields/related.py | 188 ++++++++++++++----------- docs/releases/1.7.txt | 6 + docs/topics/db/queries.txt | 25 ++++ tests/custom_managers/models.py | 20 +++ tests/custom_managers/tests.py | 97 +++++++++++-- 6 files changed, 262 insertions(+), 91 deletions(-) diff --git a/django/contrib/contenttypes/generic.py b/django/contrib/contenttypes/generic.py index 9db672c8a3..f3d66eaf31 100644 --- a/django/contrib/contenttypes/generic.py +++ b/django/contrib/contenttypes/generic.py @@ -319,6 +319,23 @@ def create_generic_related_manager(superclass): '%s__exact' % object_id_field_name: instance._get_pk_val(), } + def __call__(self, **kwargs): + # We use **kwargs rather than a kwarg argument to enforce the + # `manager='manager_name'` syntax. + manager = getattr(self.model, kwargs.pop('manager')) + manager_class = create_generic_related_manager(manager.__class__) + return manager_class( + model = self.model, + instance = self.instance, + symmetrical = self.symmetrical, + source_col_name = self.source_col_name, + target_col_name = self.target_col_name, + content_type = self.content_type, + content_type_field_name = self.content_type_field_name, + object_id_field_name = self.object_id_field_name, + prefetch_cache_name = self.prefetch_cache_name, + ) + def get_queryset(self): try: return self.instance._prefetched_objects_cache[self.prefetch_cache_name] diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index fd9e8fa4d8..23959f6b18 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -365,6 +365,92 @@ class ReverseSingleRelatedObjectDescriptor(six.with_metaclass(RenameRelatedObjec setattr(value, self.field.related.get_cache_name(), instance) +def create_foreign_related_manager(superclass, rel_field, rel_model): + class RelatedManager(superclass): + def __init__(self, instance): + super(RelatedManager, self).__init__() + self.instance = instance + self.core_filters = {'%s__exact' % rel_field.name: instance} + self.model = rel_model + + def __call__(self, **kwargs): + # We use **kwargs rather than a kwarg argument to enforce the + # `manager='manager_name'` syntax. + manager = getattr(self.model, kwargs.pop('manager')) + manager_class = create_foreign_related_manager(manager.__class__, rel_field, rel_model) + return manager_class(self.instance) + + def get_queryset(self): + try: + return self.instance._prefetched_objects_cache[rel_field.related_query_name()] + except (AttributeError, KeyError): + db = self._db or router.db_for_read(self.model, instance=self.instance) + qs = super(RelatedManager, self).get_queryset().using(db).filter(**self.core_filters) + empty_strings_as_null = connections[db].features.interprets_empty_strings_as_nulls + for field in rel_field.foreign_related_fields: + val = getattr(self.instance, field.attname) + if val is None or (val == '' and empty_strings_as_null): + return qs.none() + qs._known_related_objects = {rel_field: {self.instance.pk: self.instance}} + return qs + + def get_prefetch_queryset(self, instances): + rel_obj_attr = rel_field.get_local_related_value + instance_attr = rel_field.get_foreign_related_value + instances_dict = dict((instance_attr(inst), inst) for inst in instances) + db = self._db or router.db_for_read(self.model, instance=instances[0]) + query = {'%s__in' % rel_field.name: instances} + qs = super(RelatedManager, self).get_queryset().using(db).filter(**query) + # Since we just bypassed this class' get_queryset(), we must manage + # the reverse relation manually. + for rel_obj in qs: + instance = instances_dict[rel_obj_attr(rel_obj)] + setattr(rel_obj, rel_field.name, instance) + cache_name = rel_field.related_query_name() + return qs, rel_obj_attr, instance_attr, False, cache_name + + def add(self, *objs): + for obj in objs: + if not isinstance(obj, self.model): + raise TypeError("'%s' instance expected, got %r" % (self.model._meta.object_name, obj)) + setattr(obj, rel_field.name, self.instance) + obj.save() + add.alters_data = True + + def create(self, **kwargs): + kwargs[rel_field.name] = self.instance + db = router.db_for_write(self.model, instance=self.instance) + return super(RelatedManager, self.db_manager(db)).create(**kwargs) + create.alters_data = True + + def get_or_create(self, **kwargs): + # Update kwargs with the related object that this + # ForeignRelatedObjectsDescriptor knows about. + kwargs[rel_field.name] = self.instance + db = router.db_for_write(self.model, instance=self.instance) + return super(RelatedManager, self.db_manager(db)).get_or_create(**kwargs) + get_or_create.alters_data = True + + # remove() and clear() are only provided if the ForeignKey can have a value of null. + if rel_field.null: + def remove(self, *objs): + val = rel_field.get_foreign_related_value(self.instance) + for obj in objs: + # Is obj actually part of this descriptor set? + if rel_field.get_local_related_value(obj) == val: + setattr(obj, rel_field.name, None) + obj.save() + else: + raise rel_field.rel.to.DoesNotExist("%r is not related to %r." % (obj, self.instance)) + remove.alters_data = True + + def clear(self): + self.update(**{rel_field.name: None}) + clear.alters_data = True + + return RelatedManager + + class ForeignRelatedObjectsDescriptor(object): # This class provides the functionality that makes the related-object # managers available as attributes on a model class, for fields that have @@ -392,86 +478,11 @@ class ForeignRelatedObjectsDescriptor(object): def related_manager_cls(self): # Dynamically create a class that subclasses the related model's default # manager. - superclass = self.related.model._default_manager.__class__ - rel_field = self.related.field - rel_model = self.related.model - - class RelatedManager(superclass): - def __init__(self, instance): - super(RelatedManager, self).__init__() - self.instance = instance - self.core_filters = {'%s__exact' % rel_field.name: instance} - self.model = rel_model - - def get_queryset(self): - try: - return self.instance._prefetched_objects_cache[rel_field.related_query_name()] - except (AttributeError, KeyError): - db = self._db or router.db_for_read(self.model, instance=self.instance) - qs = super(RelatedManager, self).get_queryset().using(db).filter(**self.core_filters) - empty_strings_as_null = connections[db].features.interprets_empty_strings_as_nulls - for field in rel_field.foreign_related_fields: - val = getattr(self.instance, field.attname) - if val is None or (val == '' and empty_strings_as_null): - return qs.none() - qs._known_related_objects = {rel_field: {self.instance.pk: self.instance}} - return qs - - def get_prefetch_queryset(self, instances): - rel_obj_attr = rel_field.get_local_related_value - instance_attr = rel_field.get_foreign_related_value - instances_dict = dict((instance_attr(inst), inst) for inst in instances) - db = self._db or router.db_for_read(self.model, instance=instances[0]) - query = {'%s__in' % rel_field.name: instances} - qs = super(RelatedManager, self).get_queryset().using(db).filter(**query) - # Since we just bypassed this class' get_queryset(), we must manage - # the reverse relation manually. - for rel_obj in qs: - instance = instances_dict[rel_obj_attr(rel_obj)] - setattr(rel_obj, rel_field.name, instance) - cache_name = rel_field.related_query_name() - return qs, rel_obj_attr, instance_attr, False, cache_name - - def add(self, *objs): - for obj in objs: - if not isinstance(obj, self.model): - raise TypeError("'%s' instance expected, got %r" % (self.model._meta.object_name, obj)) - setattr(obj, rel_field.name, self.instance) - obj.save() - add.alters_data = True - - def create(self, **kwargs): - kwargs[rel_field.name] = self.instance - db = router.db_for_write(self.model, instance=self.instance) - return super(RelatedManager, self.db_manager(db)).create(**kwargs) - create.alters_data = True - - def get_or_create(self, **kwargs): - # Update kwargs with the related object that this - # ForeignRelatedObjectsDescriptor knows about. - kwargs[rel_field.name] = self.instance - db = router.db_for_write(self.model, instance=self.instance) - return super(RelatedManager, self.db_manager(db)).get_or_create(**kwargs) - get_or_create.alters_data = True - - # remove() and clear() are only provided if the ForeignKey can have a value of null. - if rel_field.null: - def remove(self, *objs): - val = rel_field.get_foreign_related_value(self.instance) - for obj in objs: - # Is obj actually part of this descriptor set? - if rel_field.get_local_related_value(obj) == val: - setattr(obj, rel_field.name, None) - obj.save() - else: - raise rel_field.rel.to.DoesNotExist("%r is not related to %r." % (obj, self.instance)) - remove.alters_data = True - - def clear(self): - self.update(**{rel_field.name: None}) - clear.alters_data = True - - return RelatedManager + return create_foreign_related_manager( + self.related.model._default_manager.__class__, + self.related.field, + self.related.model, + ) def create_many_related_manager(superclass, rel): @@ -513,6 +524,23 @@ def create_many_related_manager(superclass, rel): "a many-to-many relationship can be used." % instance.__class__.__name__) + def __call__(self, **kwargs): + # We use **kwargs rather than a kwarg argument to enforce the + # `manager='manager_name'` syntax. + manager = getattr(self.model, kwargs.pop('manager')) + manager_class = create_many_related_manager(manager.__class__, rel) + return manager_class( + model=self.model, + query_field_name=self.query_field_name, + instance=self.instance, + symmetrical=self.symmetrical, + source_field_name=self.source_field_name, + target_field_name=self.target_field_name, + reverse=self.reverse, + through=self.through, + prefetch_cache_name=self.prefetch_cache_name, + ) + def get_queryset(self): try: return self.instance._prefetched_objects_cache[self.prefetch_cache_name] diff --git a/docs/releases/1.7.txt b/docs/releases/1.7.txt index 924ec51ba9..6a29635d18 100644 --- a/docs/releases/1.7.txt +++ b/docs/releases/1.7.txt @@ -92,6 +92,12 @@ The :meth:`QuerySet.as_manager() ` class method has been added to :ref:`create Manager with QuerySet methods `. +Using a custom manager when traversing reverse relations +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +It is now possible to :ref:`specify a custom manager +` when traversing a reverse relationship. + Admin shortcuts support time zones ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/topics/db/queries.txt b/docs/topics/db/queries.txt index 23f4ec7398..696bddac95 100644 --- a/docs/topics/db/queries.txt +++ b/docs/topics/db/queries.txt @@ -1136,6 +1136,31 @@ above example code would look like this:: >>> b.entries.filter(headline__contains='Lennon') >>> b.entries.count() +.. _using-custom-reverse-manager: + +Using a custom reverse manager +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. versionadded:: 1.7 + +By default the :class:`~django.db.models.fields.related.RelatedManager` used +for reverse relations is a subclass of the :ref:`default manager ` +for that model. If you would like to specify a different manager for a given +query you can use the following syntax:: + + from django.db import models + + class Entry(models.Model): + #... + objects = models.Manager() # Default Manager + entries = EntryManager() # Custom Manager + + >>> b = Blog.objects.get(id=1) + >>> b.entry_set(manager='entries').all() + +Additional methods to handle related objects +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + In addition to the :class:`~django.db.models.query.QuerySet` methods defined in "Retrieving objects" above, the :class:`~django.db.models.ForeignKey` :class:`~django.db.models.Manager` has additional methods used to handle the diff --git a/tests/custom_managers/models.py b/tests/custom_managers/models.py index cba375f4d7..26d848b7c0 100644 --- a/tests/custom_managers/models.py +++ b/tests/custom_managers/models.py @@ -11,6 +11,7 @@ returns. from __future__ import unicode_literals +from django.contrib.contenttypes import generic from django.db import models from django.utils.encoding import python_2_unicode_compatible @@ -63,12 +64,28 @@ class BaseCustomManager(models.Manager): CustomManager = BaseCustomManager.from_queryset(CustomQuerySet) +class FunPeopleManager(models.Manager): + def get_queryset(self): + return super(FunPeopleManager, self).get_queryset().filter(fun=True) + +class BoringPeopleManager(models.Manager): + def get_queryset(self): + return super(BoringPeopleManager, self).get_queryset().filter(fun=False) + @python_2_unicode_compatible class Person(models.Model): first_name = models.CharField(max_length=30) last_name = models.CharField(max_length=30) fun = models.BooleanField(default=False) + + favorite_book = models.ForeignKey('Book', null=True, related_name='favorite_books') + favorite_thing_type = models.ForeignKey('contenttypes.ContentType', null=True) + favorite_thing_id = models.IntegerField(null=True) + favorite_thing = generic.GenericForeignKey('favorite_thing_type', 'favorite_thing_id') + objects = PersonManager() + fun_people = FunPeopleManager() + boring_people = BoringPeopleManager() custom_queryset_default_manager = CustomQuerySet.as_manager() custom_queryset_custom_manager = CustomManager('hello') @@ -84,6 +101,9 @@ class Book(models.Model): published_objects = PublishedBookManager() authors = models.ManyToManyField(Person, related_name='books') + favorite_things = generic.GenericRelation(Person, + content_type_field='favorite_thing_type', object_id_field='favorite_thing_id') + def __str__(self): return self.title diff --git a/tests/custom_managers/tests.py b/tests/custom_managers/tests.py index ff14ad8439..f9a9f33d87 100644 --- a/tests/custom_managers/tests.py +++ b/tests/custom_managers/tests.py @@ -7,10 +7,15 @@ from .models import Person, Book, Car, PersonManager, PublishedBookManager class CustomManagerTests(TestCase): - def test_manager(self): - Person.objects.create(first_name="Bugs", last_name="Bunny", fun=True) - p2 = Person.objects.create(first_name="Droopy", last_name="Dog", fun=False) + def setUp(self): + self.b1 = Book.published_objects.create( + title="How to program", author="Rodney Dangerfield", is_published=True) + self.b2 = Book.published_objects.create( + title="How to be smart", author="Albert Einstein", is_published=False) + self.p1 = Person.objects.create(first_name="Bugs", last_name="Bunny", fun=True) + self.p2 = Person.objects.create(first_name="Droopy", last_name="Dog", fun=False) + def test_manager(self): # Test a custom `Manager` method. self.assertQuerysetEqual( Person.objects.get_fun_people(), [ @@ -61,14 +66,8 @@ class CustomManagerTests(TestCase): # The RelatedManager used on the 'books' descriptor extends the default # manager - self.assertIsInstance(p2.books, PublishedBookManager) + self.assertIsInstance(self.p2.books, PublishedBookManager) - Book.published_objects.create( - title="How to program", author="Rodney Dangerfield", is_published=True - ) - b2 = Book.published_objects.create( - title="How to be smart", author="Albert Einstein", is_published=False - ) # The default manager, "objects", doesn't exist, because a custom one # was provided. @@ -76,7 +75,7 @@ class CustomManagerTests(TestCase): # The RelatedManager used on the 'authors' descriptor extends the # default manager - self.assertIsInstance(b2.authors, PersonManager) + self.assertIsInstance(self.b2.authors, PersonManager) self.assertQuerysetEqual( Book.published_objects.all(), [ @@ -114,3 +113,79 @@ class CustomManagerTests(TestCase): ], lambda c: c.name ) + + def test_related_manager_fk(self): + self.p1.favorite_book = self.b1 + self.p1.save() + self.p2.favorite_book = self.b1 + self.p2.save() + + self.assertQuerysetEqual( + self.b1.favorite_books.order_by('first_name').all(), [ + "Bugs", + "Droopy", + ], + lambda c: c.first_name + ) + self.assertQuerysetEqual( + self.b1.favorite_books(manager='boring_people').all(), [ + "Droopy", + ], + lambda c: c.first_name + ) + self.assertQuerysetEqual( + self.b1.favorite_books(manager='fun_people').all(), [ + "Bugs", + ], + lambda c: c.first_name + ) + + def test_related_manager_gfk(self): + self.p1.favorite_thing = self.b1 + self.p1.save() + self.p2.favorite_thing = self.b1 + self.p2.save() + + self.assertQuerysetEqual( + self.b1.favorite_things.order_by('first_name').all(), [ + "Bugs", + "Droopy", + ], + lambda c: c.first_name + ) + self.assertQuerysetEqual( + self.b1.favorite_things(manager='boring_people').all(), [ + "Droopy", + ], + lambda c: c.first_name + ) + self.assertQuerysetEqual( + self.b1.favorite_things(manager='fun_people').all(), [ + "Bugs", + ], + lambda c: c.first_name + ) + + def test_related_manager_m2m(self): + self.b1.authors.add(self.p1) + self.b1.authors.add(self.p2) + + self.assertQuerysetEqual( + self.b1.authors.order_by('first_name').all(), [ + "Bugs", + "Droopy", + ], + lambda c: c.first_name + ) + self.assertQuerysetEqual( + self.b1.authors(manager='boring_people').all(), [ + "Droopy", + ], + lambda c: c.first_name + ) + self.assertQuerysetEqual( + self.b1.authors(manager='fun_people').all(), [ + "Bugs", + ], + lambda c: c.first_name + )