From 85f6d893138c052f0fa24d7c2005b9c222af91b4 Mon Sep 17 00:00:00 2001
From: Markus Holtermann <info@markusholtermann.eu>
Date: Tue, 16 Sep 2014 02:25:02 +0200
Subject: [PATCH] Fixed #23426 -- Allowed parameters in migrations.RunSQL

Thanks tchaumeny and Loic for reviews.
---
 django/db/migrations/operations/special.py | 24 +++++--
 docs/ref/migration-operations.txt          | 20 +++++-
 docs/releases/1.8.txt                      |  6 ++
 tests/migrations/test_operations.py        | 81 ++++++++++++++++++++++
 4 files changed, 123 insertions(+), 8 deletions(-)

diff --git a/django/db/migrations/operations/special.py b/django/db/migrations/operations/special.py
index bfe418034c..3a29a33a6b 100644
--- a/django/db/migrations/operations/special.py
+++ b/django/db/migrations/operations/special.py
@@ -64,20 +64,32 @@ class RunSQL(Operation):
             state_operation.state_forwards(app_label, state)
 
     def database_forwards(self, app_label, schema_editor, from_state, to_state):
-        statements = schema_editor.connection.ops.prepare_sql_script(self.sql)
-        for statement in statements:
-            schema_editor.execute(statement, params=None)
+        self._run_sql(schema_editor, self.sql)
 
     def database_backwards(self, app_label, schema_editor, from_state, to_state):
         if self.reverse_sql is None:
             raise NotImplementedError("You cannot reverse this operation")
-        statements = schema_editor.connection.ops.prepare_sql_script(self.reverse_sql)
-        for statement in statements:
-            schema_editor.execute(statement, params=None)
+        self._run_sql(schema_editor, self.reverse_sql)
 
     def describe(self):
         return "Raw SQL operation"
 
+    def _run_sql(self, schema_editor, sql):
+        if isinstance(sql, (list, tuple)):
+            for sql in sql:
+                params = None
+                if isinstance(sql, (list, tuple)):
+                    elements = len(sql)
+                    if elements == 2:
+                        sql, params = sql
+                    else:
+                        raise ValueError("Expected a 2-tuple but got %d" % elements)
+                schema_editor.execute(sql, params=params)
+        else:
+            statements = schema_editor.connection.ops.prepare_sql_script(sql)
+            for statement in statements:
+                schema_editor.execute(statement, params=None)
+
 
 class RunPython(Operation):
     """
diff --git a/docs/ref/migration-operations.txt b/docs/ref/migration-operations.txt
index 2c5c1bb980..6998bdb574 100644
--- a/docs/ref/migration-operations.txt
+++ b/docs/ref/migration-operations.txt
@@ -188,6 +188,17 @@ the database. On most database backends (all but PostgreSQL), Django will
 split the SQL into individual statements prior to executing them. This
 requires installing the sqlparse_ Python library.
 
+You can also pass a list of strings or 2-tuples. The latter is used for passing
+queries and parameters in the same way as :ref:`cursor.execute()
+<executing-custom-sql>`. These three operations are equivalent::
+
+    migrations.RunSQL("INSERT INTO musician (name) VALUES ('Reinhardt');")
+    migrations.RunSQL(["INSERT INTO musician (name) VALUES ('Reinhardt');", None])
+    migrations.RunSQL(["INSERT INTO musician (name) VALUES (%s);", ['Reinhardt']])
+
+If you want to include literal percent signs in the query, you have to double
+them if you are passing parameters.
+
 The ``state_operations`` argument is so you can supply operations that are
 equivalent to the SQL in terms of project state; for example, if you are
 manually creating a column, you should pass in a list containing an ``AddField``
@@ -197,8 +208,13 @@ operation that adds that field and so will try to run it again).
 
 .. versionchanged:: 1.7.1
 
-    If you want to include literal percent signs in the query you don't need to
-    double them anymore.
+    If you want to include literal percent signs in a query without parameters
+    you don't need to double them anymore.
+
+.. versionchanged:: 1.8
+
+    The ability to pass parameters to the ``sql`` and ``reverse_sql`` queries
+    was added.
 
 .. _sqlparse: https://pypi.python.org/pypi/sqlparse
 
diff --git a/docs/releases/1.8.txt b/docs/releases/1.8.txt
index 67a6fab91d..3eec697364 100644
--- a/docs/releases/1.8.txt
+++ b/docs/releases/1.8.txt
@@ -265,6 +265,12 @@ Management Commands
 
 * :djadmin:`makemigrations` can now serialize timezone-aware values.
 
+Migrations
+^^^^^^^^^^
+
+* The :class:`~django.db.migrations.operations.RunSQL` operation can now handle
+  parameters passed to the SQL statements.
+
 Models
 ^^^^^^
 
diff --git a/tests/migrations/test_operations.py b/tests/migrations/test_operations.py
index 208f7e24ba..4f885b407e 100644
--- a/tests/migrations/test_operations.py
+++ b/tests/migrations/test_operations.py
@@ -1195,6 +1195,87 @@ class OperationTests(OperationTestBase):
             operation.database_backwards("test_runsql", editor, new_state, project_state)
         self.assertTableNotExists("i_love_ponies")
 
+    def test_run_sql_params(self):
+        """
+        #23426 - RunSQL should accept parameters.
+        """
+        project_state = self.set_up_test_model("test_runsql")
+        # Create the operation
+        operation = migrations.RunSQL(
+            "CREATE TABLE i_love_ponies (id int, special_thing varchar(15));",
+            "DROP TABLE i_love_ponies",
+        )
+        param_operation = migrations.RunSQL(
+            # forwards
+            (
+                "INSERT INTO i_love_ponies (id, special_thing) VALUES (1, 'Django');",
+                ["INSERT INTO i_love_ponies (id, special_thing) VALUES (2, %s);", ['Ponies']],
+                ("INSERT INTO i_love_ponies (id, special_thing) VALUES (%s, %s);", (3, 'Python',)),
+            ),
+            # backwards
+            [
+                "DELETE FROM i_love_ponies WHERE special_thing = 'Django';",
+                ["DELETE FROM i_love_ponies WHERE special_thing = 'Ponies';", None],
+                ("DELETE FROM i_love_ponies WHERE id = %s OR special_thing = %s;", [3, 'Python']),
+            ]
+        )
+
+        # Make sure there's no table
+        self.assertTableNotExists("i_love_ponies")
+        new_state = project_state.clone()
+        # Test the database alteration
+        with connection.schema_editor() as editor:
+            operation.database_forwards("test_runsql", editor, project_state, new_state)
+
+        # Test parameter passing
+        with connection.schema_editor() as editor:
+            param_operation.database_forwards("test_runsql", editor, project_state, new_state)
+        # Make sure all the SQL was processed
+        with connection.cursor() as cursor:
+            cursor.execute("SELECT COUNT(*) FROM i_love_ponies")
+            self.assertEqual(cursor.fetchall()[0][0], 3)
+
+        with connection.schema_editor() as editor:
+            param_operation.database_backwards("test_runsql", editor, new_state, project_state)
+        with connection.cursor() as cursor:
+            cursor.execute("SELECT COUNT(*) FROM i_love_ponies")
+            self.assertEqual(cursor.fetchall()[0][0], 0)
+
+        # And test reversal
+        with connection.schema_editor() as editor:
+            operation.database_backwards("test_runsql", editor, new_state, project_state)
+        self.assertTableNotExists("i_love_ponies")
+
+    def test_run_sql_params_invalid(self):
+        """
+        #23426 - RunSQL should fail when a list of statements with an incorrect
+        number of tuples is given.
+        """
+        project_state = self.set_up_test_model("test_runsql")
+        new_state = project_state.clone()
+        operation = migrations.RunSQL(
+            # forwards
+            [
+                ["INSERT INTO foo (bar) VALUES ('buz');"]
+            ],
+            # backwards
+            (
+                ("DELETE FROM foo WHERE bar = 'buz';", 'invalid', 'parameter count'),
+            ),
+        )
+
+        with connection.schema_editor() as editor:
+            self.assertRaisesRegexp(ValueError,
+                "Expected a 2-tuple but got 1",
+                operation.database_forwards,
+                "test_runsql", editor, project_state, new_state)
+
+        with connection.schema_editor() as editor:
+            self.assertRaisesRegexp(ValueError,
+                "Expected a 2-tuple but got 3",
+                operation.database_backwards,
+                "test_runsql", editor, new_state, project_state)
+
     def test_run_python(self):
         """
         Tests the RunPython operation