Commit ab5cbae9 authored by Andrew Godwin's avatar Andrew Godwin
Browse files

First stab at some migration creation commands

parent 2ae8a8a7
Loading
Loading
Loading
Loading
+52 −0
Original line number Diff line number Diff line
import sys
from optparse import make_option

from django.core.management.base import BaseCommand
from django.core.management.color import color_style
from django.core.exceptions import ImproperlyConfigured
from django.db import connections
from django.db.migrations.loader import MigrationLoader
from django.db.migrations.autodetector import MigrationAutodetector, InteractiveMigrationQuestioner
from django.db.migrations.state import ProjectState
from django.db.models.loading import cache


class Command(BaseCommand):
    option_list = BaseCommand.option_list + (
        make_option('--empty', action='store_true', dest='empty', default=False,
            help='Make a blank migration.'),
    )

    help = "Creates new migration(s) for apps."
    usage_str = "Usage: ./manage.py createmigration [--empty] [app [app ...]]"

    def handle(self, *app_labels, **options):

        self.verbosity = int(options.get('verbosity'))
        self.interactive = options.get('interactive')
        self.style = color_style()

        # Make sure the app they asked for exists
        app_labels = set(app_labels)
        for app_label in app_labels:
            try:
                cache.get_app(app_label)
            except ImproperlyConfigured:
                self.stderr.write("The app you specified - '%s' - could not be found. Is it in INSTALLED_APPS?" % app_label)
                sys.exit(2)

        # Load the current graph state
        loader = MigrationLoader(connections["default"])

        # Detect changes
        autodetector = MigrationAutodetector(
            loader.graph.project_state(),
            ProjectState.from_app_cache(cache),
            InteractiveMigrationQuestioner(specified_apps=app_labels),
        )
        changes = autodetector.changes()
        changes = autodetector.arrange_for_graph(changes, loader.graph)
        if app_labels:
            changes = autodetector.trim_to_apps(changes, app_labels)

        print changes
+7 −6
Original line number Diff line number Diff line
@@ -5,7 +5,7 @@ import traceback
from django.conf import settings
from django.core.management import call_command
from django.core.management.base import NoArgsCommand
from django.core.management.color import color_style
from django.core.management.color import color_style, no_style
from django.core.management.sql import custom_sql_for_model, emit_post_sync_signal, emit_pre_sync_signal
from django.db import connections, router, transaction, models, DEFAULT_DB_ALIAS
from django.db.migrations.executor import MigrationExecutor
@@ -32,6 +32,7 @@ class Command(NoArgsCommand):
        self.interactive = options.get('interactive')
        self.show_traceback = options.get('traceback')
        self.load_initial_data = options.get('load_initial_data')
        self.test_database = options.get('test_database', False)

        self.style = color_style()

@@ -144,14 +145,14 @@ class Command(NoArgsCommand):
                    # Create the model's database table, if it doesn't already exist.
                    if self.verbosity >= 3:
                        self.stdout.write("    Processing %s.%s model\n" % (app_name, model._meta.object_name))
                    sql, references = connection.creation.sql_create_model(model, self.style, seen_models)
                    sql, references = connection.creation.sql_create_model(model, no_style(), seen_models)
                    seen_models.add(model)
                    created_models.add(model)
                    for refto, refs in references.items():
                        pending_references.setdefault(refto, []).extend(refs)
                        if refto in seen_models:
                            sql.extend(connection.creation.sql_for_pending_references(refto, self.style, pending_references))
                    sql.extend(connection.creation.sql_for_pending_references(model, self.style, pending_references))
                            sql.extend(connection.creation.sql_for_pending_references(refto, no_style(), pending_references))
                    sql.extend(connection.creation.sql_for_pending_references(model, no_style(), pending_references))
                    if self.verbosity >= 1 and sql:
                        self.stdout.write("    Creating table %s\n" % model._meta.db_table)
                    for statement in sql:
@@ -172,7 +173,7 @@ class Command(NoArgsCommand):
        for app_name, model_list in manifest.items():
            for model in model_list:
                if model in created_models:
                    custom_sql = custom_sql_for_model(model, self.style, connection)
                    custom_sql = custom_sql_for_model(model, no_style(), connection)
                    if custom_sql:
                        if self.verbosity >= 2:
                            self.stdout.write("    Installing custom SQL for %s.%s model\n" % (app_name, model._meta.object_name))
@@ -194,7 +195,7 @@ class Command(NoArgsCommand):
        for app_name, model_list in manifest.items():
            for model in model_list:
                if model in created_models:
                    index_sql = connection.creation.sql_indexes_for_model(model, self.style)
                    index_sql = connection.creation.sql_indexes_for_model(model, no_style())
                    if index_sql:
                        if self.verbosity >= 2:
                            self.stdout.write("    Installing index for %s.%s model\n" % (app_name, model._meta.object_name))
+106 −33
Original line number Diff line number Diff line
import re
from django.utils.six.moves import input
from django.db.migrations import operations
from django.db.migrations.migration import Migration
from django.db.models.loading import cache


class MigrationAutodetector(object):
@@ -16,9 +18,10 @@ class MigrationAutodetector(object):
    if it wishes, with the caveat that it may not always be possible.
    """

    def __init__(self, from_state, to_state):
    def __init__(self, from_state, to_state, questioner=None):
        self.from_state = from_state
        self.to_state = to_state
        self.questioner = questioner or MigrationQuestioner()

    def changes(self):
        """
@@ -54,7 +57,7 @@ class MigrationAutodetector(object):
                    model_state.name,
                )
            )
        # Alright, now sort out and return the migrations
        # Alright, now add internal dependencies
        for app_label, migrations in self.migrations.items():
            for m1, m2 in zip(migrations, migrations[1:]):
                m2.dependencies.append((app_label, m1.name))
@@ -68,34 +71,7 @@ class MigrationAutodetector(object):
            migrations.append(instance)
        migrations[-1].operations.append(operation)

    @classmethod
    def suggest_name(cls, ops):
        """
        Given a set of operations, suggests a name for the migration
        they might represent. Names not guaranteed to be unique; they
        must be prefixed by a number or date.
        """
        if len(ops) == 1:
            if isinstance(ops[0], operations.CreateModel):
                return ops[0].name.lower()
            elif isinstance(ops[0], operations.DeleteModel):
                return "delete_%s" % ops[0].name.lower()
        elif all(isinstance(o, operations.CreateModel) for o in ops):
            return "_".join(sorted(o.name.lower() for o in ops))
        return "auto"

    @classmethod
    def parse_number(cls, name):
        """
        Given a migration name, tries to extract a number from the
        beginning of it. If no number found, returns None.
        """
        if re.match(r"^\d+_", name):
            return int(name.split("_")[0])
        return None

    @classmethod
    def arrange_for_graph(cls, changes, graph):
    def arrange_for_graph(self, changes, graph):
        """
        Takes in a result from changes() and a MigrationGraph,
        and fixes the names and dependencies of the changes so they
@@ -103,7 +79,7 @@ class MigrationAutodetector(object):
        """
        leaves = graph.leaf_nodes()
        name_map = {}
        for app_label, migrations in changes.items():
        for app_label, migrations in list(changes.items()):
            if not migrations:
                continue
            # Find the app label's current leaf node
@@ -112,11 +88,17 @@ class MigrationAutodetector(object):
                if leaf[0] == app_label:
                    app_leaf = leaf
                    break
            # Do they want an initial migration for this app?
            if app_leaf is None and not self.questioner.ask_initial(app_label):
                # They don't.
                for migration in migrations:
                    name_map[(app_label, migration.name)] = (app_label, "__first__")
                del changes[app_label]
            # Work out the next number in the sequence
            if app_leaf is None:
                next_number = 1
            else:
                next_number = (cls.parse_number(app_leaf[1]) or 0) + 1
                next_number = (self.parse_number(app_leaf[1]) or 0) + 1
            # Name each migration
            for i, migration in enumerate(migrations):
                if i == 0 and app_leaf:
@@ -124,7 +106,7 @@ class MigrationAutodetector(object):
                if i == 0 and not app_leaf:
                    new_name = "0001_initial"
                else:
                    new_name = "%04i_%s" % (next_number, cls.suggest_name(migration.operations))
                    new_name = "%04i_%s" % (next_number, self.suggest_name(migration.operations))
                name_map[(app_label, migration.name)] = (app_label, new_name)
                migration.name = new_name
        # Now fix dependencies
@@ -132,3 +114,94 @@ class MigrationAutodetector(object):
            for migration in migrations:
                migration.dependencies = [name_map.get(d, d) for d in migration.dependencies]
        return changes

    def trim_to_apps(self, changes, app_labels):
        """
        Takes changes from arrange_for_graph and set of app labels and
        returns a modified set of changes which trims out as many migrations
        that are not in app_labels as possible.
        Note that some other migrations may still be present, as they may be
        required dependencies.
        """
        # Gather other app dependencies in a first pass
        app_dependencies = {}
        for app_label, migrations in changes.items():
            for migration in migrations:
                for dep_app_label, name in migration.dependencies:
                    app_dependencies.setdefault(app_label, set()).add(dep_app_label)
        required_apps = set(app_labels)
        # Keep resolving till there's no change
        old_required_apps = None
        while old_required_apps != required_apps:
            old_required_apps = set(required_apps)
            for app_label in list(required_apps):
                required_apps.update(app_dependencies.get(app_label, set()))
        # Remove all migrations that aren't needed
        for app_label in list(changes.keys()):
            if app_label not in required_apps:
                del changes[app_label]
        return changes

    @classmethod
    def suggest_name(cls, ops):
        """
        Given a set of operations, suggests a name for the migration
        they might represent. Names not guaranteed to be unique; they
        must be prefixed by a number or date.
        """
        if len(ops) == 1:
            if isinstance(ops[0], operations.CreateModel):
                return ops[0].name.lower()
            elif isinstance(ops[0], operations.DeleteModel):
                return "delete_%s" % ops[0].name.lower()
        elif all(isinstance(o, operations.CreateModel) for o in ops):
            return "_".join(sorted(o.name.lower() for o in ops))
        return "auto"

    @classmethod
    def parse_number(cls, name):
        """
        Given a migration name, tries to extract a number from the
        beginning of it. If no number found, returns None.
        """
        if re.match(r"^\d+_", name):
            return int(name.split("_")[0])
        return None


class MigrationQuestioner(object):
    """
    Gives the autodetector responses to questions it might have.
    This base class has a built-in noninteractive mode, but the
    interactive subclass is what the command-line arguments will use.
    """

    def __init__(self, defaults=None):
        self.defaults = defaults or {}

    def ask_initial(self, app_label):
        "Should we create an initial migration for the app?"
        return self.defaults.get("ask_initial", False)


class InteractiveMigrationQuestioner(MigrationQuestioner):

    def __init__(self, specified_apps=set()):
        self.specified_apps = specified_apps

    def _boolean_input(self, question):
        result = input("%s " % question)
        while len(result) < 1 or result[0].lower() not in "yn":
            result = input("Please answer yes or no: ")
        return result[0].lower() == "y"

    def ask_initial(self, app_label):
        # Don't ask for django.contrib apps
        app = cache.get_app(app_label)
        if app.__name__.startswith("django.contrib"):
            return False
        # If it was specified on the command line, definitely true
        if app_label in self.specified_apps:
            return True
        # Now ask
        return self._boolean_input("Do you want to enable migrations for app '%s'?" % app_label)
+8 −3
Original line number Diff line number Diff line
@@ -49,7 +49,7 @@ class MigrationGraph(object):
        a database.
        """
        if node not in self.nodes:
            raise ValueError("Node %r not a valid node" % node)
            raise ValueError("Node %r not a valid node" % (node, ))
        return self.dfs(node, lambda x: self.dependencies.get(x, set()))

    def backwards_plan(self, node):
@@ -60,7 +60,7 @@ class MigrationGraph(object):
        a database.
        """
        if node not in self.nodes:
            raise ValueError("Node %r not a valid node" % node)
            raise ValueError("Node %r not a valid node" % (node, ))
        return self.dfs(node, lambda x: self.dependents.get(x, set()))

    def root_nodes(self):
@@ -120,11 +120,16 @@ class MigrationGraph(object):
    def __str__(self):
        return "Graph: %s nodes, %s edges" % (len(self.nodes), sum(len(x) for x in self.dependencies.values()))

    def project_state(self, nodes, at_end=True):
    def project_state(self, nodes=None, at_end=True):
        """
        Given a migration node or nodes, returns a complete ProjectState for it.
        If at_end is False, returns the state before the migration has run.
        If nodes is not provided, returns the overall most current project state.
        """
        if nodes is None:
            nodes = list(self.leaf_nodes())
        if len(nodes) == 0:
            return ProjectState()
        if not isinstance(nodes[0], tuple):
            nodes = [nodes]
        plan = []
+7 −0
Original line number Diff line number Diff line
@@ -60,3 +60,10 @@ class MigrationRecorder(object):
        """
        self.ensure_schema()
        self.Migration.objects.filter(app=app, name=name).delete()

    @classmethod
    def flush(cls):
        """
        Deletes all migration records. Useful if you're testing migrations.
        """
        cls.Migration.objects.all().delete()
Loading