From 0d9872fc9a70ef6966930c68c68febea7eb60ede Mon Sep 17 00:00:00 2001
From: suraj <suraj.shaw@oracle.com>
Date: Tue, 10 Sep 2024 20:56:16 +0530
Subject: [PATCH] Fixed #7732 -- Added support for connection pools on Oracle.

---
 django/db/backends/oracle/base.py     | 52 ++++++++++++++++
 django/db/backends/oracle/creation.py | 15 +++--
 django/db/backends/oracle/features.py | 19 ++++++
 docs/ref/databases.txt                | 42 ++++++++++++-
 docs/releases/5.2.txt                 |  6 ++
 tests/backends/oracle/tests.py        | 88 ++++++++++++++++++++++++++-
 tests/requirements/oracle.txt         |  2 +-
 7 files changed, 215 insertions(+), 9 deletions(-)

diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py
index df78d9ba27..3b37c38f97 100644
--- a/django/db/backends/oracle/base.py
+++ b/django/db/backends/oracle/base.py
@@ -14,6 +14,7 @@ from django.conf import settings
 from django.core.exceptions import ImproperlyConfigured
 from django.db import IntegrityError
 from django.db.backends.base.base import BaseDatabaseWrapper
+from django.db.backends.oracle.oracledb_any import is_oracledb
 from django.db.backends.utils import debug_transaction
 from django.utils.asyncio import async_unsafe
 from django.utils.encoding import force_bytes, force_str
@@ -235,6 +236,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
     introspection_class = DatabaseIntrospection
     ops_class = DatabaseOperations
     validation_class = DatabaseValidation
+    _connection_pools = {}
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -243,10 +245,52 @@ class DatabaseWrapper(BaseDatabaseWrapper):
         )
         self.features.can_return_columns_from_insert = use_returning_into
 
+    @property
+    def is_pool(self):
+        return self.settings_dict["OPTIONS"].get("pool", False)
+
+    @property
+    def pool(self):
+        if not self.is_pool:
+            return None
+
+        if self.settings_dict.get("CONN_MAX_AGE", 0) != 0:
+            raise ImproperlyConfigured(
+                "Pooling doesn't support persistent connections."
+            )
+
+        pool_key = (self.alias, self.settings_dict["USER"])
+        if pool_key not in self._connection_pools:
+            connect_kwargs = self.get_connection_params()
+            pool_options = connect_kwargs.pop("pool")
+            if pool_options is not True:
+                connect_kwargs.update(pool_options)
+
+            pool = Database.create_pool(
+                user=self.settings_dict["USER"],
+                password=self.settings_dict["PASSWORD"],
+                dsn=dsn(self.settings_dict),
+                **connect_kwargs,
+            )
+            self._connection_pools.setdefault(pool_key, pool)
+
+        return self._connection_pools[pool_key]
+
+    def close_pool(self):
+        if self.pool:
+            self.pool.close(force=True)
+            pool_key = (self.alias, self.settings_dict["USER"])
+            del self._connection_pools[pool_key]
+
     def get_database_version(self):
         return self.oracle_version
 
     def get_connection_params(self):
+        # Pooling feature is only supported for oracledb.
+        if self.is_pool and not is_oracledb:
+            raise ImproperlyConfigured(
+                "Pooling isn't supported by cx_Oracle. Use python-oracledb instead."
+            )
         conn_params = self.settings_dict["OPTIONS"].copy()
         if "use_returning_into" in conn_params:
             del conn_params["use_returning_into"]
@@ -254,6 +298,8 @@ class DatabaseWrapper(BaseDatabaseWrapper):
 
     @async_unsafe
     def get_new_connection(self, conn_params):
+        if self.pool:
+            return self.pool.acquire()
         return Database.connect(
             user=self.settings_dict["USER"],
             password=self.settings_dict["PASSWORD"],
@@ -345,6 +391,12 @@ class DatabaseWrapper(BaseDatabaseWrapper):
         else:
             return True
 
+    def close_if_health_check_failed(self):
+        if self.pool:
+            # The pool only returns healthy connections.
+            return
+        return super().close_if_health_check_failed()
+
     @cached_property
     def oracle_version(self):
         with self.temporary_connection():
diff --git a/django/db/backends/oracle/creation.py b/django/db/backends/oracle/creation.py
index b0a5177728..682379930f 100644
--- a/django/db/backends/oracle/creation.py
+++ b/django/db/backends/oracle/creation.py
@@ -205,13 +205,15 @@ class DatabaseCreation(BaseDatabaseCreation):
         Destroy a test database, prompting the user for confirmation if the
         database already exists. Return the name of the test database created.
         """
-        self.connection.settings_dict["USER"] = self.connection.settings_dict[
-            "SAVED_USER"
-        ]
-        self.connection.settings_dict["PASSWORD"] = self.connection.settings_dict[
-            "SAVED_PASSWORD"
-        ]
+        if not self.connection.is_pool:
+            self.connection.settings_dict["USER"] = self.connection.settings_dict[
+                "SAVED_USER"
+            ]
+            self.connection.settings_dict["PASSWORD"] = self.connection.settings_dict[
+                "SAVED_PASSWORD"
+            ]
         self.connection.close()
+        self.connection.close_pool()
         parameters = self._get_test_db_params()
         with self._maindb_connection.cursor() as cursor:
             if self._test_user_create():
@@ -223,6 +225,7 @@ class DatabaseCreation(BaseDatabaseCreation):
                     self.log("Destroying test database tables...")
                 self._execute_test_db_destruction(cursor, parameters, verbosity)
         self._maindb_connection.close()
+        self._maindb_connection.close_pool()
 
     def _execute_test_db_creation(self, cursor, parameters, verbosity, keepdb=False):
         if verbosity >= 2:
diff --git a/django/db/backends/oracle/features.py b/django/db/backends/oracle/features.py
index 72c6180f50..ad9ab8da55 100644
--- a/django/db/backends/oracle/features.py
+++ b/django/db/backends/oracle/features.py
@@ -139,6 +139,25 @@ class DatabaseFeatures(BaseDatabaseFeatures):
                     },
                 }
             )
+        if self.connection.is_pool:
+            skips.update(
+                {
+                    "Pooling does not support persistent connections": {
+                        "backends.base.test_base.ConnectionHealthChecksTests."
+                        "test_health_checks_enabled",
+                        "backends.base.test_base.ConnectionHealthChecksTests."
+                        "test_health_checks_enabled_errors_occurred",
+                        "backends.base.test_base.ConnectionHealthChecksTests."
+                        "test_health_checks_disabled",
+                        "backends.base.test_base.ConnectionHealthChecksTests."
+                        "test_set_autocommit_health_checks_enabled",
+                        "servers.tests.LiveServerTestCloseConnectionTest."
+                        "test_closes_connections",
+                        "backends.oracle.tests.TransactionalTests."
+                        "test_password_with_at_sign",
+                    },
+                }
+            )
         if is_oracledb and self.connection.oracledb_version >= (2, 1, 2):
             skips.update(
                 {
diff --git a/docs/ref/databases.txt b/docs/ref/databases.txt
index 57e94140c2..d19c78b9ec 100644
--- a/docs/ref/databases.txt
+++ b/docs/ref/databases.txt
@@ -994,7 +994,7 @@ Oracle notes
 ============
 
 Django supports `Oracle Database Server`_ versions 19c and higher. Version
-1.3.2 or higher of the `oracledb`_ Python driver is required.
+2.3.0 or higher of the `oracledb`_ Python driver is required.
 
 .. deprecated:: 5.0
 
@@ -1105,6 +1105,46 @@ Example of a full DSN string::
         "(CONNECT_DATA=(SERVICE_NAME=orclpdb1)))"
     )
 
+.. _oracle-pool:
+
+Connection pool
+---------------
+
+.. versionadded:: 5.2
+
+To use a connection pool with `oracledb`_, set ``"pool"`` to ``True`` in the
+:setting:`OPTIONS` part of your database configuration. This uses the driver's
+`create_pool()`_ default values::
+
+    DATABASES = {
+        "default": {
+            "ENGINE": "django.db.backends.oracle",
+            # ...
+            "OPTIONS": {
+                "pool": True,
+            },
+        },
+    }
+
+To pass custom parameters to the driver's `create_pool()`_  function, you can
+alternatively set ``"pool"`` to be a dict::
+
+    DATABASES = {
+        "default": {
+            "ENGINE": "django.db.backends.oracle",
+            # ...
+            "OPTIONS": {
+                "pool": {
+                    "min": 1,
+                    "max": 10,
+                    # ...
+                }
+            },
+        },
+    }
+
+.. _`create_pool()`: https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html#connection-pooling
+
 Threaded option
 ---------------
 
diff --git a/docs/releases/5.2.txt b/docs/releases/5.2.txt
index 58d5727f32..918bd6a700 100644
--- a/docs/releases/5.2.txt
+++ b/docs/releases/5.2.txt
@@ -197,6 +197,9 @@ Database backends
   instead of ``utf8``, which is an alias for the deprecated character set
   ``utf8mb3``.
 
+* Oracle backends now support :ref:`connection pools <oracle-pool>`, by setting
+  ``"pool"`` in the :setting:`OPTIONS` part of your database configuration.
+
 Decorators
 ~~~~~~~~~~
 
@@ -464,6 +467,9 @@ Miscellaneous
   * :meth:`.QuerySet.update_or_create`
   * :meth:`.QuerySet.aupdate_or_create`
 
+* The minimum supported version of ``oracledb`` is increased from 1.3.2 to
+  2.3.0.
+
 .. _deprecated-features-5.2:
 
 Features deprecated in 5.2
diff --git a/tests/backends/oracle/tests.py b/tests/backends/oracle/tests.py
index a4aa26cd2e..0c4dbce8ba 100644
--- a/tests/backends/oracle/tests.py
+++ b/tests/backends/oracle/tests.py
@@ -1,12 +1,28 @@
+import copy
 import unittest
 from unittest import mock
 
-from django.db import DatabaseError, NotSupportedError, connection
+from django.core.exceptions import ImproperlyConfigured
+from django.db import DatabaseError, NotSupportedError, ProgrammingError, connection
 from django.db.models import BooleanField
 from django.test import TestCase, TransactionTestCase
 
 from ..models import Square, VeryLongModelNameZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ
 
+try:
+    from django.db.backends.oracle.oracledb_any import is_oracledb
+except ImportError:
+    is_oracledb = False
+
+
+def no_pool_connection(alias=None):
+    new_connection = connection.copy(alias)
+    new_connection.settings_dict = copy.deepcopy(connection.settings_dict)
+    # Ensure that the second connection circumvents the pool, this is kind
+    # of a hack, but we cannot easily change the pool connections.
+    new_connection.settings_dict["OPTIONS"]["pool"] = False
+    return new_connection
+
 
 @unittest.skipUnless(connection.vendor == "oracle", "Oracle tests")
 class Tests(TestCase):
@@ -69,6 +85,76 @@ class Tests(TestCase):
             connection.check_database_version_supported()
         self.assertTrue(mocked_get_database_version.called)
 
+    @unittest.skipUnless(is_oracledb, "Pool specific tests")
+    def test_pool_set_to_true(self):
+        new_connection = no_pool_connection(alias="default_pool")
+        new_connection.settings_dict["OPTIONS"]["pool"] = True
+        try:
+            self.assertIsNotNone(new_connection.pool)
+        finally:
+            new_connection.close_pool()
+
+    @unittest.skipUnless(is_oracledb, "Pool specific tests")
+    def test_pool_reuse(self):
+        new_connection = no_pool_connection(alias="default_pool")
+        new_connection.settings_dict["OPTIONS"]["pool"] = {
+            "min": 0,
+            "max": 2,
+        }
+        self.assertIsNotNone(new_connection.pool)
+
+        connections = []
+
+        def get_connection():
+            # copy() reuses the existing alias and as such the same pool.
+            conn = new_connection.copy()
+            conn.connect()
+            connections.append(conn)
+            return conn
+
+        try:
+            connection_1 = get_connection()  # First connection.
+            get_connection()  # Get the second connection.
+            sql = "select sys_context('userenv', 'sid') from dual"
+            sids = [conn.cursor().execute(sql).fetchone()[0] for conn in connections]
+            connection_1.close()  # Release back to the pool.
+            connection_3 = get_connection()
+            sid = connection_3.cursor().execute(sql).fetchone()[0]
+            # Reuses the first connection as it is available.
+            self.assertEqual(sid, sids[0])
+        finally:
+            # Release all connections back to the pool.
+            for conn in connections:
+                conn.close()
+            new_connection.close_pool()
+
+    @unittest.skipUnless(is_oracledb, "Pool specific tests")
+    def test_cannot_open_new_connection_in_atomic_block(self):
+        new_connection = no_pool_connection(alias="default_pool")
+        new_connection.settings_dict["OPTIONS"]["pool"] = True
+        msg = "Cannot open a new connection in an atomic block."
+        new_connection.in_atomic_block = True
+        new_connection.closed_in_transaction = True
+        with self.assertRaisesMessage(ProgrammingError, msg):
+            new_connection.ensure_connection()
+
+    @unittest.skipUnless(is_oracledb, "Pool specific tests")
+    def test_pooling_not_support_persistent_connections(self):
+        new_connection = no_pool_connection(alias="default_pool")
+        new_connection.settings_dict["OPTIONS"]["pool"] = True
+        new_connection.settings_dict["CONN_MAX_AGE"] = 10
+        msg = "Pooling doesn't support persistent connections."
+        with self.assertRaisesMessage(ImproperlyConfigured, msg):
+            new_connection.pool
+
+    @unittest.skipIf(is_oracledb, "cx_oracle specific tests")
+    def test_cx_Oracle_not_support_pooling(self):
+        new_connection = no_pool_connection()
+        new_connection.settings_dict["OPTIONS"]["pool"] = True
+        msg = "Pooling isn't supported by cx_Oracle. Use python-oracledb instead."
+        with self.assertRaisesMessage(ImproperlyConfigured, msg):
+            new_connection.connect()
+
 
 @unittest.skipUnless(connection.vendor == "oracle", "Oracle tests")
 class TransactionalTests(TransactionTestCase):
diff --git a/tests/requirements/oracle.txt b/tests/requirements/oracle.txt
index 5f9058f822..d1490df020 100644
--- a/tests/requirements/oracle.txt
+++ b/tests/requirements/oracle.txt
@@ -1 +1 @@
-oracledb >= 1.3.2
+oracledb >= 2.3.0