Commit 78d43a5e authored by Marten Kenbeek's avatar Marten Kenbeek Committed by Markus Holtermann
Browse files

Fixed #24366 -- Optimized traversal of large migration dependency graphs.

Switched from an adjancency list and uncached, iterative depth-first
search to a Node-based design with direct parent/child links and a
cached, recursive depth-first search. With this change, calculating
a migration plan for a large graph takes several seconds instead of
several hours.

Marked test `migrations.test_graph.GraphTests.test_dfs` as an expected
failure due to reaching the maximum recursion depth.
parent 7fa7dd48
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -47,7 +47,7 @@ class MigrationExecutor(object):
                # child(ren) in the same app, and no further.
                next_in_app = sorted(
                    n for n in
                    self.loader.graph.dependents.get(target, set())
                    self.loader.graph.node_map[target].children
                    if n[0] == target[0]
                )
                for node in next_in_app:
+91 −30
Original line number Diff line number Diff line
@@ -5,6 +5,68 @@ from collections import deque
from django.db.migrations.state import ProjectState
from django.utils.datastructures import OrderedSet
from django.utils.encoding import python_2_unicode_compatible
from django.utils.functional import total_ordering


@python_2_unicode_compatible
@total_ordering
class Node(object):
    """
    A single node in the migration graph. Contains direct links to adjacent
    nodes in either direction.
    """
    def __init__(self, key):
        self.key = key
        self.children = set()
        self.parents = set()

    def __eq__(self, other):
        return self.key == other

    def __lt__(self, other):
        return self.key < other

    def __hash__(self):
        return hash(self.key)

    def __getitem__(self, item):
        return self.key[item]

    def __str__(self):
        return str(self.key)

    def __repr__(self):
        return '<Node: (%r, %r)>' % self.key

    def add_child(self, child):
        self.children.add(child)

    def add_parent(self, parent):
        self.parents.add(parent)

    # Use manual caching, @cached_property effectively doubles the
    # recursion depth for each recursion.
    def ancestors(self):
        # Use self.key instead of self to speed up the frequent hashing
        # when constructing an OrderedSet.
        if '_ancestors' not in self.__dict__:
            ancestors = deque([self.key])
            for parent in sorted(self.parents):
                ancestors.extendleft(reversed(parent.ancestors()))
            self.__dict__['_ancestors'] = list(OrderedSet(ancestors))
        return self.__dict__['_ancestors']

    # Use manual caching, @cached_property effectively doubles the
    # recursion depth for each recursion.
    def descendants(self):
        # Use self.key instead of self to speed up the frequent hashing
        # when constructing an OrderedSet.
        if '_descendants' not in self.__dict__:
            descendants = deque([self.key])
            for child in sorted(self.children):
                descendants.extendleft(reversed(child.descendants()))
            self.__dict__['_descendants'] = list(OrderedSet(descendants))
        return self.__dict__['_descendants']


@python_2_unicode_compatible
@@ -32,12 +94,15 @@ class MigrationGraph(object):
    """

    def __init__(self):
        self.node_map = {}
        self.nodes = {}
        self.dependencies = {}
        self.dependents = {}
        self.cached = False

    def add_node(self, node, implementation):
        self.nodes[node] = implementation
    def add_node(self, key, implementation):
        node = Node(key)
        self.node_map[key] = node
        self.nodes[key] = implementation
        self.clear_cache()

    def add_dependency(self, migration, child, parent):
        if child not in self.nodes:
@@ -50,8 +115,16 @@ class MigrationGraph(object):
                "Migration %s dependencies reference nonexistent parent node %r" % (migration, parent),
                parent
            )
        self.dependencies.setdefault(child, set()).add(parent)
        self.dependents.setdefault(parent, set()).add(child)
        self.node_map[child].add_parent(self.node_map[parent])
        self.node_map[parent].add_child(self.node_map[child])
        self.clear_cache()

    def clear_cache(self):
        if self.cached:
            for node in self.nodes:
                self.node_map[node].__dict__.pop('_ancestors', None)
                self.node_map[node].__dict__.pop('_descendants', None)
            self.cached = False

    def forwards_plan(self, node):
        """
@@ -62,7 +135,10 @@ class MigrationGraph(object):
        """
        if node not in self.nodes:
            raise NodeNotFoundError("Node %r not a valid node" % (node, ), node)
        return self.dfs(node, lambda x: self.dependencies.get(x, set()))
        # Use parent.key instead of parent to speed up the frequent hashing in ensure_not_cyclic
        self.ensure_not_cyclic(node, lambda x: (parent.key for parent in self.node_map[x].parents))
        self.cached = True
        return self.node_map[node].ancestors()

    def backwards_plan(self, node):
        """
@@ -73,7 +149,10 @@ class MigrationGraph(object):
        """
        if node not in self.nodes:
            raise NodeNotFoundError("Node %r not a valid node" % (node, ), node)
        return self.dfs(node, lambda x: self.dependents.get(x, set()))
        # Use child.key instead of child to speed up the frequent hashing in ensure_not_cyclic
        self.ensure_not_cyclic(node, lambda x: (child.key for child in self.node_map[x].children))
        self.cached = True
        return self.node_map[node].descendants()

    def root_nodes(self, app=None):
        """
@@ -82,7 +161,7 @@ class MigrationGraph(object):
        """
        roots = set()
        for node in self.nodes:
            if (not any(key[0] == node[0] for key in self.dependencies.get(node, set()))
            if (not any(key[0] == node[0] for key in self.node_map[node].parents)
                    and (not app or app == node[0])):
                roots.add(node)
        return sorted(roots)
@@ -97,7 +176,7 @@ class MigrationGraph(object):
        """
        leaves = set()
        for node in self.nodes:
            if (not any(key[0] == node[0] for key in self.dependents.get(node, set()))
            if (not any(key[0] == node[0] for key in self.node_map[node].children)
                    and (not app or app == node[0])):
                leaves.add(node)
        return sorted(leaves)
@@ -105,7 +184,7 @@ class MigrationGraph(object):
    def ensure_not_cyclic(self, start, get_children):
        # Algo from GvR:
        # http://neopythonic.blogspot.co.uk/2009/01/detecting-cycles-in-directed-graph.html
        todo = set(self.nodes.keys())
        todo = set(self.nodes)
        while todo:
            node = todo.pop()
            stack = [node]
@@ -122,28 +201,10 @@ class MigrationGraph(object):
                else:
                    node = stack.pop()

    def dfs(self, start, get_children):
        """
        Iterative depth first search, for finding dependencies.
        """
        self.ensure_not_cyclic(start, get_children)
        visited = deque()
        visited.append(start)
        stack = deque(sorted(get_children(start)))
        while stack:
            node = stack.popleft()
            visited.appendleft(node)
            children = sorted(get_children(node), reverse=True)
            # reverse sorting is needed because prepending using deque.extendleft
            # also effectively reverses values
            stack.extendleft(children)

        return list(OrderedSet(visited))

    def __str__(self):
        return "Graph: %s nodes, %s edges" % (
            len(self.nodes),
            sum(len(x) for x in self.dependencies.values()),
            sum(len(node.parents) for node in self.node_map.values()),
        )

    def make_state(self, nodes=None, at_end=True, real_apps=None):
+20 −2
Original line number Diff line number Diff line
from unittest import expectedFailure

from django.db.migrations.graph import (
    CircularDependencyError, MigrationGraph, NodeNotFoundError,
)
@@ -151,7 +153,23 @@ class GraphTests(TestCase):
            graph.forwards_plan, ('C', '0001')
        )

    def test_dfs(self):
    def test_deep_graph(self):
        graph = MigrationGraph()
        root = ("app_a", "1")
        graph.add_node(root, None)
        expected = [root]
        for i in range(2, 750):
            parent = ("app_a", str(i - 1))
            child = ("app_a", str(i))
            graph.add_node(child, None)
            graph.add_dependency(str(i), child, parent)
            expected.append(child)

        actual = graph.node_map[root].descendants()
        self.assertEqual(expected[::-1], actual)

    @expectedFailure
    def test_recursion_depth(self):
        graph = MigrationGraph()
        root = ("app_a", "1")
        graph.add_node(root, None)
@@ -163,7 +181,7 @@ class GraphTests(TestCase):
            graph.add_dependency(str(i), child, parent)
            expected.append(child)

        actual = graph.dfs(root, lambda x: graph.dependents.get(x, set()))
        actual = graph.node_map[root].descendants()
        self.assertEqual(expected[::-1], actual)

    def test_plan_invalid_node(self):