Commit 80bdf68d authored by Andrew Godwin's avatar Andrew Godwin
Browse files

Add AlterField and RenameField operations

parent 6f667999
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
from .models import CreateModel, DeleteModel, AlterModelTable
from .fields import AddField, RemoveField
from .fields import AddField, RemoveField, AlterField, RenameField
+68 −0
Original line number Diff line number Diff line
@@ -54,3 +54,71 @@ class RemoveField(Operation):

    def describe(self):
        return "Remove field %s from %s" % (self.name, self.model_name)


class AlterField(Operation):
    """
    Alters a field's database column (e.g. null, max_length) to the provided new field
    """

    def __init__(self, model_name, name, field):
        self.model_name = model_name
        self.name = name
        self.field = field

    def state_forwards(self, app_label, state):
        state.models[app_label, self.model_name.lower()].fields = [
            (n, self.field if n == self.name else f) for n, f in state.models[app_label, self.model_name.lower()].fields
        ]

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        from_model = from_state.render().get_model(app_label, self.model_name)
        to_model = to_state.render().get_model(app_label, self.model_name)
        schema_editor.alter_field(
            from_model,
            from_model._meta.get_field_by_name(self.name)[0],
            to_model._meta.get_field_by_name(self.name)[0],
        )

    def database_backwards(self, app_label, schema_editor, from_state, to_state):
        self.database_forwards(app_label, schema_editor, from_state, to_state)

    def describe(self):
        return "Alter field %s on %s" % (self.name, self.model_name)


class RenameField(Operation):
    """
    Renames a field on the model. Might affect db_column too.
    """

    def __init__(self, model_name, old_name, new_name):
        self.model_name = model_name
        self.old_name = old_name
        self.new_name = new_name

    def state_forwards(self, app_label, state):
        state.models[app_label, self.model_name.lower()].fields = [
            (self.new_name if n == self.old_name else n, f) for n, f in state.models[app_label, self.model_name.lower()].fields
        ]

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        from_model = from_state.render().get_model(app_label, self.model_name)
        to_model = to_state.render().get_model(app_label, self.model_name)
        schema_editor.alter_field(
            from_model,
            from_model._meta.get_field_by_name(self.old_name)[0],
            to_model._meta.get_field_by_name(self.new_name)[0],
        )

    def database_backwards(self, app_label, schema_editor, from_state, to_state):
        from_model = from_state.render().get_model(app_label, self.model_name)
        to_model = to_state.render().get_model(app_label, self.model_name)
        schema_editor.alter_field(
            from_model,
            from_model._meta.get_field_by_name(self.new_name)[0],
            to_model._meta.get_field_by_name(self.old_name)[0],
        )

    def describe(self):
        return "Rename field %s on %s to %s" % (self.old_name, self.model_name, self.new_name)
+8 −1
Original line number Diff line number Diff line
@@ -93,10 +93,17 @@ class ModelState(object):

    def clone(self):
        "Returns an exact copy of this ModelState"
        # We deep-clone the fields using deconstruction
        fields = []
        for name, field in self.fields:
            _, path, args, kwargs = field.deconstruct()
            field_class = import_by_path(path)
            fields.append((name, field_class(*args, **kwargs)))
        # Now make a copy
        return self.__class__(
            app_label = self.app_label,
            name = self.name,
            fields = list(self.fields),
            fields = fields,
            options = dict(self.options),
            bases = self.bases,
        )
+52 −1
Original line number Diff line number Diff line
@@ -22,6 +22,12 @@ class OperationTests(TestCase):
    def assertColumnNotExists(self, table, column):
        self.assertNotIn(column, [c.name for c in connection.introspection.get_table_description(connection.cursor(), table)])

    def assertColumnNull(self, table, column):
        self.assertEqual([c.null_ok for c in connection.introspection.get_table_description(connection.cursor(), table) if c.name == column][0], True)

    def assertColumnNotNull(self, table, column):
        self.assertEqual([c.null_ok for c in connection.introspection.get_table_description(connection.cursor(), table) if c.name == column][0], False)

    def set_up_test_model(self, app_label):
        """
        Creates a test model state and database table.
@@ -50,7 +56,7 @@ class OperationTests(TestCase):
            "Pony",
            [
                ("id", models.AutoField(primary_key=True)),
                ("pink", models.BooleanField(default=True)),
                ("pink", models.IntegerField(default=1)),
            ],
        )
        # Test the state alteration
@@ -157,3 +163,48 @@ class OperationTests(TestCase):
            operation.database_backwards("test_almota", editor, new_state, project_state)
        self.assertTableExists("test_almota_pony")
        self.assertTableNotExists("test_almota_pony_2")

    def test_alter_field(self):
        """
        Tests the AlterField operation.
        """
        project_state = self.set_up_test_model("test_alfl")
        # Test the state alteration
        operation = migrations.AlterField("Pony", "pink", models.IntegerField(null=True))
        new_state = project_state.clone()
        operation.state_forwards("test_alfl", new_state)
        self.assertEqual([f for n, f in project_state.models["test_alfl", "pony"].fields if n == "pink"][0].null, False)
        self.assertEqual([f for n, f in new_state.models["test_alfl", "pony"].fields if n == "pink"][0].null, True)
        # Test the database alteration
        self.assertColumnNotNull("test_alfl_pony", "pink")
        with connection.schema_editor() as editor:
            operation.database_forwards("test_alfl", editor, project_state, new_state)
        self.assertColumnNull("test_alfl_pony", "pink")
        # And test reversal
        with connection.schema_editor() as editor:
            operation.database_backwards("test_alfl", editor, new_state, project_state)
        self.assertColumnNotNull("test_alfl_pony", "pink")

    def test_rename_field(self):
        """
        Tests the RenameField operation.
        """
        project_state = self.set_up_test_model("test_rnfl")
        # Test the state alteration
        operation = migrations.RenameField("Pony", "pink", "blue")
        new_state = project_state.clone()
        operation.state_forwards("test_rnfl", new_state)
        self.assertIn("blue", [n for n, f in new_state.models["test_rnfl", "pony"].fields])
        self.assertNotIn("pink", [n for n, f in new_state.models["test_rnfl", "pony"].fields])
        # Test the database alteration
        self.assertColumnExists("test_rnfl_pony", "pink")
        self.assertColumnNotExists("test_rnfl_pony", "blue")
        with connection.schema_editor() as editor:
            operation.database_forwards("test_rnfl", editor, project_state, new_state)
        self.assertColumnExists("test_rnfl_pony", "blue")
        self.assertColumnNotExists("test_rnfl_pony", "pink")
        # And test reversal
        with connection.schema_editor() as editor:
            operation.database_backwards("test_rnfl", editor, new_state, project_state)
        self.assertColumnExists("test_rnfl_pony", "pink")
        self.assertColumnNotExists("test_rnfl_pony", "blue")