Commit 959a3f97 authored by Andrew Godwin's avatar Andrew Godwin
Browse files

Add some field schema alteration methods and tests.

parent 8ba5bf31
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -419,6 +419,9 @@ class BaseDatabaseFeatures(object):
    # Can we roll back DDL in a transaction?
    can_rollback_ddl = False

    # Can we issue more than one ALTER COLUMN clause in an ALTER TABLE?
    supports_combined_alters = False

    def __init__(self, connection):
        self.connection = connection

+1 −0
Original line number Diff line number Diff line
@@ -85,6 +85,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
    supports_tablespaces = True
    can_distinct_on_fields = True
    can_rollback_ddl = True
    supports_combined_alters = True

class DatabaseWrapper(BaseDatabaseWrapper):
    vendor = 'postgresql'
+224 −40
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@ from django.conf import settings
from django.db import transaction
from django.db.utils import load_backend
from django.utils.log import getLogger
from django.db.models.fields.related import ManyToManyField

logger = getLogger('django.db.backends.schema')

@@ -29,11 +30,15 @@ class BaseDatabaseSchemaEditor(object):
    sql_rename_table = "ALTER TABLE %(old_table)s RENAME TO %(new_table)s"
    sql_delete_table = "DROP TABLE %(table)s CASCADE"

    sql_create_column = "ALTER TABLE %(table)s ADD COLUMN %(definition)s"
    sql_create_column = "ALTER TABLE %(table)s ADD COLUMN %(column)s %(definition)s"
    sql_alter_column = "ALTER TABLE %(table)s %(changes)s"
    sql_alter_column_type = "ALTER COLUMN %(column)s TYPE %(type)s"
    sql_alter_column_null = "ALTER COLUMN %(column)s DROP NOT NULL"
    sql_alter_column_not_null = "ALTER COLUMN %(column)s SET NOT NULL"
    sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s CASCADE;"
    sql_alter_column_default = "ALTER COLUMN %(column)s SET DEFAULT %(default)s"
    sql_alter_column_no_default = "ALTER COLUMN %(column)s DROP DEFAULT"
    sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s CASCADE"
    sql_rename_column = "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s"

    sql_create_check = "ADD CONSTRAINT %(name)s CHECK (%(check)s)"
    sql_delete_check = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
@@ -91,6 +96,59 @@ class BaseDatabaseSchemaEditor(object):
    def quote_name(self, name):
        return self.connection.ops.quote_name(name)

    # Field <-> database mapping functions

    def column_sql(self, model, field, include_default=False):
        """
        Takes a field and returns its column definition.
        The field must already have had set_attributes_from_name called.
        """
        # Get the column's type and use that as the basis of the SQL
        sql = field.db_type(connection=self.connection)
        params = []
        # Check for fields that aren't actually columns (e.g. M2M)
        if sql is None:
            return None
        # Optionally add the tablespace if it's an implicitly indexed column
        tablespace = field.db_tablespace or model._meta.db_tablespace
        if tablespace and self.connection.features.supports_tablespaces and field.unique:
            sql += " %s" % self.connection.ops.tablespace_sql(tablespace, inline=True)
        # Work out nullability
        null = field.null
        # Oracle treats the empty string ('') as null, so coerce the null
        # option whenever '' is a possible value.
        if (field.empty_strings_allowed and not field.primary_key and
                self.connection.features.interprets_empty_strings_as_nulls):
            null = True
        if null:
            sql += " NULL"
        else:
            sql += " NOT NULL"
        # Primary key/unique outputs
        if field.primary_key:
            sql += " PRIMARY KEY"
        elif field.unique:
            sql += " UNIQUE"
        # If we were told to include a default value, do so
        if include_default:
            sql += " DEFAULT %s"
            params += [self.effective_default(field)]
        # Return the sql
        return sql, params

    def effective_default(self, field):
        "Returns a field's effective database default value"
        if field.has_default():
            default = field.get_default()
        elif not field.null and field.blank and field.empty_strings_allowed:
            default = ""
        else:
            default = None
        # If it's a callable, call it
        if callable(default):
            default = default()
        return default

    # Actions

    def create_model(self, model):
@@ -100,18 +158,20 @@ class BaseDatabaseSchemaEditor(object):
        """
        # Do nothing if this is an unmanaged or proxy model
        if not model._meta.managed or model._meta.proxy:
            return [], {}
            return
        # Create column SQL, add FK deferreds if needed
        column_sqls = []
        params = []
        for field in model._meta.local_fields:
            # SQL
            definition = self.column_sql(model, field)
            definition, extra_params = self.column_sql(model, field)
            if definition is None:
                continue
            column_sqls.append("%s %s" % (
                self.quote_name(field.column),
                definition,
            ))
            params.extend(extra_params)
            # FK
            if field.rel:
                to_table = field.rel.to._meta.db_table
@@ -134,45 +194,169 @@ class BaseDatabaseSchemaEditor(object):
            "table": model._meta.db_table,
            "definition": ", ".join(column_sqls)
        }
        self.execute(sql)
        self.execute(sql, params)

    def column_sql(self, model, field, include_default=False):
    def delete_model(self, model):
        """
        Takes a field and returns its column definition.
        The field must already have had set_attributes_from_name called.
        Deletes a model from the database.
        """
        # Get the column's type and use that as the basis of the SQL
        sql = field.db_type(connection=self.connection)
        # Check for fields that aren't actually columns (e.g. M2M)
        if sql is None:
            return None
        # Optionally add the tablespace if it's an implicitly indexed column
        tablespace = field.db_tablespace or model._meta.db_tablespace
        if tablespace and self.connection.features.supports_tablespaces and field.unique:
            sql += " %s" % self.connection.ops.tablespace_sql(tablespace, inline=True)
        # Work out nullability
        null = field.null
        # Oracle treats the empty string ('') as null, so coerce the null
        # option whenever '' is a possible value.
        if (field.empty_strings_allowed and not field.primary_key and
                self.connection.features.interprets_empty_strings_as_nulls):
            null = True
        if null:
            sql += " NULL"
        else:
            sql += " NOT NULL"
        # Primary key/unique outputs
        if field.primary_key:
            sql += " PRIMARY KEY"
        elif field.unique:
            sql += " UNIQUE"
        # If we were told to include a default value, do so
        if include_default:
            raise NotImplementedError()
        # Return the sql
        return sql

    def delete_model(self, model):
        # Do nothing if this is an unmanaged or proxy model
        if not model._meta.managed or model._meta.proxy:
            return
        # Delete the table
        self.execute(self.sql_delete_table % {
            "table": self.quote_name(model._meta.db_table),
        })

    def create_field(self, model, field, keep_default=False):
        """
        Creates a field on a model.
        Usually involves adding a column, but may involve adding a
        table instead (for M2M fields)
        """
        # Special-case implicit M2M tables
        if isinstance(field, ManyToManyField) and field.rel.through._meta.auto_created:
            return self.create_model(field.rel.through)
        # Get the column's definition
        definition, params = self.column_sql(model, field, include_default=True)
        # It might not actually have a column behind it
        if definition is None:
            return
        # Build the SQL and run it
        sql = self.sql_create_column % {
            "table": self.quote_name(model._meta.db_table),
            "column": self.quote_name(field.column),
            "definition": definition,
        }
        self.execute(sql, params)
        # Drop the default if we need to
        # (Django usually does not use in-database defaults)
        if not keep_default and field.default is not None:
            sql = self.sql_alter_column % {
                "table": self.quote_name(model._meta.db_table),
                "changes": self.sql_alter_column_no_default % {
                    "column": self.quote_name(field.column),
                }
            }
        # Add any FK constraints later
        if field.rel:
            to_table = field.rel.to._meta.db_table
            to_column = field.rel.to._meta.get_field(field.rel.field_name).column
            self.deferred_sql.append(
                self.sql_create_fk % {
                    "name": '%s_refs_%s_%x' % (
                        field.column,
                        to_column,
                        abs(hash((model._meta.db_table, to_table)))
                    ),
                    "table": self.quote_name(model._meta.db_table),
                    "column": self.quote_name(field.column),
                    "to_table": self.quote_name(to_table),
                    "to_column": self.quote_name(to_column),
                }
            )

    def delete_field(self, model, field):
        """
        Removes a field from a model. Usually involves deleting a column,
        but for M2Ms may involve deleting a table.
        """
        # Special-case implicit M2M tables
        if isinstance(field, ManyToManyField) and field.rel.through._meta.auto_created:
            return self.delete_model(field.rel.through)
        # Get the column's definition
        definition, params = self.column_sql(model, field)
        # It might not actually have a column behind it
        if definition is None:
            return
        # Delete the column
        sql = self.sql_delete_column % {
            "table": self.quote_name(model._meta.db_table),
            "column": self.quote_name(field.column),
        }
        self.execute(sql)

    def alter_field(self, model, old_field, new_field):
        """
        Allows a field's type, uniqueness, nullability, default, column,
        constraints etc. to be modified.
        Requires a copy of the old field as well so we can only perform
        changes that are required.
        """
        # Ensure this field is even column-based
        old_type = old_field.db_type(connection=self.connection)
        new_type = new_field.db_type(connection=self.connection)
        if old_type is None and new_type is None:
            # TODO: Handle M2M fields being repointed
            return
        elif old_type is None or new_type is None:
            raise ValueError("Cannot alter field %s into %s - they are not compatible types" % (
                    old_field,
                    new_field,
                ))
        # First, have they renamed the column?
        if old_field.column != new_field.column:
            self.execute(self.sql_rename_column % {
                "table": self.quote_name(model._meta.db_table),
                "old_column": self.quote_name(old_field.column),
                "new_column": self.quote_name(new_field.column),
            })
        # Next, start accumulating actions to do
        actions = []
        # Type change?
        if old_type != new_type:
            actions.append((
                self.sql_alter_column_type % {
                    "column": self.quote_name(new_field.column),
                    "type": new_type,
                },
                [],
            ))
        # Default change?
        old_default = self.effective_default(old_field)
        new_default = self.effective_default(new_field)
        if old_default != new_default:
            if new_default is None:
                actions.append((
                    self.sql_alter_column_no_default % {
                        "column": self.quote_name(new_field.column),
                    },
                    [],
                ))
            else:
                actions.append((
                    self.sql_alter_column_default % {
                        "column": self.quote_name(new_field.column),
                        "default": "%s",
                    },
                    [new_default],
                ))
        # Nullability change?
        if old_field.null != new_field.null:
            if new_field.null:
                actions.append((
                    self.sql_alter_column_null % {
                        "column": self.quote_name(new_field.column),
                    },
                    [],
                ))
            else:
                actions.append((
                    self.sql_alter_column_null % {
                        "column": self.quote_name(new_field.column),
                    },
                    [],
                ))
        # Combine actions together if we can (e.g. postgres)
        if self.connection.features.supports_combined_alters:
            sql, params = tuple(zip(*actions))
            actions = [(", ".join(sql), params)]
        # Apply those actions
        for sql, params in actions:
            self.execute(
                self.sql_alter_column % {
                    "table": self.quote_name(model._meta.db_table),
                    "changes": sql,
                },
                params,
            )
+73 −1
Original line number Diff line number Diff line
@@ -2,8 +2,9 @@ from __future__ import absolute_import
import copy
import datetime
from django.test import TestCase
from django.db.models.loading import cache
from django.db import connection, DatabaseError, IntegrityError
from django.db.models.fields import IntegerField, TextField
from django.db.models.loading import cache
from .models import Author, Book


@@ -18,6 +19,8 @@ class SchemaTests(TestCase):

    models = [Author, Book]

    # Utility functions

    def setUp(self):
        # Make sure we're in manual transaction mode
        connection.commit_unless_managed()
@@ -51,6 +54,18 @@ class SchemaTests(TestCase):
        cache.app_store = self.old_app_store
        cache._get_models_cache = {}

    def column_classes(self, model):
        cursor = connection.cursor()
        return dict(
            (d[0], (connection.introspection.get_field_type(d[1], d), d))
            for d in connection.introspection.get_table_description(
                cursor,
                model._meta.db_table,
            )
        )

    # Tests

    def test_creation_deletion(self):
        """
        Tries creating a model's table, and then deleting it.
@@ -100,3 +115,60 @@ class SchemaTests(TestCase):
                pub_date = datetime.datetime.now(),
            )
            connection.commit()

    def test_create_field(self):
        """
        Tests adding fields to models
        """
        # Create the table
        editor = connection.schema_editor()
        editor.start()
        editor.create_model(Author)
        editor.commit()
        # Ensure there's no age field
        columns = self.column_classes(Author)
        self.assertNotIn("age", columns)
        # Alter the name field to a TextField
        new_field = IntegerField(null=True)
        new_field.set_attributes_from_name("age")
        editor = connection.schema_editor()
        editor.start()
        editor.create_field(
            Author,
            new_field,
        )
        editor.commit()
        # Ensure the field is right afterwards
        columns = self.column_classes(Author)
        self.assertEqual(columns['age'][0], "IntegerField")
        self.assertEqual(columns['age'][1][6], True)

    def test_alter(self):
        """
        Tests simple altering of fields
        """
        # Create the table
        editor = connection.schema_editor()
        editor.start()
        editor.create_model(Author)
        editor.commit()
        # Ensure the field is right to begin with
        columns = self.column_classes(Author)
        self.assertEqual(columns['name'][0], "CharField")
        self.assertEqual(columns['name'][1][3], 255)
        self.assertEqual(columns['name'][1][6], False)
        # Alter the name field to a TextField
        new_field = TextField(null=True)
        new_field.set_attributes_from_name("name")
        editor = connection.schema_editor()
        editor.start()
        editor.alter_field(
            Author,
            Author._meta.get_field_by_name("name")[0],
            new_field,
        )
        editor.commit()
        # Ensure the field is right afterwards
        columns = self.column_classes(Author)
        self.assertEqual(columns['name'][0], "TextField")
        self.assertEqual(columns['name'][1][6], True)