Commit 4755f8fc authored by Marc Tamlyn's avatar Marc Tamlyn
Browse files

Fixed #24343 -- Ensure db converters are used for foreign keys.

Joint effort between myself, Josh, Anssi and Shai.
parent dbacbc72
Loading
Loading
Loading
Loading
+8 −5
Original line number Diff line number Diff line
@@ -585,10 +585,10 @@ class Random(ExpressionNode):


class Col(ExpressionNode):
    def __init__(self, alias, target, source=None):
        if source is None:
            source = target
        super(Col, self).__init__(output_field=source)
    def __init__(self, alias, target, output_field=None):
        if output_field is None:
            output_field = target
        super(Col, self).__init__(output_field=output_field)
        self.alias, self.target = alias, target

    def __repr__(self):
@@ -606,7 +606,10 @@ class Col(ExpressionNode):
        return [self]

    def get_db_converters(self, connection):
        if self.target == self.output_field:
            return self.output_field.get_db_converters(connection)
        return (self.output_field.get_db_converters(connection) +
                self.target.get_db_converters(connection))


class Ref(ExpressionNode):
+5 −5
Original line number Diff line number Diff line
@@ -330,12 +330,12 @@ class Field(RegisterLookupMixin):
            ]
        return []

    def get_col(self, alias, source=None):
        if source is None:
            source = self
        if alias != self.model._meta.db_table or source != self:
    def get_col(self, alias, output_field=None):
        if output_field is None:
            output_field = self
        if alias != self.model._meta.db_table or output_field != self:
            from django.db.models.expressions import Col
            return Col(alias, self, source)
            return Col(alias, self, output_field)
        else:
            return self.cached_col

+14 −0
Original line number Diff line number Diff line
@@ -2064,6 +2064,20 @@ class ForeignKey(ForeignObject):
    def db_parameters(self, connection):
        return {"type": self.db_type(connection), "check": []}

    def convert_empty_strings(self, value, connection, context):
        if (not value) and isinstance(value, six.string_types):
            return None
        return value

    def get_db_converters(self, connection):
        converters = super(ForeignKey, self).get_db_converters(connection)
        if connection.features.interprets_empty_strings_as_nulls:
            converters += [self.convert_empty_strings]
        return converters

    def get_col(self, alias, output_field=None):
        return super(ForeignKey, self).get_col(alias, output_field or self.related_field)


class OneToOneField(ForeignKey):
    """
+3 −3
Original line number Diff line number Diff line
@@ -57,7 +57,7 @@ class ModelIterator(BaseIterator):
        model_cls = klass_info['model']
        select_fields = klass_info['select_fields']
        model_fields_start, model_fields_end = select_fields[0], select_fields[-1] + 1
        init_list = [f[0].output_field.attname
        init_list = [f[0].target.attname
                     for f in select[model_fields_start:model_fields_end]]
        if len(init_list) != len(model_cls._meta.concrete_fields):
            init_set = set(init_list)
@@ -1618,7 +1618,7 @@ class RelatedPopulator(object):
            self.cols_start = select_fields[0]
            self.cols_end = select_fields[-1] + 1
            self.init_list = [
                f[0].output_field.attname for f in select[self.cols_start:self.cols_end]
                f[0].target.attname for f in select[self.cols_start:self.cols_end]
            ]
            self.reorder_for_init = None
        else:
@@ -1627,7 +1627,7 @@ class RelatedPopulator(object):
            ]
            reorder_map = []
            for idx in select_fields:
                field = select[idx][0].output_field
                field = select[idx][0].target
                init_pos = model_init_attnames.index(field.attname)
                reorder_map.append((init_pos, field.attname, idx))
            reorder_map.sort()
+1 −1
Original line number Diff line number Diff line
@@ -1458,7 +1458,7 @@ class Query(object):
        # database from tripping over IN (...,NULL,...) selects and returning
        # nothing
        col = query.select[0]
        select_field = col.field
        select_field = col.target
        alias = col.alias
        if self.is_nullable(select_field):
            lookup_class = select_field.get_lookup('isnull')
Loading