Commit d59f1993 authored by Loic Bistuer's avatar Loic Bistuer Committed by Tim Graham
Browse files

Made MigrationWriter look for a "deconstruct" attribute on functions.

Refs #20978.
parent 5df8f749
Loading
Loading
Loading
Loading
+23 −17
Original line number Diff line number Diff line
@@ -73,6 +73,26 @@ class MigrationWriter(object):
                raise ImportError("Cannot open migrations module %s for app %s" % (migrations_module_name, self.migration.app_label))
        return os.path.join(basedir, self.filename)

    @classmethod
    def serialize_deconstructed(cls, path, args, kwargs):
        module, name = path.rsplit(".", 1)
        if module == "django.db.models":
            imports = set(["from django.db import models"])
            name = "models.%s" % name
        else:
            imports = set(["import %s" % module])
            name = path
        arg_strings = []
        for arg in args:
            arg_string, arg_imports = cls.serialize(arg)
            arg_strings.append(arg_string)
            imports.update(arg_imports)
        for kw, arg in kwargs.items():
            arg_string, arg_imports = cls.serialize(arg)
            imports.update(arg_imports)
            arg_strings.append("%s=%s" % (kw, arg_string))
        return "%s(%s)" % (name, ", ".join(arg_strings)), imports

    @classmethod
    def serialize(cls, value):
        """
@@ -119,23 +139,7 @@ class MigrationWriter(object):
        # Django fields
        elif isinstance(value, models.Field):
            attr_name, path, args, kwargs = value.deconstruct()
            module, name = path.rsplit(".", 1)
            if module == "django.db.models":
                imports = set(["from django.db import models"])
                name = "models.%s" % name
            else:
                imports = set(["import %s" % module])
                name = path
            arg_strings = []
            for arg in args:
                arg_string, arg_imports = cls.serialize(arg)
                arg_strings.append(arg_string)
                imports.update(arg_imports)
            for kw, arg in kwargs.items():
                arg_string, arg_imports = cls.serialize(arg)
                imports.update(arg_imports)
                arg_strings.append("%s=%s" % (kw, arg_string))
            return "%s(%s)" % (name, ", ".join(arg_strings)), imports
            return cls.serialize_deconstructed(path, args, kwargs)
        # Functions
        elif isinstance(value, (types.FunctionType, types.BuiltinFunctionType)):
            # Special-cases, as these don't have im_class
@@ -152,6 +156,8 @@ class MigrationWriter(object):
                klass = value.im_class
                module = klass.__module__
                return "%s.%s.%s" % (module, klass.__name__, value.__name__), set(["import %s" % module])
            elif hasattr(value, 'deconstruct'):
                return cls.serialize_deconstructed(*value.deconstruct())
            elif value.__name__ == '<lambda>':
                raise ValueError("Cannot serialize function: lambda")
            elif value.__module__ is None:
+1 −0
Original line number Diff line number Diff line
@@ -35,6 +35,7 @@ def SET(value):
    else:
        def set_on_delete(collector, field, sub_objs, using):
            collector.add_field_update(field, value, sub_objs)
    set_on_delete.deconstruct = lambda: ('django.db.models.SET', (value,), {})
    return set_on_delete


+4 −0
Original line number Diff line number Diff line
@@ -63,6 +63,10 @@ class WriterTests(TestCase):
        # Functions
        with six.assertRaisesRegex(self, ValueError, 'Cannot serialize function: lambda'):
            self.assertSerializedEqual(lambda x: 42)
        self.assertSerializedEqual(models.SET_NULL)
        string, imports = MigrationWriter.serialize(models.SET(42))
        self.assertEqual(string, 'models.SET(42)')
        self.serialize_round_trip(models.SET(42))
        # Datetime stuff
        self.assertSerializedEqual(datetime.datetime.utcnow())
        self.assertSerializedEqual(datetime.datetime.utcnow)