Commit 4ce7a6bc authored by Andrew Godwin's avatar Andrew Godwin
Browse files

Fixed #22750, #22248: Model renaming now also alters field FKs

parent 1e84d261
Loading
Loading
Loading
Loading
+13 −3
Original line number Diff line number Diff line
@@ -580,11 +580,21 @@ class MigrationAutodetector(object):
        for app_label, model_name, field_name in sorted(self.old_field_keys.intersection(self.new_field_keys)):
            # Did the field change?
            old_model_name = self.renamed_models.get((app_label, model_name), model_name)
            old_model_state = self.from_state.models[app_label, old_model_name]
            new_model_state = self.to_state.models[app_label, model_name]
            old_field_name = self.renamed_fields.get((app_label, model_name, field_name), field_name)
            old_field_dec = self.deep_deconstruct(old_model_state.get_field_by_name(old_field_name))
            new_field_dec = self.deep_deconstruct(new_model_state.get_field_by_name(field_name))
            old_field = self.old_apps.get_model(app_label, old_model_name)._meta.get_field_by_name(old_field_name)[0]
            new_field = self.new_apps.get_model(app_label, model_name)._meta.get_field_by_name(field_name)[0]
            # Implement any model renames on relations; these are handled by RenameModel
            # so we need to exclude them from the comparison
            if hasattr(new_field, "rel") and getattr(new_field.rel, "to", None):
                rename_key = (
                    new_field.rel.to._meta.app_label,
                    new_field.rel.to._meta.object_name.lower(),
                )
                if rename_key in self.renamed_models:
                    new_field.rel.to = old_field.rel.to
            old_field_dec = self.deep_deconstruct(old_field)
            new_field_dec = self.deep_deconstruct(new_field)
            if old_field_dec != new_field_dec:
                self.add_operation(
                    app_label,
+36 −10
Original line number Diff line number Diff line
@@ -113,9 +113,28 @@ class RenameModel(Operation):
        self.new_name = new_name

    def state_forwards(self, app_label, state):
        # Get all of the related objects we need to repoint
        apps = state.render(skip_cache=True)
        model = apps.get_model(app_label, self.old_name)
        related_objects = model._meta.get_all_related_objects()
        related_m2m_objects = model._meta.get_all_related_many_to_many_objects()
        # Rename the model
        state.models[app_label, self.new_name.lower()] = state.models[app_label, self.old_name.lower()]
        state.models[app_label, self.new_name.lower()].name = self.new_name
        del state.models[app_label, self.old_name.lower()]
        # Repoint the FKs and M2Ms pointing to us
        for related_object in (related_objects + related_m2m_objects):
            related_key = (
                related_object.model._meta.app_label,
                related_object.model._meta.object_name.lower(),
            )
            new_fields = []
            for name, field in state.models[related_key].fields:
                if name == related_object.field.name:
                    field = field.clone()
                    field.rel.to = "%s.%s" % (app_label, self.new_name)
                new_fields.append((name, field))
            state.models[related_key].fields = new_fields

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        old_apps = from_state.render()
@@ -123,23 +142,30 @@ class RenameModel(Operation):
        old_model = old_apps.get_model(app_label, self.old_name)
        new_model = new_apps.get_model(app_label, self.new_name)
        if router.allow_migrate(schema_editor.connection.alias, new_model):
            # Move the main table
            schema_editor.alter_db_table(
                new_model,
                old_model._meta.db_table,
                new_model._meta.db_table,
            )
            # Alter the fields pointing to us
            related_objects = old_model._meta.get_all_related_objects()
            related_m2m_objects = old_model._meta.get_all_related_many_to_many_objects()
            for related_object in (related_objects + related_m2m_objects):
                to_field = new_apps.get_model(
                    related_object.model._meta.app_label,
                    related_object.model._meta.object_name.lower(),
                )._meta.get_field_by_name(related_object.field.name)[0]
                schema_editor.alter_field(
                    related_object.model,
                    related_object.field,
                    to_field,
                )

    def database_backwards(self, app_label, schema_editor, from_state, to_state):
        old_apps = from_state.render()
        new_apps = to_state.render()
        old_model = old_apps.get_model(app_label, self.new_name)
        new_model = new_apps.get_model(app_label, self.old_name)
        if router.allow_migrate(schema_editor.connection.alias, new_model):
            schema_editor.alter_db_table(
                new_model,
                old_model._meta.db_table,
                new_model._meta.db_table,
            )
        self.new_name, self.old_name = self.old_name, self.new_name
        self.database_forwards(app_label, schema_editor, from_state, to_state)
        self.new_name, self.old_name = self.old_name, self.new_name

    def references_model(self, name, app_label=None):
        return (
+7 −3
Original line number Diff line number Diff line
@@ -38,9 +38,9 @@ class ProjectState(object):
            real_apps=self.real_apps,
        )

    def render(self, include_real=None, ignore_swappable=False):
    def render(self, include_real=None, ignore_swappable=False, skip_cache=False):
        "Turns the project state into actual models in a new Apps"
        if self.apps is None:
        if self.apps is None or skip_cache:
            # Any apps in self.real_apps should have all their models included
            # in the render. We don't use the original model instances as there
            # are some variables that refer to the Apps object.
@@ -87,7 +87,11 @@ class ProjectState(object):
                        ))
                    else:
                        do_pending_lookups(model)
        try:
            return self.apps
        finally:
            if skip_cache:
                self.apps = None

    @classmethod
    def from_apps(cls, apps):
+9 −24
Original line number Diff line number Diff line
@@ -77,9 +77,7 @@ class AutodetectorTests(TestCase):
        return output

    def assertNumberMigrations(self, changes, app_label, number):
        if not changes.get(app_label, None):
            self.fail("No migrations found for %s\n%s" % (app_label, self.repr_changes(changes)))
        if len(changes[app_label]) != number:
        if len(changes.get(app_label, [])) != number:
            self.fail("Incorrect number of migrations (%s) for %s (expected %s)\n%s" % (
                len(changes[app_label]),
                app_label,
@@ -285,7 +283,7 @@ class AutodetectorTests(TestCase):
        changes = autodetector._detect_changes()

        # Right number of migrations for model rename?
        self.assertEqual(len(changes['testapp']), 1)
        self.assertNumberMigrations(changes, 'testapp', 1)
        # Right number of actions?
        migration = changes['testapp'][0]
        self.assertEqual(len(migration.operations), 1)
@@ -294,17 +292,9 @@ class AutodetectorTests(TestCase):
        self.assertEqual(action.__class__.__name__, "RenameModel")
        self.assertEqual(action.old_name, "Author")
        self.assertEqual(action.new_name, "Writer")

        # Right number of migrations for related field rename?
        self.assertEqual(len(changes['otherapp']), 1)
        # Right number of actions?
        migration = changes['otherapp'][0]
        self.assertEqual(len(migration.operations), 1)
        # Right action?
        action = migration.operations[0]
        self.assertEqual(action.__class__.__name__, "AlterField")
        self.assertEqual(action.name, "author")
        self.assertEqual(action.field.rel.to, "testapp.Writer")
        # Now that RenameModel handles related fields too, there should be
        # no AlterField for the related field.
        self.assertNumberMigrations(changes, 'otherapp', 0)

    def test_rename_model_with_renamed_rel_field(self):
        """
@@ -316,9 +306,8 @@ class AutodetectorTests(TestCase):
        after = self.make_project_state([self.author_renamed_with_book, self.book_with_field_and_author_renamed])
        autodetector = MigrationAutodetector(before, after, MigrationQuestioner({"ask_rename_model": True, "ask_rename": True}))
        changes = autodetector._detect_changes()

        # Right number of migrations for model rename?
        self.assertEqual(len(changes['testapp']), 1)
        self.assertNumberMigrations(changes, 'testapp', 1)
        # Right number of actions?
        migration = changes['testapp'][0]
        self.assertEqual(len(migration.operations), 1)
@@ -327,21 +316,17 @@ class AutodetectorTests(TestCase):
        self.assertEqual(action.__class__.__name__, "RenameModel")
        self.assertEqual(action.old_name, "Author")
        self.assertEqual(action.new_name, "Writer")

        # Right number of migrations for related field rename?
        self.assertEqual(len(changes['otherapp']), 1)
        # Alter is already taken care of.
        self.assertNumberMigrations(changes, 'otherapp', 1)
        # Right number of actions?
        migration = changes['otherapp'][0]
        self.assertEqual(len(migration.operations), 2)
        self.assertEqual(len(migration.operations), 1)
        # Right actions?
        action = migration.operations[0]
        self.assertEqual(action.__class__.__name__, "RenameField")
        self.assertEqual(action.old_name, "author")
        self.assertEqual(action.new_name, "writer")
        action = migration.operations[1]
        self.assertEqual(action.__class__.__name__, "AlterField")
        self.assertEqual(action.name, "writer")
        self.assertEqual(action.field.rel.to, "testapp.Writer")

    def test_fk_dependency(self):
        "Tests that having a ForeignKey automatically adds a dependency"
+14 −0
Original line number Diff line number Diff line
@@ -46,3 +46,17 @@ class MigrationTestBase(TransactionTestCase):

    def assertIndexNotExists(self, table, columns):
        return self.assertIndexExists(table, columns, False)

    def assertFKExists(self, table, columns, to, value=True):
        with connection.cursor() as cursor:
            self.assertEqual(
                value,
                any(
                    c["foreign_key"] == to
                    for c in connection.introspection.get_constraints(cursor, table).values()
                    if c['columns'] == list(columns)
                ),
            )

    def assertFKNotExists(self, table, columns, to, value=True):
        return self.assertFKExists(table, columns, to, False)
Loading