Commit bb218245 authored by Malcolm Tredinnick's avatar Malcolm Tredinnick
Browse files

Fixed handling of multiple fields in a model pointing to the same related model.

Thanks to ElliotM, mk and oyvind for some excellent test cases for this. Fixed #7110, #7125.


git-svn-id: http://code.djangoproject.com/svn/django/trunk@7778 bcc190cf-cafb-0310-a4f2-bffc1f526a37
parent d800c0b0
Loading
Loading
Loading
Loading
+11 −0
Original line number Diff line number Diff line
@@ -692,6 +692,11 @@ class ForeignKey(RelatedField, Field):
    def contribute_to_class(self, cls, name):
        super(ForeignKey, self).contribute_to_class(cls, name)
        setattr(cls, self.name, ReverseSingleRelatedObjectDescriptor(self))
        if isinstance(self.rel.to, basestring):
            target = self.rel.to
        else:
            target = self.rel.to._meta.db_table
        cls._meta.duplicate_targets[self.column] = (target, "o2m")

    def contribute_to_related_class(self, cls, related):
        setattr(cls, related.get_accessor_name(), ForeignRelatedObjectsDescriptor(related))
@@ -826,6 +831,12 @@ class ManyToManyField(RelatedField, Field):
        # Set up the accessor for the m2m table name for the relation
        self.m2m_db_table = curry(self._get_m2m_db_table, cls._meta)

        if isinstance(self.rel.to, basestring):
            target = self.rel.to
        else:
            target = self.rel.to._meta.db_table
        cls._meta.duplicate_targets[self.column] = (target, "m2m")

    def contribute_to_related_class(self, cls, related):
        # m2m relations to self do not have a ManyRelatedObjectsDescriptor,
        # as it would be redundant - unless the field is non-symmetrical.
+19 −0
Original line number Diff line number Diff line
@@ -44,6 +44,7 @@ class Options(object):
        self.one_to_one_field = None
        self.abstract = False
        self.parents = SortedDict()
        self.duplicate_targets = {}

    def contribute_to_class(self, cls, name):
        from django.db import connection
@@ -115,6 +116,24 @@ class Options(object):
                        auto_created=True)
                model.add_to_class('id', auto)

        # Determine any sets of fields that are pointing to the same targets
        # (e.g. two ForeignKeys to the same remote model). The query
        # construction code needs to know this. At the end of this,
        # self.duplicate_targets will map each duplicate field column to the
        # columns it duplicates.
        collections = {}
        for column, target in self.duplicate_targets.iteritems():
            try:
                collections[target].add(column)
            except KeyError:
                collections[target] = set([column])
        self.duplicate_targets = {}
        for elt in collections.itervalues():
            if len(elt) == 1:
                continue
            for column in elt:
                self.duplicate_targets[column] = elt.difference(set([column]))

    def add_field(self, field):
        # Insert the given field in the order in which it was created, using
        # the "creation_counter" attribute of the field.
+87 −12
Original line number Diff line number Diff line
@@ -57,6 +57,7 @@ class Query(object):
        self.start_meta = None
        self.select_fields = []
        self.related_select_fields = []
        self.dupe_avoidance = {}

        # SQL-related attributes
        self.select = []
@@ -165,6 +166,7 @@ class Query(object):
        obj.start_meta = self.start_meta
        obj.select_fields = self.select_fields[:]
        obj.related_select_fields = self.related_select_fields[:]
        obj.dupe_avoidance = self.dupe_avoidance.copy()
        obj.select = self.select[:]
        obj.tables = self.tables[:]
        obj.where = deepcopy(self.where)
@@ -830,8 +832,8 @@ class Query(object):

        if reuse and always_create and table in self.table_map:
            # Convert the 'reuse' to case to be "exclude everything but the
            # reusable set for this table".
            exclusions = set(self.table_map[table]).difference(reuse)
            # reusable set, minus exclusions, for this table".
            exclusions = set(self.table_map[table]).difference(reuse).union(set(exclusions))
            always_create = False
        t_ident = (lhs_table, table, lhs_col, col)
        if not always_create:
@@ -866,7 +868,8 @@ class Query(object):
        return alias

    def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
            used=None, requested=None, restricted=None, nullable=None):
            used=None, requested=None, restricted=None, nullable=None,
            dupe_set=None):
        """
        Fill in the information needed for a select_related query. The current
        depth is measured as the number of connections away from the root model
@@ -876,6 +879,7 @@ class Query(object):
        if not restricted and self.max_depth and cur_depth > self.max_depth:
            # We've recursed far enough; bail out.
            return

        if not opts:
            opts = self.get_meta()
            root_alias = self.get_initial_alias()
@@ -883,6 +887,10 @@ class Query(object):
            self.related_select_fields = []
        if not used:
            used = set()
        if dupe_set is None:
            dupe_set = set()
        orig_dupe_set = dupe_set
        orig_used = used

        # Setup for the case when only particular related fields should be
        # included in the related selection.
@@ -897,6 +905,8 @@ class Query(object):
            if (not f.rel or (restricted and f.name not in requested) or
                    (not restricted and f.null) or f.rel.parent_link):
                continue
            dupe_set = orig_dupe_set.copy()
            used = orig_used.copy()
            table = f.rel.to._meta.db_table
            if nullable or f.null:
                promote = True
@@ -907,12 +917,26 @@ class Query(object):
                alias = root_alias
                for int_model in opts.get_base_chain(model):
                    lhs_col = int_opts.parents[int_model].column
                    dedupe = lhs_col in opts.duplicate_targets
                    if dedupe:
                        used.update(self.dupe_avoidance.get(id(opts), lhs_col),
                                ())
                        dupe_set.add((opts, lhs_col))
                    int_opts = int_model._meta
                    alias = self.join((alias, int_opts.db_table, lhs_col,
                            int_opts.pk.column), exclusions=used,
                            promote=promote)
                    for (dupe_opts, dupe_col) in dupe_set:
                        self.update_dupe_avoidance(dupe_opts, dupe_col, alias)
            else:
                alias = root_alias

            dedupe = f.column in opts.duplicate_targets
            if dupe_set or dedupe:
                used.update(self.dupe_avoidance.get((id(opts), f.column), ()))
                if dedupe:
                    dupe_set.add((opts, f.column))

            alias = self.join((alias, table, f.column,
                    f.rel.get_related_field().column), exclusions=used,
                    promote=promote)
@@ -928,8 +952,10 @@ class Query(object):
                new_nullable = f.null
            else:
                new_nullable = None
            for dupe_opts, dupe_col in dupe_set:
                self.update_dupe_avoidance(dupe_opts, dupe_col, alias)
            self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1,
                    used, next, restricted, new_nullable)
                    used, next, restricted, new_nullable, dupe_set)

    def add_filter(self, filter_expr, connector=AND, negate=False, trim=False,
            can_reuse=None):
@@ -1128,7 +1154,9 @@ class Query(object):
        (which gives the table we are joining to), 'alias' is the alias for the
        table we are joining to. If dupe_multis is True, any many-to-many or
        many-to-one joins will always create a new alias (necessary for
        disjunctive filters).
        disjunctive filters). If can_reuse is not None, it's a list of aliases
        that can be reused in these joins (nothing else can be reused in this
        case).

        Returns the final field involved in the join, the target database
        column (used for any 'where' constraint), the final 'opts' value and the
@@ -1136,7 +1164,14 @@ class Query(object):
        """
        joins = [alias]
        last = [0]
        dupe_set = set()
        exclusions = set()
        for pos, name in enumerate(names):
            try:
                exclusions.add(int_alias)
            except NameError:
                pass
            exclusions.add(alias)
            last.append(len(joins))
            if name == 'pk':
                name = opts.pk.name
@@ -1155,6 +1190,7 @@ class Query(object):
                    names = opts.get_all_field_names()
                    raise FieldError("Cannot resolve keyword %r into field. "
                            "Choices are: %s" % (name, ", ".join(names)))

            if not allow_many and (m2m or not direct):
                for alias in joins:
                    self.unref_alias(alias)
@@ -1164,12 +1200,27 @@ class Query(object):
                alias_list = []
                for int_model in opts.get_base_chain(model):
                    lhs_col = opts.parents[int_model].column
                    dedupe = lhs_col in opts.duplicate_targets
                    if dedupe:
                        exclusions.update(self.dupe_avoidance.get(
                                (id(opts), lhs_col), ()))
                        dupe_set.add((opts, lhs_col))
                    opts = int_model._meta
                    alias = self.join((alias, opts.db_table, lhs_col,
                            opts.pk.column), exclusions=joins)
                            opts.pk.column), exclusions=exclusions)
                    joins.append(alias)
                    exclusions.add(alias)
                    for (dupe_opts, dupe_col) in dupe_set:
                        self.update_dupe_avoidance(dupe_opts, dupe_col, alias)
            cached_data = opts._join_cache.get(name)
            orig_opts = opts
            dupe_col = direct and field.column or field.field.column
            dedupe = dupe_col in opts.duplicate_targets
            if dupe_set or dedupe:
                if dedupe:
                    dupe_set.add((opts, dupe_col))
                exclusions.update(self.dupe_avoidance.get((id(opts), dupe_col),
                        ()))

            if direct:
                if m2m:
@@ -1191,9 +1242,11 @@ class Query(object):
                                target)

                    int_alias = self.join((alias, table1, from_col1, to_col1),
                            dupe_multis, joins, nullable=True, reuse=can_reuse)
                            dupe_multis, exclusions, nullable=True,
                            reuse=can_reuse)
                    alias = self.join((int_alias, table2, from_col2, to_col2),
                            dupe_multis, joins, nullable=True, reuse=can_reuse)
                            dupe_multis, exclusions, nullable=True,
                            reuse=can_reuse)
                    joins.extend([int_alias, alias])
                elif field.rel:
                    # One-to-one or many-to-one field
@@ -1209,7 +1262,7 @@ class Query(object):
                                opts, target)

                    alias = self.join((alias, table, from_col, to_col),
                            exclusions=joins, nullable=field.null)
                            exclusions=exclusions, nullable=field.null)
                    joins.append(alias)
                else:
                    # Non-relation fields.
@@ -1237,9 +1290,11 @@ class Query(object):
                                target)

                    int_alias = self.join((alias, table1, from_col1, to_col1),
                            dupe_multis, joins, nullable=True, reuse=can_reuse)
                            dupe_multis, exclusions, nullable=True,
                            reuse=can_reuse)
                    alias = self.join((int_alias, table2, from_col2, to_col2),
                            dupe_multis, joins, nullable=True, reuse=can_reuse)
                            dupe_multis, exclusions, nullable=True,
                            reuse=can_reuse)
                    joins.extend([int_alias, alias])
                else:
                    # One-to-many field (ForeignKey defined on the target model)
@@ -1257,14 +1312,34 @@ class Query(object):
                                opts, target)

                    alias = self.join((alias, table, from_col, to_col),
                            dupe_multis, joins, nullable=True, reuse=can_reuse)
                            dupe_multis, exclusions, nullable=True,
                            reuse=can_reuse)
                    joins.append(alias)

            for (dupe_opts, dupe_col) in dupe_set:
                try:
                    self.update_dupe_avoidance(dupe_opts, dupe_col, int_alias)
                except NameError:
                    self.update_dupe_avoidance(dupe_opts, dupe_col, alias)

        if pos != len(names) - 1:
            raise FieldError("Join on field %r not permitted." % name)

        return field, target, opts, joins, last

    def update_dupe_avoidance(self, opts, col, alias):
        """
        For a column that is one of multiple pointing to the same table, update
        the internal data structures to note that this alias shouldn't be used
        for those other columns.
        """
        ident = id(opts)
        for name in opts.duplicate_targets[col]:
            try:
                self.dupe_avoidance[ident, name].add(alias)
            except KeyError:
                self.dupe_avoidance[ident, name] = set([alias])

    def split_exclude(self, filter_expr, prefix):
        """
        When doing an exclude against any kind of N-to-many relation, we need
+40 −0
Original line number Diff line number Diff line
@@ -28,6 +28,24 @@ class Child(models.Model):
    parent = models.ForeignKey(Parent)


# Multiple paths to the same model (#7110, #7125)
class Category(models.Model):
    name = models.CharField(max_length=20)

    def __unicode__(self):
        return self.name

class Record(models.Model):
    category = models.ForeignKey(Category)

class Relation(models.Model):
    left = models.ForeignKey(Record, related_name='left_set')
    right = models.ForeignKey(Record, related_name='right_set')

    def __unicode__(self):
        return u"%s - %s" % (self.left.category.name, self.right.category.name)


__test__ = {'API_TESTS':"""
>>> Third.objects.create(id='3', name='An example')
<Third: Third object>
@@ -73,4 +91,26 @@ Traceback (most recent call last):
    ...
ValueError: Cannot assign "<First: First object>": "Child.parent" must be a "Parent" instance.

# Test of multiple ForeignKeys to the same model (bug #7125)

>>> c1 = Category.objects.create(name='First')
>>> c2 = Category.objects.create(name='Second')
>>> c3 = Category.objects.create(name='Third')
>>> r1 = Record.objects.create(category=c1)
>>> r2 = Record.objects.create(category=c1)
>>> r3 = Record.objects.create(category=c2)
>>> r4 = Record.objects.create(category=c2)
>>> r5 = Record.objects.create(category=c3)
>>> r = Relation.objects.create(left=r1, right=r2)
>>> r = Relation.objects.create(left=r3, right=r4)
>>> r = Relation.objects.create(left=r1, right=r3)
>>> r = Relation.objects.create(left=r5, right=r2)
>>> r = Relation.objects.create(left=r3, right=r2)

>>> Relation.objects.filter(left__category__name__in=['First'], right__category__name__in=['Second'])
[<Relation: First - Second>]

>>> Category.objects.filter(record__left_set__right__category__name='Second').order_by('name')
[<Category: First>, <Category: Second>]

"""}
+0 −0

Empty file added.

Loading