Commit f7e90711 authored by Akshesh's avatar Akshesh Committed by Tim Graham
Browse files

Refs #26709 -- Added IndexOperation to avoid code duplication.

parent b1e7d19d
Loading
Loading
Loading
Loading
+17 −9
Original line number Diff line number Diff line
@@ -744,7 +744,15 @@ class AlterModelManagers(ModelOptionOperation):
        return "Change managers on %s" % (self.name, )


class AddIndex(Operation):
class IndexOperation(Operation):
    option_name = 'indexes'

    @cached_property
    def model_name_lower(self):
        return self.model_name.lower()


class AddIndex(IndexOperation):
    """
    Add an index on a model.
    """
@@ -759,9 +767,9 @@ class AddIndex(Operation):
        self.index = index

    def state_forwards(self, app_label, state):
        model_state = state.models[app_label, self.model_name.lower()]
        model_state = state.models[app_label, self.model_name_lower]
        self.index.model = state.apps.get_model(app_label, self.model_name)
        model_state.options['indexes'].append(self.index)
        model_state.options[self.option_name].append(self.index)

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        schema_editor.add_index(self.index)
@@ -787,7 +795,7 @@ class AddIndex(Operation):
        )


class RemoveIndex(Operation):
class RemoveIndex(IndexOperation):
    """
    Remove an index from a model.
    """
@@ -797,17 +805,17 @@ class RemoveIndex(Operation):
        self.name = name

    def state_forwards(self, app_label, state):
        model_state = state.models[app_label, self.model_name.lower()]
        indexes = model_state.options['indexes']
        model_state.options['indexes'] = [idx for idx in indexes if idx.name != self.name]
        model_state = state.models[app_label, self.model_name_lower]
        indexes = model_state.options[self.option_name]
        model_state.options[self.option_name] = [idx for idx in indexes if idx.name != self.name]

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        from_model_state = from_state.models[app_label, self.model_name.lower()]
        from_model_state = from_state.models[app_label, self.model_name_lower]
        index = from_model_state.get_index_by_name(self.name)
        schema_editor.remove_index(index)

    def database_backwards(self, app_label, schema_editor, from_state, to_state):
        to_model_state = to_state.models[app_label, self.model_name.lower()]
        to_model_state = to_state.models[app_label, self.model_name_lower]
        index = to_model_state.get_index_by_name(self.name)
        schema_editor.add_index(index)