From 534d8d875eebac6aee278f0ffa6cc59760dac546 Mon Sep 17 00:00:00 2001
From: Adnan Umer <u.adnan@outlook.com>
Date: Wed, 18 Apr 2018 22:30:25 +0500
Subject: [PATCH] Fixed #28600 -- Added prefetch_related() support to
 RawQuerySet.

---
 django/db/models/query.py       | 30 ++++++++++++++++++++++++++-
 docs/releases/2.1.txt           |  2 ++
 tests/prefetch_related/tests.py | 36 ++++++++++++++++++++++++++++++++-
 3 files changed, 66 insertions(+), 2 deletions(-)

diff --git a/django/db/models/query.py b/django/db/models/query.py
index f16af1e91d..8d4c2d083c 100644
--- a/django/db/models/query.py
+++ b/django/db/models/query.py
@@ -727,7 +727,9 @@ class QuerySet:
     def raw(self, raw_query, params=None, translations=None, using=None):
         if using is None:
             using = self.db
-        return RawQuerySet(raw_query, model=self.model, params=params, translations=translations, using=using)
+        qs = RawQuerySet(raw_query, model=self.model, params=params, translations=translations, using=using)
+        qs._prefetch_related_lookups = self._prefetch_related_lookups[:]
+        return qs
 
     def _values(self, *fields, **expressions):
         clone = self._chain()
@@ -1278,6 +1280,8 @@ class RawQuerySet:
         self.params = params or ()
         self.translations = translations or {}
         self._result_cache = None
+        self._prefetch_related_lookups = ()
+        self._prefetch_done = False
 
     def resolve_model_init_order(self):
         """Resolve the init field names and value positions."""
@@ -1289,9 +1293,33 @@ class RawQuerySet:
         model_init_names = [f.attname for f in model_init_fields]
         return model_init_names, model_init_order, annotation_fields
 
+    def prefetch_related(self, *lookups):
+        """Same as QuerySet.prefetch_related()"""
+        clone = self._clone()
+        if lookups == (None,):
+            clone._prefetch_related_lookups = ()
+        else:
+            clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups
+        return clone
+
+    def _prefetch_related_objects(self):
+        prefetch_related_objects(self._result_cache, *self._prefetch_related_lookups)
+        self._prefetch_done = True
+
+    def _clone(self):
+        """Same as QuerySet._clone()"""
+        c = self.__class__(
+            self.raw_query, model=self.model, query=self.query, params=self.params,
+            translations=self.translations, using=self._db, hints=self._hints
+        )
+        c._prefetch_related_lookups = self._prefetch_related_lookups[:]
+        return c
+
     def _fetch_all(self):
         if self._result_cache is None:
             self._result_cache = list(self.iterator())
+        if self._prefetch_related_lookups and not self._prefetch_done:
+            self._prefetch_related_objects()
 
     def __len__(self):
         self._fetch_all()
diff --git a/docs/releases/2.1.txt b/docs/releases/2.1.txt
index f3a99e5fc0..7595be46c0 100644
--- a/docs/releases/2.1.txt
+++ b/docs/releases/2.1.txt
@@ -239,6 +239,8 @@ Models
 * The new :meth:`.QuerySet.explain` method displays the database's execution
   plan of a queryset's query.
 
+*  :meth:`.QuerySet.raw` now supports :meth:`~.QuerySet.prefetch_related`.
+
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
 
diff --git a/tests/prefetch_related/tests.py b/tests/prefetch_related/tests.py
index e92d7f349f..5a701bffec 100644
--- a/tests/prefetch_related/tests.py
+++ b/tests/prefetch_related/tests.py
@@ -14,7 +14,7 @@ from .models import (
 )
 
 
-class PrefetchRelatedTests(TestCase):
+class TestDataMixin:
     @classmethod
     def setUpTestData(cls):
         cls.book1 = Book.objects.create(title='Poems')
@@ -38,6 +38,8 @@ class PrefetchRelatedTests(TestCase):
         cls.reader1.books_read.add(cls.book1, cls.book4)
         cls.reader2.books_read.add(cls.book2, cls.book4)
 
+
+class PrefetchRelatedTests(TestDataMixin, TestCase):
     def assertWhereContains(self, sql, needle):
         where_idx = sql.index('WHERE')
         self.assertEqual(
@@ -281,6 +283,38 @@ class PrefetchRelatedTests(TestCase):
         self.assertWhereContains(sql, self.author1.id)
 
 
+class RawQuerySetTests(TestDataMixin, TestCase):
+    def test_basic(self):
+        with self.assertNumQueries(2):
+            books = Book.objects.raw(
+                "SELECT * FROM prefetch_related_book WHERE id = %s",
+                (self.book1.id,)
+            ).prefetch_related('authors')
+            book1 = list(books)[0]
+
+        with self.assertNumQueries(0):
+            self.assertCountEqual(book1.authors.all(), [self.author1, self.author2, self.author3])
+
+    def test_prefetch_before_raw(self):
+        with self.assertNumQueries(2):
+            books = Book.objects.prefetch_related('authors').raw(
+                "SELECT * FROM prefetch_related_book WHERE id = %s",
+                (self.book1.id,)
+            )
+            book1 = list(books)[0]
+
+        with self.assertNumQueries(0):
+            self.assertCountEqual(book1.authors.all(), [self.author1, self.author2, self.author3])
+
+    def test_clear(self):
+        with self.assertNumQueries(5):
+            with_prefetch = Author.objects.raw(
+                "SELECT * FROM prefetch_related_author"
+            ).prefetch_related('books')
+            without_prefetch = with_prefetch.prefetch_related(None)
+            [list(a.books.all()) for a in without_prefetch]
+
+
 class CustomPrefetchTests(TestCase):
     @classmethod
     def traverse_qs(cls, obj_iter, path):