Commit 4c413e23 authored by Claude Paroz's avatar Claude Paroz
Browse files

Fixed #17785 -- Preferred column names in get_relations introspection

Thanks Thomas Güttler for the report and the initial patch, and
Tim Graham for the review.
parent b75c7079
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -73,11 +73,11 @@ class Command(BaseCommand):
                except NotImplementedError:
                    constraints = {}
                used_column_names = []  # Holds column names used in the table so far
                for i, row in enumerate(connection.introspection.get_table_description(cursor, table_name)):
                for row in connection.introspection.get_table_description(cursor, table_name):
                    comment_notes = []  # Holds Field notes, to be displayed in a Python comment.
                    extra_params = OrderedDict()  # Holds Field parameters such as 'db_column'.
                    column_name = row[0]
                    is_relation = i in relations
                    is_relation = column_name in relations

                    att_name, params, notes = self.normalize_col_name(
                        column_name, used_column_names, is_relation)
@@ -94,7 +94,7 @@ class Command(BaseCommand):
                            extra_params['unique'] = True

                    if is_relation:
                        rel_to = "self" if relations[i][1] == table_name else table2model(relations[i][1])
                        rel_to = "self" if relations[column_name][1] == table_name else table2model(relations[column_name][1])
                        if rel_to in known_models:
                            field_type = 'ForeignKey(%s' % rel_to
                        else:
+3 −13
Original line number Diff line number Diff line
@@ -80,25 +80,15 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
            )
        return fields

    def _name_to_index(self, cursor, table_name):
        """
        Returns a dictionary of {field_name: field_index} for the given table.
        Indexes are 0-based.
        """
        return {d[0]: i for i, d in enumerate(self.get_table_description(cursor, table_name))}

    def get_relations(self, cursor, table_name):
        """
        Returns a dictionary of {field_index: (field_index_other_table, other_table)}
        representing all relationships to the given table. Indexes are 0-based.
        Returns a dictionary of {field_name: (field_name_other_table, other_table)}
        representing all relationships to the given table.
        """
        my_field_dict = self._name_to_index(cursor, table_name)
        constraints = self.get_key_columns(cursor, table_name)
        relations = {}
        for my_fieldname, other_table, other_field in constraints:
            other_field_index = self._name_to_index(cursor, other_table)[other_field]
            my_field_index = my_field_dict[my_fieldname]
            relations[my_field_index] = (other_field_index, other_table)
            relations[my_fieldname] = (other_field, other_table)
        return relations

    def get_key_columns(self, cursor, table_name):
+3 −3
Original line number Diff line number Diff line
@@ -78,12 +78,12 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):

    def get_relations(self, cursor, table_name):
        """
        Returns a dictionary of {field_index: (field_index_other_table, other_table)}
        representing all relationships to the given table. Indexes are 0-based.
        Returns a dictionary of {field_name: (field_name_other_table, other_table)}
        representing all relationships to the given table.
        """
        table_name = table_name.upper()
        cursor.execute("""
    SELECT ta.column_id - 1, tb.table_name, tb.column_id - 1
    SELECT ta.column_name, tb.table_name, tb.column_name
    FROM   user_constraints, USER_CONS_COLUMNS ca, USER_CONS_COLUMNS cb,
           user_tab_cols ta, user_tab_cols tb
    WHERE  user_constraints.table_name = %s AND
+10 −9
Original line number Diff line number Diff line
@@ -69,20 +69,21 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):

    def get_relations(self, cursor, table_name):
        """
        Returns a dictionary of {field_index: (field_index_other_table, other_table)}
        representing all relationships to the given table. Indexes are 0-based.
        Returns a dictionary of {field_name: (field_name_other_table, other_table)}
        representing all relationships to the given table.
        """
        cursor.execute("""
            SELECT con.conkey, con.confkey, c2.relname
            FROM pg_constraint con, pg_class c1, pg_class c2
            WHERE c1.oid = con.conrelid
                AND c2.oid = con.confrelid
                AND c1.relname = %s
            SELECT c2.relname, a1.attname, a2.attname
            FROM pg_constraint con
            LEFT JOIN pg_class c1 ON con.conrelid = c1.oid
            LEFT JOIN pg_class c2 ON con.confrelid = c2.oid
            LEFT JOIN pg_attribute a1 ON c1.oid = a1.attrelid AND a1.attnum = con.conkey[1]
            LEFT JOIN pg_attribute a2 ON c2.oid = a2.attrelid AND a2.attnum = con.confkey[1]
            WHERE c1.relname = %s
                AND con.contype = 'f'""", [table_name])
        relations = {}
        for row in cursor.fetchall():
            # row[0] and row[1] are single-item lists, so grab the single item.
            relations[row[0][0] - 1] = (row[1][0] - 1, row[2])
            relations[row[1]] = (row[2], row[0])
        return relations

    def get_key_columns(self, cursor, table_name):
+10 −11
Original line number Diff line number Diff line
@@ -106,23 +106,22 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
        # Walk through and look for references to other tables. SQLite doesn't
        # really have enforced references, but since it echoes out the SQL used
        # to create the table we can look for REFERENCES statements used there.
        field_names = []
        for field_index, field_desc in enumerate(results.split(',')):
        for field_desc in results.split(','):
            field_desc = field_desc.strip()
            if field_desc.startswith("UNIQUE"):
                continue

            field_names.append(field_desc.split()[0].strip('"'))
            m = re.search('references (\S*) ?\(["|]?(.*)["|]?\)', field_desc, re.I)
            if not m:
                continue

            table, column = [s.strip('"') for s in m.groups()]

            if field_desc.startswith("FOREIGN KEY"):
                # Find index of the target FK field
                # Find name of the target FK field
                m = re.match('FOREIGN KEY\(([^\)]*)\).*', field_desc, re.I)
                fkey_field = m.groups()[0].strip('"')
                field_index = field_names.index(fkey_field)
                field_name = m.groups()[0].strip('"')
            else:
                field_name = field_desc.split()[0].strip('"')

            cursor.execute("SELECT sql FROM sqlite_master WHERE tbl_name = %s", [table])
            result = cursor.fetchall()[0]
@@ -130,14 +129,14 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
            li, ri = other_table_results.index('('), other_table_results.rindex(')')
            other_table_results = other_table_results[li + 1:ri]

            for other_index, other_desc in enumerate(other_table_results.split(',')):
            for other_desc in other_table_results.split(','):
                other_desc = other_desc.strip()
                if other_desc.startswith('UNIQUE'):
                    continue

                name = other_desc.split(' ', 1)[0].strip('"')
                if name == column:
                    relations[field_index] = (other_index, table)
                other_name = other_desc.split(' ', 1)[0].strip('"')
                if other_name == column:
                    relations[field_name] = (other_name, table)
                    break

        return relations
Loading