Commit 8a1f0177 authored by Andrew Godwin's avatar Andrew Godwin
Browse files

Add root_node and leaf_node functions to MigrationGraph

parent 9ce83546
Loading
Loading
Loading
Loading
+28 −2
Original line number Diff line number Diff line
@@ -19,8 +19,9 @@ class MigrationGraph(object):
    replacing migration, and repoint any dependencies that pointed to the
    replaced migrations to point to the replacing one.

    A node should be a tuple: (app_path, migration_name) - but the code
    here doesn't really care.
    A node should be a tuple: (app_path, migration_name). The tree special-cases
    things within an app - namely, root nodes and leaf nodes ignore dependencies
    to other apps.
    """

    def __init__(self):
@@ -59,6 +60,31 @@ class MigrationGraph(object):
            raise ValueError("Node %r not a valid node" % node)
        return self.dfs(node, lambda x: self.dependents.get(x, set()))

    def root_nodes(self):
        """
        Returns all root nodes - that is, nodes with no dependencies inside
        their app. These are the starting point for an app.
        """
        roots = set()
        for node in self.nodes:
            if not filter(lambda key: key[0] == node[0], self.dependencies.get(node, set())):
                roots.add(node)
        return roots

    def leaf_nodes(self):
        """
        Returns all leaf nodes - that is, nodes with no dependents in their app.
        These are the "most current" version of an app's schema.
        Having more than one per app is technically an error, but one that
        gets handled further up, in the interactive command - it's usually the
        result of a VCS merge and needs some user input.
        """
        leaves = set()
        for node in self.nodes:
            if not filter(lambda key: key[0] == node[0], self.dependents.get(node, set())):
                leaves.add(node)
        return leaves

    def dfs(self, start, get_children):
        """
        Dynamic programming based depth first search, for finding dependencies.
+18 −0
Original line number Diff line number Diff line
@@ -44,6 +44,15 @@ class GraphTests(TransactionTestCase):
            graph.backwards_plan(("app_b", "0002")),
            [('app_a', '0004'), ('app_a', '0003'), ('app_b', '0002')],
        )
        # Test roots and leaves
        self.assertEqual(
            graph.root_nodes(),
            set([('app_a', '0001'), ('app_b', '0001')]),
        )
        self.assertEqual(
            graph.leaf_nodes(),
            set([('app_a', '0004'), ('app_b', '0002')]),
        )

    def test_complex_graph(self):
        """
@@ -81,6 +90,15 @@ class GraphTests(TransactionTestCase):
            graph.backwards_plan(("app_b", "0001")),
            [('app_a', '0004'), ('app_c', '0002'), ('app_c', '0001'), ('app_a', '0003'), ('app_b', '0002'), ('app_b', '0001')],
        )
        # Test roots and leaves
        self.assertEqual(
            graph.root_nodes(),
            set([('app_a', '0001'), ('app_b', '0001'), ('app_c', '0001')]),
        )
        self.assertEqual(
            graph.leaf_nodes(),
            set([('app_a', '0004'), ('app_b', '0002'), ('app_c', '0002')]),
        )

    def test_circular_graph(self):
        """