Commit 5ab8b5d7 authored by Andrew Godwin's avatar Andrew Godwin
Browse files

Fix migration planner to fully understand squashed migrations. And test.

parent 4cfbde71
Loading
Loading
Loading
Loading
+15 −4
Original line number Diff line number Diff line
@@ -11,7 +11,6 @@ class MigrationExecutor(object):
    def __init__(self, connection, progress_callback=None):
        self.connection = connection
        self.loader = MigrationLoader(self.connection)
        self.loader.load_disk()
        self.recorder = MigrationRecorder(self.connection)
        self.progress_callback = progress_callback

@@ -20,7 +19,7 @@ class MigrationExecutor(object):
        Given a set of targets, returns a list of (Migration instance, backwards?).
        """
        plan = []
        applied = self.recorder.applied_migrations()
        applied = set(self.loader.applied_migrations)
        for target in targets:
            # If the target is (appname, None), that means unmigrate everything
            if target[1] is None:
@@ -87,7 +86,13 @@ class MigrationExecutor(object):
            with self.connection.schema_editor() as schema_editor:
                project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False)
                migration.apply(project_state, schema_editor)
        # For replacement migrations, record individual statuses
        if migration.replaces:
            for app_label, name in migration.replaces:
                self.recorder.record_applied(app_label, name)
        else:
            self.recorder.record_applied(migration.app_label, migration.name)
        # Report prgress
        if self.progress_callback:
            self.progress_callback("apply_success", migration)

@@ -101,6 +106,12 @@ class MigrationExecutor(object):
            with self.connection.schema_editor() as schema_editor:
                project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False)
                migration.unapply(project_state, schema_editor)
        # For replacement migrations, record individual statuses
        if migration.replaces:
            for app_label, name in migration.replaces:
                self.recorder.record_unapplied(app_label, name)
        else:
            self.recorder.record_unapplied(migration.app_label, migration.name)
        # Report progress
        if self.progress_callback:
            self.progress_callback("unapply_success", migration)
+29 −24
Original line number Diff line number Diff line
import os
import sys
from importlib import import_module
from django.utils.functional import cached_property
from django.db.models.loading import cache
from django.db.migrations.recorder import MigrationRecorder
from django.db.migrations.graph import MigrationGraph
from django.utils import six
from django.conf import settings


@@ -32,10 +33,12 @@ class MigrationLoader(object):
    in memory.
    """

    def __init__(self, connection):
    def __init__(self, connection, load=True):
        self.connection = connection
        self.disk_migrations = None
        self.applied_migrations = None
        if load:
            self.build_graph()

    @classmethod
    def migrations_module(cls, app_label):
@@ -55,6 +58,7 @@ class MigrationLoader(object):
            # Get the migrations module directory
            app_label = app.__name__.split(".")[-2]
            module_name = self.migrations_module(app_label)
            was_loaded = module_name in sys.modules
            try:
                module = import_module(module_name)
            except ImportError as e:
@@ -71,6 +75,9 @@ class MigrationLoader(object):
                # Module is not a package (e.g. migrations.py).
                if not hasattr(module, '__path__'):
                    continue
                # Force a reload if it's already loaded (tests need this)
                if was_loaded:
                    six.moves.reload_module(module)
            self.migrated_apps.add(app_label)
            directory = os.path.dirname(module.__file__)
            # Scan for .py[c|o] files
@@ -107,9 +114,6 @@ class MigrationLoader(object):

    def get_migration_by_prefix(self, app_label, name_prefix):
        "Returns the migration(s) which match the given app label and name _prefix_"
        # Make sure we have the disk data
        if self.disk_migrations is None:
            self.load_disk()
        # Do the search
        results = []
        for l, n in self.disk_migrations:
@@ -122,16 +126,15 @@ class MigrationLoader(object):
        else:
            return self.disk_migrations[results[0]]

    @cached_property
    def graph(self):
    def build_graph(self):
        """
        Builds a migration dependency graph using both the disk and database.
        You'll need to rebuild the graph if you apply migrations. This isn't
        usually a problem as generally migration stuff runs in a one-shot process.
        """
        # Make sure we have the disk data
        if self.disk_migrations is None:
        # Load disk data
        self.load_disk()
        # And the database data
        if self.applied_migrations is None:
        # Load database data
        recorder = MigrationRecorder(self.connection)
        self.applied_migrations = recorder.applied_migrations()
        # Do a first pass to separate out replacing and non-replacing migrations
@@ -152,12 +155,12 @@ class MigrationLoader(object):
        # Carry out replacements if we can - that is, if all replaced migrations
        # are either unapplied or missing.
        for key, migration in replacing.items():
            # Do the check
            can_replace = True
            for target in migration.replaces:
                if target in self.applied_migrations:
                    can_replace = False
                    break
            # Ensure this replacement migration is not in applied_migrations
            self.applied_migrations.discard(key)
            # Do the check. We can replace if all our replace targets are
            # applied, or if all of them are unapplied.
            applied_statuses = [(target in self.applied_migrations) for target in migration.replaces]
            can_replace = all(applied_statuses) or (not any(applied_statuses))
            if not can_replace:
                continue
            # Alright, time to replace. Step through the replaced migrations
@@ -171,14 +174,16 @@ class MigrationLoader(object):
                    normal[child_key].dependencies.remove(replaced)
                    normal[child_key].dependencies.append(key)
            normal[key] = migration
            # Mark the replacement as applied if all its replaced ones are
            if all(applied_statuses):
                self.applied_migrations.add(key)
        # Finally, make a graph and load everything into it
        graph = MigrationGraph()
        self.graph = MigrationGraph()
        for key, migration in normal.items():
            graph.add_node(key, migration)
            self.graph.add_node(key, migration)
        for key, migration in normal.items():
            for parent in migration.dependencies:
                graph.add_dependency(key, parent)
        return graph
                self.graph.add_dependency(key, parent)


class BadMigrationError(Exception):
+5 −0
Original line number Diff line number Diff line
@@ -39,6 +39,11 @@ class Migration(object):
    def __init__(self, name, app_label):
        self.name = name
        self.app_label = app_label
        # Copy dependencies & other attrs as we might mutate them at runtime
        self.operations = list(self.__class__.operations)
        self.dependencies = list(self.__class__.dependencies)
        self.run_before = list(self.__class__.run_before)
        self.replaces = list(self.__class__.replaces)

    def __eq__(self, other):
        if not isinstance(other, Migration):
+53 −0
Original line number Diff line number Diff line
@@ -38,7 +38,58 @@ class ExecutorTests(TransactionTestCase):
        # Are the tables there now?
        self.assertIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
        self.assertIn("migrations_book", connection.introspection.get_table_list(connection.cursor()))
        # Rebuild the graph to reflect the new DB state
        executor.loader.build_graph()
        # Alright, let's undo what we did
        plan = executor.migration_plan([("migrations", None)])
        self.assertEqual(
            plan,
            [
                (executor.loader.graph.nodes["migrations", "0002_second"], True),
                (executor.loader.graph.nodes["migrations", "0001_initial"], True),
            ],
        )
        executor.migrate([("migrations", None)])
        # Are the tables gone?
        self.assertNotIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
        self.assertNotIn("migrations_book", connection.introspection.get_table_list(connection.cursor()))

    @override_settings(MIGRATION_MODULES={"migrations": "migrations.test_migrations_squashed"})
    def test_run_with_squashed(self):
        """
        Tests running a squashed migration from zero (should ignore what it replaces)
        """
        executor = MigrationExecutor(connection)
        executor.recorder.flush()
        # Check our leaf node is the squashed one
        leaves = [key for key in executor.loader.graph.leaf_nodes() if key[0] == "migrations"]
        self.assertEqual(leaves, [("migrations", "0001_squashed_0002")])
        # Check the plan
        plan = executor.migration_plan([("migrations", "0001_squashed_0002")])
        self.assertEqual(
            plan,
            [
                (executor.loader.graph.nodes["migrations", "0001_squashed_0002"], False),
            ],
        )
        # Were the tables there before?
        self.assertNotIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
        self.assertNotIn("migrations_book", connection.introspection.get_table_list(connection.cursor()))
        # Alright, let's try running it
        executor.migrate([("migrations", "0001_squashed_0002")])
        # Are the tables there now?
        self.assertIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
        self.assertIn("migrations_book", connection.introspection.get_table_list(connection.cursor()))
        # Rebuild the graph to reflect the new DB state
        executor.loader.build_graph()
        # Alright, let's undo what we did. Should also just use squashed.
        plan = executor.migration_plan([("migrations", None)])
        self.assertEqual(
            plan,
            [
                (executor.loader.graph.nodes["migrations", "0001_squashed_0002"], True),
            ],
        )
        executor.migrate([("migrations", None)])
        # Are the tables gone?
        self.assertNotIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
@@ -70,6 +121,8 @@ class ExecutorTests(TransactionTestCase):
        )
        # Fake-apply all migrations
        executor.migrate([("migrations", "0002_second"), ("sessions", "0001_initial")], fake=True)
        # Rebuild the graph to reflect the new DB state
        executor.loader.build_graph()
        # Now plan a second time and make sure it's empty
        plan = executor.migration_plan([("migrations", "0002_second"), ("sessions", "0001_initial")])
        self.assertEqual(plan, [])
+22 −9
Original line number Diff line number Diff line
@@ -82,21 +82,34 @@ class LoaderTests(TestCase):
            migration_loader.get_migration_by_prefix("migrations", "blarg")

    def test_load_import_error(self):
        migration_loader = MigrationLoader(connection)

        with override_settings(MIGRATION_MODULES={"migrations": "migrations.faulty_migrations.import_error"}):
            with self.assertRaises(ImportError):
                migration_loader.load_disk()
                MigrationLoader(connection)

    def test_load_module_file(self):
        migration_loader = MigrationLoader(connection)

        with override_settings(MIGRATION_MODULES={"migrations": "migrations.faulty_migrations.file"}):
            migration_loader.load_disk()
            MigrationLoader(connection)

    @skipIf(six.PY2, "PY2 doesn't load empty dirs.")
    def test_load_empty_dir(self):
        migration_loader = MigrationLoader(connection)

        with override_settings(MIGRATION_MODULES={"migrations": "migrations.faulty_migrations.namespace"}):
            migration_loader.load_disk()
            MigrationLoader(connection)

    @override_settings(MIGRATION_MODULES={"migrations": "migrations.test_migrations_squashed"})
    def test_loading_squashed(self):
        "Tests loading a squashed migration"
        migration_loader = MigrationLoader(connection)
        recorder = MigrationRecorder(connection)
        # Loading with nothing applied should just give us the one node
        self.assertEqual(
            len(migration_loader.graph.nodes),
            1,
        )
        # However, fake-apply one migration and it should now use the old two
        recorder.record_applied("migrations", "0001_initial")
        migration_loader.build_graph()
        self.assertEqual(
            len(migration_loader.graph.nodes),
            2,
        )
        recorder.flush()
Loading