Commit 248fdb11 authored by Andrew Godwin's avatar Andrew Godwin
Browse files

Change FKs when what they point to changes

parent f3582a05
Loading
Loading
Loading
Loading
+20 −0
Original line number Diff line number Diff line
@@ -615,6 +615,11 @@ class BaseDatabaseSchemaEditor(object):
                    "extra": "",
                }
            )
        # Type alteration on primary key? Then we need to alter the column
        # referring to us.
        rels_to_update = []
        if old_field.primary_key and new_field.primary_key and old_type != new_type:
            rels_to_update.extend(model._meta.get_all_related_objects())
        # Changed to become primary key?
        # Note that we don't detect unsetting of a PK, as we assume another field
        # will always come along and replace it.
@@ -641,6 +646,21 @@ class BaseDatabaseSchemaEditor(object):
                    "columns": self.quote_name(new_field.column),
                }
            )
            # Update all referencing columns
            rels_to_update.extend(model._meta.get_all_related_objects())
        # Handle out type alters on the other end of rels from the PK stuff above
        for rel in rels_to_update:
            rel_db_params = rel.field.db_parameters(connection=self.connection)
            rel_type = rel_db_params['type']
            self.execute(
                self.sql_alter_column % {
                    "table": self.quote_name(rel.model._meta.db_table),
                    "changes": self.sql_alter_column_type % {
                        "column": self.quote_name(rel.field.column),
                        "type": rel_type,
                    }
                }
            )
        # Does it have a foreign key?
        if new_field.rel:
            self.execute(
+70 −59
Original line number Diff line number Diff line
@@ -153,30 +153,46 @@ class MigrationAutodetector(object):
            )
        # Changes within models
        kept_models = set(old_model_keys).intersection(new_model_keys)
        old_fields = set()
        new_fields = set()
        for app_label, model_name in kept_models:
            old_model_state = self.from_state.models[app_label, model_name]
            new_model_state = self.to_state.models[app_label, model_name]
            # Collect field changes for later global dealing with (so AddFields
            # always come before AlterFields even on separate models)
            old_fields.update((app_label, model_name, x) for x, y in old_model_state.fields)
            new_fields.update((app_label, model_name, x) for x, y in new_model_state.fields)
            # Unique_together changes
            if old_model_state.options.get("unique_together", set()) != new_model_state.options.get("unique_together", set()):
                self.add_to_migration(
                    app_label,
                    operations.AlterUniqueTogether(
                        name=model_name,
                        unique_together=new_model_state.options.get("unique_together", set()),
                    )
                )
        # New fields
            old_field_names = set(x for x, y in old_model_state.fields)
            new_field_names = set(x for x, y in new_model_state.fields)
            for field_name in new_field_names - old_field_names:
        for app_label, model_name, field_name in new_fields - old_fields:
            old_model_state = self.from_state.models[app_label, model_name]
            new_model_state = self.to_state.models[app_label, model_name]
            field = new_model_state.get_field_by_name(field_name)
            # Scan to see if this is actually a rename!
            field_dec = field.deconstruct()[1:]
            found_rename = False
                for removed_field_name in (old_field_names - new_field_names):
                    if old_model_state.get_field_by_name(removed_field_name).deconstruct()[1:] == field_dec:
                        if self.questioner.ask_rename(model_name, removed_field_name, field_name, field):
            for rem_app_label, rem_model_name, rem_field_name in (old_fields - new_fields):
                if rem_app_label == app_label and rem_model_name == model_name:
                    if old_model_state.get_field_by_name(rem_field_name).deconstruct()[1:] == field_dec:
                        if self.questioner.ask_rename(model_name, rem_field_name, field_name, field):
                            self.add_to_migration(
                                app_label,
                                operations.RenameField(
                                    model_name=model_name,
                                    old_name=removed_field_name,
                                    old_name=rem_field_name,
                                    new_name=field_name,
                                )
                            )
                            old_field_names.remove(removed_field_name)
                            new_field_names.remove(field_name)
                            old_fields.remove((rem_app_label, rem_model_name, rem_field_name))
                            new_fields.remove((app_label, model_name, field_name))
                            found_rename = True
                            break
            if found_rename:
@@ -204,7 +220,9 @@ class MigrationAutodetector(object):
                    )
                )
        # Old fields
            for field_name in old_field_names - new_field_names:
        for app_label, model_name, field_name in old_fields - new_fields:
            old_model_state = self.from_state.models[app_label, model_name]
            new_model_state = self.to_state.models[app_label, model_name]
            self.add_to_migration(
                app_label,
                operations.RemoveField(
@@ -213,8 +231,10 @@ class MigrationAutodetector(object):
                )
            )
        # The same fields
            for field_name in old_field_names.intersection(new_field_names):
        for app_label, model_name, field_name in old_fields.intersection(new_fields):
            # Did the field change?
            old_model_state = self.from_state.models[app_label, model_name]
            new_model_state = self.to_state.models[app_label, model_name]
            old_field_dec = old_model_state.get_field_by_name(field_name).deconstruct()
            new_field_dec = new_model_state.get_field_by_name(field_name).deconstruct()
            if old_field_dec != new_field_dec:
@@ -226,15 +246,6 @@ class MigrationAutodetector(object):
                        field=new_model_state.get_field_by_name(field_name),
                    )
                )
            # unique_together changes
            if old_model_state.options.get("unique_together", set()) != new_model_state.options.get("unique_together", set()):
                self.add_to_migration(
                    app_label,
                    operations.AlterUniqueTogether(
                        name=model_name,
                        unique_together=new_model_state.options.get("unique_together", set()),
                    )
                )
        # Alright, now add internal dependencies
        for app_label, migrations in self.migrations.items():
            for m1, m2 in zip(migrations, migrations[1:]):
+5 −7
Original line number Diff line number Diff line
@@ -24,10 +24,9 @@ class AddField(Operation):
        state.models[app_label, self.model_name.lower()].fields.append((self.name, field))

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        from_model = from_state.render().get_model(app_label, self.model_name)
        to_model = to_state.render().get_model(app_label, self.model_name)
        if router.allow_migrate(schema_editor.connection.alias, to_model):
            schema_editor.add_field(from_model, to_model._meta.get_field_by_name(self.name)[0])
            schema_editor.add_field(to_model, to_model._meta.get_field_by_name(self.name)[0])

    def database_backwards(self, app_label, schema_editor, from_state, to_state):
        from_model = from_state.render().get_model(app_label, self.model_name)
@@ -74,10 +73,9 @@ class RemoveField(Operation):
            schema_editor.remove_field(from_model, from_model._meta.get_field_by_name(self.name)[0])

    def database_backwards(self, app_label, schema_editor, from_state, to_state):
        from_model = from_state.render().get_model(app_label, self.model_name)
        to_model = to_state.render().get_model(app_label, self.model_name)
        if router.allow_migrate(schema_editor.connection.alias, to_model):
            schema_editor.add_field(from_model, to_model._meta.get_field_by_name(self.name)[0])
            schema_editor.add_field(to_model, to_model._meta.get_field_by_name(self.name)[0])

    def describe(self):
        return "Remove field %s from %s" % (self.name, self.model_name)
@@ -109,7 +107,7 @@ class AlterField(Operation):
        to_model = to_state.render().get_model(app_label, self.model_name)
        if router.allow_migrate(schema_editor.connection.alias, to_model):
            schema_editor.alter_field(
                from_model,
                to_model,
                from_model._meta.get_field_by_name(self.name)[0],
                to_model._meta.get_field_by_name(self.name)[0],
            )
@@ -155,7 +153,7 @@ class RenameField(Operation):
        to_model = to_state.render().get_model(app_label, self.model_name)
        if router.allow_migrate(schema_editor.connection.alias, to_model):
            schema_editor.alter_field(
                from_model,
                to_model,
                from_model._meta.get_field_by_name(self.old_name)[0],
                to_model._meta.get_field_by_name(self.new_name)[0],
            )
@@ -165,7 +163,7 @@ class RenameField(Operation):
        to_model = to_state.render().get_model(app_label, self.model_name)
        if router.allow_migrate(schema_editor.connection.alias, to_model):
            schema_editor.alter_field(
                from_model,
                to_model,
                from_model._meta.get_field_by_name(self.new_name)[0],
                to_model._meta.get_field_by_name(self.old_name)[0],
            )