From db926a0048bbfabafc5711034d08e2a221564f3e Mon Sep 17 00:00:00 2001
From: Krzysztof Gogolewski <krz.gogolewski@gmail.com>
Date: Tue, 20 Mar 2018 22:14:22 +0100
Subject: [PATCH] Fixed #29243 -- Improved efficiency of migration graph
 algorithm.

---
 django/db/migrations/graph.py  | 121 ++++++++-------------------------
 django/db/migrations/loader.py |   6 +-
 tests/migrations/test_graph.py | 103 +++++++++++++---------------
 3 files changed, 81 insertions(+), 149 deletions(-)

diff --git a/django/db/migrations/graph.py b/django/db/migrations/graph.py
index 1f0bafb00e..0096392c5d 100644
--- a/django/db/migrations/graph.py
+++ b/django/db/migrations/graph.py
@@ -1,18 +1,9 @@
-import warnings
 from functools import total_ordering
 
 from django.db.migrations.state import ProjectState
-from django.utils.datastructures import OrderedSet
 
 from .exceptions import CircularDependencyError, NodeNotFoundError
 
-RECURSION_DEPTH_WARNING = (
-    "Maximum recursion depth exceeded while generating migration graph, "
-    "falling back to iterative approach. If you're experiencing performance issues, "
-    "consider squashing migrations as described at "
-    "https://docs.djangoproject.com/en/dev/topics/migrations/#squashing-migrations."
-)
-
 
 @total_ordering
 class Node:
@@ -49,49 +40,20 @@ class Node:
     def add_parent(self, parent):
         self.parents.add(parent)
 
-    # Use manual caching, @cached_property effectively doubles the
-    # recursion depth for each recursion.
-    def ancestors(self):
-        # Use self.key instead of self to speed up the frequent hashing
-        # when constructing an OrderedSet.
-        if '_ancestors' not in self.__dict__:
-            ancestors = []
-            for parent in sorted(self.parents, reverse=True):
-                ancestors += parent.ancestors()
-            ancestors.append(self.key)
-            self.__dict__['_ancestors'] = list(OrderedSet(ancestors))
-        return self.__dict__['_ancestors']
-
-    # Use manual caching, @cached_property effectively doubles the
-    # recursion depth for each recursion.
-    def descendants(self):
-        # Use self.key instead of self to speed up the frequent hashing
-        # when constructing an OrderedSet.
-        if '_descendants' not in self.__dict__:
-            descendants = []
-            for child in sorted(self.children, reverse=True):
-                descendants += child.descendants()
-            descendants.append(self.key)
-            self.__dict__['_descendants'] = list(OrderedSet(descendants))
-        return self.__dict__['_descendants']
-
 
 class DummyNode(Node):
+    """
+    A node that doesn't correspond to a migration file on disk.
+    (A squashed migration that was removed, for example.)
+
+    After the migration graph is processed, all dummy nodes should be removed.
+    Any that remain make validate_consistency() raise NodeNotFoundError.
+    """
     def __init__(self, key, origin, error_message):
         super().__init__(key)
         self.origin = origin
         self.error_message = error_message
 
-    def promote(self):
-        """
-        Transition dummy to a normal node and clean off excess attribs.
-        Creating a Node object from scratch would be too much of a
-        hassle as many dependendies would need to be remapped.
-        """
-        del self.origin
-        del self.error_message
-        self.__class__ = Node
-
     def raise_error(self):
         raise NodeNotFoundError(self.error_message, self.key, origin=self.origin)
 
@@ -122,19 +84,12 @@ class MigrationGraph:
     def __init__(self):
         self.node_map = {}
         self.nodes = {}
-        self.cached = False
 
     def add_node(self, key, migration):
-        # If the key already exists, then it must be a dummy node.
-        dummy_node = self.node_map.get(key)
-        if dummy_node:
-            # Promote DummyNode to Node.
-            dummy_node.promote()
-        else:
-            node = Node(key)
-            self.node_map[key] = node
+        if key in self.node_map:  # Explicit check: a bare assert is stripped under -O.
+            raise ValueError("Node %r is already in the migration graph." % (key,))
+        self.node_map[key] = Node(key)
         self.nodes[key] = migration
-        self.clear_cache()
 
     def add_dummy_node(self, key, origin, error_message):
         node = DummyNode(key, origin, error_message)
@@ -163,7 +118,6 @@ class MigrationGraph:
         self.node_map[parent].add_child(self.node_map[child])
         if not skip_validation:
             self.validate_consistency()
-        self.clear_cache()
 
     def remove_replaced_nodes(self, replacement, replaced):
         """
@@ -199,7 +153,6 @@ class MigrationGraph:
                     if parent.key not in replaced:
                         replacement_node.add_parent(parent)
                         parent.add_child(replacement_node)
-        self.clear_cache()
 
     def remove_replacement_node(self, replacement, replaced):
         """
@@ -236,19 +189,11 @@ class MigrationGraph:
             parent.children.remove(replacement_node)
             # NOTE: There is no need to remap parent dependencies as we can
             # assume the replaced nodes already have the correct ancestry.
-        self.clear_cache()
 
     def validate_consistency(self):
         """Ensure there are no dummy nodes remaining in the graph."""
         [n.raise_error() for n in self.node_map.values() if isinstance(n, DummyNode)]
 
-    def clear_cache(self):
-        if self.cached:
-            for node in self.nodes:
-                self.node_map[node].__dict__.pop('_ancestors', None)
-                self.node_map[node].__dict__.pop('_descendants', None)
-            self.cached = False
-
     def forwards_plan(self, target):
         """
         Given a node, return a list of which previous nodes (dependencies) must
@@ -257,16 +202,7 @@ class MigrationGraph:
         """
         if target not in self.nodes:
             raise NodeNotFoundError("Node %r not a valid node" % (target,), target)
-        # Use parent.key instead of parent to speed up the frequent hashing in ensure_not_cyclic
-        self.ensure_not_cyclic(target, lambda x: (parent.key for parent in self.node_map[x].parents))
-        self.cached = True
-        node = self.node_map[target]
-        try:
-            return node.ancestors()
-        except RuntimeError:
-            # fallback to iterative dfs
-            warnings.warn(RECURSION_DEPTH_WARNING, RuntimeWarning)
-            return self.iterative_dfs(node)
+        return self.iterative_dfs(self.node_map[target])
 
     def backwards_plan(self, target):
         """
@@ -276,26 +212,24 @@ class MigrationGraph:
         """
         if target not in self.nodes:
             raise NodeNotFoundError("Node %r not a valid node" % (target,), target)
-        # Use child.key instead of child to speed up the frequent hashing in ensure_not_cyclic
-        self.ensure_not_cyclic(target, lambda x: (child.key for child in self.node_map[x].children))
-        self.cached = True
-        node = self.node_map[target]
-        try:
-            return node.descendants()
-        except RuntimeError:
-            # fallback to iterative dfs
-            warnings.warn(RECURSION_DEPTH_WARNING, RuntimeWarning)
-            return self.iterative_dfs(node, forwards=False)
+        return self.iterative_dfs(self.node_map[target], forwards=False)
 
     def iterative_dfs(self, start, forwards=True):
         """Iterative depth-first search for finding dependencies."""
         visited = []
-        stack = [start]
+        visited_set = set()  # Nodes already appended to visited, for fast membership tests.
+        stack = [(start, False)]  # (node, processed) pairs.
         while stack:
-            node = stack.pop()
-            visited.append(node)
-            stack += sorted(node.parents if forwards else node.children)
-        return list(OrderedSet(reversed(visited)))
+            node, processed = stack.pop()
+            if node in visited_set:
+                pass  # Already in the plan, reached earlier through another path.
+            elif processed:
+                visited_set.add(node)  # Everything pushed after this entry has been handled.
+                visited.append(node.key)
+            else:
+                stack.append((node, True))  # Comes back after everything pushed on top of it.
+                stack += [(n, False) for n in sorted(node.parents if forwards else node.children)]
+        return visited
 
     def root_nodes(self, app=None):
         """
@@ -322,7 +256,7 @@ class MigrationGraph:
                 leaves.add(node)
         return sorted(leaves)
 
-    def ensure_not_cyclic(self, start, get_children):
+    def ensure_not_cyclic(self):
         # Algo from GvR:
         # http://neopythonic.blogspot.co.uk/2009/01/detecting-cycles-in-directed-graph.html
         todo = set(self.nodes)
@@ -331,7 +265,10 @@ class MigrationGraph:
             stack = [node]
             while stack:
                 top = stack[-1]
-                for node in get_children(top):
+                for child in self.node_map[top].children:
+                    # Use child.key instead of child to speed up the frequent
+                    # hashing.
+                    node = child.key
                     if node in stack:
                         cycle = stack[stack.index(node):]
                         raise CircularDependencyError(", ".join("%s.%s" % n for n in cycle))
diff --git a/django/db/migrations/loader.py b/django/db/migrations/loader.py
index 4147cc3f09..79de511860 100644
--- a/django/db/migrations/loader.py
+++ b/django/db/migrations/loader.py
@@ -213,11 +213,12 @@ class MigrationLoader:
         self.replacements = {}
         for key, migration in self.disk_migrations.items():
             self.graph.add_node(key, migration)
-            # Internal (aka same-app) dependencies.
-            self.add_internal_dependencies(key, migration)
             # Replacing migrations.
             if migration.replaces:
                 self.replacements[key] = migration
+        for key, migration in self.disk_migrations.items():
+            # Internal (same app) dependencies.
+            self.add_internal_dependencies(key, migration)
         # Add external dependencies now that the internal ones have been resolved.
         for key, migration in self.disk_migrations.items():
             self.add_external_dependencies(key, migration)
@@ -268,6 +269,7 @@ class MigrationLoader:
                         exc.node
                     ) from exc
             raise exc
+        self.graph.ensure_not_cyclic()
 
     def check_consistent_history(self, connection):
         """
diff --git a/tests/migrations/test_graph.py b/tests/migrations/test_graph.py
index 884aaa70f7..10a5696f57 100644
--- a/tests/migrations/test_graph.py
+++ b/tests/migrations/test_graph.py
@@ -1,9 +1,7 @@
 from django.db.migrations.exceptions import (
     CircularDependencyError, NodeNotFoundError,
 )
-from django.db.migrations.graph import (
-    RECURSION_DEPTH_WARNING, DummyNode, MigrationGraph, Node,
-)
+from django.db.migrations.graph import DummyNode, MigrationGraph, Node
 from django.test import SimpleTestCase
 
 
@@ -145,7 +143,7 @@ class GraphTests(SimpleTestCase):
         graph.add_dependency("app_b.0001", ("app_b", "0001"), ("app_a", "0003"))
         # Test whole graph
         with self.assertRaises(CircularDependencyError):
-            graph.forwards_plan(("app_a", "0003"))
+            graph.ensure_not_cyclic()
 
     def test_circular_graph_2(self):
         graph = MigrationGraph()
@@ -157,9 +155,9 @@ class GraphTests(SimpleTestCase):
         graph.add_dependency('C.0001', ('C', '0001'), ('B', '0001'))
 
         with self.assertRaises(CircularDependencyError):
-            graph.forwards_plan(('C', '0001'))
+            graph.ensure_not_cyclic()
 
-    def test_graph_recursive(self):
+    def test_iterative_dfs(self):
         graph = MigrationGraph()
         root = ("app_a", "1")
         graph.add_node(root, None)
@@ -178,28 +176,29 @@ class GraphTests(SimpleTestCase):
         backwards_plan = graph.backwards_plan(root)
         self.assertEqual(expected[::-1], backwards_plan)
 
-    def test_graph_iterative(self):
+    def test_iterative_dfs_complexity(self):
+        """
+        In a graph with merge migrations, iterative_dfs() traverses each node
+        only once even if there are multiple paths leading to it.
+        """
+        n = 50  # 2 ** (n - 1) distinct paths connect the first and last app_a nodes.
         graph = MigrationGraph()
-        root = ("app_a", "1")
-        graph.add_node(root, None)
-        expected = [root]
-        for i in range(2, 1000):
-            parent = ("app_a", str(i - 1))
-            child = ("app_a", str(i))
-            graph.add_node(child, None)
-            graph.add_dependency(str(i), child, parent)
-            expected.append(child)
-        leaf = expected[-1]
-
-        with self.assertWarnsMessage(RuntimeWarning, RECURSION_DEPTH_WARNING):
-            forwards_plan = graph.forwards_plan(leaf)
-
-        self.assertEqual(expected, forwards_plan)
-
-        with self.assertWarnsMessage(RuntimeWarning, RECURSION_DEPTH_WARNING):
-            backwards_plan = graph.backwards_plan(root)
-
-        self.assertEqual(expected[::-1], backwards_plan)
+        for i in range(1, n + 1):
+            graph.add_node(('app_a', str(i)), None)
+            graph.add_node(('app_b', str(i)), None)
+            graph.add_node(('app_c', str(i)), None)
+        for i in range(1, n):  # Diamond per level: b_i and c_i depend on a_i; a_(i + 1) on both.
+            graph.add_dependency(None, ('app_b', str(i)), ('app_a', str(i)))
+            graph.add_dependency(None, ('app_c', str(i)), ('app_a', str(i)))
+            graph.add_dependency(None, ('app_a', str(i + 1)), ('app_b', str(i)))
+            graph.add_dependency(None, ('app_a', str(i + 1)), ('app_c', str(i)))
+        plan = graph.forwards_plan(('app_a', str(n)))
+        expected = [
+            (app, str(i))
+            for i in range(1, n)
+            for app in ['app_a', 'app_c', 'app_b']  # The LIFO stack pops app_c before app_b.
+        ] + [('app_a', str(n))]
+        self.assertEqual(plan, expected)
 
     def test_plan_invalid_node(self):
         """
@@ -241,34 +240,39 @@ class GraphTests(SimpleTestCase):
         with self.assertRaisesMessage(NodeNotFoundError, msg):
             graph.add_dependency("app_a.0002", ("app_a", "0002"), ("app_a", "0001"))
 
-    def test_validate_consistency(self):
-        """
-        Tests for missing nodes, using `validate_consistency()` to raise the error.
-        """
-        # Build graph
+    def test_validate_consistency_missing_parent(self):
         graph = MigrationGraph()
         graph.add_node(("app_a", "0001"), None)
-        # Add dependency with missing parent node (skipping validation).
         graph.add_dependency("app_a.0001", ("app_a", "0001"), ("app_b", "0002"), skip_validation=True)
         msg = "Migration app_a.0001 dependencies reference nonexistent parent node ('app_b', '0002')"
         with self.assertRaisesMessage(NodeNotFoundError, msg):
             graph.validate_consistency()
-        # Add missing parent node and ensure `validate_consistency()` no longer raises error.
+
+    def test_validate_consistency_missing_child(self):
+        graph = MigrationGraph()
         graph.add_node(("app_b", "0002"), None)
-        graph.validate_consistency()
-        # Add dependency with missing child node (skipping validation).
-        graph.add_dependency("app_a.0002", ("app_a", "0002"), ("app_a", "0001"), skip_validation=True)
-        msg = "Migration app_a.0002 dependencies reference nonexistent child node ('app_a', '0002')"
+        graph.add_dependency("app_b.0002", ("app_a", "0001"), ("app_b", "0002"), skip_validation=True)
+        msg = "Migration app_b.0002 dependencies reference nonexistent child node ('app_a', '0001')"
         with self.assertRaisesMessage(NodeNotFoundError, msg):
             graph.validate_consistency()
-        # Add missing child node and ensure `validate_consistency()` no longer raises error.
-        graph.add_node(("app_a", "0002"), None)
+
+    def test_validate_consistency_no_error(self):
+        graph = MigrationGraph()
+        graph.add_node(("app_a", "0001"), None)
+        graph.add_node(("app_b", "0002"), None)
+        graph.add_dependency("app_a.0001", ("app_a", "0001"), ("app_b", "0002"), skip_validation=True)
         graph.validate_consistency()
-        # Rawly add dummy node.
-        msg = "app_a.0001 (req'd by app_a.0002) is missing!"
+
+    def test_validate_consistency_dummy(self):
+        """
+        validate_consistency() raises an error if there's an isolated dummy
+        node.
+        """
+        msg = "app_a.0001 (req'd by app_b.0002) is missing!"
+        graph = MigrationGraph()
         graph.add_dummy_node(
             key=("app_a", "0001"),
-            origin="app_a.0002",
+            origin="app_b.0002",
             error_message=msg
         )
         with self.assertRaisesMessage(NodeNotFoundError, msg):
@@ -382,7 +386,7 @@ class GraphTests(SimpleTestCase):
         graph.add_dependency("app_c.0001_squashed_0002", ("app_c", "0001_squashed_0002"), ("app_b", "0002"))
 
         with self.assertRaises(CircularDependencyError):
-            graph.forwards_plan(("app_c", "0001_squashed_0002"))
+            graph.ensure_not_cyclic()
 
     def test_stringify(self):
         graph = MigrationGraph()
@@ -413,14 +417,3 @@ class NodeTests(SimpleTestCase):
             error_message='x is missing',
         )
         self.assertEqual(repr(node), "<DummyNode: ('app_a', '0001')>")
-
-    def test_dummynode_promote(self):
-        dummy = DummyNode(
-            key=('app_a', '0001'),
-            origin='app_a.0002',
-            error_message="app_a.0001 (req'd by app_a.0002) is missing!",
-        )
-        dummy.promote()
-        self.assertIsInstance(dummy, Node)
-        self.assertFalse(hasattr(dummy, 'origin'))
-        self.assertFalse(hasattr(dummy, 'error_message'))