Commit 375178fc authored by Andrew Godwin's avatar Andrew Godwin
Browse files

Add M2M repointing

parent a92bae0f
Loading
Loading
Loading
Loading
+22 −10
Original line number Diff line number Diff line
@@ -21,7 +21,6 @@ class BaseDatabaseSchemaEditor(object):
    commit() is called.

    TODO:
        - Repointing of M2Ms
        - Check constraints (PosIntField)
    """

@@ -154,13 +153,13 @@ class BaseDatabaseSchemaEditor(object):

    # Actions

    def create_model(self, model):
    def create_model(self, model, force=False):
        """
        Takes a model and creates a table for it in the database.
        Will also create any accompanying indexes or unique constraints.
        """
        # Do nothing if this is an unmanaged or proxy model
        if not model._meta.managed or model._meta.proxy:
        if not force and (not model._meta.managed or model._meta.proxy):
            return
        # Create column SQL, add FK deferreds if needed
        column_sqls = []
@@ -214,13 +213,16 @@ class BaseDatabaseSchemaEditor(object):
            "definition": ", ".join(column_sqls)
        }
        self.execute(sql, params)
        # Make M2M tables
        for field in model._meta.local_many_to_many:
            self.create_model(field.rel.through, force=True)

    def delete_model(self, model):
    def delete_model(self, model, force=False):
        """
        Deletes a model from the database.
        """
        # Do nothing if this is an unmanaged or proxy model
        if not model._meta.managed or model._meta.proxy:
        if not force and (not model._meta.managed or model._meta.proxy):
            return
        # Delete the table
        self.execute(self.sql_delete_table % {
@@ -287,7 +289,7 @@ class BaseDatabaseSchemaEditor(object):
        """
        # Special-case implicit M2M tables
        if isinstance(field, ManyToManyField) and field.rel.through._meta.auto_created:
            return self.create_model(field.rel.through)
            return self.create_model(field.rel.through, force=True)
        # Get the column's definition
        definition, params = self.column_sql(model, field, include_default=True)
        # It might not actually have a column behind it
@@ -358,11 +360,10 @@ class BaseDatabaseSchemaEditor(object):
        # Ensure this field is even column-based
        old_type = old_field.db_type(connection=self.connection)
        new_type = self._type_for_alter(new_field)
        if old_type is None and new_type is None:
            # TODO: Handle M2M fields being repointed
            return
        if old_type is None and new_type is None and (old_field.rel.through and new_field.rel.through and old_field.rel.through._meta.auto_created and new_field.rel.through._meta.auto_created):
            return self._alter_many_to_many(model, old_field, new_field, strict)
        elif old_type is None or new_type is None:
            raise ValueError("Cannot alter field %s into %s - they are not compatible types" % (
            raise ValueError("Cannot alter field %s into %s - they are not compatible types (probably means only one is an M2M with implicit through model)" % (
                    old_field,
                    new_field,
                ))
@@ -543,6 +544,17 @@ class BaseDatabaseSchemaEditor(object):
                }
            )

    def _alter_many_to_many(self, model, old_field, new_field, strict):
        "Alters M2Ms to repoint their to= endpoints."
        # Rename the through table
        self.alter_db_table(old_field.rel.through, old_field.rel.through._meta.db_table, new_field.rel.through._meta.db_table)
        # Repoint the FK to the other side
        self.alter_field(
            new_field.rel.through,
            old_field.rel.through._meta.get_field_by_name(old_field.m2m_reverse_field_name())[0],
            new_field.rel.through._meta.get_field_by_name(new_field.m2m_reverse_field_name())[0],
        )

    def _type_for_alter(self, field):
        """
        Returns a field's type suitable for ALTER COLUMN.
+25 −4
Original line number Diff line number Diff line
@@ -101,11 +101,10 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
        # Ensure this field is even column-based
        old_type = old_field.db_type(connection=self.connection)
        new_type = self._type_for_alter(new_field)
        if old_type is None and new_type is None:
            # TODO: Handle M2M fields being repointed
            return
        if old_type is None and new_type is None and (old_field.rel.through and new_field.rel.through and old_field.rel.through._meta.auto_created and new_field.rel.through._meta.auto_created):
            return self._alter_many_to_many(model, old_field, new_field, strict)
        elif old_type is None or new_type is None:
            raise ValueError("Cannot alter field %s into %s - they are not compatible types" % (
            raise ValueError("Cannot alter field %s into %s - they are not compatible types (probably means only one is an M2M with implicit through model)" % (
                    old_field,
                    new_field,
                ))
@@ -114,3 +113,25 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):

    def alter_unique_together(self, model, old_unique_together, new_unique_together):
        self._remake_table(model, override_uniques=new_unique_together)

    def _alter_many_to_many(self, model, old_field, new_field, strict):
        "Alters M2Ms to repoint their to= endpoints."
        # Make a new through table
        self.create_model(new_field.rel.through)
        # Copy the data across
        self.execute("INSERT INTO %s (%s) SELECT %s FROM %s;" % (
            self.quote_name(new_field.rel.through._meta.db_table),
            ', '.join([
                "id",
                new_field.m2m_column_name(),
                new_field.m2m_reverse_name(),
            ]),
            ', '.join([
                "id",
                old_field.m2m_column_name(),
                old_field.m2m_reverse_name(),
            ]),
            self.quote_name(old_field.rel.through._meta.db_table),
        ))
        # Delete the old through table
        self.delete_model(old_field.rel.through, force=True)
+10 −0
Original line number Diff line number Diff line
@@ -29,6 +29,16 @@ class Book(models.Model):
        managed = False


class BookWithM2M(models.Model):
    author = models.ForeignKey(Author)
    title = models.CharField(max_length=100, db_index=True)
    pub_date = models.DateTimeField()
    tags = models.ManyToManyField("Tag", related_name="books")

    class Meta:
        managed = False


class BookWithSlug(models.Model):
    author = models.ForeignKey(Author)
    title = models.CharField(max_length=100, db_index=True)
+62 −2
Original line number Diff line number Diff line
@@ -7,7 +7,7 @@ from django.db import connection, DatabaseError, IntegrityError
from django.db.models.fields import IntegerField, TextField, CharField, SlugField
from django.db.models.fields.related import ManyToManyField, ForeignKey
from django.db.models.loading import cache
from .models import Author, Book, BookWithSlug, AuthorWithM2M, Tag, TagUniqueRename, UniqueTest
from .models import Author, Book, BookWithSlug, BookWithM2M, AuthorWithM2M, Tag, TagUniqueRename, UniqueTest


class SchemaTests(TestCase):
@@ -19,7 +19,7 @@ class SchemaTests(TestCase):
    as the code it is testing.
    """

    models = [Author, Book, BookWithSlug, AuthorWithM2M, Tag, TagUniqueRename, UniqueTest]
    models = [Author, Book, BookWithSlug, BookWithM2M, AuthorWithM2M, Tag, TagUniqueRename, UniqueTest]

    # Utility functions

@@ -248,6 +248,21 @@ class SchemaTests(TestCase):
        self.assertEqual(columns['display_name'][0], "CharField")
        self.assertNotIn("name", columns)

    def test_m2m_create(self):
        """
        Tests M2M fields on models during creation
        """
        # Create the tables
        editor = connection.schema_editor()
        editor.start()
        editor.create_model(Author)
        editor.create_model(Tag)
        editor.create_model(BookWithM2M)
        editor.commit()
        # Ensure there is now an m2m table there
        columns = self.column_classes(BookWithM2M._meta.get_field_by_name("tags")[0].rel.through)
        self.assertEqual(columns['tag_id'][0], "IntegerField")

    def test_m2m(self):
        """
        Tests adding/removing M2M fields on models
@@ -287,6 +302,51 @@ class SchemaTests(TestCase):
        self.assertRaises(DatabaseError, self.column_classes, new_field.rel.through)
        connection.rollback()

    def test_m2m_repoint(self):
        """
        Tests repointing M2M fields
        """
        # Create the tables
        editor = connection.schema_editor()
        editor.start()
        editor.create_model(Author)
        editor.create_model(BookWithM2M)
        editor.create_model(Tag)
        editor.create_model(UniqueTest)
        editor.commit()
        # Ensure the M2M exists and points to Tag
        constraints = connection.introspection.get_constraints(connection.cursor(), BookWithM2M._meta.get_field_by_name("tags")[0].rel.through._meta.db_table)
        if connection.features.supports_foreign_keys:
            for name, details in constraints.items():
                if details['columns'] == set(["tag_id"]) and details['foreign_key']:
                    self.assertEqual(details['foreign_key'], ('schema_tag', 'id'))
                    break
            else:
                self.fail("No FK constraint for tag_id found")
        # Repoint the M2M
        new_field = ManyToManyField(UniqueTest)
        new_field.contribute_to_class(BookWithM2M, "uniques")
        editor = connection.schema_editor()
        editor.start()
        editor.alter_field(
            Author,
            BookWithM2M._meta.get_field_by_name("tags")[0],
            new_field,
        )
        editor.commit()
        # Ensure old M2M is gone
        self.assertRaises(DatabaseError, self.column_classes, BookWithM2M._meta.get_field_by_name("tags")[0].rel.through)
        connection.rollback()
        # Ensure the new M2M exists and points to UniqueTest
        constraints = connection.introspection.get_constraints(connection.cursor(), new_field.rel.through._meta.db_table)
        if connection.features.supports_foreign_keys:
            for name, details in constraints.items():
                if details['columns'] == set(["uniquetest_id"]) and details['foreign_key']:
                    self.assertEqual(details['foreign_key'], ('schema_uniquetest', 'id'))
                    break
            else:
                self.fail("No FK constraint for tag_id found")

    def test_unique(self):
        """
        Tests removing and adding unique constraints to a single column.