Commit b878c73f authored by Ben Reilly's avatar Ben Reilly
Browse files

switch out recursive dfs for stack based approach, to avoid possibly hitting the recursion limit

parent 45768e6b
Loading
Loading
Loading
Loading
+20 −25
Original line number Diff line number Diff line
@@ -94,31 +94,26 @@ class MigrationGraph(object):
        """
        Dynamic programming based depth first search, for finding dependencies.
        """
        cache = {}

        def _dfs(start, get_children, path):
            # If we already computed this, use that (dynamic programming)
            if (start, get_children) in cache:
                return cache[(start, get_children)]
            # If we've traversed here before, that's a circular dep
            if start in path:
                raise CircularDependencyError(path[path.index(start):] + [start])
            # Build our own results list, starting with us
            results = []
            results.append(start)
            # We need to add to results all the migrations this one depends on
            children = sorted(get_children(start))
            path.append(start)
            for n in children:
                results = _dfs(n, get_children, path) + results
            path.pop()
            # Use OrderedSet to ensure only one instance of each result
            results = list(OrderedSet(results))
            # Populate DP cache
            cache[(start, get_children)] = results
            # Done!
            return results
        return _dfs(start, get_children, [])
        visited = []
        visited.append(start)
        path = [start]
        stack = sorted(get_children(start))
        while stack:
            node = stack.pop(0)

            if node in path:
                raise CircularDependencyError()
            path.append(node)

            visited.insert(0, node)
            children = sorted(get_children(node))

            if not children:
                path = []

            stack = children + stack

        return list(OrderedSet(visited))

    def __str__(self):
        return "Graph: %s nodes, %s edges" % (len(self.nodes), sum(len(x) for x in self.dependencies.values()))
+15 −0
Original line number Diff line number Diff line
@@ -134,6 +134,21 @@ class GraphTests(TestCase):
            graph.forwards_plan, ("app_a", "0003"),
        )

    def test_dfs(self):
        graph = MigrationGraph()
        root = ("app_a", "1")
        graph.add_node(root, None)
        expected = [root]
        for i in xrange(2, 1000):
            parent = ("app_a", str(i - 1))
            child = ("app_a", str(i))
            graph.add_node(child, None)
            graph.add_dependency(str(i), child, parent)
            expected.append(child)

        actual = graph.dfs(root, lambda x: graph.dependents.get(x, set()))
        self.assertEqual(expected[::-1], actual)

    def test_plan_invalid_node(self):
        """
        Tests for forwards/backwards_plan of nonexistent node.