From 3b429c96736b8328c40e5d77282b0d30de563c3c Mon Sep 17 00:00:00 2001
From: Simon Charette <charette.s@gmail.com>
Date: Tue, 24 May 2016 15:25:05 -0400
Subject: [PATCH] Refs #25530 -- Tracked references of deferred SQL statements.

---
 django/db/backends/base/schema.py       |  61 ++++++-----
 django/db/backends/ddl_references.py    | 128 ++++++++++++++++++++++++
 django/db/backends/sqlite3/schema.py    |  10 +-
 tests/backends/test_ddl_references.py   | 125 +++++++++++++++++++++++
 tests/indexes/tests.py                  |   8 +-
 tests/model_options/test_tablespaces.py |   2 +-
 6 files changed, 302 insertions(+), 32 deletions(-)
 create mode 100644 django/db/backends/ddl_references.py
 create mode 100644 tests/backends/test_ddl_references.py

diff --git a/django/db/backends/base/schema.py b/django/db/backends/base/schema.py
index 20243ae070..bf22711131 100644
--- a/django/db/backends/base/schema.py
+++ b/django/db/backends/base/schema.py
@@ -2,6 +2,9 @@ import hashlib
 import logging
 from datetime import datetime
 
+from django.db.backends.ddl_references import (
+    Columns, ForeignKeyName, IndexName, Statement, Table,
+)
 from django.db.backends.utils import strip_quotes
 from django.db.models import Index
 from django.db.transaction import TransactionManagementError, atomic
@@ -97,6 +100,8 @@ class BaseDatabaseSchemaEditor:
                 "Executing DDL statements while in a transaction on databases "
                 "that can't perform a rollback is prohibited."
             )
+        # Account for non-string statement objects.
+        sql = str(sql)
         # Log the command we're running, then run it
         logger.debug("%s; (params %r)", sql, params, extra={'params': params, 'sql': sql})
         if self.collect_sql:
@@ -878,13 +883,19 @@ class BaseDatabaseSchemaEditor:
         tablespace_sql = self._get_index_tablespace_sql(model, fields)
         columns = [field.column for field in fields]
         sql_create_index = sql or self.sql_create_index
-        return sql_create_index % {
-            "table": self.quote_name(model._meta.db_table),
-            "name": self.quote_name(self._create_index_name(model._meta.db_table, columns, suffix=suffix)),
-            "using": "",
-            "columns": ", ".join(self.quote_name(column) for column in columns),
-            "extra": tablespace_sql,
-        }
+        table = model._meta.db_table
+
+        def create_index_name(*args, **kwargs):
+            return self.quote_name(self._create_index_name(*args, **kwargs))
+
+        return Statement(
+            sql_create_index,
+            table=Table(table, self.quote_name),
+            name=IndexName(table, columns, suffix, create_index_name),
+            using='',
+            columns=Columns(table, columns, self.quote_name),
+            extra=tablespace_sql,
+        )
 
     def _model_indexes_sql(self, model):
         """
@@ -930,26 +941,28 @@ class BaseDatabaseSchemaEditor:
         from_column = field.column
         to_table = field.target_field.model._meta.db_table
         to_column = field.target_field.column
-        suffix = suffix % {
-            "to_table": to_table,
-            "to_column": to_column,
-        }
 
-        return self.sql_create_fk % {
-            "table": self.quote_name(from_table),
-            "name": self.quote_name(self._create_index_name(model._meta.db_table, [from_column], suffix=suffix)),
-            "column": self.quote_name(from_column),
-            "to_table": self.quote_name(to_table),
-            "to_column": self.quote_name(to_column),
-            "deferrable": self.connection.ops.deferrable_sql(),
-        }
+        def create_fk_name(*args, **kwargs):
+            return self.quote_name(self._create_index_name(*args, **kwargs))
+
+        return Statement(
+            self.sql_create_fk,
+            table=Table(from_table, self.quote_name),
+            name=ForeignKeyName(from_table, [from_column], to_table, [to_column], suffix, create_fk_name),
+            column=Columns(from_table, [from_column], self.quote_name),
+            to_table=Table(to_table, self.quote_name),
+            to_column=Columns(to_table, [to_column], self.quote_name),
+            deferrable=self.connection.ops.deferrable_sql(),
+        )
 
     def _create_unique_sql(self, model, columns):
-        return self.sql_create_unique % {
-            "table": self.quote_name(model._meta.db_table),
-            "name": self.quote_name(self._create_index_name(model._meta.db_table, columns, suffix="_uniq")),
-            "columns": ", ".join(self.quote_name(column) for column in columns),
-        }
+        table = model._meta.db_table
+        return Statement(
+            self.sql_create_unique,
+            table=Table(table, self.quote_name),
+            name=IndexName(table, columns, '_uniq', self._create_index_name),
+            columns=Columns(table, columns, self.quote_name),
+        )
 
     def _delete_constraint_sql(self, template, model, name):
         return template % {
diff --git a/django/db/backends/ddl_references.py b/django/db/backends/ddl_references.py
new file mode 100644
index 0000000000..dd4d1aa415
--- /dev/null
+++ b/django/db/backends/ddl_references.py
@@ -0,0 +1,128 @@
+"""
+Helpers to manipulate deferred DDL statements that might need to be adjusted or
+discarded within when executing a migration.
+"""
+
+
+class Reference:
+    """Base class that defines the reference interface."""
+
+    def references_table(self, table):
+        """
+        Return whether or not this instance references the specified table.
+        """
+        return False
+
+    def references_column(self, table, column):
+        """
+        Return whether or not this instance references the specified column.
+        """
+        return False
+
+    def __repr__(self):
+        return '<%s %r>' % (self.__class__.__name__, str(self))
+
+    def __str__(self):
+        raise NotImplementedError('Subclasses must define how they should be converted to string.')
+
+
+class Table(Reference):
+    """Hold a reference to a table."""
+
+    def __init__(self, table, quote_name):
+        self.table = table
+        self.quote_name = quote_name
+
+    def references_table(self, table):
+        return self.table == table
+
+    def __str__(self):
+        return self.quote_name(self.table)
+
+
+class TableColumns(Table):
+    """Base class for references to multiple columns of a table."""
+
+    def __init__(self, table, columns):
+        self.table = table
+        self.columns = columns
+
+    def references_column(self, table, column):
+        return self.table == table and column in self.columns
+
+
+class Columns(TableColumns):
+    """Hold a reference to one or many columns."""
+
+    def __init__(self, table, columns, quote_name):
+        self.quote_name = quote_name
+        super().__init__(table, columns)
+
+    def __str__(self):
+        return ', '.join(self.quote_name(column) for column in self.columns)
+
+
+class IndexName(TableColumns):
+    """Hold a reference to an index name."""
+
+    def __init__(self, table, columns, suffix, create_index_name):
+        self.suffix = suffix
+        self.create_index_name = create_index_name
+        super().__init__(table, columns)
+
+    def __str__(self):
+        return self.create_index_name(self.table, self.columns, self.suffix)
+
+
+class ForeignKeyName(TableColumns):
+    """Hold a reference to a foreign key name."""
+
+    def __init__(self, from_table, from_columns, to_table, to_columns, suffix_template, create_fk_name):
+        self.to_reference = TableColumns(to_table, to_columns)
+        self.suffix_template = suffix_template
+        self.create_fk_name = create_fk_name
+        super().__init__(from_table, from_columns,)
+
+    def references_table(self, table):
+        return super().references_table(table) or self.to_reference.references_table(table)
+
+    def references_column(self, table, column):
+        return (
+            super().references_column(table, column) or
+            self.to_reference.references_column(table, column)
+        )
+
+    def __str__(self):
+        suffix = self.suffix_template % {
+            'to_table': self.to_reference.table,
+            'to_column': self.to_reference.columns[0],
+        }
+        return self.create_fk_name(self.table, self.columns, suffix)
+
+
+class Statement(Reference):
+    """
+    Statement template and formatting parameters container.
+
+    Allows keeping a reference to a statement without interpolating identifiers
+    that might have to be adjusted if they're referencing a table or column
+    that is removed
+    """
+    def __init__(self, template, **parts):
+        self.template = template
+        self.parts = parts
+
+    def references_table(self, table):
+        return any(
+            hasattr(part, 'references_table') and part.references_table(table)
+            for part in self.parts.values()
+        )
+
+    def references_column(self, table, column):
+        return any(
+            hasattr(part, 'references_column') and part.references_column(table, column)
+            for part in self.parts.values()
+        )
+
+    def __str__(self):
+        return self.template % self.parts
diff --git a/django/db/backends/sqlite3/schema.py b/django/db/backends/sqlite3/schema.py
index 10d7c623f8..5517d78a97 100644
--- a/django/db/backends/sqlite3/schema.py
+++ b/django/db/backends/sqlite3/schema.py
@@ -5,6 +5,7 @@ from decimal import Decimal
 
 from django.apps.registry import Apps
 from django.db.backends.base.schema import BaseDatabaseSchemaEditor
+from django.db.backends.ddl_references import Statement
 
 
 class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
@@ -189,9 +190,12 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
             # Rename the old table to make way for the new
             self.alter_db_table(model, temp_model._meta.db_table, model._meta.db_table)
 
-            # Create a new table with the updated schema. We remove things
-            # from the deferred SQL that match our table name, too
-            self.deferred_sql = [x for x in self.deferred_sql if temp_model._meta.db_table not in x]
+            # Remove all deferred statements referencing the temporary table.
+            for sql in list(self.deferred_sql):
+                if isinstance(sql, Statement) and sql.references_table(temp_model._meta.db_table):
+                    self.deferred_sql.remove(sql)
+
+            # Create a new table with the updated schema.
             self.create_model(temp_model)
 
             # Copy data from the old table into the new table
diff --git a/tests/backends/test_ddl_references.py b/tests/backends/test_ddl_references.py
new file mode 100644
index 0000000000..268eed988b
--- /dev/null
+++ b/tests/backends/test_ddl_references.py
@@ -0,0 +1,125 @@
+from django.db.backends.ddl_references import (
+    Columns, ForeignKeyName, IndexName, Statement, Table,
+)
+from django.test import SimpleTestCase
+
+
+class TableTests(SimpleTestCase):
+    def setUp(self):
+        self.reference = Table('table', lambda table: table.upper())
+
+    def test_references_table(self):
+        self.assertIs(self.reference.references_table('table'), True)
+        self.assertIs(self.reference.references_table('other'), False)
+
+    def test_repr(self):
+        self.assertEqual(repr(self.reference), "<Table 'TABLE'>")
+
+    def test_str(self):
+        self.assertEqual(str(self.reference), 'TABLE')
+
+
+class ColumnsTests(TableTests):
+    def setUp(self):
+        self.reference = Columns(
+            'table', ['first_column', 'second_column'], lambda column: column.upper()
+        )
+
+    def test_references_column(self):
+        self.assertIs(self.reference.references_column('other', 'first_column'), False)
+        self.assertIs(self.reference.references_column('table', 'third_column'), False)
+        self.assertIs(self.reference.references_column('table', 'first_column'), True)
+
+    def test_repr(self):
+        self.assertEqual(repr(self.reference), "<Columns 'FIRST_COLUMN, SECOND_COLUMN'>")
+
+    def test_str(self):
+        self.assertEqual(str(self.reference), 'FIRST_COLUMN, SECOND_COLUMN')
+
+
+class IndexNameTests(ColumnsTests):
+    def setUp(self):
+        def create_index_name(table_name, column_names, suffix):
+            return ', '.join("%s_%s_%s" % (table_name, column_name, suffix) for column_name in column_names)
+        self.reference = IndexName(
+            'table', ['first_column', 'second_column'], 'suffix', create_index_name
+        )
+
+    def test_repr(self):
+        self.assertEqual(repr(self.reference), "<IndexName 'table_first_column_suffix, table_second_column_suffix'>")
+
+    def test_str(self):
+        self.assertEqual(str(self.reference), 'table_first_column_suffix, table_second_column_suffix')
+
+
+class ForeignKeyNameTests(IndexNameTests):
+    def setUp(self):
+        def create_foreign_key_name(table_name, column_names, suffix):
+            return ', '.join("%s_%s_%s" % (table_name, column_name, suffix) for column_name in column_names)
+        self.reference = ForeignKeyName(
+            'table', ['first_column', 'second_column'],
+            'to_table', ['to_first_column', 'to_second_column'],
+            '%(to_table)s_%(to_column)s_fk',
+            create_foreign_key_name,
+        )
+
+    def test_references_table(self):
+        super().test_references_table()
+        self.assertIs(self.reference.references_table('to_table'), True)
+
+    def test_references_column(self):
+        super().test_references_column()
+        self.assertIs(self.reference.references_column('to_table', 'second_column'), False)
+        self.assertIs(self.reference.references_column('to_table', 'to_second_column'), True)
+
+    def test_repr(self):
+        self.assertEqual(
+            repr(self.reference),
+            "<ForeignKeyName 'table_first_column_to_table_to_first_column_fk, "
+            "table_second_column_to_table_to_first_column_fk'>"
+        )
+
+    def test_str(self):
+        self.assertEqual(
+            str(self.reference),
+            'table_first_column_to_table_to_first_column_fk, '
+            'table_second_column_to_table_to_first_column_fk'
+        )
+
+
+class MockReference(object):
+    def __init__(self, representation, referenced_tables, referenced_columns):
+        self.representation = representation
+        self.referenced_tables = referenced_tables
+        self.referenced_columns = referenced_columns
+
+    def references_table(self, table):
+        return table in self.referenced_tables
+
+    def references_column(self, table, column):
+        return (table, column) in self.referenced_columns
+
+    def __str__(self):
+        return self.representation
+
+
+class StatementTests(SimpleTestCase):
+    def test_references_table(self):
+        statement = Statement('', reference=MockReference('', {'table'}, {}), non_reference='')
+        self.assertIs(statement.references_table('table'), True)
+        self.assertIs(statement.references_table('other'), False)
+
+    def test_references_column(self):
+        statement = Statement('', reference=MockReference('', {}, {('table', 'column')}), non_reference='')
+        self.assertIs(statement.references_column('table', 'column'), True)
+        self.assertIs(statement.references_column('other', 'column'), False)
+
+    def test_repr(self):
+        reference = MockReference('reference', {}, {})
+        statement = Statement("%(reference)s - %(non_reference)s", reference=reference, non_reference='non_reference')
+        self.assertEqual(repr(statement), "<Statement 'reference - non_reference'>")
+
+    def test_str(self):
+        reference = MockReference('reference', {}, {})
+        statement = Statement("%(reference)s - %(non_reference)s", reference=reference, non_reference='non_reference')
+        self.assertEqual(str(statement), 'reference - non_reference')
diff --git a/tests/indexes/tests.py b/tests/indexes/tests.py
index c2d76feeb9..ee2cbd1564 100644
--- a/tests/indexes/tests.py
+++ b/tests/indexes/tests.py
@@ -51,7 +51,7 @@ class SchemaIndexesTests(TestCase):
 
     def test_index_together(self):
         editor = connection.schema_editor()
-        index_sql = editor._model_indexes_sql(Article)
+        index_sql = [str(statement) for statement in editor._model_indexes_sql(Article)]
         self.assertEqual(len(index_sql), 1)
         # Ensure the index name is properly quoted
         self.assertIn(
@@ -70,7 +70,7 @@ class SchemaIndexesTests(TestCase):
     def test_postgresql_text_indexes(self):
         """Test creation of PostgreSQL-specific text indexes (#12234)"""
         from .models import IndexedArticle
-        index_sql = connection.schema_editor()._model_indexes_sql(IndexedArticle)
+        index_sql = [str(statement) for statement in connection.schema_editor()._model_indexes_sql(IndexedArticle)]
         self.assertEqual(len(index_sql), 5)
         self.assertIn('("headline" varchar_pattern_ops)', index_sql[1])
         self.assertIn('("body" text_pattern_ops)', index_sql[3])
@@ -99,7 +99,7 @@ class SchemaIndexesMySQLTests(TransactionTestCase):
         )
         if storage != "InnoDB":
             self.skip("This test only applies to the InnoDB storage engine")
-        index_sql = connection.schema_editor()._model_indexes_sql(ArticleTranslation)
+        index_sql = [str(statement) for statement in connection.schema_editor()._model_indexes_sql(ArticleTranslation)]
         self.assertEqual(index_sql, [
             'CREATE INDEX `indexes_articletranslation_article_no_constraint_id_d6c0806b` '
             'ON `indexes_articletranslation` (`article_no_constraint_id`)'
@@ -114,7 +114,7 @@ class SchemaIndexesMySQLTests(TransactionTestCase):
                 new_field.set_attributes_from_name('new_foreign_key')
                 editor.add_field(ArticleTranslation, new_field)
                 field_created = True
-                self.assertEqual(editor.deferred_sql, [
+                self.assertEqual([str(statement) for statement in editor.deferred_sql], [
                     'ALTER TABLE `indexes_articletranslation` '
                     'ADD CONSTRAINT `indexes_articletrans_new_foreign_key_id_d27a9146_fk_indexes_a` '
                     'FOREIGN KEY (`new_foreign_key_id`) REFERENCES `indexes_article` (`id`)'
diff --git a/tests/model_options/test_tablespaces.py b/tests/model_options/test_tablespaces.py
index 03a137603b..79b0a8bb75 100644
--- a/tests/model_options/test_tablespaces.py
+++ b/tests/model_options/test_tablespaces.py
@@ -15,7 +15,7 @@ def sql_for_table(model):
 
 
 def sql_for_index(model):
-    return '\n'.join(connection.schema_editor()._model_indexes_sql(model))
+    return '\n'.join(str(sql) for sql in connection.schema_editor()._model_indexes_sql(model))
 
 
 # We can't test the DEFAULT_TABLESPACE and DEFAULT_INDEX_TABLESPACE settings