Commit 3a6580e4 authored by Andrew Godwin's avatar Andrew Godwin
Browse files

Make get_constraints return columns in order

parent 61ff46cf
Loading
Loading
Loading
Loading
+6 −3
Original line number Diff line number Diff line
import re
from .base import FIELD_TYPE

from django.utils.datastructures import SortedSet
from django.db.backends import BaseDatabaseIntrospection, FieldInfo
from django.utils.encoding import force_text

@@ -141,7 +141,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
        for constraint, column, ref_table, ref_column in cursor.fetchall():
            if constraint not in constraints:
                constraints[constraint] = {
                    'columns': set(),
                    'columns': SortedSet(),
                    'primary_key': False,
                    'unique': False,
                    'index': False,
@@ -169,7 +169,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
        for table, non_unique, index, colseq, column in [x[:5] for x in cursor.fetchall()]:
            if index not in constraints:
                constraints[index] = {
                    'columns': set(),
                    'columns': SortedSet(),
                    'primary_key': False,
                    'unique': False,
                    'index': True,
@@ -178,5 +178,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
                }
            constraints[index]['index'] = True
            constraints[index]['columns'].add(column)
        # Convert the sorted sets to lists
        for constraint in constraints.values():
            constraint['columns'] = list(constraint['columns'])
        # Return
        return constraints
+7 −10
Original line number Diff line number Diff line
@@ -140,7 +140,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
            # If we're the first column, make the record
            if constraint not in constraints:
                constraints[constraint] = {
                    "columns": set(),
                    "columns": [],
                    "primary_key": kind.lower() == "primary key",
                    "unique": kind.lower() in ["primary key", "unique"],
                    "foreign_key": tuple(used_cols[0].split(".", 1)) if kind.lower() == "foreign key" else None,
@@ -148,7 +148,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
                    "index": False,
                }
            # Record the details
            constraints[constraint]['columns'].add(column)
            constraints[constraint]['columns'].append(column)
        # Now get CHECK constraint columns
        cursor.execute("""
            SELECT kc.constraint_name, kc.column_name
@@ -166,7 +166,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
            # If we're the first column, make the record
            if constraint not in constraints:
                constraints[constraint] = {
                    "columns": set(),
                    "columns": [],
                    "primary_key": False,
                    "unique": False,
                    "foreign_key": False,
@@ -174,17 +174,14 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
                    "index": False,
                }
            # Record the details
            constraints[constraint]['columns'].add(column)
            constraints[constraint]['columns'].append(column)
        # Now get indexes
        cursor.execute("""
            SELECT
                c2.relname,
                ARRAY(
                    SELECT attr.attname
                    FROM unnest(idx.indkey) i, pg_catalog.pg_attribute attr
                    WHERE
                        attr.attnum = i AND
                        attr.attrelid = c.oid
                    SELECT (SELECT attname FROM pg_catalog.pg_attribute WHERE attnum = i AND attrelid = c.oid)
                    FROM unnest(idx.indkey) i
                ),
                idx.indisunique,
                idx.indisprimary
@@ -197,7 +194,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
        for index, columns, unique, primary in cursor.fetchall():
            if index not in constraints:
                constraints[index] = {
                    "columns": set(columns),
                    "columns": list(columns),
                    "primary_key": primary,
                    "unique": unique,
                    "foreign_key": False,
+7 −6
Original line number Diff line number Diff line
@@ -87,6 +87,7 @@ class BaseDatabaseSchemaEditor(object):
        cursor = self.connection.cursor()
        # Log the command we're running, then run it
        logger.debug("%s; (params %r)" % (sql, params))
        #print("%s; (params %r)" % (sql, params))
        cursor.execute(sql, params)

    def quote_name(self, name):
@@ -228,12 +229,12 @@ class BaseDatabaseSchemaEditor(object):
        Note: The input unique_togethers must be doubly-nested, not the single-
        nested ["foo", "bar"] format.
        """
        olds = set(frozenset(fields) for fields in old_unique_together)
        news = set(frozenset(fields) for fields in new_unique_together)
        olds = set(tuple(fields) for fields in old_unique_together)
        news = set(tuple(fields) for fields in new_unique_together)
        # Deleted uniques
        for fields in olds.difference(news):
            columns = [model._meta.get_field_by_name(field)[0].column for field in fields]
            constraint_names = self._constraint_names(model, list(columns), unique=True)
            constraint_names = self._constraint_names(model, columns, unique=True)
            if len(constraint_names) != 1:
                raise ValueError("Found wrong number (%s) of constraints for %s(%s)" % (
                    len(constraint_names),
@@ -261,8 +262,8 @@ class BaseDatabaseSchemaEditor(object):
        Note: The input index_togethers must be doubly-nested, not the single-
        nested ["foo", "bar"] format.
        """
        olds = set(frozenset(fields) for fields in old_index_together)
        news = set(frozenset(fields) for fields in new_index_together)
        olds = set(tuple(fields) for fields in old_index_together)
        news = set(tuple(fields) for fields in new_index_together)
        # Deleted indexes
        for fields in olds.difference(news):
            columns = [model._meta.get_field_by_name(field)[0].column for field in fields]
@@ -646,7 +647,7 @@ class BaseDatabaseSchemaEditor(object):
        """
        Returns all constraint names matching the columns and conditions
        """
        column_names = set(column_names) if column_names else None
        column_names = list(column_names) if column_names else None
        constraints = self.connection.introspection.get_constraints(self.connection.cursor(), model._meta.db_table)
        result = []
        for name, infodict in constraints.items():
+3 −3
Original line number Diff line number Diff line
@@ -197,14 +197,14 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
            for index_rank, column_rank, column in cursor.fetchall():
                if index not in constraints:
                    constraints[index] = {
                        "columns": set(),
                        "columns": [],
                        "primary_key": False,
                        "unique": bool(unique),
                        "foreign_key": False,
                        "check": False,
                        "index": True,
                    }
                constraints[index]['columns'].add(column)
                constraints[index]['columns'].append(column)
        # Get the PK
        pk_column = self.get_primary_key_column(cursor, table_name)
        if pk_column:
@@ -213,7 +213,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
            # deletes PK constraints by name, as you can't delete constraints
            # in SQLite; we remake the table with a new PK instead.
            constraints["__primary__"] = {
                "columns": set([pk_column]),
                "columns": [pk_column],
                "primary_key": True,
                "unique": False,  # It's not actually a unique constraint.
                "foreign_key": False,
+9 −9
Original line number Diff line number Diff line
@@ -128,7 +128,7 @@ class SchemaTests(TransactionTestCase):
        # Make sure the new FK constraint is present
        constraints = connection.introspection.get_constraints(connection.cursor(), Book._meta.db_table)
        for name, details in constraints.items():
            if details['columns'] == set(["author_id"]) and details['foreign_key']:
            if details['columns'] == ["author_id"] and details['foreign_key']:
                self.assertEqual(details['foreign_key'], ('schema_tag', 'id'))
                break
        else:
@@ -285,7 +285,7 @@ class SchemaTests(TransactionTestCase):
        constraints = connection.introspection.get_constraints(connection.cursor(), BookWithM2M._meta.get_field_by_name("tags")[0].rel.through._meta.db_table)
        if connection.features.supports_foreign_keys:
            for name, details in constraints.items():
                if details['columns'] == set(["tag_id"]) and details['foreign_key']:
                if details['columns'] == ["tag_id"] and details['foreign_key']:
                    self.assertEqual(details['foreign_key'], ('schema_tag', 'id'))
                    break
            else:
@@ -306,7 +306,7 @@ class SchemaTests(TransactionTestCase):
            constraints = connection.introspection.get_constraints(connection.cursor(), new_field.rel.through._meta.db_table)
            if connection.features.supports_foreign_keys:
                for name, details in constraints.items():
                    if details['columns'] == set(["uniquetest_id"]) and details['foreign_key']:
                    if details['columns'] == ["uniquetest_id"] and details['foreign_key']:
                        self.assertEqual(details['foreign_key'], ('schema_uniquetest', 'id'))
                        break
                else:
@@ -327,7 +327,7 @@ class SchemaTests(TransactionTestCase):
        # Ensure the constraint exists
        constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table)
        for name, details in constraints.items():
            if details['columns'] == set(["height"]) and details['check']:
            if details['columns'] == ["height"] and details['check']:
                break
        else:
            self.fail("No check constraint for height found")
@@ -343,7 +343,7 @@ class SchemaTests(TransactionTestCase):
            )
        constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table)
        for name, details in constraints.items():
            if details['columns'] == set(["height"]) and details['check']:
            if details['columns'] == ["height"] and details['check']:
                self.fail("Check constraint for height found")
        # Alter the column to re-add it
        with connection.schema_editor() as editor:
@@ -355,7 +355,7 @@ class SchemaTests(TransactionTestCase):
            )
        constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table)
        for name, details in constraints.items():
            if details['columns'] == set(["height"]) and details['check']:
            if details['columns'] == ["height"] and details['check']:
                break
        else:
            self.fail("No check constraint for height found")
@@ -465,7 +465,7 @@ class SchemaTests(TransactionTestCase):
            any(
                c["index"]
                for c in connection.introspection.get_constraints(connection.cursor(), "schema_tag").values()
                if c['columns'] == set(["slug", "title"])
                if c['columns'] == ["slug", "title"]
            ),
        )
        # Alter the model to add an index
@@ -481,7 +481,7 @@ class SchemaTests(TransactionTestCase):
            any(
                c["index"]
                for c in connection.introspection.get_constraints(connection.cursor(), "schema_tag").values()
                if c['columns'] == set(["slug", "title"])
                if c['columns'] == ["slug", "title"]
            ),
        )
        # Alter it back
@@ -499,7 +499,7 @@ class SchemaTests(TransactionTestCase):
            any(
                c["index"]
                for c in connection.introspection.get_constraints(connection.cursor(), "schema_tag").values()
                if c['columns'] == set(["slug", "title"])
                if c['columns'] == ["slug", "title"]
            ),
        )