Commit 423244bc authored by Claude Paroz's avatar Claude Paroz
Browse files

Fixed #4680 -- Improved initial_sql parsing

In particular, allow the '--' sequence to be present in string
values without being interpreted as comment marker.
Thanks Tim Chase for the report and shaleh for the initial patch.
parent 32a4df6c
Loading
Loading
Loading
Loading
+17 −10
Original line number Diff line number Diff line
@@ -136,6 +136,20 @@ def sql_all(app, style, connection):
    "Returns a list of CREATE TABLE SQL, initial-data inserts, and CREATE INDEX SQL for the given module."
    return sql_create(app, style, connection) + sql_custom(app, style, connection) + sql_indexes(app, style, connection)

def _split_statements(content):
    comment_re = re.compile(r"^((?:'[^']*'|[^'])*?)--.*$")
    statements = []
    statement = ""
    for line in content.split("\n"):
        cleaned_line = comment_re.sub(r"\1", line).strip()
        if not cleaned_line:
            continue
        statement += cleaned_line
        if statement.endswith(";"):
            statements.append(statement)
            statement = ""
    return statements

def custom_sql_for_model(model, style, connection):
    opts = model._meta
    app_dir = os.path.normpath(os.path.join(os.path.dirname(models.get_app(model._meta.app_label).__file__), 'sql'))
@@ -149,10 +163,6 @@ def custom_sql_for_model(model, style, connection):
        for f in post_sql_fields:
            output.extend(f.post_create_sql(style, model._meta.db_table))

    # Some backends can't execute more than one SQL statement at a time,
    # so split into separate statements.
    statements = re.compile(r";[ \t]*$", re.M)

    # Find custom SQL, if it's available.
    backend_name = connection.settings_dict['ENGINE'].split('.')[-1]
    sql_files = [os.path.join(app_dir, "%s.%s.sql" % (opts.object_name.lower(), backend_name)),
@@ -160,12 +170,9 @@ def custom_sql_for_model(model, style, connection):
    for sql_file in sql_files:
        if os.path.exists(sql_file):
            with open(sql_file, 'U') as fp:
                for statement in statements.split(fp.read().decode(settings.FILE_CHARSET)):
                    # Remove any comments from the file
                    statement = re.sub(r"--.*([\n\Z]|$)", "", statement)
                    if statement.strip():
                        output.append(statement + ";")

                # Some backends can't execute more than one SQL statement at a time,
                # so split into separate statements.
                output.extend(_split_statements(fp.read().decode(settings.FILE_CHARSET)))
    return output


+4 −2
Original line number Diff line number Diff line
INSERT INTO initial_sql_regress_simple (name) VALUES ('John');
-- a comment
INSERT INTO initial_sql_regress_simple (name) VALUES ('John'); -- another comment
INSERT INTO initial_sql_regress_simple (name) VALUES ('-- Comment Man');
INSERT INTO initial_sql_regress_simple (name) VALUES ('Paul');
INSERT INTO initial_sql_regress_simple (name) VALUES ('Ringo');
INSERT INTO initial_sql_regress_simple (name) VALUES ('George');
+18 −4
Original line number Diff line number Diff line
@@ -4,12 +4,26 @@ from .models import Simple


class InitialSQLTests(TestCase):
    def test_initial_sql(self):
    # The format of the included SQL file for this test suite is important.
    # It must end with a trailing newline in order to test the fix for #2161.

        # However, as pointed out by #14661, test data loaded by custom SQL
    def test_initial_sql(self):
        # As pointed out by #14661, test data loaded by custom SQL
        # can't be relied upon; as a result, the test framework flushes the
        # data contents before every test. This test validates that this has
        # occurred.
        self.assertEqual(Simple.objects.count(), 0)

    def test_custom_sql(self):
        from django.core.management.sql import custom_sql_for_model
        from django.core.management.color import no_style
        from django.db import connections, DEFAULT_DB_ALIAS

        # Simulate the custom SQL loading by syncdb
        connection = connections[DEFAULT_DB_ALIAS]
        custom_sql = custom_sql_for_model(Simple, no_style(), connection)
        self.assertEqual(len(custom_sql), 8)
        cursor = connection.cursor()
        for sql in custom_sql:
            cursor.execute(sql)
        self.assertEqual(Simple.objects.count(), 8)