Commit 73e30e9d authored by Andrew Godwin's avatar Andrew Godwin
Browse files

Better naming, and prompt for NOT NULL field addition

parent 41214eaf
Loading
Loading
Loading
Loading
+60 −1
Original line number Diff line number Diff line
import re
import sys
from django.utils import datetime_safe
from django.utils.six.moves import input
from django.db.migrations import operations
from django.db.migrations.migration import Migration
@@ -66,12 +68,16 @@ class MigrationAutodetector(object):
            old_field_names = set([x for x, y in old_model_state.fields])
            new_field_names = set([x for x, y in new_model_state.fields])
            for field_name in new_field_names - old_field_names:
                field = [y for x, y in new_model_state.fields if x == field_name][0]
                # You can't just add NOT NULL fields with no default
                if not field.null and not field.has_default():
                    field.default = self.questioner.ask_not_null_addition(field_name, model_name)
                self.add_to_migration(
                    app_label,
                    operations.AddField(
                        model_name = model_name,
                        name = field_name,
                        field = [y for x, y in new_model_state.fields if x == field_name][0],
                        field = field,
                    )
                )
            # Old fields
@@ -180,6 +186,10 @@ class MigrationAutodetector(object):
                return ops[0].name.lower()
            elif isinstance(ops[0], operations.DeleteModel):
                return "delete_%s" % ops[0].name.lower()
            elif isinstance(ops[0], operations.AddField):
                return "%s_%s" % (ops[0].model_name.lower(), ops[0].name.lower())
            elif isinstance(ops[0], operations.RemoveField):
                return "remove_%s_%s" % (ops[0].model_name.lower(), ops[0].name.lower())
        elif all(isinstance(o, operations.CreateModel) for o in ops):
            return "_".join(sorted(o.name.lower() for o in ops))
        return "auto"
@@ -209,6 +219,11 @@ class MigrationQuestioner(object):
        "Should we create an initial migration for the app?"
        return self.defaults.get("ask_initial", False)

    def ask_not_null_addition(self, field_name, model_name):
        "Adding a NOT NULL field to a model"
        # None means quit
        return None


class InteractiveMigrationQuestioner(MigrationQuestioner):

@@ -221,7 +236,22 @@ class InteractiveMigrationQuestioner(MigrationQuestioner):
            result = input("Please answer yes or no: ")
        return result[0].lower() == "y"

    def _choice_input(self, question, choices):
        print question
        for i, choice in enumerate(choices):
            print " %s) %s" % (i + 1, choice)
        result = input("Select an option: ")
        while True:
            try:
                value = int(result)
                if 0 < value <= len(choices):
                    return value
            except ValueError:
                pass
            result = input("Please select a valid option: ")

    def ask_initial(self, app_label):
        "Should we create an initial migration for the app?"
        # Don't ask for django.contrib apps
        app = cache.get_app(app_label)
        if app.__name__.startswith("django.contrib"):
@@ -231,3 +261,32 @@ class InteractiveMigrationQuestioner(MigrationQuestioner):
            return True
        # Now ask
        return self._boolean_input("Do you want to enable migrations for app '%s'?" % app_label)

    def ask_not_null_addition(self, field_name, model_name):
        "Adding a NOT NULL field to a model"
        choice = self._choice_input(
            "You are trying to add a non-nullable field '%s' to %s without a default;\n" % (field_name, model_name) +
            "this is not possible. Please select a fix:",
            [
                "Provide a one-off default now (will be set on all existing rows)",
                "Quit, and let me add a default in models.py",
            ]
        )
        if choice == 2:
            sys.exit(3)
        else:
            print("Please enter the default value now, as valid Python")
            print("The datetime module is available, so you can do e.g. datetime.date.today()")
            while True:
                code = input(">>> ")
                if not code:
                    print("Please enter some code, or 'exit' (with no quotes) to exit.")
                elif code == "exit":
                    sys.exit(1)
                else:
                    try:
                        return eval(code, {}, {"datetime": datetime_safe})
                    except (SyntaxError, NameError) as e:
                        print("Invalid input: %s" % e)
                    else:
                        break