Commit f1e408ff authored by Vytis Banaitis's avatar Vytis Banaitis Committed by Tim Graham
Browse files

Fixed #25044 -- Fixed migrations for renaming ManyToManyField's through model.

parent 16a842b3
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -758,6 +758,7 @@ answer newbie questions, and generally made Django that much better:
    Vladimir Kuzma <vladimirkuzma.ch@gmail.com>
    Vlado <vlado@labath.org>
    Vsevolod Solovyov
    Vytis Banaitis <vytis.banaitis@gmail.com>
    wam-djangobug@wamber.net
    Wang Chun <wangchun@exoweb.net>
    Warren Smith <warren@wandrsmith.net>
+7 −0
Original line number Diff line number Diff line
@@ -867,6 +867,13 @@ class MigrationAutodetector(object):
                )
                if rename_key in self.renamed_models:
                    new_field.remote_field.model = old_field.remote_field.model
            if hasattr(new_field, "remote_field") and getattr(new_field.remote_field, "through", None):
                rename_key = (
                    new_field.remote_field.through._meta.app_label,
                    new_field.remote_field.through._meta.model_name,
                )
                if rename_key in self.renamed_models:
                    new_field.remote_field.through = old_field.remote_field.through
            old_field_dec = self.deep_deconstruct(old_field)
            new_field_dec = self.deep_deconstruct(new_field)
            if old_field_dec != new_field_dec:
+22 −0
Original line number Diff line number Diff line
@@ -307,6 +307,28 @@ class RenameModel(ModelOperation):
                new_fields.append((name, field))
            state.models[related_key].fields = new_fields
            state.reload_model(*related_key)
        # Repoint M2Ms with through pointing to us
        related_models = {
            f.remote_field.model for f in model._meta.fields
            if getattr(f.remote_field, 'model', None)
        }
        model_name = '%s.%s' % (app_label, self.old_name)
        for related_model in related_models:
            if related_model == model:
                related_key = (app_label, self.new_name_lower)
            else:
                related_key = (related_model._meta.app_label, related_model._meta.model_name)
            new_fields = []
            changed = False
            for name, field in state.models[related_key].fields:
                if field.is_relation and field.many_to_many and field.remote_field.through == model_name:
                    field = field.clone()
                    field.remote_field.through = '%s.%s' % (app_label, self.new_name)
                    changed = True
                new_fields.append((name, field))
            if changed:
                state.models[related_key].fields = new_fields
                state.reload_model(*related_key)
        state.reload_model(app_label, self.new_name_lower)

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
+24 −0
Original line number Diff line number Diff line
@@ -259,6 +259,10 @@ class AutodetectorTests(TestCase):
        ("id", models.AutoField(primary_key=True)),
        ("publishers", models.ManyToManyField("testapp.Publisher", through="testapp.Contract")),
    ])
    author_with_renamed_m2m_through = ModelState("testapp", "Author", [
        ("id", models.AutoField(primary_key=True)),
        ("publishers", models.ManyToManyField("testapp.Publisher", through="testapp.Deal")),
    ])
    author_with_former_m2m = ModelState("testapp", "Author", [
        ("id", models.AutoField(primary_key=True)),
        ("publishers", models.CharField(max_length=100)),
@@ -286,6 +290,11 @@ class AutodetectorTests(TestCase):
        ("author", models.ForeignKey("testapp.Author", models.CASCADE)),
        ("publisher", models.ForeignKey("testapp.Publisher", models.CASCADE)),
    ])
    contract_renamed = ModelState("testapp", "Deal", [
        ("id", models.AutoField(primary_key=True)),
        ("author", models.ForeignKey("testapp.Author", models.CASCADE)),
        ("publisher", models.ForeignKey("testapp.Publisher", models.CASCADE)),
    ])
    publisher = ModelState("testapp", "Publisher", [
        ("id", models.AutoField(primary_key=True)),
        ("name", models.CharField(max_length=100)),
@@ -848,6 +857,21 @@ class AutodetectorTests(TestCase):
        # no AlterField for the related field.
        self.assertNumberMigrations(changes, 'otherapp', 0)

    def test_rename_m2m_through_model(self):
        """
        Tests autodetection of renamed models that are used in M2M relations as
        through models.
        """
        # Make state
        before = self.make_project_state([self.author_with_m2m_through, self.publisher, self.contract])
        after = self.make_project_state([self.author_with_renamed_m2m_through, self.publisher, self.contract_renamed])
        autodetector = MigrationAutodetector(before, after, MigrationQuestioner({'ask_rename_model': True}))
        changes = autodetector._detect_changes()
        # Right number/type of migrations?
        self.assertNumberMigrations(changes, 'testapp', 1)
        self.assertOperationTypes(changes, 'testapp', 0, ['RenameModel'])
        self.assertOperationAttributes(changes, 'testapp', 0, 0, old_name='Contract', new_name='Deal')

    def test_rename_model_with_renamed_rel_field(self):
        """
        Tests autodetection of renamed models while simultaneously renaming one
+41 −0
Original line number Diff line number Diff line
@@ -727,6 +727,47 @@ class OperationTests(OperationTestBase):
        self.assertEqual(Rider.objects.count(), 2)
        self.assertEqual(Pony._meta.get_field('riders').remote_field.through.objects.count(), 2)

    def test_rename_m2m_through_model(self):
        app_label = "test_rename_through"
        project_state = self.apply_operations(app_label, ProjectState(), operations=[
            migrations.CreateModel("Rider", fields=[
                ("id", models.AutoField(primary_key=True)),
            ]),
            migrations.CreateModel("Pony", fields=[
                ("id", models.AutoField(primary_key=True)),
            ]),
            migrations.CreateModel("PonyRider", fields=[
                ("id", models.AutoField(primary_key=True)),
                ("rider", models.ForeignKey("test_rename_through.Rider", models.CASCADE)),
                ("pony", models.ForeignKey("test_rename_through.Pony", models.CASCADE)),
            ]),
            migrations.AddField(
                "Pony",
                "riders",
                models.ManyToManyField("test_rename_through.Rider", through="test_rename_through.PonyRider"),
            ),
        ])
        Pony = project_state.apps.get_model(app_label, "Pony")
        Rider = project_state.apps.get_model(app_label, "Rider")
        PonyRider = project_state.apps.get_model(app_label, "PonyRider")
        pony = Pony.objects.create()
        rider = Rider.objects.create()
        PonyRider.objects.create(pony=pony, rider=rider)

        project_state = self.apply_operations(app_label, project_state, operations=[
            migrations.RenameModel("PonyRider", "PonyRider2"),
        ])
        Pony = project_state.apps.get_model(app_label, "Pony")
        Rider = project_state.apps.get_model(app_label, "Rider")
        PonyRider = project_state.apps.get_model(app_label, "PonyRider2")
        pony = Pony.objects.first()
        rider = Rider.objects.create()
        PonyRider.objects.create(pony=pony, rider=rider)
        self.assertEqual(Pony.objects.count(), 1)
        self.assertEqual(Rider.objects.count(), 2)
        self.assertEqual(PonyRider.objects.count(), 2)
        self.assertEqual(pony.riders.count(), 2)

    def test_add_field(self):
        """
        Tests the AddField operation.