Commit 395c6083 authored by Claude Paroz's avatar Claude Paroz
Browse files

Fixed #12460 -- Improved inspectdb handling of special field names

Thanks mihail lukin for the report and elijahr and kgibula for their
work on the patch.
parent 10d32072
Loading
Loading
Loading
Loading
+78 −39
Original line number Diff line number Diff line
from __future__ import unicode_literals

import keyword
from optparse import make_option

@@ -31,6 +33,7 @@ class Command(NoArgsCommand):
        table_name_filter = options.get('table_name_filter')

        table2model = lambda table_name: table_name.title().replace('_', '').replace(' ', '').replace('-', '')
        strip_prefix = lambda s: s.startswith("u'") and s[1:] or s

        cursor = connection.cursor()
        yield "# This is an auto-generated Django model module."
@@ -41,6 +44,7 @@ class Command(NoArgsCommand):
        yield "#"
        yield "# Also note: You'll have to insert the output of 'django-admin.py sqlcustom [appname]'"
        yield "# into your database."
        yield "from __future__ import unicode_literals"
        yield ''
        yield 'from %s import models' % self.db_module
        yield ''
@@ -59,16 +63,19 @@ class Command(NoArgsCommand):
                indexes = connection.introspection.get_indexes(cursor, table_name)
            except NotImplementedError:
                indexes = {}
            used_column_names = [] # Holds column names used in the table so far
            for i, row in enumerate(connection.introspection.get_table_description(cursor, table_name)):
                column_name = row[0]
                att_name = column_name.lower()
                comment_notes = [] # Holds Field notes, to be displayed in a Python comment.
                extra_params = {}  # Holds Field parameters such as 'db_column'.
                column_name = row[0]
                is_relation = i in relations

                att_name, params, notes = self.normalize_col_name(
                    column_name, used_column_names, is_relation)
                extra_params.update(params)
                comment_notes.extend(notes)

                # If the column name can't be used verbatim as a Python
                # attribute, set the "db_column" for this Field.
                if ' ' in att_name or '-' in att_name or keyword.iskeyword(att_name) or column_name != att_name:
                    extra_params['db_column'] = column_name
                used_column_names.append(att_name)

                # Add primary_key and unique, if necessary.
                if column_name in indexes:
@@ -77,30 +84,12 @@ class Command(NoArgsCommand):
                    elif indexes[column_name]['unique']:
                        extra_params['unique'] = True

                # Modify the field name to make it Python-compatible.
                if ' ' in att_name:
                    att_name = att_name.replace(' ', '_')
                    comment_notes.append('Field renamed to remove spaces.')

                if '-' in att_name:
                    att_name = att_name.replace('-', '_')
                    comment_notes.append('Field renamed to remove dashes.')

                if column_name != att_name:
                    comment_notes.append('Field name made lowercase.')

                if i in relations:
                if is_relation:
                    rel_to = relations[i][1] == table_name and "'self'" or table2model(relations[i][1])

                    if rel_to in known_models:
                        field_type = 'ForeignKey(%s' % rel_to
                    else:
                        field_type = "ForeignKey('%s'" % rel_to

                    if att_name.endswith('_id'):
                        att_name = att_name[:-3]
                    else:
                        extra_params['db_column'] = column_name
                else:
                    # Calling `get_field_type` to get the field type string and any
                    # additional paramters and notes.
@@ -110,16 +99,6 @@ class Command(NoArgsCommand):

                    field_type += '('

                if keyword.iskeyword(att_name):
                    att_name += '_field'
                    comment_notes.append('Field renamed because it was a Python reserved word.')

                if att_name[0].isdigit():
                    att_name = 'number_%s' % att_name
                    extra_params['db_column'] = six.text_type(column_name)
                    comment_notes.append("Field renamed because it wasn't a "
                        "valid Python identifier.")

                # Don't output 'id = meta.AutoField(primary_key=True)', because
                # that's assumed if it doesn't exist.
                if att_name == 'id' and field_type == 'AutoField(' and extra_params == {'primary_key': True}:
@@ -136,7 +115,9 @@ class Command(NoArgsCommand):
                if extra_params:
                    if not field_desc.endswith('('):
                        field_desc += ', '
                    field_desc += ', '.join(['%s=%r' % (k, v) for k, v in extra_params.items()])
                    field_desc += ', '.join([
                        '%s=%s' % (k, strip_prefix(repr(v)))
                        for k, v in extra_params.items()])
                field_desc += ')'
                if comment_notes:
                    field_desc += ' # ' + ' '.join(comment_notes)
@@ -144,6 +125,64 @@ class Command(NoArgsCommand):
            for meta_line in self.get_meta(table_name):
                yield meta_line

    def normalize_col_name(self, col_name, used_column_names, is_relation):
        """
        Modify the column name to make it Python-compatible as a field name
        """
        field_params = {}
        field_notes = []

        new_name = col_name.lower()
        if new_name != col_name:
            field_notes.append('Field name made lowercase.')

        if is_relation:
            if new_name.endswith('_id'):
                new_name = new_name[:-3]
            else:
                field_params['db_column'] = col_name

        if ' ' in new_name:
            new_name = new_name.replace(' ', '_')
            field_notes.append('Field renamed to remove spaces.')

        if '-' in new_name:
            new_name = new_name.replace('-', '_')
            field_notes.append('Field renamed to remove dashes.')

        if new_name.find('__') >= 0:
            while new_name.find('__') >= 0:
                new_name = new_name.replace('__', '_')
            field_notes.append("Field renamed because it contained more than one '_' in a row.")

        if new_name.startswith('_'):
            new_name = 'field%s' % new_name
            field_notes.append("Field renamed because it started with '_'.")

        if new_name.endswith('_'):
            new_name = '%sfield' % new_name
            field_notes.append("Field renamed because it ended with '_'.")

        if keyword.iskeyword(new_name):
            new_name += '_field'
            field_notes.append('Field renamed because it was a Python reserved word.')

        if new_name[0].isdigit():
            new_name = 'number_%s' % new_name
            field_notes.append("Field renamed because it wasn't a valid Python identifier.")

        if new_name in used_column_names:
            num = 0
            while '%s_%d' % (new_name, num) in used_column_names:
                num += 1
            new_name = '%s_%d' % (new_name, num)
            field_notes.append('Field renamed because of name conflict.')

        if col_name != new_name and field_notes:
            field_params['db_column'] = col_name

        return new_name, field_params, field_notes

    def get_field_type(self, connection, table_name, row):
        """
        Given the database connection, the table name, and the cursor row
@@ -181,6 +220,6 @@ class Command(NoArgsCommand):
        to construct the inner Meta class for the model corresponding
        to the given database table name.
        """
        return ['    class Meta:',
                '        db_table = %r' % table_name,
                '']
        return ["    class Meta:",
                "        db_table = '%s'" % table_name,
                ""]
+6 −0
Original line number Diff line number Diff line
@@ -19,3 +19,9 @@ class DigitsInColumnName(models.Model):
    all_digits = models.CharField(max_length=11, db_column='123')
    leading_digit = models.CharField(max_length=11, db_column='4extra')
    leading_digits = models.CharField(max_length=11, db_column='45extra')

class UnderscoresInColumnName(models.Model):
    field = models.IntegerField(db_column='field')
    field_field_0 = models.IntegerField(db_column='Field_')
    field_field_1 = models.IntegerField(db_column='Field__')
    field_field_2 = models.IntegerField(db_column='__field')
+25 −12
Original line number Diff line number Diff line
from __future__ import unicode_literals

from django.core.management import call_command
from django.test import TestCase, skipUnlessDBFeature
from django.utils.six import StringIO
@@ -17,7 +19,6 @@ class InspectDBTestCase(TestCase):
        # the Django test suite, check that one of its tables hasn't been
        # inspected
        self.assertNotIn("class DjangoContentType(models.Model):", out.getvalue(), msg=error_message)
        out.close()

    @skipUnlessDBFeature('can_introspect_foreign_keys')
    def test_attribute_name_not_python_keyword(self):
@@ -27,15 +28,16 @@ class InspectDBTestCase(TestCase):
        call_command('inspectdb',
                     table_name_filter=lambda tn:tn.startswith('inspectdb_'),
                     stdout=out)
        output = out.getvalue()
        error_message = "inspectdb generated an attribute name which is a python keyword"
        self.assertNotIn("from = models.ForeignKey(InspectdbPeople)", out.getvalue(), msg=error_message)
        self.assertNotIn("from = models.ForeignKey(InspectdbPeople)", output, msg=error_message)
        # As InspectdbPeople model is defined after InspectdbMessage, it should be quoted
        self.assertIn("from_field = models.ForeignKey('InspectdbPeople')", out.getvalue())
        self.assertIn("from_field = models.ForeignKey('InspectdbPeople', db_column='from_id')",
            output)
        self.assertIn("people_pk = models.ForeignKey(InspectdbPeople, primary_key=True)",
            out.getvalue())
            output)
        self.assertIn("people_unique = models.ForeignKey(InspectdbPeople, unique=True)",
            out.getvalue())
        out.close()
            output)

    def test_digits_column_name_introspection(self):
        """Introspection of column names consist/start with digits (#16536/#17676)"""
@@ -45,13 +47,24 @@ class InspectDBTestCase(TestCase):
        call_command('inspectdb',
                     table_name_filter=lambda tn:tn.startswith('inspectdb_'),
                     stdout=out)
        output = out.getvalue()
        error_message = "inspectdb generated a model field name which is a number"
        self.assertNotIn("    123 = models.CharField", out.getvalue(), msg=error_message)
        self.assertIn("number_123 = models.CharField", out.getvalue())
        self.assertNotIn("    123 = models.CharField", output, msg=error_message)
        self.assertIn("number_123 = models.CharField", output)

        error_message = "inspectdb generated a model field name which starts with a digit"
        self.assertNotIn("    4extra = models.CharField", out.getvalue(), msg=error_message)
        self.assertIn("number_4extra = models.CharField", out.getvalue())
        self.assertNotIn("    4extra = models.CharField", output, msg=error_message)
        self.assertIn("number_4extra = models.CharField", output)

        self.assertNotIn("    45extra = models.CharField", output, msg=error_message)
        self.assertIn("number_45extra = models.CharField", output)

        self.assertNotIn("    45extra = models.CharField", out.getvalue(), msg=error_message)
        self.assertIn("number_45extra = models.CharField", out.getvalue())
    def test_underscores_column_name_introspection(self):
        """Introspection of column names containing underscores (#12460)"""
        out = StringIO()
        call_command('inspectdb', stdout=out)
        output = out.getvalue()
        self.assertIn("field = models.IntegerField()", output)
        self.assertIn("field_field = models.IntegerField(db_column='Field_')", output)
        self.assertIn("field_field_0 = models.IntegerField(db_column='Field__')", output)
        self.assertIn("field_field_1 = models.IntegerField(db_column='__field')", output)