From abd0ad7681422d7c40a5ed12cc3c9ffca6b88422 Mon Sep 17 00:00:00 2001
From: oliver <myungsekyo@gmail.com>
Date: Thu, 2 Aug 2018 14:20:46 +0900
Subject: [PATCH] Fixed #29626, #29584 -- Added optimized versions of
 get_many() and delete_many() for the db cache backend.

---
 django/core/cache/backends/db.py | 86 ++++++++++++++++++++------------
 tests/cache/tests.py             | 14 ++++++
 2 files changed, 67 insertions(+), 33 deletions(-)

diff --git a/django/core/cache/backends/db.py b/django/core/cache/backends/db.py
index 76aff9c582..21b5aa88ad 100644
--- a/django/core/cache/backends/db.py
+++ b/django/core/cache/backends/db.py
@@ -49,8 +49,17 @@ class DatabaseCache(BaseDatabaseCache):
     pickle_protocol = pickle.HIGHEST_PROTOCOL
 
     def get(self, key, default=None, version=None):
-        key = self.make_key(key, version=version)
-        self.validate_key(key)
+        return self.get_many([key], version).get(key, default)
+
+    def get_many(self, keys, version=None):
+        if not keys:
+            return {}
+
+        key_map = {}
+        for key in keys:
+            self.validate_key(key)
+            key_map[self.make_key(key, version)] = key
+
         db = router.db_for_read(self.cache_model_class)
         connection = connections[db]
         quote_name = connection.ops.quote_name
@@ -58,43 +67,36 @@ class DatabaseCache(BaseDatabaseCache):
 
         with connection.cursor() as cursor:
             cursor.execute(
-                'SELECT %s, %s, %s FROM %s WHERE %s = %%s' % (
+                'SELECT %s, %s, %s FROM %s WHERE %s IN (%s)' % (
                     quote_name('cache_key'),
                     quote_name('value'),
                     quote_name('expires'),
                     table,
                     quote_name('cache_key'),
+                    ', '.join(['%s'] * len(key_map)),
                 ),
-                [key]
+                list(key_map),
             )
-            row = cursor.fetchone()
-        if row is None:
-            return default
+            rows = cursor.fetchall()
 
-        expires = row[2]
+        result = {}
+        expired_keys = []
         expression = models.Expression(output_field=models.DateTimeField())
-        for converter in (connection.ops.get_db_converters(expression) +
-                          expression.get_db_converters(connection)):
-            if func_supports_parameter(converter, 'context'):  # RemovedInDjango30Warning
-                expires = converter(expires, expression, connection, {})
+        converters = (connection.ops.get_db_converters(expression) + expression.get_db_converters(connection))
+        for key, value, expires in rows:
+            for converter in converters:
+                if func_supports_parameter(converter, 'context'):  # RemovedInDjango30Warning
+                    expires = converter(expires, expression, connection, {})
+                else:
+                    expires = converter(expires, expression, connection)
+            if expires < timezone.now():
+                expired_keys.append(key)
             else:
-                expires = converter(expires, expression, connection)
-
-        if expires < timezone.now():
-            db = router.db_for_write(self.cache_model_class)
-            connection = connections[db]
-            with connection.cursor() as cursor:
-                cursor.execute(
-                    'DELETE FROM %s WHERE %s = %%s' % (
-                        table,
-                        quote_name('cache_key'),
-                    ),
-                    [key]
-                )
-            return default
-
-        value = connection.ops.process_clob(row[1])
-        return pickle.loads(base64.b64decode(value.encode()))
+                value = connection.ops.process_clob(value)
+                value = pickle.loads(base64.b64decode(value.encode()))
+                result[key_map.get(key)] = value
+        self._base_delete_many(expired_keys)
+        return result
 
     def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
         key = self.make_key(key, version=version)
@@ -202,15 +204,33 @@ class DatabaseCache(BaseDatabaseCache):
                 return True
 
     def delete(self, key, version=None):
-        key = self.make_key(key, version=version)
-        self.validate_key(key)
+        self.delete_many([key], version)
+
+    def delete_many(self, keys, version=None):
+        key_list = []
+        for key in keys:
+            self.validate_key(key)
+            key_list.append(self.make_key(key, version))
+        self._base_delete_many(key_list)
+
+    def _base_delete_many(self, keys):
+        if not keys:
+            return
 
         db = router.db_for_write(self.cache_model_class)
         connection = connections[db]
-        table = connection.ops.quote_name(self._table)
+        quote_name = connection.ops.quote_name
+        table = quote_name(self._table)
 
         with connection.cursor() as cursor:
-            cursor.execute("DELETE FROM %s WHERE cache_key = %%s" % table, [key])
+            cursor.execute(
+                'DELETE FROM %s WHERE %s IN (%s)' % (
+                    table,
+                    quote_name('cache_key'),
+                    ', '.join(['%s'] * len(keys)),
+                ),
+                keys,
+            )
 
     def has_key(self, key, version=None):
         key = self.make_key(key, version=version)
diff --git a/tests/cache/tests.py b/tests/cache/tests.py
index a101639f49..6578eb288f 100644
--- a/tests/cache/tests.py
+++ b/tests/cache/tests.py
@@ -1005,6 +1005,20 @@ class DBCacheTests(BaseCacheTests, TransactionTestCase):
             table_name = connection.ops.quote_name('test cache table')
             cursor.execute('DROP TABLE %s' % table_name)
 
+    def test_get_many_num_queries(self):
+        cache.set_many({'a': 1, 'b': 2})
+        cache.set('expired', 'expired', 0.01)
+        with self.assertNumQueries(1):
+            self.assertEqual(cache.get_many(['a', 'b']), {'a': 1, 'b': 2})
+        time.sleep(0.02)
+        with self.assertNumQueries(2):
+            self.assertEqual(cache.get_many(['a', 'b', 'expired']), {'a': 1, 'b': 2})
+
+    def test_delete_many_num_queries(self):
+        cache.set_many({'a': 1, 'b': 2, 'c': 3})
+        with self.assertNumQueries(1):
+            cache.delete_many(['a', 'b', 'c'])
+
     def test_zero_cull(self):
         self._perform_cull_test(caches['zero_cull'], 50, 18)