From 5009e45dfe366ebf520eecb49d72d5a11cee1623 Mon Sep 17 00:00:00 2001
From: Luke Plant <L.Plant.98@cantab.net>
Date: Fri, 30 Sep 2011 10:41:25 +0000
Subject: [PATCH] Fixed #14270 - related manager classes should be cached

Thanks to Alex Gaynor for the report and initial patch, and mrmachine for
more work on it.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@16916 bcc190cf-cafb-0310-a4f2-bffc1f526a37
---
 django/db/models/fields/related.py         | 89 ++++++++++++----------
 django/utils/functional.py                 | 14 +++-
 tests/modeltests/many_to_one/tests.py      | 10 +++
 tests/regressiontests/m2m_regress/tests.py | 14 ++++
 4 files changed, 87 insertions(+), 40 deletions(-)

diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py
index eee2ecf4d6..95c4cac253 100644
--- a/django/db/models/fields/related.py
+++ b/django/db/models/fields/related.py
@@ -9,7 +9,7 @@ from django.db.models.query_utils import QueryWrapper
 from django.db.models.deletion import CASCADE
 from django.utils.encoding import smart_unicode
 from django.utils.translation import ugettext_lazy as _, string_concat
-from django.utils.functional import curry
+from django.utils.functional import curry, memoize, cached_property
 from django.core import exceptions
 from django import forms
 
@@ -386,8 +386,8 @@ class ForeignRelatedObjectsDescriptor(object):
         if instance is None:
             return self
 
-        return self.create_manager(instance,
-                self.related.model._default_manager.__class__)
+        manager = self.create_manager(self.related.model._default_manager.__class__)
+        return manager(instance)
 
     def __set__(self, instance, value):
         if instance is None:
@@ -406,22 +406,30 @@ class ForeignRelatedObjectsDescriptor(object):
         than the default manager, as returned by __get__). Used by
         Model.delete().
         """
-        return self.create_manager(instance,
-                self.related.model._base_manager.__class__)
+        manager = self.create_manager(self.related.model._base_manager.__class__)
+        return manager(instance)
 
-    def create_manager(self, instance, superclass):
+    def create_manager(self, superclass):
         """
         Creates the managers used by other methods (__get__() and delete()).
         """
+
+        # We use closures for these values so that we only need to memoize this
+        # function on the one argument of 'superclass', and the two places that
+        # call create_manager simply need to pass instance to the manager
+        # __init__
         rel_field = self.related.field
+        rel_model = self.related.model
+        attname = rel_field.rel.get_related_field().attname
+
         class RelatedManager(superclass):
-            def __init__(self, model=None, core_filters=None, instance=None,
-                         rel_field=None):
+            def __init__(self, instance):
                 super(RelatedManager, self).__init__()
-                self.model = model
-                self.core_filters = core_filters
                 self.instance = instance
-                self.rel_field = rel_field
+                self.core_filters = {
+                    '%s__%s' % (rel_field.name, attname): getattr(instance, attname)
+                }
+                self.model = rel_model
 
             def get_query_set(self):
                 db = self._db or router.db_for_read(self.model, instance=self.instance)
@@ -431,12 +439,12 @@ class ForeignRelatedObjectsDescriptor(object):
                 for obj in objs:
                     if not isinstance(obj, self.model):
                         raise TypeError("'%s' instance expected" % self.model._meta.object_name)
-                    setattr(obj, self.rel_field.name, self.instance)
+                    setattr(obj, rel_field.name, self.instance)
                     obj.save()
             add.alters_data = True
 
             def create(self, **kwargs):
-                kwargs[self.rel_field.name] = self.instance
+                kwargs[rel_field.name] = self.instance
                 db = router.db_for_write(self.model, instance=self.instance)
                 return super(RelatedManager, self.db_manager(db)).create(**kwargs)
             create.alters_data = True
@@ -444,7 +452,7 @@ class ForeignRelatedObjectsDescriptor(object):
             def get_or_create(self, **kwargs):
                 # Update kwargs with the related object that this
                 # ForeignRelatedObjectsDescriptor knows about.
-                kwargs[self.rel_field.name] = self.instance
+                kwargs[rel_field.name] = self.instance
                 db = router.db_for_write(self.model, instance=self.instance)
                 return super(RelatedManager, self.db_manager(db)).get_or_create(**kwargs)
             get_or_create.alters_data = True
@@ -452,27 +460,22 @@ class ForeignRelatedObjectsDescriptor(object):
             # remove() and clear() are only provided if the ForeignKey can have a value of null.
             if rel_field.null:
                 def remove(self, *objs):
-                    val = getattr(self.instance, self.rel_field.rel.get_related_field().attname)
+                    val = getattr(self.instance, attname)
                     for obj in objs:
                         # Is obj actually part of this descriptor set?
-                        if getattr(obj, self.rel_field.attname) == val:
-                            setattr(obj, self.rel_field.name, None)
+                        if getattr(obj, rel_field.attname) == val:
+                            setattr(obj, rel_field.name, None)
                             obj.save()
                         else:
-                            raise self.rel_field.rel.to.DoesNotExist("%r is not related to %r." % (obj, self.instance))
+                            raise rel_field.rel.to.DoesNotExist("%r is not related to %r." % (obj, self.instance))
                 remove.alters_data = True
 
                 def clear(self):
-                    self.update(**{self.rel_field.name: None})
+                    self.update(**{rel_field.name: None})
                 clear.alters_data = True
 
-        attname = rel_field.rel.get_related_field().name
-        return RelatedManager(model=self.related.model,
-                              core_filters = {'%s__%s' % (rel_field.name, attname):
-                                                  getattr(instance, attname)},
-                              instance=instance,
-                              rel_field=rel_field,
-                              )
+        return RelatedManager
+    create_manager = memoize(create_manager, {}, 2)
 
 
 def create_many_related_manager(superclass, rel):
@@ -663,17 +666,22 @@ class ManyRelatedObjectsDescriptor(object):
     def __init__(self, related):
         self.related = related   # RelatedObject instance
 
+    @cached_property
+    def related_manager_cls(self):
+        # Dynamically create a class that subclasses the related
+        # model's default manager.
+        return create_many_related_manager(
+            self.related.model._default_manager.__class__,
+            self.related.field.rel
+        )
+
     def __get__(self, instance, instance_type=None):
         if instance is None:
             return self
 
-        # Dynamically create a class that subclasses the related
-        # model's default manager.
         rel_model = self.related.model
-        superclass = rel_model._default_manager.__class__
-        RelatedManager = create_many_related_manager(superclass, self.related.field.rel)
 
-        manager = RelatedManager(
+        manager = self.related_manager_cls(
             model=rel_model,
             core_filters={'%s__pk' % self.related.field.name: instance._get_pk_val()},
             instance=instance,
@@ -716,18 +724,21 @@ class ReverseManyRelatedObjectsDescriptor(object):
         # a property to ensure that the fully resolved value is returned.
         return self.field.rel.through
 
+    @cached_property
+    def related_manager_cls(self):
+        # Dynamically create a class that subclasses the related model's
+        # default manager.
+        return create_many_related_manager(
+            self.field.rel.to._default_manager.__class__,
+            self.field.rel
+        )
+
     def __get__(self, instance, instance_type=None):
         if instance is None:
             return self
 
-        # Dynamically create a class that subclasses the related
-        # model's default manager.
-        rel_model=self.field.rel.to
-        superclass = rel_model._default_manager.__class__
-        RelatedManager = create_many_related_manager(superclass, self.field.rel)
-
-        manager = RelatedManager(
-            model=rel_model,
+        manager = self.related_manager_cls(
+            model=self.field.rel.to,
             core_filters={'%s__pk' % self.field.related_query_name(): instance._get_pk_val()},
             instance=instance,
             symmetrical=self.field.rel.symmetrical,
diff --git a/django/utils/functional.py b/django/utils/functional.py
index 1345d3b005..67b727f012 100644
--- a/django/utils/functional.py
+++ b/django/utils/functional.py
@@ -28,6 +28,18 @@ def memoize(func, cache, num_args):
         return result
     return wrapper
 
+class cached_property(object):
+    """
+    Decorator that creates converts a method with a single
+    self argument into a property cached on the instance.
+    """
+    def __init__(self, func):
+        self.func = func
+
+    def __get__(self, instance, type):
+        res = instance.__dict__[self.func.__name__] = self.func(instance)
+        return res
+
 class Promise(object):
     """
     This is just a base class for the proxy class created in
@@ -288,4 +300,4 @@ def partition(predicate, values):
     results = ([], [])
     for item in values:
         results[predicate(item)].append(item)
-    return results
\ No newline at end of file
+    return results
diff --git a/tests/modeltests/many_to_one/tests.py b/tests/modeltests/many_to_one/tests.py
index 9f60c21f47..4f561b40ca 100644
--- a/tests/modeltests/many_to_one/tests.py
+++ b/tests/modeltests/many_to_one/tests.py
@@ -399,3 +399,13 @@ class ManyToOneTests(TestCase):
         self.assertEqual(repr(a3),
                          repr(Article.objects.get(reporter_id=self.r2.id,
                                              pub_date=datetime(2011, 5, 7))))
+
+    def test_manager_class_caching(self):
+        r1 = Reporter.objects.create(first_name='Mike')
+        r2 = Reporter.objects.create(first_name='John')
+
+        # Same twice
+        self.assertTrue(r1.article_set.__class__ is r1.article_set.__class__)
+
+        # Same as each other
+        self.assertTrue(r1.article_set.__class__ is r2.article_set.__class__)
diff --git a/tests/regressiontests/m2m_regress/tests.py b/tests/regressiontests/m2m_regress/tests.py
index 7bf2381a91..9ae888a6ce 100644
--- a/tests/regressiontests/m2m_regress/tests.py
+++ b/tests/regressiontests/m2m_regress/tests.py
@@ -73,3 +73,17 @@ class M2MRegressionTests(TestCase):
 
         self.assertQuerysetEqual(c1.tags.all(), ["<Tag: t1>", "<Tag: t2>"])
         self.assertQuerysetEqual(t1.tag_collections.all(), ["<TagCollection: c1>"])
+
+    def test_manager_class_caching(self):
+        e1 = Entry.objects.create()
+        e2 = Entry.objects.create()
+        t1 = Tag.objects.create()
+        t2 = Tag.objects.create()
+
+        # Get same manager twice in a row:
+        self.assertTrue(t1.entry_set.__class__ is t1.entry_set.__class__)
+        self.assertTrue(e1.topics.__class__ is e1.topics.__class__)
+
+        # Get same manager for different instances
+        self.assertTrue(e1.topics.__class__ is e2.topics.__class__)
+        self.assertTrue(t1.entry_set.__class__ is t2.entry_set.__class__)