Commit bb39037f authored by Matthew Schinckel's avatar Matthew Schinckel Committed by Tim Graham
Browse files

Fixed #22788 -- Ensured custom migration operations can be written.

This inspects the migration operation, and if it is not in the
django.db.migrations module, it adds the relevant imports to the
migration writer and uses the correct class name.
parent 37a8f5ae
Loading
Loading
Loading
Loading
+10 −2
Original line number Diff line number Diff line
@@ -10,7 +10,7 @@ import sys
import types

from django.apps import apps
from django.db import models
from django.db import models, migrations
from django.db.migrations.loader import MigrationLoader
from django.utils import datetime_safe, six
from django.utils.encoding import force_text
@@ -44,7 +44,15 @@ class OperationWriter(object):
        argspec = inspect.getargspec(self.operation.__init__)
        normalized_kwargs = inspect.getcallargs(self.operation.__init__, *args, **kwargs)

        # See if this operation is in django.db.migrations. If it is,
        # We can just use the fact we already have that imported,
        # otherwise, we need to add an import for the operation class.
        if getattr(migrations, name, None) == self.operation.__class__:
            self.feed('migrations.%s(' % name)
        else:
            imports.add('import %s' % (self.operation.__class__.__module__))
            self.feed('%s.%s(' % (self.operation.__class__.__module__, name))

        self.indent()
        for arg_name in argspec.args[1:]:
            arg_value = normalized_kwargs[arg_name]
+0 −0

Empty file added.

+22 −0
Original line number Diff line number Diff line
from django.db.migrations.operations.base import Operation


class TestOperation(Operation):
    def __init__(self):
        pass

    @property
    def reversible(self):
        return True

    def state_forwards(self, app_label, state):
        pass

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        pass

    def state_backwards(self, app_label, state):
        pass

    def database_backwards(self, app_label, schema_editor, from_state, to_state):
        pass
+26 −0
Original line number Diff line number Diff line
from django.db.migrations.operations.base import Operation


class TestOperation(Operation):
    def __init__(self):
        pass

    @property
    def reversible(self):
        return True

    def state_forwards(self, app_label, state):
        pass

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        pass

    def state_backwards(self, app_label, state):
        pass

    def database_backwards(self, app_label, schema_editor, from_state, to_state):
        pass


class CreateModel(TestOperation):
    pass
+22 −0
Original line number Diff line number Diff line
@@ -16,6 +16,9 @@ from django.utils.deconstruct import deconstructible
from django.utils.translation import ugettext_lazy as _
from django.utils.timezone import get_default_timezone

import custom_migration_operations.operations
import custom_migration_operations.more_operations


class TestModel1(object):
    def upload_to(self):
@@ -222,3 +225,22 @@ class WriterTests(TestCase):
                expected_path = os.path.join(base_dir, *(app.split('.') + ['migrations', '0001_initial.py']))
                writer = MigrationWriter(migration)
                self.assertEqual(writer.path, expected_path)

    def test_custom_operation(self):
        migration = type(str("Migration"), (migrations.Migration,), {
            "operations": [
                custom_migration_operations.operations.TestOperation(),
                custom_migration_operations.operations.CreateModel(),
                migrations.CreateModel("MyModel", (), {}, (models.Model,)),
                custom_migration_operations.more_operations.TestOperation()
            ],
            "dependencies": []
        })
        writer = MigrationWriter(migration)
        output = writer.as_string()
        result = self.safe_exec(output)
        self.assertIn("custom_migration_operations", result)
        self.assertNotEqual(
            result['custom_migration_operations'].operations.TestOperation,
            result['custom_migration_operations'].more_operations.TestOperation
        )