Commit e6f7f453 authored by Andrew Godwin's avatar Andrew Godwin
Browse files

Add an Executor for end-to-end running

parent 7f9a0b70
Loading
Loading
Loading
Loading
+68 −0
Original line number Diff line number Diff line
from .loader import MigrationLoader
from .recorder import MigrationRecorder


class MigrationExecutor(object):
    """
    End-to-end migration execution - loads migrations, and runs them
    up or down to a specified set of targets.
    """

    def __init__(self, connection):
        self.connection = connection
        self.loader = MigrationLoader(self.connection)
        self.recorder = MigrationRecorder(self.connection)

    def migration_plan(self, targets):
        """
        Given a set of targets, returns a list of (Migration instance, backwards?).
        """
        plan = []
        applied = self.recorder.applied_migrations()
        for target in targets:
            # If the migration is already applied, do backwards mode,
            # otherwise do forwards mode.
            if target in applied:
                for migration in self.loader.graph.backwards_plan(target)[:-1]:
                    if migration in applied:
                        plan.append((self.loader.graph.nodes[migration], True))
                        applied.remove(migration)
            else:
                for migration in self.loader.graph.forwards_plan(target):
                    if migration not in applied:
                        plan.append((self.loader.graph.nodes[migration], False))
                        applied.add(migration)
        return plan

    def migrate(self, targets):
        """
        Migrates the database up to the given targets.
        """
        plan = self.migration_plan(targets)
        for migration, backwards in plan:
            if not backwards:
                self.apply_migration(migration)
            else:
                self.unapply_migration(migration)

    def apply_migration(self, migration):
        """
        Runs a migration forwards.
        """
        print "Applying %s" % migration
        with self.connection.schema_editor() as schema_editor:
            project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False)
            migration.apply(project_state, schema_editor)
        self.recorder.record_applied(migration.app_label, migration.name)
        print "Finished %s" % migration

    def unapply_migration(self, migration):
        """
        Runs a migration backwards.
        """
        print "Unapplying %s" % migration
        with self.connection.schema_editor() as schema_editor:
            project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False)
            migration.unapply(project_state, schema_editor)
        self.recorder.record_unapplied(migration.app_label, migration.name)
        print "Finished %s" % migration
+48 −0
Original line number Diff line number Diff line
@@ -36,6 +36,17 @@ class Migration(object):
        self.name = name
        self.app_label = app_label

    def __eq__(self, other):
        if not isinstance(other, Migration):
            return False
        return (self.name == other.name) and (self.app_label == other.app_label)

    def __ne__(self, other):
        return not (self == other)

    def __repr__(self):
        return "<Migration %s.%s>" % (self.app_label, self.name)

    def mutate_state(self, project_state):
        """
        Takes a ProjectState and returns a new one with the migration's
@@ -45,3 +56,40 @@ class Migration(object):
        for operation in self.operations:
            operation.state_forwards(self.app_label, new_state)
        return new_state

    def apply(self, project_state, schema_editor):
        """
        Takes a project_state representing all migrations prior to this one
        and a schema_editor for a live database and applies the migration
        in a forwards order.

        Returns the resulting project state for efficient re-use by following
        Migrations.
        """
        for operation in self.operations:
            # Get the state after the operation has run
            new_state = project_state.clone()
            operation.state_forwards(self.app_label, new_state)
            # Run the operation
            operation.database_forwards(self.app_label, schema_editor, project_state, new_state)
            # Switch states
            project_state = new_state
        return project_state

    def unapply(self, project_state, schema_editor):
        """
        Takes a project_state representing all migrations prior to this one
        and a schema_editor for a live database and applies the migration
        in a reverse order.
        """
        # We need to pre-calculate the stack of project states
        to_run = []
        for operation in self.operations:
            new_state = project_state.clone()
            operation.state_forwards(self.app_label, new_state)
            to_run.append((operation, project_state, new_state))
            project_state = new_state
        # Now run them in reverse
        to_run.reverse()
        for operation, to_state, from_state in to_run:
            operation.database_backwards(self.app_label, schema_editor, from_state, to_state)
+8 −8
Original line number Diff line number Diff line
@@ -16,13 +16,13 @@ class AddField(Operation):

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        app_cache = to_state.render()
        model = app_cache.get_model(app_label, self.name)
        schema_editor.add_field(model, model._meta.get_field_by_name(self.name))
        model = app_cache.get_model(app_label, self.model_name)
        schema_editor.add_field(model, model._meta.get_field_by_name(self.name)[0])

    def database_backwards(self, app_label, schema_editor, from_state, to_state):
        app_cache = from_state.render()
        model = app_cache.get_model(app_label, self.name)
        schema_editor.remove_field(model, model._meta.get_field_by_name(self.name))
        model = app_cache.get_model(app_label, self.model_name)
        schema_editor.remove_field(model, model._meta.get_field_by_name(self.name)[0])


class RemoveField(Operation):
@@ -43,10 +43,10 @@ class RemoveField(Operation):

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        app_cache = from_state.render()
        model = app_cache.get_model(app_label, self.name)
        schema_editor.remove_field(model, model._meta.get_field_by_name(self.name))
        model = app_cache.get_model(app_label, self.model_name)
        schema_editor.remove_field(model, model._meta.get_field_by_name(self.name)[0])

    def database_backwards(self, app_label, schema_editor, from_state, to_state):
        app_cache = to_state.render()
        model = app_cache.get_model(app_label, self.name)
        schema_editor.add_field(model, model._meta.get_field_by_name(self.name))
        model = app_cache.get_model(app_label, self.model_name)
        schema_editor.add_field(model, model._meta.get_field_by_name(self.name)[0])
+1 −1
Original line number Diff line number Diff line
@@ -11,7 +11,7 @@ class Migration(migrations.Migration):

        migrations.RemoveField("Author", "silly_field"),

        migrations.AddField("Author", "important", models.BooleanField()),
        migrations.AddField("Author", "rating", models.IntegerField(default=0)),

        migrations.CreateModel(
            "Book",
+35 −0
Original line number Diff line number Diff line
from django.test import TransactionTestCase
from django.db import connection
from django.db.migrations.executor import MigrationExecutor


class ExecutorTests(TransactionTestCase):
    """
    Tests the migration executor (full end-to-end running).

    Bear in mind that if these are failing you should fix the other
    test failures first, as they may be propagating into here.
    """

    def test_run(self):
        """
        Tests running a simple set of migrations.
        """
        executor = MigrationExecutor(connection)
        # Let's look at the plan first and make sure it's up to scratch
        plan = executor.migration_plan([("migrations", "0002_second")])
        self.assertEqual(
            plan,
            [
                (executor.loader.graph.nodes["migrations", "0001_initial"], False),
                (executor.loader.graph.nodes["migrations", "0002_second"], False),
            ],
        )
        # Were the tables there before?
        self.assertNotIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
        self.assertNotIn("migrations_book", connection.introspection.get_table_list(connection.cursor()))
        # Alright, let's try running it
        executor.migrate([("migrations", "0002_second")])
        # Are the tables there now?
        self.assertIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
        self.assertIn("migrations_book", connection.introspection.get_table_list(connection.cursor()))
Loading