Commit 1aa3e09c authored by Claude Paroz's avatar Claude Paroz
Browse files

Fixed #23745 -- Reused states as much as possible in migrations

Thanks Tim Graham and Markus Holtermann for the reviews.
parent 2a9c4b49
Loading
Loading
Loading
Loading
+13 −11
Original line number Diff line number Diff line
@@ -98,13 +98,13 @@ class MigrationExecutor(object):
            self.progress_callback("apply_start", migration, fake)
        if not fake:
            # Test to see if this is an already-applied initial migration
            if self.detect_soft_applied(state, migration):
            applied, state = self.detect_soft_applied(state, migration)
            if applied:
                fake = True
            else:
                # Alright, do it normally
                with self.connection.schema_editor() as schema_editor:
                    project_state = self.loader.project_state((migration.app_label, migration.name), at_end=False)
                    migration.apply(project_state, schema_editor)
                    state = migration.apply(state, schema_editor)
        # For replacement migrations, record individual statuses
        if migration.replaces:
            for app_label, name in migration.replaces:
@@ -124,8 +124,7 @@ class MigrationExecutor(object):
            self.progress_callback("unapply_start", migration, fake)
        if not fake:
            with self.connection.schema_editor() as schema_editor:
                project_state = self.loader.project_state((migration.app_label, migration.name), at_end=False)
                migration.unapply(project_state, schema_editor)
                state = migration.unapply(state, schema_editor)
        # For replacement migrations, record individual statuses
        if migration.replaces:
            for app_label, name in migration.replaces:
@@ -143,12 +142,15 @@ class MigrationExecutor(object):
        tables it would create exist. This is intended only for use
        on initial migrations (as it only looks for CreateModel).
        """
        project_state = self.loader.project_state((migration.app_label, migration.name), at_end=True)
        apps = project_state.apps
        found_create_migration = False
        # Bail if the migration isn't the first one in its app
        if [name for app, name in migration.dependencies if app == migration.app_label]:
            return False
            return False, project_state
        if project_state is None:
            after_state = self.loader.project_state((migration.app_label, migration.name), at_end=True)
        else:
            after_state = migration.mutate_state(project_state)
        apps = after_state.apps
        found_create_migration = False
        # Make sure all create model are done
        for operation in migration.operations:
            if isinstance(operation, migrations.CreateModel):
@@ -158,8 +160,8 @@ class MigrationExecutor(object):
                    # main app cache, as it's not a direct dependency.
                    model = global_apps.get_model(model._meta.swapped)
                if model._meta.db_table not in self.connection.introspection.table_names(self.connection.cursor()):
                    return False
                    return False, project_state
                found_create_migration = True
        # If we get this far and we found at least one CreateModel migration,
        # the migration is considered implicitly applied.
        return found_create_migration
        return found_create_migration, after_state
+8 −11
Original line number Diff line number Diff line
@@ -97,19 +97,17 @@ class Migration(object):
                schema_editor.collected_sql.append("-- %s" % operation.describe())
                schema_editor.collected_sql.append("--")
                continue
            # Get the state after the operation has run
            new_state = project_state.clone()
            operation.state_forwards(self.app_label, new_state)
            # Save the state before the operation has run
            old_state = project_state.clone()
            operation.state_forwards(self.app_label, project_state)
            # Run the operation
            if not schema_editor.connection.features.can_rollback_ddl and operation.atomic:
                # We're forcing a transaction on a non-transactional-DDL backend
                with atomic(schema_editor.connection.alias):
                    operation.database_forwards(self.app_label, schema_editor, project_state, new_state)
                    operation.database_forwards(self.app_label, schema_editor, old_state, project_state)
            else:
                # Normal behaviour
                operation.database_forwards(self.app_label, schema_editor, project_state, new_state)
            # Switch states
            project_state = new_state
                operation.database_forwards(self.app_label, schema_editor, old_state, project_state)
        return project_state

    def unapply(self, project_state, schema_editor, collect_sql=False):
@@ -133,10 +131,9 @@ class Migration(object):
            # If it's irreversible, error out
            if not operation.reversible:
                raise Migration.IrreversibleError("Operation %s in %s is not reversible" % (operation, self))
            new_state = project_state.clone()
            operation.state_forwards(self.app_label, new_state)
            to_run.append((operation, project_state, new_state))
            project_state = new_state
            old_state = project_state.clone()
            operation.state_forwards(self.app_label, project_state)
            to_run.append((operation, old_state, project_state))
        # Now run them in reverse
        to_run.reverse()
        for operation, to_state, from_state in to_run:
+4 −0
Original line number Diff line number Diff line
@@ -38,6 +38,7 @@ class AddField(Operation):
        else:
            field = self.field
        state.models[app_label, self.model_name.lower()].fields.append((self.name, field))
        state.reload_model(app_label, self.model_name)

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        to_model = to_state.apps.get_model(app_label, self.model_name)
@@ -94,6 +95,7 @@ class RemoveField(Operation):
            if name != self.name:
                new_fields.append((name, instance))
        state.models[app_label, self.model_name.lower()].fields = new_fields
        state.reload_model(app_label, self.model_name)

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        from_model = from_state.apps.get_model(app_label, self.model_name)
@@ -150,6 +152,7 @@ class AlterField(Operation):
        state.models[app_label, self.model_name.lower()].fields = [
            (n, field if n == self.name else f) for n, f in state.models[app_label, self.model_name.lower()].fields
        ]
        state.reload_model(app_label, self.model_name)

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        to_model = to_state.apps.get_model(app_label, self.model_name)
@@ -220,6 +223,7 @@ class RenameField(Operation):
                    [self.new_name if n == self.old_name else n for n in together]
                    for together in options[option]
                ]
        state.reload_model(app_label, self.model_name)

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        to_model = to_state.apps.get_model(app_label, self.model_name)
+12 −5
Original line number Diff line number Diff line
@@ -39,14 +39,14 @@ class CreateModel(Operation):
        )

    def state_forwards(self, app_label, state):
        state.models[app_label, self.name.lower()] = ModelState(
        state.add_model(ModelState(
            app_label,
            self.name,
            list(self.fields),
            dict(self.options),
            tuple(self.bases),
            list(self.managers),
        )
        ))

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        model = to_state.apps.get_model(app_label, self.name)
@@ -98,7 +98,7 @@ class DeleteModel(Operation):
        )

    def state_forwards(self, app_label, state):
        del state.models[app_label, self.name.lower()]
        state.remove_model(app_label, self.name)

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        model = from_state.apps.get_model(app_label, self.name)
@@ -141,12 +141,13 @@ class RenameModel(Operation):
        # Get all of the related objects we need to repoint
        apps = state.apps
        model = apps.get_model(app_label, self.old_name)
        model._meta.apps = apps
        related_objects = model._meta.get_all_related_objects()
        related_m2m_objects = model._meta.get_all_related_many_to_many_objects()
        # Rename the model
        state.models[app_label, self.new_name.lower()] = state.models[app_label, self.old_name.lower()]
        state.models[app_label, self.new_name.lower()].name = self.new_name
        del state.models[app_label, self.old_name.lower()]
        state.remove_model(app_label, self.old_name)
        # Repoint the FKs and M2Ms pointing to us
        for related_object in (related_objects + related_m2m_objects):
            # Use the new related key for self referential related objects.
@@ -164,7 +165,8 @@ class RenameModel(Operation):
                    field.rel.to = "%s.%s" % (app_label, self.new_name)
                new_fields.append((name, field))
            state.models[related_key].fields = new_fields
        del state.apps  # FIXME: this should be replaced by a logic in state (update_model?)
            state.reload_model(*related_key)
        state.reload_model(app_label, self.new_name)

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        new_model = to_state.apps.get_model(app_label, self.new_name)
@@ -235,6 +237,7 @@ class AlterModelTable(Operation):

    def state_forwards(self, app_label, state):
        state.models[app_label, self.name.lower()].options["db_table"] = self.table
        state.reload_model(app_label, self.name)

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        new_model = to_state.apps.get_model(app_label, self.name)
@@ -290,6 +293,7 @@ class AlterUniqueTogether(Operation):
    def state_forwards(self, app_label, state):
        model_state = state.models[app_label, self.name.lower()]
        model_state.options[self.option_name] = self.unique_together
        state.reload_model(app_label, self.name)

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        new_model = to_state.apps.get_model(app_label, self.name)
@@ -337,6 +341,7 @@ class AlterIndexTogether(Operation):
    def state_forwards(self, app_label, state):
        model_state = state.models[app_label, self.name.lower()]
        model_state.options[self.option_name] = self.index_together
        state.reload_model(app_label, self.name)

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        new_model = to_state.apps.get_model(app_label, self.name)
@@ -381,6 +386,7 @@ class AlterOrderWithRespectTo(Operation):
    def state_forwards(self, app_label, state):
        model_state = state.models[app_label, self.name.lower()]
        model_state.options['order_with_respect_to'] = self.order_with_respect_to
        state.reload_model(app_label, self.name)

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        to_model = to_state.apps.get_model(app_label, self.name)
@@ -451,6 +457,7 @@ class AlterModelOptions(Operation):
        for key in self.ALTER_OPTION_KEYS:
            if key not in self.options and key in model_state.options:
                del model_state.options[key]
        state.reload_model(app_label, self.name)

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        pass
+68 −4
Original line number Diff line number Diff line
from __future__ import unicode_literals
from collections import OrderedDict
import copy

from django.apps import AppConfig
from django.apps.registry import Apps, apps as global_apps
@@ -30,15 +32,52 @@ class ProjectState(object):
        # Apps to include from main registry, usually unmigrated ones
        self.real_apps = real_apps or []

    def add_model_state(self, model_state):
        self.models[(model_state.app_label, model_state.name.lower())] = model_state
    def add_model(self, model_state):
        app_label, model_name = model_state.app_label, model_state.name.lower()
        self.models[(app_label, model_name)] = model_state
        if 'apps' in self.__dict__:  # hasattr would cache the property
            self.reload_model(app_label, model_name)

    def remove_model(self, app_label, model_name):
        model_name = model_name.lower()
        del self.models[app_label, model_name]
        if 'apps' in self.__dict__:  # hasattr would cache the property
            self.apps.unregister_model(app_label, model_name)

    def reload_model(self, app_label, model_name):
        if 'apps' in self.__dict__:  # hasattr would cache the property
            # Get relations before reloading the models, as _meta.apps may change
            model_name = model_name.lower()
            try:
                related_old = {
                    f.model for f in
                    self.apps.get_model(app_label, model_name)._meta.get_all_related_objects()
                }
            except LookupError:
                related_old = set()
            self._reload_one_model(app_label, model_name)
            # Reload models if there are relations
            model = self.apps.get_model(app_label, model_name)
            related_m2m = {f.rel.to for f, _ in model._meta.get_m2m_with_model()}
            for rel_model in related_old.union(related_m2m):
                self._reload_one_model(rel_model._meta.app_label, rel_model._meta.model_name)
            if related_m2m:
                # Re-render this model after related models have been reloaded
                self._reload_one_model(app_label, model_name)

    def _reload_one_model(self, app_label, model_name):
        self.apps.unregister_model(app_label, model_name)
        self.models[app_label, model_name].render(self.apps)

    def clone(self):
        "Returns an exact copy of this ProjectState"
        return ProjectState(
        new_state = ProjectState(
            models={k: v.clone() for k, v in self.models.items()},
            real_apps=self.real_apps,
        )
        if 'apps' in self.__dict__:
            new_state.apps = self.apps.clone()
        return new_state

    @cached_property
    def apps(self):
@@ -147,6 +186,31 @@ class StateApps(Apps):
            else:
                do_pending_lookups(model)

    def clone(self):
        """
        Return a clone of this registry, mainly used by the migration framework.
        """
        clone = StateApps([], {})
        clone.all_models = copy.deepcopy(self.all_models)
        clone.app_configs = copy.deepcopy(self.app_configs)
        return clone

    def register_model(self, app_label, model):
        self.all_models[app_label][model._meta.model_name] = model
        if app_label not in self.app_configs:
            self.app_configs[app_label] = AppConfigStub(app_label)
            self.app_configs[app_label].models = OrderedDict()
        self.app_configs[app_label].models[model._meta.model_name] = model
        self.clear_cache()

    def unregister_model(self, app_label, model_name):
        try:
            del self.all_models[app_label][model_name]
            del self.app_configs[app_label].models[model_name]
        except KeyError:
            pass
        self.clear_cache()


class ModelState(object):
    """
@@ -368,7 +432,7 @@ class ModelState(object):
        for mgr_name, manager in self.managers:
            body[mgr_name] = manager

        # Then, make a Model object
        # Then, make a Model object (apps.register_model is called in __new__)
        return type(
            str(self.name),
            bases,
Loading