Commit 85f6d893 authored by Markus Holtermann's avatar Markus Holtermann Committed by Tim Graham
Browse files

Fixed #23426 -- Allowed parameters in migrations.RunSQL

Thanks tchaumeny and Loic for reviews.
parent d49993fa
Loading
Loading
Loading
Loading
+18 −6
Original line number Diff line number Diff line
@@ -64,20 +64,32 @@ class RunSQL(Operation):
            state_operation.state_forwards(app_label, state)

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        statements = schema_editor.connection.ops.prepare_sql_script(self.sql)
        for statement in statements:
            schema_editor.execute(statement, params=None)
        self._run_sql(schema_editor, self.sql)

    def database_backwards(self, app_label, schema_editor, from_state, to_state):
        if self.reverse_sql is None:
            raise NotImplementedError("You cannot reverse this operation")
        statements = schema_editor.connection.ops.prepare_sql_script(self.reverse_sql)
        for statement in statements:
            schema_editor.execute(statement, params=None)
        self._run_sql(schema_editor, self.reverse_sql)

    def describe(self):
        return "Raw SQL operation"

    def _run_sql(self, schema_editor, sql):
        if isinstance(sql, (list, tuple)):
            for sql in sql:
                params = None
                if isinstance(sql, (list, tuple)):
                    elements = len(sql)
                    if elements == 2:
                        sql, params = sql
                    else:
                        raise ValueError("Expected a 2-tuple but got %d" % elements)
                schema_editor.execute(sql, params=params)
        else:
            statements = schema_editor.connection.ops.prepare_sql_script(sql)
            for statement in statements:
                schema_editor.execute(statement, params=None)


class RunPython(Operation):
    """
+18 −2
Original line number Diff line number Diff line
@@ -188,6 +188,17 @@ the database. On most database backends (all but PostgreSQL), Django will
split the SQL into individual statements prior to executing them. This
requires installing the sqlparse_ Python library.

You can also pass a list of strings or 2-tuples. The latter is used for passing
queries and parameters in the same way as :ref:`cursor.execute()
<executing-custom-sql>`. These three operations are equivalent::

    migrations.RunSQL("INSERT INTO musician (name) VALUES ('Reinhardt');")
    migrations.RunSQL(["INSERT INTO musician (name) VALUES ('Reinhardt');", None])
    migrations.RunSQL(["INSERT INTO musician (name) VALUES (%s);", ['Reinhardt']])

If you want to include literal percent signs in the query, you have to double
them if you are passing parameters.

The ``state_operations`` argument is so you can supply operations that are
equivalent to the SQL in terms of project state; for example, if you are
manually creating a column, you should pass in a list containing an ``AddField``
@@ -197,8 +208,13 @@ operation that adds that field and so will try to run it again).

.. versionchanged:: 1.7.1

    If you want to include literal percent signs in the query you don't need to
    double them anymore.
    If you want to include literal percent signs in a query without parameters
    you don't need to double them anymore.

.. versionchanged:: 1.8

    The ability to pass parameters to the ``sql`` and ``reverse_sql`` queries
    was added.

.. _sqlparse: https://pypi.python.org/pypi/sqlparse

+6 −0
Original line number Diff line number Diff line
@@ -265,6 +265,12 @@ Management Commands

* :djadmin:`makemigrations` can now serialize timezone-aware values.

Migrations
^^^^^^^^^^

* The :class:`~django.db.migrations.operations.RunSQL` operation can now handle
  parameters passed to the SQL statements.

Models
^^^^^^

+81 −0
Original line number Diff line number Diff line
@@ -1195,6 +1195,87 @@ class OperationTests(OperationTestBase):
            operation.database_backwards("test_runsql", editor, new_state, project_state)
        self.assertTableNotExists("i_love_ponies")

    def test_run_sql_params(self):
        """
        #23426 - RunSQL should accept parameters.
        """
        project_state = self.set_up_test_model("test_runsql")
        # Create the operation
        operation = migrations.RunSQL(
            "CREATE TABLE i_love_ponies (id int, special_thing varchar(15));",
            "DROP TABLE i_love_ponies",
        )
        param_operation = migrations.RunSQL(
            # forwards
            (
                "INSERT INTO i_love_ponies (id, special_thing) VALUES (1, 'Django');",
                ["INSERT INTO i_love_ponies (id, special_thing) VALUES (2, %s);", ['Ponies']],
                ("INSERT INTO i_love_ponies (id, special_thing) VALUES (%s, %s);", (3, 'Python',)),
            ),
            # backwards
            [
                "DELETE FROM i_love_ponies WHERE special_thing = 'Django';",
                ["DELETE FROM i_love_ponies WHERE special_thing = 'Ponies';", None],
                ("DELETE FROM i_love_ponies WHERE id = %s OR special_thing = %s;", [3, 'Python']),
            ]
        )

        # Make sure there's no table
        self.assertTableNotExists("i_love_ponies")
        new_state = project_state.clone()
        # Test the database alteration
        with connection.schema_editor() as editor:
            operation.database_forwards("test_runsql", editor, project_state, new_state)

        # Test parameter passing
        with connection.schema_editor() as editor:
            param_operation.database_forwards("test_runsql", editor, project_state, new_state)
        # Make sure all the SQL was processed
        with connection.cursor() as cursor:
            cursor.execute("SELECT COUNT(*) FROM i_love_ponies")
            self.assertEqual(cursor.fetchall()[0][0], 3)

        with connection.schema_editor() as editor:
            param_operation.database_backwards("test_runsql", editor, new_state, project_state)
        with connection.cursor() as cursor:
            cursor.execute("SELECT COUNT(*) FROM i_love_ponies")
            self.assertEqual(cursor.fetchall()[0][0], 0)

        # And test reversal
        with connection.schema_editor() as editor:
            operation.database_backwards("test_runsql", editor, new_state, project_state)
        self.assertTableNotExists("i_love_ponies")

    def test_run_sql_params_invalid(self):
        """
        #23426 - RunSQL should fail when a list of statements with an incorrect
        number of tuples is given.
        """
        project_state = self.set_up_test_model("test_runsql")
        new_state = project_state.clone()
        operation = migrations.RunSQL(
            # forwards
            [
                ["INSERT INTO foo (bar) VALUES ('buz');"]
            ],
            # backwards
            (
                ("DELETE FROM foo WHERE bar = 'buz';", 'invalid', 'parameter count'),
            ),
        )

        with connection.schema_editor() as editor:
            self.assertRaisesRegexp(ValueError,
                "Expected a 2-tuple but got 1",
                operation.database_forwards,
                "test_runsql", editor, project_state, new_state)

        with connection.schema_editor() as editor:
            self.assertRaisesRegexp(ValueError,
                "Expected a 2-tuple but got 3",
                operation.database_backwards,
                "test_runsql", editor, new_state, project_state)

    def test_run_python(self):
        """
        Tests the RunPython operation