Commit 5db028af authored by Andrew Godwin's avatar Andrew Godwin
Browse files

Fix altering of SERIAL columns and InnoDB being picky about FK changes

parent cee4fe73
Loading
Loading
Loading
Loading
+54 −1
Original line number Diff line number Diff line
@@ -2,4 +2,57 @@ from django.db.backends.schema import BaseDatabaseSchemaEditor


class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
    pass

    sql_create_sequence = "CREATE SEQUENCE %(sequence)s"
    sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE"
    sql_set_sequence_max = "SELECT setval('%(sequence)s', MAX(%(column)s)) FROM %(table)s"

    def _alter_column_type_sql(self, table, column, type):
        """
        Makes ALTER TYPE with SERIAL make sense.
        """
        if type.lower() == "serial":
            sequence_name = "%s_%s_seq" % (table, column)
            return (
                (
                    self.sql_alter_column_type % {
                        "column": self.quote_name(column),
                        "type": "integer",
                    },
                    [],
                ),
                [
                    (
                        self.sql_delete_sequence % {
                            "sequence": sequence_name,
                        },
                        [],
                    ),
                    (
                        self.sql_create_sequence % {
                            "sequence": sequence_name,
                        },
                        [],
                    ),
                    (
                        self.sql_alter_column % {
                            "table": table,
                            "changes": self.sql_alter_column_default % {
                                "column": column,
                                "default": "nextval('%s')" % sequence_name,
                            }
                        },
                        [],
                    ),
                    (
                        self.sql_set_sequence_max % {
                            "table": table,
                            "column": column,
                            "sequence": sequence_name,
                        },
                        [],
                    ),
                ],
            )
        else:
            return super(DatabaseSchemaEditor, self)._alter_column_type_sql(table, column, type)
+55 −10
Original line number Diff line number Diff line
@@ -498,6 +498,18 @@ class BaseDatabaseSchemaEditor(object):
                        "name": fk_name,
                    }
                )
        # Drop incoming FK constraints if we're a primary key and things are going
        # to change.
        if old_field.primary_key and new_field.primary_key and old_type != new_type:
            for rel in new_field.model._meta.get_all_related_objects():
                rel_fk_names = self._constraint_names(rel.model, [rel.field.column], foreign_key=True)
                for fk_name in rel_fk_names:
                    self.execute(
                        self.sql_delete_fk % {
                            "table": self.quote_name(rel.model._meta.db_table),
                            "name": fk_name,
                        }
                    )
        # Change check constraints?
        if old_db_params['check'] != new_db_params['check'] and old_db_params['check']:
            constraint_names = self._constraint_names(model, [old_field.column], check=True)
@@ -524,15 +536,12 @@ class BaseDatabaseSchemaEditor(object):
            })
        # Next, start accumulating actions to do
        actions = []
        post_actions = []
        # Type change?
        if old_type != new_type:
            actions.append((
                self.sql_alter_column_type % {
                    "column": self.quote_name(new_field.column),
                    "type": new_type,
                },
                [],
            ))
            fragment, other_actions = self._alter_column_type_sql(model._meta.db_table, new_field.column, new_type)
            actions.append(fragment)
            post_actions.extend(other_actions)
        # Default change?
        old_default = self.effective_default(old_field)
        new_default = self.effective_default(new_field)
@@ -596,6 +605,9 @@ class BaseDatabaseSchemaEditor(object):
                    },
                    params,
                )
        if post_actions:
            for sql, params in post_actions:
                self.execute(sql, params)
        # Added a unique?
        if not old_field.unique and new_field.unique:
            self.execute(
@@ -619,7 +631,7 @@ class BaseDatabaseSchemaEditor(object):
        # referring to us.
        rels_to_update = []
        if old_field.primary_key and new_field.primary_key and old_type != new_type:
            rels_to_update.extend(model._meta.get_all_related_objects())
            rels_to_update.extend(new_field.model._meta.get_all_related_objects())
        # Changed to become primary key?
        # Note that we don't detect unsetting of a PK, as we assume another field
        # will always come along and replace it.
@@ -647,8 +659,8 @@ class BaseDatabaseSchemaEditor(object):
                }
            )
            # Update all referencing columns
            rels_to_update.extend(model._meta.get_all_related_objects())
        # Handle out type alters on the other end of rels from the PK stuff above
            rels_to_update.extend(new_field.model._meta.get_all_related_objects())
        # Handle our type alters on the other end of rels from the PK stuff above
        for rel in rels_to_update:
            rel_db_params = rel.field.db_parameters(connection=self.connection)
            rel_type = rel_db_params['type']
@@ -672,6 +684,18 @@ class BaseDatabaseSchemaEditor(object):
                    "to_column": self.quote_name(new_field.rel.get_related_field().column),
                }
            )
        # Rebuild FKs that pointed to us if we previously had to drop them
        if old_field.primary_key and new_field.primary_key and old_type != new_type:
            for rel in new_field.model._meta.get_all_related_objects():
                self.execute(
                    self.sql_create_fk % {
                        "table": self.quote_name(rel.model._meta.db_table),
                        "name": self._create_index_name(rel.model, [rel.field.column], suffix="_fk"),
                        "column": self.quote_name(rel.field.column),
                        "to_table": self.quote_name(model._meta.db_table),
                        "to_column": self.quote_name(new_field.column),
                    }
                )
        # Does it have check constraints we need to add?
        if old_db_params['check'] != new_db_params['check'] and new_db_params['check']:
            self.execute(
@@ -686,6 +710,27 @@ class BaseDatabaseSchemaEditor(object):
        if self.connection.features.connection_persists_old_columns:
            self.connection.close()

    def _alter_column_type_sql(self, table, column, type):
        """
        Hook to specialise column type alteration for different backends,
        for cases when a creation type is different to an alteration type
        (e.g. SERIAL in PostgreSQL, PostGIS fields).

        Should return two things; an SQL fragment of (sql, params) to insert
        into an ALTER TABLE statement, and a list of extra (sql, params) tuples
        to run once the field is altered.
        """
        return (
            (
                self.sql_alter_column_type % {
                    "column": self.quote_name(column),
                    "type": type,
                },
                [],
            ),
            [],
        )

    def _alter_many_to_many(self, model, old_field, new_field, strict):
        """
        Alters M2Ms to repoint their to= endpoints.
+7 −5
Original line number Diff line number Diff line
@@ -24,9 +24,10 @@ class AddField(Operation):
        state.models[app_label, self.model_name.lower()].fields.append((self.name, field))

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        from_model = from_state.render().get_model(app_label, self.model_name)
        to_model = to_state.render().get_model(app_label, self.model_name)
        if router.allow_migrate(schema_editor.connection.alias, to_model):
            schema_editor.add_field(to_model, to_model._meta.get_field_by_name(self.name)[0])
            schema_editor.add_field(from_model, to_model._meta.get_field_by_name(self.name)[0])

    def database_backwards(self, app_label, schema_editor, from_state, to_state):
        from_model = from_state.render().get_model(app_label, self.model_name)
@@ -73,9 +74,10 @@ class RemoveField(Operation):
            schema_editor.remove_field(from_model, from_model._meta.get_field_by_name(self.name)[0])

    def database_backwards(self, app_label, schema_editor, from_state, to_state):
        from_model = from_state.render().get_model(app_label, self.model_name)
        to_model = to_state.render().get_model(app_label, self.model_name)
        if router.allow_migrate(schema_editor.connection.alias, to_model):
            schema_editor.add_field(to_model, to_model._meta.get_field_by_name(self.name)[0])
            schema_editor.add_field(from_model, to_model._meta.get_field_by_name(self.name)[0])

    def describe(self):
        return "Remove field %s from %s" % (self.name, self.model_name)
@@ -107,7 +109,7 @@ class AlterField(Operation):
        to_model = to_state.render().get_model(app_label, self.model_name)
        if router.allow_migrate(schema_editor.connection.alias, to_model):
            schema_editor.alter_field(
                to_model,
                from_model,
                from_model._meta.get_field_by_name(self.name)[0],
                to_model._meta.get_field_by_name(self.name)[0],
            )
@@ -153,7 +155,7 @@ class RenameField(Operation):
        to_model = to_state.render().get_model(app_label, self.model_name)
        if router.allow_migrate(schema_editor.connection.alias, to_model):
            schema_editor.alter_field(
                to_model,
                from_model,
                from_model._meta.get_field_by_name(self.old_name)[0],
                to_model._meta.get_field_by_name(self.new_name)[0],
            )
@@ -163,7 +165,7 @@ class RenameField(Operation):
        to_model = to_state.render().get_model(app_label, self.model_name)
        if router.allow_migrate(schema_editor.connection.alias, to_model):
            schema_editor.alter_field(
                to_model,
                from_model,
                from_model._meta.get_field_by_name(self.new_name)[0],
                to_model._meta.get_field_by_name(self.old_name)[0],
            )
+56 −1
Original line number Diff line number Diff line
import unittest
from django.db import connection, models, migrations, router
from django.db.models.fields import NOT_PROVIDED
from django.db.transaction import atomic
@@ -13,7 +14,7 @@ class OperationTests(MigrationTestBase):
    both forwards and backwards.
    """

    def set_up_test_model(self, app_label, second_model=False):
    def set_up_test_model(self, app_label, second_model=False, related_model=False):
        """
        Creates a test model state and database table.
        """
@@ -38,6 +39,14 @@ class OperationTests(MigrationTestBase):
        )]
        if second_model:
            operations.append(migrations.CreateModel("Stable", [("id", models.AutoField(primary_key=True))]))
        if related_model:
            operations.append(migrations.CreateModel(
                "Rider",
                [
                    ("id", models.AutoField(primary_key=True)),
                    ("pony", models.ForeignKey("Pony")),
                ],
            ))
        project_state = ProjectState()
        for operation in operations:
            operation.state_forwards(app_label, project_state)
@@ -269,6 +278,52 @@ class OperationTests(MigrationTestBase):
            operation.database_backwards("test_alfl", editor, new_state, project_state)
        self.assertColumnNotNull("test_alfl_pony", "pink")

    def test_alter_field_pk(self):
        """
        Tests the AlterField operation on primary keys (for things like PostgreSQL's SERIAL weirdness)
        """
        project_state = self.set_up_test_model("test_alflpk")
        # Test the state alteration
        operation = migrations.AlterField("Pony", "id", models.IntegerField(primary_key=True))
        new_state = project_state.clone()
        operation.state_forwards("test_alflpk", new_state)
        self.assertIsInstance(project_state.models["test_alflpk", "pony"].get_field_by_name("id"), models.AutoField)
        self.assertIsInstance(new_state.models["test_alflpk", "pony"].get_field_by_name("id"), models.IntegerField)
        # Test the database alteration
        with connection.schema_editor() as editor:
            operation.database_forwards("test_alflpk", editor, project_state, new_state)
        # And test reversal
        with connection.schema_editor() as editor:
            operation.database_backwards("test_alflpk", editor, new_state, project_state)

    @unittest.skipUnless(connection.features.supports_foreign_keys, "No FK support")
    def test_alter_field_pk_fk(self):
        """
        Tests the AlterField operation on primary keys changes any FKs pointing to it.
        """
        project_state = self.set_up_test_model("test_alflpkfk", related_model=True)
        # Test the state alteration
        operation = migrations.AlterField("Pony", "id", models.FloatField(primary_key=True))
        new_state = project_state.clone()
        operation.state_forwards("test_alflpkfk", new_state)
        self.assertIsInstance(project_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.AutoField)
        self.assertIsInstance(new_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.FloatField)
        # Test the database alteration
        id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0]
        fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0]
        self.assertEqual(id_type, fk_type)
        with connection.schema_editor() as editor:
            operation.database_forwards("test_alflpkfk", editor, project_state, new_state)
        id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0]
        fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0]
        self.assertEqual(id_type, fk_type)
        # And test reversal
        with connection.schema_editor() as editor:
            operation.database_backwards("test_alflpkfk", editor, new_state, project_state)
        id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0]
        fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0]
        self.assertEqual(id_type, fk_type)

    def test_rename_field(self):
        """
        Tests the RenameField operation.
+1 −0
Original line number Diff line number Diff line
@@ -636,6 +636,7 @@ class SchemaTests(TransactionTestCase):
        # Alter to change the PK
        new_field = SlugField(primary_key=True)
        new_field.set_attributes_from_name("slug")
        new_field.model = Tag
        with connection.schema_editor() as editor:
            editor.remove_field(Tag, Tag._meta.get_field_by_name("id")[0])
            editor.alter_field(