Commit c4b2a326 authored by Andrew Godwin's avatar Andrew Godwin
Browse files

Add support for unique_together

parent b139315f
Loading
Loading
Loading
Loading
+52 −1
Original line number Diff line number Diff line
@@ -29,6 +29,7 @@ class BaseDatabaseSchemaEditor(object):

    # Overrideable SQL templates
    sql_create_table = "CREATE TABLE %(table)s (%(definition)s)"
    sql_create_table_unique = "UNIQUE (%(columns)s)"
    sql_rename_table = "ALTER TABLE %(old_table)s RENAME TO %(new_table)s"
    sql_delete_table = "DROP TABLE %(table)s CASCADE"

@@ -51,7 +52,7 @@ class BaseDatabaseSchemaEditor(object):
    sql_create_fk = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s FOREIGN KEY (%(column)s) REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED"
    sql_delete_fk = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"

    sql_create_index = "CREATE %(unique)s INDEX %(name)s ON %(table)s (%(columns)s)%s;"
    sql_create_index = "CREATE %(unique)s INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s;"
    sql_delete_index = "DROP INDEX %(name)s"

    sql_create_pk = "ALTER TABLE %(table)s ADD CONSTRAINT %(constraint)s PRIMARY KEY (%(columns)s)"
@@ -174,6 +175,17 @@ class BaseDatabaseSchemaEditor(object):
                definition,
            ))
            params.extend(extra_params)
            # Indexes
            if field.db_index:
                self.deferred_sql.append(
                    self.sql_create_index % {
                        "unique": "",
                        "name": self._create_index_name(model, [field.column], suffix=""),
                        "table": self.quote_name(model._meta.db_table),
                        "columns": self.quote_name(field.column),
                        "extra": "",
                    }
                )
            # FK
            if field.rel:
                to_table = field.rel.to._meta.db_table
@@ -191,6 +203,12 @@ class BaseDatabaseSchemaEditor(object):
                        "to_column": self.quote_name(to_column),
                    }
                )
        # Add any unique_togethers
        for fields in model._meta.unique_together:
            columns = [model._meta.get_field_by_name(field)[0].column for field in fields]
            column_sqls.append(self.sql_create_table_unique % {
                "columns": ", ".join(self.quote_name(column) for column in columns),
            })
        # Make the table
        sql = self.sql_create_table % {
            "table": model._meta.db_table,
@@ -210,6 +228,39 @@ class BaseDatabaseSchemaEditor(object):
            "table": self.quote_name(model._meta.db_table),
        })

    def alter_unique_together(self, model, old_unique_together, new_unique_together):
        """
        Deals with a model changing its unique_together.
        Note: The input unique_togethers must be doubly-nested, not the single-
        nested ["foo", "bar"] format.
        """
        olds = set(frozenset(fields) for fields in old_unique_together)
        news = set(frozenset(fields) for fields in new_unique_together)
        # Deleted uniques
        for fields in olds.difference(news):
            columns = [model._meta.get_field_by_name(field)[0].column for field in fields]
            constraint_names = self._constraint_names(model, list(columns), unique=True)
            if len(constraint_names) != 1:
                raise ValueError("Found wrong number (%s) of constraints for %s(%s)" % (
                    len(constraint_names),
                    model._meta.db_table,
                    ", ".join(columns),
                ))
            self.execute(
                self.sql_delete_unique % {
                    "table": self.quote_name(model._meta.db_table),
                    "name": constraint_names[0],
                },
            )
        # Created uniques
        for fields in news.difference(olds):
            columns = [model._meta.get_field_by_name(field)[0].column for field in fields]
            self.execute(self.sql_create_unique % {
                "table": self.quote_name(model._meta.db_table),
                "name": self._create_index_name(model, columns, suffix="_uniq"),
                "columns": ", ".join(self.quote_name(column) for column in columns),
            })

    def create_field(self, model, field, keep_default=False):
        """
        Creates a field on a model.
+12 −0
Original line number Diff line number Diff line
@@ -32,3 +32,15 @@ class Book(models.Model):
class Tag(models.Model):
    title = models.CharField(max_length=255)
    slug = models.SlugField(unique=True)

    class Meta:
        managed = False


class UniqueTest(models.Model):
    year = models.IntegerField()
    slug = models.SlugField(unique=False)

    class Meta:
        managed = False
        unique_together = ["year", "slug"]
+46 −2
Original line number Diff line number Diff line
@@ -6,7 +6,7 @@ from django.db import connection, DatabaseError, IntegrityError
from django.db.models.fields import IntegerField, TextField, CharField, SlugField
from django.db.models.fields.related import ManyToManyField
from django.db.models.loading import cache
from .models import Author, Book, AuthorWithM2M, Tag
from .models import Author, Book, AuthorWithM2M, Tag, UniqueTest


class SchemaTests(TestCase):
@@ -18,7 +18,7 @@ class SchemaTests(TestCase):
    as the code it is testing.
    """

    models = [Author, Book, AuthorWithM2M, Tag]
    models = [Author, Book, AuthorWithM2M, Tag, UniqueTest]

    # Utility functions

@@ -298,3 +298,47 @@ class SchemaTests(TestCase):
        Tag.objects.create(title="foo", slug="foo")
        self.assertRaises(IntegrityError, Tag.objects.create, title="bar", slug="foo")
        connection.rollback()

    def test_unique_together(self):
        """
        Tests removing and adding unique_together constraints on a model.
        """
        # Create the table
        editor = connection.schema_editor()
        editor.start()
        editor.create_model(UniqueTest)
        editor.commit()
        # Ensure the fields are unique to begin with
        UniqueTest.objects.create(year=2012, slug="foo")
        UniqueTest.objects.create(year=2011, slug="foo")
        UniqueTest.objects.create(year=2011, slug="bar")
        self.assertRaises(IntegrityError, UniqueTest.objects.create, year=2012, slug="foo")
        connection.rollback()
        # Alter the model to it's non-unique-together companion
        editor = connection.schema_editor()
        editor.start()
        editor.alter_unique_together(
            UniqueTest,
            UniqueTest._meta.unique_together,
            [],
        )
        editor.commit()
        # Ensure the fields are no longer unique
        UniqueTest.objects.create(year=2012, slug="foo")
        UniqueTest.objects.create(year=2012, slug="foo")
        connection.rollback()
        # Alter it back
        new_new_field = SlugField(unique=True)
        new_new_field.set_attributes_from_name("slug")
        editor = connection.schema_editor()
        editor.start()
        editor.alter_unique_together(
            UniqueTest,
            [],
            UniqueTest._meta.unique_together,
        )
        editor.commit()
        # Ensure the fields are unique again
        UniqueTest.objects.create(year=2012, slug="foo")
        self.assertRaises(IntegrityError, UniqueTest.objects.create, year=2012, slug="foo")
        connection.rollback()