Commit f633ba77 authored by Markus Holtermann's avatar Markus Holtermann Committed by Loic Bistuer
Browse files

Fixed #23609 -- Fixed IntegrityError that prevented altering a NULL column...

Fixed #23609 -- Fixed IntegrityError that prevented altering a NULL column into a NOT NULL one due to existing rows

Thanks to Simon Charette, Loic Bistuer and Tim Graham for the review.
parent 15d350fb
Loading
Loading
Loading
Loading
+40 −3
Original line number Diff line number Diff line
@@ -44,6 +44,7 @@ class BaseDatabaseSchemaEditor(object):
    sql_alter_column_no_default = "ALTER COLUMN %(column)s DROP DEFAULT"
    sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s CASCADE"
    sql_rename_column = "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s"
    sql_update_with_default = "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL"

    sql_create_check = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s CHECK (%(check)s)"
    sql_delete_check = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
@@ -533,12 +534,19 @@ class BaseDatabaseSchemaEditor(object):
            })
        # Next, start accumulating actions to do
        actions = []
        null_actions = []
        post_actions = []
        # Type change?
        if old_type != new_type:
            fragment, other_actions = self._alter_column_type_sql(model._meta.db_table, new_field.column, new_type)
            actions.append(fragment)
            post_actions.extend(other_actions)
        # When changing a column NULL constraint to NOT NULL with a given
        # default value, we need to perform 4 steps:
        #  1. Add a default for new incoming writes
        #  2. Update existing NULL rows with new default
        #  3. Replace NULL constraint with NOT NULL
        #  4. Drop the default again.
        # Default change?
        old_default = self.effective_default(old_field)
        new_default = self.effective_default(new_field)
@@ -573,7 +581,7 @@ class BaseDatabaseSchemaEditor(object):
        # Nullability change?
        if old_field.null != new_field.null:
            if new_field.null:
                actions.append((
                null_actions.append((
                    self.sql_alter_column_null % {
                        "column": self.quote_name(new_field.column),
                        "type": new_type,
@@ -581,14 +589,23 @@ class BaseDatabaseSchemaEditor(object):
                    [],
                ))
            else:
                actions.append((
                null_actions.append((
                    self.sql_alter_column_not_null % {
                        "column": self.quote_name(new_field.column),
                        "type": new_type,
                    },
                    [],
                ))
        if actions:
        # Only if we have a default and there is a change from NULL to NOT NULL
        four_way_default_alteration = (
            new_field.has_default() and
            (old_field.null and not new_field.null)
        )
        if actions or null_actions:
            if not four_way_default_alteration:
                # If we don't have to do a 4-way default alteration we can
                # directly run a (NOT) NULL alteration
                actions = actions + null_actions
            # Combine actions together if we can (e.g. postgres)
            if self.connection.features.supports_combined_alters:
                sql, params = tuple(zip(*actions))
@@ -602,6 +619,26 @@ class BaseDatabaseSchemaEditor(object):
                    },
                    params,
                )
            if four_way_default_alteration:
                # Update existing rows with default value
                self.execute(
                    self.sql_update_with_default % {
                        "table": self.quote_name(model._meta.db_table),
                        "column": self.quote_name(new_field.column),
                        "default": "%s",
                    },
                    [new_default],
                )
                # Since we didn't run a NOT NULL change before we need to do it
                # now
                for sql, params in null_actions:
                    self.execute(
                        self.sql_alter_column % {
                            "table": self.quote_name(model._meta.db_table),
                            "changes": sql,
                        },
                        params,
                    )
        if post_actions:
            for sql, params in post_actions:
                self.execute(sql, params)
+8 −1
Original line number Diff line number Diff line
@@ -78,6 +78,13 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
            del body[old_field.name]
            del mapping[old_field.column]
            body[new_field.name] = new_field
            if old_field.null and not new_field.null:
                case_sql = "coalesce(%(col)s, %(default)s)" % {
                    'col': self.quote_name(old_field.column),
                    'default': self.quote_value(self.effective_default(new_field))
                }
                mapping[new_field.column] = case_sql
            else:
                mapping[new_field.column] = self.quote_name(old_field.column)
            rename_mapping[old_field.name] = new_field.name
        # Remove any deleted fields
+13 −3
Original line number Diff line number Diff line
@@ -6,8 +6,8 @@ import datetime
from itertools import chain

from django.utils import six
from django.db import models
from django.conf import settings
from django.db import models
from django.db.migrations import operations
from django.db.migrations.migration import Migration
from django.db.migrations.questioner import MigrationQuestioner
@@ -838,7 +838,6 @@ class MigrationAutodetector(object):
        for app_label, model_name, field_name in sorted(self.old_field_keys.intersection(self.new_field_keys)):
            # Did the field change?
            old_model_name = self.renamed_models.get((app_label, model_name), model_name)
            new_model_state = self.to_state.models[app_label, model_name]
            old_field_name = self.renamed_fields.get((app_label, model_name, field_name), field_name)
            old_field = self.old_apps.get_model(app_label, old_model_name)._meta.get_field_by_name(old_field_name)[0]
            new_field = self.new_apps.get_model(app_label, model_name)._meta.get_field_by_name(field_name)[0]
@@ -854,12 +853,23 @@ class MigrationAutodetector(object):
            old_field_dec = self.deep_deconstruct(old_field)
            new_field_dec = self.deep_deconstruct(new_field)
            if old_field_dec != new_field_dec:
                preserve_default = True
                if (old_field.null and not new_field.null and not new_field.has_default() and
                        not isinstance(new_field, models.ManyToManyField)):
                    field = new_field.clone()
                    new_default = self.questioner.ask_not_null_alteration(field_name, model_name)
                    if new_default is not models.NOT_PROVIDED:
                        field.default = new_default
                        preserve_default = False
                else:
                    field = new_field
                self.add_operation(
                    app_label,
                    operations.AlterField(
                        model_name=model_name,
                        name=field_name,
                        field=new_model_state.get_field_by_name(field_name),
                        field=field,
                        preserve_default=preserve_default,
                    )
                )

+12 −2
Original line number Diff line number Diff line
@@ -104,14 +104,20 @@ class AlterField(Operation):
    Alters a field's database column (e.g. null, max_length) to the provided new field
    """

    def __init__(self, model_name, name, field):
    def __init__(self, model_name, name, field, preserve_default=True):
        self.model_name = model_name
        self.name = name
        self.field = field
        self.preserve_default = preserve_default

    def state_forwards(self, app_label, state):
        if not self.preserve_default:
            field = self.field.clone()
            field.default = NOT_PROVIDED
        else:
            field = self.field
        state.models[app_label, self.model_name.lower()].fields = [
            (n, self.field if n == self.name else f) for n, f in state.models[app_label, self.model_name.lower()].fields
            (n, field if n == self.name else f) for n, f in state.models[app_label, self.model_name.lower()].fields
        ]

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
@@ -128,7 +134,11 @@ class AlterField(Operation):
                    from_field.rel.to = to_field.rel.to
                elif to_field.rel and isinstance(to_field.rel.to, six.string_types):
                    to_field.rel.to = from_field.rel.to
            if not self.preserve_default:
                to_field.default = self.field.default
            schema_editor.alter_field(from_model, from_field, to_field)
            if not self.preserve_default:
                to_field.default = NOT_PROVIDED

    def database_backwards(self, app_label, schema_editor, from_state, to_state):
        self.database_forwards(app_label, schema_editor, from_state, to_state)
+56 −24
Original line number Diff line number Diff line
from __future__ import unicode_literals
from __future__ import print_function, unicode_literals

import importlib
import os
import sys

from django.apps import apps
from django.db.models.fields import NOT_PROVIDED
from django.utils import datetime_safe, six, timezone
from django.utils.six.moves import input

@@ -55,6 +56,11 @@ class MigrationQuestioner(object):
        # None means quit
        return None

    def ask_not_null_alteration(self, field_name, model_name):
        "Changing a NULL field to NOT NULL"
        # None means quit
        return None

    def ask_rename(self, model_name, old_name, new_name, field_instance):
        "Was this field really renamed?"
        return self.defaults.get("ask_rename", False)
@@ -92,24 +98,9 @@ class InteractiveMigrationQuestioner(MigrationQuestioner):
                pass
            result = input("Please select a valid option: ")

    def ask_not_null_addition(self, field_name, model_name):
        "Adding a NOT NULL field to a model"
        if not self.dry_run:
            choice = self._choice_input(
                "You are trying to add a non-nullable field '%s' to %s without a default;\n" % (field_name, model_name) +
                "we can't do that (the database needs something to populate existing rows).\n" +
                "Please select a fix:",
                [
                    "Provide a one-off default now (will be set on all existing rows)",
                    "Quit, and let me add a default in models.py",
                ]
            )
            if choice == 2:
                sys.exit(3)
            else:
    def _ask_default(self):
        print("Please enter the default value now, as valid Python")
                print("The datetime and django.utils.timezone modules are "
                      "available, so you can do e.g. timezone.now()")
        print("The datetime and django.utils.timezone modules are available, so you can do e.g. timezone.now()")
        while True:
            if six.PY3:
                # Six does not correctly abstract over the fact that
@@ -127,6 +118,47 @@ class InteractiveMigrationQuestioner(MigrationQuestioner):
                    return eval(code, {}, {"datetime": datetime_safe, "timezone": timezone})
                except (SyntaxError, NameError) as e:
                    print("Invalid input: %s" % e)

    def ask_not_null_addition(self, field_name, model_name):
        "Adding a NOT NULL field to a model"
        if not self.dry_run:
            choice = self._choice_input(
                "You are trying to add a non-nullable field '%s' to %s without a default; "
                "we can't do that (the database needs something to populate existing rows).\n"
                "Please select a fix:" % (field_name, model_name),
                [
                    "Provide a one-off default now (will be set on all existing rows)",
                    "Quit, and let me add a default in models.py",
                ]
            )
            if choice == 2:
                sys.exit(3)
            else:
                return self._ask_default()
        return None

    def ask_not_null_alteration(self, field_name, model_name):
        "Changing a NULL field to NOT NULL"
        if not self.dry_run:
            choice = self._choice_input(
                "You are trying to change the nullable field '%s' on %s to non-nullable "
                "without a default; we can't do that (the database needs something to "
                "populate existing rows).\n"
                "Please select a fix:" % (field_name, model_name),
                [
                    "Provide a one-off default now (will be set on all existing rows)",
                    ("Ignore for now, and let me handle existing rows with NULL myself "
                     "(e.g. adding a RunPython or RunSQL operation in the new migration "
                     "file before the AlterField operation)"),
                    "Quit, and let me add a default in models.py",
                ]
            )
            if choice == 2:
                return NOT_PROVIDED
            elif choice == 3:
                sys.exit(3)
            else:
                return self._ask_default()
        return None

    def ask_rename(self, model_name, old_name, new_name, field_instance):
Loading