Commit e9a456d1 authored by Andrew Godwin's avatar Andrew Godwin
Browse files

Fixed #22581: Pass default values for schema through get_db_prep_save()

parent fc974313
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -167,6 +167,9 @@ class BaseDatabaseSchemaEditor(object):
        # If it's a callable, call it
        if callable(default):
            default = default()
        # Run it through the field's get_db_prep_save method so we can send it
        # to the database.
        default = field.get_db_prep_save(default, self.connection)
        return default

    def quote_value(self, value):
+5 −2
Original line number Diff line number Diff line
from decimal import Decimal
from django.utils import six
from django.apps.registry import Apps
from django.db.backends.schema import BaseDatabaseSchemaEditor
@@ -19,6 +20,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
        # Manual emulation of SQLite parameter quoting
        if isinstance(value, type(True)):
            return str(int(value))
        elif isinstance(value, (Decimal, float)):
            return str(value)
        elif isinstance(value, six.integer_types):
            return str(value)
        elif isinstance(value, six.string_types):
@@ -26,7 +29,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
        elif value is None:
            return "NULL"
        else:
            raise ValueError("Cannot quote parameter value %r" % value)
            raise ValueError("Cannot quote parameter value %r of type %s" % (value, type(value)))

    def _remake_table(self, model, create_fields=[], delete_fields=[], alter_fields=[], rename_fields=[], override_uniques=None):
        """
@@ -52,7 +55,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
            # If there's a default, insert it into the copy map
            if field.has_default():
                mapping[field.column] = self.quote_value(
                    field.get_default()
                    self.effective_default(field)
                )
        # Add in any altered fields
        for (old_field, new_field) in alter_fields:
+34 −0
Original line number Diff line number Diff line
@@ -231,6 +231,40 @@ class SchemaTests(TransactionTestCase):
        else:
            self.assertEqual(field_type, 'BooleanField')

    def test_add_field_default_transform(self):
        """
        Tests adding fields to models with a default that is not directly
        valid in the database (#22581)
        """
        class TestTransformField(IntegerField):
            # Weird field that saves the count of items in its value
            def get_default(self):
                return self.default
            def get_prep_value(self, value):
                if value is None:
                    return 0
                return len(value)
        # Create the table
        with connection.schema_editor() as editor:
            editor.create_model(Author)
        # Add some rows of data
        Author.objects.create(name="Andrew", height=30)
        Author.objects.create(name="Andrea")
        # Add the field with a default it needs to cast (to string in this case)
        new_field = TestTransformField(default={1:2})
        new_field.set_attributes_from_name("thing")
        with connection.schema_editor() as editor:
            editor.add_field(
                Author,
                new_field,
            )
        # Ensure the field is there
        columns = self.column_classes(Author)
        field_type, field_info = columns['thing']
        self.assertEqual(field_type, 'IntegerField')
        # Make sure the values were transformed correctly
        self.assertEqual(Author.objects.extra(where=["thing = 1"]).count(), 2)

    def test_alter(self):
        """
        Tests simple altering of fields