Commit bacbbb48 authored by Andrew Godwin's avatar Andrew Godwin
Browse files

RunSQL migration operation and alpha SeparateDatabaseAndState op'n.

parent 9079436b
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
from .models import CreateModel, DeleteModel, AlterModelTable, AlterUniqueTogether, AlterIndexTogether
from .fields import AddField, RemoveField, AlterField, RenameField
from .special import SeparateDatabaseAndState, RunSQL
+95 −0
Original line number Diff line number Diff line
import re
from .base import Operation
from django.db import models, router
from django.db.migrations.state import ModelState


class SeparateDatabaseAndState(Operation):
    """
    Takes two lists of operations - ones that will be used for the database,
    and ones that will be used for the state change. This allows operations
    that don't support state change to have it applied, or have operations
    that affect the state or not the database, or so on.
    """

    def __init__(self, database_operations=None, state_operations=None):
        self.database_operations = database_operations or []
        self.state_operations = state_operations or []

    def state_forwards(self, app_label, state):
        for state_operation in self.state_operations:
            state_operation.state_forwards(app_label, state)

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        # We calculate state separately in here since our state functions aren't useful
        for database_operation in self.database_operations:
            to_state = from_state.clone()
            database_operation.state_forwards(app_label, to_state)
            database_operation.database_forwards(self, app_label, schema_editor, from_state, to_state)
            from_state = to_state

    def database_backwards(self, app_label, schema_editor, from_state, to_state):
        # We calculate state separately in here since our state functions aren't useful
        base_state = to_state
        for pos, database_operation in enumerate(reversed(self.database_operations)):
            to_state = base_state.clone()
            for dbop in self.database_operations[:-(pos+1)]:
                dbop.state_forwards(app_label, to_state)
            from_state = base_state.clone()
            database_operation.state_forwards(app_label, from_state)
            database_operation.database_backwards(self, app_label, schema_editor, from_state, to_state)

    def describe(self):
        return "Custom state/database change combination"


class RunSQL(Operation):
    """
    Runs some raw SQL - a single statement by default, but it will attempt
    to parse and split it into multiple statements if multiple=True.

    A reverse SQL statement may be provided.

    Also accepts a list of operations that represent the state change effected
    by this SQL change, in case it's custom column/table creation/deletion.
    """

    def __init__(self, sql, reverse_sql=None, state_operations=None, multiple=False):
        self.sql = sql
        self.reverse_sql = reverse_sql
        self.state_operations = state_operations or []
        self.multiple = multiple

    def state_forwards(self, app_label, state):
        for state_operation in self.state_operations:
            state_operation.state_forwards(app_label, state)

    def _split_sql(self, sql):
        regex = r"(?mx) ([^';]* (?:'[^']*'[^';]*)*)"
        comment_regex = r"(?mx) (?:^\s*$)|(?:--.*$)"
        # First, strip comments
        sql = "\n".join([x.strip().replace("%", "%%") for x in re.split(comment_regex, sql) if x.strip()])
        # Now get each statement
        for st in re.split(regex, sql)[1:][::2]:
            yield st

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        if self.multiple:
            statements = self._split_sql(self.sql)
        else:
            statements = [self.sql]
        for statement in statements:
            schema_editor.execute(statement)

    def database_backwards(self, app_label, schema_editor, from_state, to_state):
        if self.reverse_sql is None:
            raise NotImplementedError("You cannot reverse this operation")
        if self.multiple:
            statements = self._split_sql(self.reverse_sql)
        else:
            statements = [self.reverse_sql]
        for statement in statements:
            schema_editor.execute(statement)

    def describe(self):
        return "Raw SQL operation"
+26 −0
Original line number Diff line number Diff line
@@ -280,6 +280,32 @@ class OperationTests(MigrationTestBase):
            operation.database_backwards("test_alinto", editor, new_state, project_state)
        self.assertIndexNotExists("test_alinto_pony", ["pink", "weight"])

    def test_run_sql(self):
        """
        Tests the AlterIndexTogether operation.
        """
        project_state = self.set_up_test_model("test_runsql")
        # Create the operation
        operation = migrations.RunSQL(
            "CREATE TABLE i_love_ponies (id int, special_thing int)",
            "DROP TABLE i_love_ponies",
            state_operations = [migrations.CreateModel("SomethingElse", [("id", models.AutoField(primary_key=True))])],
        )
        # Test the state alteration
        new_state = project_state.clone()
        operation.state_forwards("test_runsql", new_state)
        self.assertEqual(len(new_state.models["test_runsql", "somethingelse"].fields), 1)
        # Make sure there's no table
        self.assertTableNotExists("i_love_ponies")
        # Test the database alteration
        with connection.schema_editor() as editor:
            operation.database_forwards("test_runsql", editor, project_state, new_state)
        self.assertTableExists("i_love_ponies")
        # And test reversal
        with connection.schema_editor() as editor:
            operation.database_backwards("test_runsql", editor, new_state, project_state)
        self.assertTableNotExists("i_love_ponies")


class MigrateNothingRouter(object):
    """