Commit e24e9e04 authored by Andrew Godwin's avatar Andrew Godwin
Browse files

Fixed #23014: Renaming not atomic with unique together

parent 7dacc6ae
Loading
Loading
Loading
Loading
+13 −2
Original line number Diff line number Diff line
@@ -43,7 +43,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
        else:
            raise ValueError("Cannot quote parameter value %r of type %s" % (value, type(value)))

    def _remake_table(self, model, create_fields=[], delete_fields=[], alter_fields=[], rename_fields=[], override_uniques=None):
    def _remake_table(self, model, create_fields=[], delete_fields=[], alter_fields=[], override_uniques=None):
        """
        Shortcut to transform a model from old_model into new_model
        """
@@ -52,6 +52,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
        # Since mapping might mix column names and default values,
        # its values must be already quoted.
        mapping = dict((f.column, self.quote_name(f.column)) for f in model._meta.local_fields)
        # This maps field names (not columns) for things like unique_together
        rename_mapping = {}
        # If any of the new or altered fields is introducing a new PK,
        # remove the old one
        restore_pk_field = None
@@ -77,6 +79,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
            del mapping[old_field.column]
            body[new_field.name] = new_field
            mapping[new_field.column] = self.quote_name(old_field.column)
            rename_mapping[old_field.name] = new_field.name
        # Remove any deleted fields
        for field in delete_fields:
            del body[field.name]
@@ -92,11 +95,19 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
        # the internal references of some of the provided fields.
        body = copy.deepcopy(body)

        # Work out the new value of unique_together, taking renames into
        # account
        if override_uniques is None:
            override_uniques = [
                [rename_mapping.get(n, n) for n in unique]
                for unique in model._meta.unique_together
            ]

        # Construct a new model for the new state
        meta_contents = {
            'app_label': model._meta.app_label,
            'db_table': model._meta.db_table + "__new",
            'unique_together': model._meta.unique_together if override_uniques is None else override_uniques,
            'unique_together': override_uniques,
            'apps': apps,
        }
        meta = type("Meta", tuple(), meta_contents)
+12 −1
Original line number Diff line number Diff line
@@ -791,7 +791,18 @@ class MigrationAutodetector(object):
            old_model_name = self.renamed_models.get((app_label, model_name), model_name)
            old_model_state = self.from_state.models[app_label, old_model_name]
            new_model_state = self.to_state.models[app_label, model_name]
            if old_model_state.options.get(option_name) != new_model_state.options.get(option_name):
            # We run the old version through the field renames to account for those
            if old_model_state.options.get(option_name) is None:
                old_value = None
            else:
                old_value = [
                    [
                        self.renamed_fields.get((app_label, model_name, n), n)
                        for n in unique
                    ]
                    for unique in old_model_state.options[option_name]
                ]
            if old_value != new_model_state.options.get(option_name):
                self.add_operation(
                    app_label,
                    operation(
+8 −0
Original line number Diff line number Diff line
@@ -162,9 +162,17 @@ class RenameField(Operation):
        self.new_name = new_name

    def state_forwards(self, app_label, state):
        # Rename the field
        state.models[app_label, self.model_name.lower()].fields = [
            (self.new_name if n == self.old_name else n, f) for n, f in state.models[app_label, self.model_name.lower()].fields
        ]
        # Fix unique_together to refer to the new field
        options = state.models[app_label, self.model_name.lower()].options
        if "unique_together" in options:
            options['unique_together'] = [
                [self.new_name if n == self.old_name else n for n in unique]
                for unique in options['unique_together']
            ]

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        from_model = from_state.render().get_model(app_label, self.model_name)
+13 −2
Original line number Diff line number Diff line
@@ -46,7 +46,7 @@ class OperationTestBase(MigrationTestBase):
        operation.state_forwards(app_label, new_state)
        return project_state, new_state

    def set_up_test_model(self, app_label, second_model=False, third_model=False, related_model=False, mti_model=False, proxy_model=False):
    def set_up_test_model(self, app_label, second_model=False, third_model=False, related_model=False, mti_model=False, proxy_model=False, unique_together=False):
        """
        Creates a test model state and database table.
        """
@@ -85,6 +85,7 @@ class OperationTestBase(MigrationTestBase):
            ],
            options={
                "swappable": "TEST_SWAP_MODEL",
                "unique_together": [["pink", "weight"]] if unique_together else [],
            },
        )]
        if second_model:
@@ -862,7 +863,7 @@ class OperationTests(OperationTestBase):
        """
        Tests the RenameField operation.
        """
        project_state = self.set_up_test_model("test_rnfl")
        project_state = self.set_up_test_model("test_rnfl", unique_together=True)
        # Test the state alteration
        operation = migrations.RenameField("Pony", "pink", "blue")
        self.assertEqual(operation.describe(), "Rename field pink on Pony to blue")
@@ -870,6 +871,9 @@ class OperationTests(OperationTestBase):
        operation.state_forwards("test_rnfl", new_state)
        self.assertIn("blue", [n for n, f in new_state.models["test_rnfl", "pony"].fields])
        self.assertNotIn("pink", [n for n, f in new_state.models["test_rnfl", "pony"].fields])
        # Make sure the unique_together has the renamed column too
        self.assertIn("blue", new_state.models["test_rnfl", "pony"].options['unique_together'][0])
        self.assertNotIn("pink", new_state.models["test_rnfl", "pony"].options['unique_together'][0])
        # Test the database alteration
        self.assertColumnExists("test_rnfl_pony", "pink")
        self.assertColumnNotExists("test_rnfl_pony", "blue")
@@ -877,6 +881,13 @@ class OperationTests(OperationTestBase):
            operation.database_forwards("test_rnfl", editor, project_state, new_state)
        self.assertColumnExists("test_rnfl_pony", "blue")
        self.assertColumnNotExists("test_rnfl_pony", "pink")
        # Ensure the unique constraint has been ported over
        with connection.cursor() as cursor:
            cursor.execute("INSERT INTO test_rnfl_pony (blue, weight) VALUES (1, 1)")
            with self.assertRaises(IntegrityError):
                with atomic():
                    cursor.execute("INSERT INTO test_rnfl_pony (blue, weight) VALUES (1, 1)")
            cursor.execute("DELETE FROM test_rnfl_pony")
        # And test reversal
        with connection.schema_editor() as editor:
            operation.database_backwards("test_rnfl", editor, new_state, project_state)