From 056ace0f395a58eeac03da9f9ee7e3872e1e407b Mon Sep 17 00:00:00 2001
From: Aymeric Augustin <aymeric.augustin@m4x.org>
Date: Wed, 2 Jan 2013 21:42:52 +0100
Subject: [PATCH] [1.5.x] Fixed #19547 -- Caching of related instances.

When &'ing or |'ing querysets, wrong values could be cached, and crashes
could happen.

Thanks Marc Tamlyn for figuring out the problem and writing the patch.

Backport of 07fbc6a.
---
 django/db/models/fields/related.py            |  2 +-
 django/db/models/query.py                     | 29 +++++++++----
 .../fixtures/tournament.json                  | 11 +++++
 .../known_related_objects/models.py           |  4 ++
 .../modeltests/known_related_objects/tests.py | 42 ++++++++++++++++++-
 5 files changed, 79 insertions(+), 9 deletions(-)

diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py
index e2947b0093..4638a981bc 100644
--- a/django/db/models/fields/related.py
+++ b/django/db/models/fields/related.py
@@ -497,7 +497,7 @@ class ForeignRelatedObjectsDescriptor(object):
                 except (AttributeError, KeyError):
                     db = self._db or router.db_for_read(self.model, instance=self.instance)
                     qs = super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters)
-                    qs._known_related_object = (rel_field.name, self.instance)
+                    qs._known_related_objects = {rel_field: {self.instance.pk: self.instance}}
                     return qs
 
             def get_prefetch_query_set(self, instances):
diff --git a/django/db/models/query.py b/django/db/models/query.py
index 9d0fbc02dc..d1388a5c80 100644
--- a/django/db/models/query.py
+++ b/django/db/models/query.py
@@ -44,7 +44,7 @@ class QuerySet(object):
         self._for_write = False
         self._prefetch_related_lookups = []
         self._prefetch_done = False
-        self._known_related_object = None       # (attname, rel_obj)
+        self._known_related_objects = {}        # {rel_field, {pk: rel_obj}}
 
     ########################
     # PYTHON MAGIC METHODS #
@@ -221,6 +221,7 @@ class QuerySet(object):
         if isinstance(other, EmptyQuerySet):
             return other._clone()
         combined = self._clone()
+        combined._merge_known_related_objects(other)
         combined.query.combine(other.query, sql.AND)
         return combined
 
@@ -229,6 +230,7 @@ class QuerySet(object):
         combined = self._clone()
         if isinstance(other, EmptyQuerySet):
             return combined
+        combined._merge_known_related_objects(other)
         combined.query.combine(other.query, sql.OR)
         return combined
 
@@ -289,10 +291,9 @@ class QuerySet(object):
                     init_list.append(field.attname)
             model_cls = deferred_class_factory(self.model, skip)
 
-        # Cache db, model and known_related_object outside the loop
+        # Cache db and model outside the loop
         db = self.db
         model = self.model
-        kro_attname, kro_instance = self._known_related_object or (None, None)
         compiler = self.query.get_compiler(using=db)
         if fill_cache:
             klass_info = get_klass_info(model, max_depth=max_depth,
@@ -323,9 +324,16 @@ class QuerySet(object):
                 for i, aggregate in enumerate(aggregate_select):
                     setattr(obj, aggregate, row[i + aggregate_start])
 
-            # Add the known related object to the model, if there is one
-            if kro_instance:
-                setattr(obj, kro_attname, kro_instance)
+            # Add the known related objects to the model, if there are any
+            if self._known_related_objects:
+                for field, rel_objs in self._known_related_objects.items():
+                    pk = getattr(obj, field.get_attname())
+                    try:
+                        rel_obj = rel_objs[pk]
+                    except KeyError:
+                        pass               # may happen in qs1 | qs2 scenarios
+                    else:
+                        setattr(obj, field.name, rel_obj)
 
             yield obj
 
@@ -902,7 +910,7 @@ class QuerySet(object):
         c = klass(model=self.model, query=query, using=self._db)
         c._for_write = self._for_write
         c._prefetch_related_lookups = self._prefetch_related_lookups[:]
-        c._known_related_object = self._known_related_object
+        c._known_related_objects = self._known_related_objects
         c.__dict__.update(kwargs)
         if setup and hasattr(c, '_setup_query'):
             c._setup_query()
@@ -942,6 +950,13 @@ class QuerySet(object):
         """
         pass
 
+    def _merge_known_related_objects(self, other):
+        """
+        Keep track of all known related objects from either QuerySet instance.
+        """
+        for field, objects in other._known_related_objects.items():
+            self._known_related_objects.setdefault(field, {}).update(objects)
+
     def _setup_aggregate_query(self, aggregates):
         """
         Prepare the query for computing a result that contains aggregate annotations.
diff --git a/tests/modeltests/known_related_objects/fixtures/tournament.json b/tests/modeltests/known_related_objects/fixtures/tournament.json
index 2f2b1c5627..b8f053e152 100644
--- a/tests/modeltests/known_related_objects/fixtures/tournament.json
+++ b/tests/modeltests/known_related_objects/fixtures/tournament.json
@@ -13,11 +13,19 @@
             "name": "Tourney 2"
             }
         },
+    {
+        "pk": 1,
+        "model": "known_related_objects.organiser",
+        "fields": {
+            "name": "Organiser 1"
+            }
+        },
     {
         "pk": 1,
         "model": "known_related_objects.pool",
         "fields": {
             "tournament": 1,
+            "organiser": 1,
             "name": "T1 Pool 1"
             }
         },
@@ -26,6 +34,7 @@
         "model": "known_related_objects.pool",
         "fields": {
             "tournament": 1,
+            "organiser": 1,
             "name": "T1 Pool 2"
             }
         },
@@ -34,6 +43,7 @@
         "model": "known_related_objects.pool",
         "fields": {
             "tournament": 2,
+            "organiser": 1,
             "name": "T2 Pool 1"
             }
         },
@@ -42,6 +52,7 @@
         "model": "known_related_objects.pool",
         "fields": {
             "tournament": 2,
+            "organiser": 1,
             "name": "T2 Pool 2"
             }
         },
diff --git a/tests/modeltests/known_related_objects/models.py b/tests/modeltests/known_related_objects/models.py
index 4c516dd7e8..e256cc38f2 100644
--- a/tests/modeltests/known_related_objects/models.py
+++ b/tests/modeltests/known_related_objects/models.py
@@ -9,9 +9,13 @@ from django.db import models
 class Tournament(models.Model):
     name = models.CharField(max_length=30)
 
+class Organiser(models.Model):
+    name = models.CharField(max_length=30)
+
 class Pool(models.Model):
     name = models.CharField(max_length=30)
     tournament = models.ForeignKey(Tournament)
+    organiser = models.ForeignKey(Organiser)
 
 class PoolStyle(models.Model):
     name = models.CharField(max_length=30)
diff --git a/tests/modeltests/known_related_objects/tests.py b/tests/modeltests/known_related_objects/tests.py
index 24feab2241..2371ac2e20 100644
--- a/tests/modeltests/known_related_objects/tests.py
+++ b/tests/modeltests/known_related_objects/tests.py
@@ -2,7 +2,7 @@ from __future__ import absolute_import
 
 from django.test import TestCase
 
-from .models import Tournament, Pool, PoolStyle
+from .models import Tournament, Organiser, Pool, PoolStyle
 
 class ExistingRelatedInstancesTests(TestCase):
     fixtures = ['tournament.json']
@@ -27,6 +27,46 @@ class ExistingRelatedInstancesTests(TestCase):
             pool2 = tournaments[1].pool_set.all()[0]
             self.assertIs(tournaments[1], pool2.tournament)
 
+    def test_queryset_or(self):
+        tournament_1 = Tournament.objects.get(pk=1)
+        tournament_2 = Tournament.objects.get(pk=2)
+        with self.assertNumQueries(1):
+            pools = tournament_1.pool_set.all() | tournament_2.pool_set.all()
+            related_objects = set(pool.tournament for pool in pools)
+            self.assertEqual(related_objects, set((tournament_1, tournament_2)))
+
+    def test_queryset_or_different_cached_items(self):
+        tournament = Tournament.objects.get(pk=1)
+        organiser = Organiser.objects.get(pk=1)
+        with self.assertNumQueries(1):
+            pools = tournament.pool_set.all() | organiser.pool_set.all()
+            first = pools.filter(pk=1)[0]
+            self.assertIs(first.tournament, tournament)
+            self.assertIs(first.organiser, organiser)
+
+    def test_queryset_or_only_one_with_precache(self):
+        tournament_1 = Tournament.objects.get(pk=1)
+        tournament_2 = Tournament.objects.get(pk=2)
+        # 2 queries here as pool id 3 has tournament 2, which is not cached
+        with self.assertNumQueries(2):
+            pools = tournament_1.pool_set.all() | Pool.objects.filter(pk=3)
+            related_objects = set(pool.tournament for pool in pools)
+            self.assertEqual(related_objects, set((tournament_1, tournament_2)))
+        # and the other direction
+        with self.assertNumQueries(2):
+            pools = Pool.objects.filter(pk=3) | tournament_1.pool_set.all()
+            related_objects = set(pool.tournament for pool in pools)
+            self.assertEqual(related_objects, set((tournament_1, tournament_2)))
+
+    def test_queryset_and(self):
+        tournament = Tournament.objects.get(pk=1)
+        organiser = Organiser.objects.get(pk=1)
+        with self.assertNumQueries(1):
+            pools = tournament.pool_set.all() & organiser.pool_set.all()
+            first = pools.filter(pk=1)[0]
+            self.assertIs(first.tournament, tournament)
+            self.assertIs(first.organiser, organiser)
+
     def test_one_to_one(self):
         with self.assertNumQueries(2):
             style = PoolStyle.objects.get(pk=1)