Commit a92bae0f authored by Andrew Godwin's avatar Andrew Godwin
Browse files

Repoint ForeignKeys when their to= changes.

parent d683263f
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -118,7 +118,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
                    "columns": set(),
                    "primary_key": kind.lower() == "primary key",
                    "unique": kind.lower() in ["primary key", "unique"],
                    "foreign_key": set([tuple(x.split(".", 1)) for x in used_cols]) if kind.lower() == "foreign key" else None,
                    "foreign_key": tuple(used_cols[0].split(".", 1)) if kind.lower() == "foreign key" else None,
                    "check": False,
                    "index": False,
                }
+30 −2
Original line number Diff line number Diff line
@@ -21,7 +21,6 @@ class BaseDatabaseSchemaEditor(object):
    commit() is called.

    TODO:
        - Repointing of FKs
        - Repointing of M2Ms
        - Check constraints (PosIntField)
    """
@@ -401,6 +400,22 @@ class BaseDatabaseSchemaEditor(object):
                        "name": index_name,
                    }
                )
        # Drop any FK constraints, we'll remake them later
        if getattr(old_field, "rel"):
            fk_names = self._constraint_names(model, [old_field.column], foreign_key=True)
            if strict and len(fk_names) != 1:
                raise ValueError("Found wrong number (%s) of foreign key constraints for %s.%s" % (
                    len(fk_names),
                    model._meta.db_table,
                    old_field.column,
                ))
            for fk_name in fk_names:
                self.execute(
                    self.sql_delete_fk % {
                        "table": self.quote_name(model._meta.db_table),
                        "name": fk_name,
                    }
                )
        # Have they renamed the column?
        if old_field.column != new_field.column:
            self.execute(self.sql_rename_column % {
@@ -516,6 +531,17 @@ class BaseDatabaseSchemaEditor(object):
                    "columns": self.quote_name(new_field.column),
                }
            )
        # Does it have a foreign key?
        if getattr(new_field, "rel"):
            self.execute(
                self.sql_create_fk % {
                    "table": self.quote_name(model._meta.db_table),
                    "name": self._create_index_name(model, [new_field.column], suffix="_fk"),
                    "column": self.quote_name(new_field.column),
                    "to_table": self.quote_name(new_field.rel.to._meta.db_table),
                    "to_column": self.quote_name(new_field.rel.get_related_field().column),
                }
            )

    def _type_for_alter(self, field):
        """
@@ -543,7 +569,7 @@ class BaseDatabaseSchemaEditor(object):
            index_name = '%s%s' % (table_name[:(self.connection.features.max_index_name_length - len(part))], part)
        return index_name

    def _constraint_names(self, model, column_names=None, unique=None, primary_key=None, index=None):
    def _constraint_names(self, model, column_names=None, unique=None, primary_key=None, index=None, foreign_key=None):
        "Returns all constraint names matching the columns and conditions"
        column_names = set(column_names) if column_names else None
        constraints = self.connection.introspection.get_constraints(self.connection.cursor(), model._meta.db_table)
@@ -556,5 +582,7 @@ class BaseDatabaseSchemaEditor(object):
                    continue
                if index is not None and infodict['index'] != index:
                    continue
                if foreign_key is not None and not infodict['foreign_key']:
                    continue
                result.append(name)
        return result
+25 −4
Original line number Diff line number Diff line
@@ -5,7 +5,7 @@ from django.test import TestCase
from django.utils.unittest import skipUnless
from django.db import connection, DatabaseError, IntegrityError
from django.db.models.fields import IntegerField, TextField, CharField, SlugField
from django.db.models.fields.related import ManyToManyField
from django.db.models.fields.related import ManyToManyField, ForeignKey
from django.db.models.loading import cache
from .models import Author, Book, BookWithSlug, AuthorWithM2M, Tag, TagUniqueRename, UniqueTest

@@ -114,15 +114,16 @@ class SchemaTests(TestCase):
        )

    @skipUnless(connection.features.supports_foreign_keys, "No FK support")
    def test_creation_fk(self):
        "Tests that creating tables out of FK order works"
    def test_fk(self):
        "Tests that creating tables out of FK order, then repointing, works"
        # Create the table
        editor = connection.schema_editor()
        editor.start()
        editor.create_model(Book)
        editor.create_model(Author)
        editor.create_model(Tag)
        editor.commit()
        # Check that both tables are there
        # Check that initial tables are there
        try:
            list(Author.objects.all())
        except DatabaseError, e:
@@ -139,6 +140,26 @@ class SchemaTests(TestCase):
                pub_date = datetime.datetime.now(),
            )
            connection.commit()
        # Repoint the FK constraint
        new_field = ForeignKey(Tag)
        new_field.set_attributes_from_name("author")
        editor = connection.schema_editor()
        editor.start()
        editor.alter_field(
            Book,
            Book._meta.get_field_by_name("author")[0],
            new_field,
            strict=True,
        )
        editor.commit()
        # Make sure the new FK constraint is present
        constraints = connection.introspection.get_constraints(connection.cursor(), Book._meta.db_table)
        for name, details in constraints.items():
            if details['columns'] == set(["author_id"]) and details['foreign_key']:
                self.assertEqual(details['foreign_key'], ('schema_tag', 'id'))
                break
        else:
            self.fail("No FK constraint for author_id found")

    def test_create_field(self):
        """