Commit ce05b8a6 authored by Andrew Godwin's avatar Andrew Godwin
Browse files

Don't make a second migration if there was a force-null-default addcol.

parent df800b16
Loading
Loading
Loading
Loading
+20 −8
Original line number Diff line number Diff line
@@ -188,7 +188,18 @@ class MigrationAutodetector(object):
                    continue
                # You can't just add NOT NULL fields with no default
                if not field.null and not field.has_default():
                    field = field.clone()
                    field.default = self.questioner.ask_not_null_addition(field_name, model_name)
                    self.add_to_migration(
                        app_label,
                        operations.AddField(
                            model_name=model_name,
                            name=field_name,
                            field=field,
                            preserve_default=False,
                        )
                    )
                else:
                    self.add_to_migration(
                        app_label,
                        operations.AddField(
@@ -434,7 +445,8 @@ class InteractiveMigrationQuestioner(MigrationQuestioner):
        "Adding a NOT NULL field to a model"
        choice = self._choice_input(
            "You are trying to add a non-nullable field '%s' to %s without a default;\n" % (field_name, model_name) +
            "this is not possible. Please select a fix:",
            "we can't do that (the database needs something to populate existing rows).\n" +
            "Please select a fix:",
            [
                "Provide a one-off default now (will be set on all existing rows)",
                "Quit, and let me add a default in models.py",
+10 −2
Original line number Diff line number Diff line
from django.db import router
from django.db.models.fields import NOT_PROVIDED
from .base import Operation


@@ -7,13 +8,20 @@ class AddField(Operation):
    Adds a field to a model.
    """

    def __init__(self, model_name, name, field):
    def __init__(self, model_name, name, field, preserve_default=True):
        self.model_name = model_name
        self.name = name
        self.field = field
        self.preserve_default = preserve_default

    def state_forwards(self, app_label, state):
        state.models[app_label, self.model_name.lower()].fields.append((self.name, self.field))
        # If preserve default is off, don't use the default for future state
        if not self.preserve_default:
            field = self.field.clone()
            field.default = NOT_PROVIDED
        else:
            field = self.field
        state.models[app_label, self.model_name.lower()].fields.append((self.name, field))

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        from_model = from_state.render().get_model(app_label, self.model_name)
+33 −1
Original line number Diff line number Diff line
from django.db import connection, models, migrations, router
from django.db.models.fields import NOT_PROVIDED
from django.db.transaction import atomic
from django.db.utils import IntegrityError
from django.db.migrations.state import ProjectState
@@ -130,10 +131,19 @@ class OperationTests(MigrationTestBase):
        """
        project_state = self.set_up_test_model("test_adfl")
        # Test the state alteration
        operation = migrations.AddField("Pony", "height", models.FloatField(null=True))
        operation = migrations.AddField(
            "Pony",
            "height",
            models.FloatField(null=True, default=5),
        )
        new_state = project_state.clone()
        operation.state_forwards("test_adfl", new_state)
        self.assertEqual(len(new_state.models["test_adfl", "pony"].fields), 4)
        field = [
            f for n, f in new_state.models["test_adfl", "pony"].fields
            if n == "height"
        ][0]
        self.assertEqual(field.default, 5)
        # Test the database alteration
        self.assertColumnNotExists("test_adfl_pony", "height")
        with connection.schema_editor() as editor:
@@ -144,6 +154,28 @@ class OperationTests(MigrationTestBase):
            operation.database_backwards("test_adfl", editor, new_state, project_state)
        self.assertColumnNotExists("test_adfl_pony", "height")

    def test_add_field_preserve_default(self):
        """
        Tests the AddField operation's state alteration
        when preserve_default = False.
        """
        project_state = self.set_up_test_model("test_adflpd")
        # Test the state alteration
        operation = migrations.AddField(
            "Pony",
            "height",
            models.FloatField(null=True, default=4),
            preserve_default = False,
        )
        new_state = project_state.clone()
        operation.state_forwards("test_adflpd", new_state)
        self.assertEqual(len(new_state.models["test_adflpd", "pony"].fields), 4)
        field = [
            f for n, f in new_state.models["test_adflpd", "pony"].fields
            if n == "height"
        ][0]
        self.assertEqual(field.default, NOT_PROVIDED)

    def test_add_field_m2m(self):
        """
        Tests the AddField operation with a ManyToManyField.