diff --git a/django/core/cache/backends/redis.py b/django/core/cache/backends/redis.py index eda8ac9457..c47b8efcd3 100644 --- a/django/core/cache/backends/redis.py +++ b/django/core/cache/backends/redis.py @@ -4,6 +4,8 @@ import pickle import random import re +from redis.connection import parse_url + from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache from django.utils.functional import cached_property from django.utils.module_loading import import_string @@ -28,6 +30,8 @@ class RedisSerializer: class RedisCacheClient: + _pools = {} + def __init__( self, servers, @@ -40,7 +44,6 @@ class RedisCacheClient: self._lib = redis self._servers = servers - self._pools = {} self._client = self._lib.Redis @@ -69,11 +72,16 @@ class RedisCacheClient: def _get_connection_pool(self, write): index = self._get_connection_pool_index(write) - if index not in self._pools: - self._pools[index] = self._pool_class.from_url( - self._servers[index], + if index in self._pools: + if self._pools[index].connection_kwargs == { **self._pool_options, - ) + **parse_url(self._servers[index]), + }: + return self._pools[index] + self._pools[index] = self._pool_class.from_url( + self._servers[index], + **self._pool_options, + ) return self._pools[index] def get_client(self, key=None, *, write=False): diff --git a/tests/cache/tests.py b/tests/cache/tests.py index 978efdd9d3..57c9d07f66 100644 --- a/tests/cache/tests.py +++ b/tests/cache/tests.py @@ -1870,6 +1870,27 @@ class RedisCacheTests(BaseCacheTests, TestCase): self.assertEqual(pool.connection_kwargs["socket_timeout"], 0.1) self.assertIs(pool.connection_kwargs["retry_on_timeout"], True) + def test_redis_common_pool(self): + pool1 = cache._cache._get_connection_pool(write=False) + pool2 = cache._cache._get_connection_pool(write=False) + self.assertEqual(id(pool1), id(pool2)) + + @override_settings( + CACHES=caches_setting_for_tests( + base=RedisCache_params, + exclude=redis_excluded_caches, + OPTIONS={ + "db": 6, + "socket_timeout": 0.1, + "retry_on_timeout": True, + }, + ) + ) + def get_connection_pool(): + return cache._cache._get_connection_pool(write=False) + + self.assertNotEqual(id(pool1), id(get_connection_pool())) + class FileBasedCachePathLibTests(FileBasedCacheTests): def mkdtemp(self):