Commit b4492a8c authored by Anssi Kääriäinen's avatar Anssi Kääriäinen
Browse files

Fixed #19837 -- Refactored split_exclude() join generation

The refactoring mainly concentrates on making sure the inner and outer
query agree about the split position. The split position is where the
multijoin happens, and thus the split position also determines the
columns used in the "WHERE col1 IN (SELECT col2 from ...)" condition.

This commit fixes a regression caused by #10790 and commit
69597e5b. The regression was caused
by wrong cols in the split position.
parent ffcfb19f
Loading
Loading
Loading
Loading
+4 −2
Original line number Diff line number Diff line
@@ -12,8 +12,10 @@ class MultiJoin(Exception):
    multi-valued join was attempted (if the caller wants to treat that
    exceptionally).
    """
    def __init__(self, level):
        self.level = level
    def __init__(self, names_pos, path_with_names):
        self.level = names_pos
        # The path travelled, this includes the path to the multijoin.
        self.names_with_path = path_with_names

class Empty(object):
    pass
+67 −74
Original line number Diff line number Diff line
@@ -1200,7 +1200,7 @@ class Query(object):
                can_reuse.update(join_list)
        except MultiJoin as e:
            self.split_exclude(filter_expr, LOOKUP_SEP.join(parts[:e.level]),
                    can_reuse)
                               can_reuse, e.names_with_path)
            return

        if (lookup_type == 'isnull' and value is True and not negate and
@@ -1324,7 +1324,7 @@ class Query(object):
        (the last used join field), and target (which is a field guaranteed to
        contain the same value as the final field).
        """
        path = []
        path, names_with_path = [], []
        for pos, name in enumerate(names):
            if name == 'pk':
                name = opts.pk.name
@@ -1361,16 +1361,17 @@ class Query(object):
                                             opts, final_field, False, True))
            if hasattr(field, 'get_path_info'):
                pathinfos, opts, target, final_field = field.get_path_info()
                if not allow_many:
                    for inner_pos, p in enumerate(pathinfos):
                        if p.m2m:
                            names_with_path.append((name, pathinfos[0:inner_pos + 1]))
                            raise MultiJoin(pos + 1, names_with_path)
                path.extend(pathinfos)
                names_with_path.append((name, pathinfos))
            else:
                # Local non-relational field.
                final_field = target = field
                break
        multijoin_pos = None
        for m2mpos, pathinfo in enumerate(path):
            if pathinfo.m2m:
                multijoin_pos = m2mpos
                break

        if pos != len(names) - 1:
            if pos == len(names) - 2:
@@ -1379,8 +1380,6 @@ class Query(object):
                    "the lookup type?" % (name, names[pos + 1]))
            else:
                raise FieldError("Join on field %r not permitted." % name)
        if multijoin_pos is not None and len(path) >= multijoin_pos and not allow_many:
            raise MultiJoin(multijoin_pos + 1)
        return path, final_field, target

    def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True,
@@ -1454,7 +1453,7 @@ class Query(object):
                break
        return target.column, joins[-1], joins

    def split_exclude(self, filter_expr, prefix, can_reuse):
    def split_exclude(self, filter_expr, prefix, can_reuse, names_with_path):
        """
        When doing an exclude against any kind of N-to-many relation, we need
        to use a subquery. This method constructs the nested query, given the
@@ -1462,11 +1461,10 @@ class Query(object):
        N-to-many relation field.

        As an example we could have original filter ~Q(child__name='foo').
        We would get here with filter_expr = child_name, prefix = child and
        can_reuse is a set of joins we can reuse for filtering in the original
        query.
        We would get here with filter_expr = child__name, prefix = child and
        can_reuse is a set of joins usable for filters in the original query.

        We will turn this into
        We will turn this into equivalent of:
            WHERE pk NOT IN (SELECT parent_id FROM thetable
                             WHERE name = 'foo' AND parent_id IS NOT NULL)

@@ -1474,42 +1472,46 @@ class Query(object):
        saner null handling, and is easier for the backend's optimizer to
        handle.
        """
        # Generate the inner query.
        query = Query(self.model)
        query.add_filter(filter_expr)
        query.bump_prefix()
        query.clear_ordering(True)
        query.set_start(prefix)
        # Adding extra check to make sure the selected field will not be null
        # Try to have as simple as possible subquery -> trim leading joins from
        # the subquery.
        trimmed_joins = query.trim_start(names_with_path)
        # Add extra check to make sure the selected field will not be null
        # since we are adding a IN <subquery> clause. This prevents the
        # database from tripping over IN (...,NULL,...) selects and returning
        # nothing
        alias, col = query.select[0].col
        query.where.add((Constraint(alias, col, None), 'isnull', False), AND)
        # We need to trim the last part from the prefix.
        trimmed_prefix = LOOKUP_SEP.join(prefix.split(LOOKUP_SEP)[0:-1])
        if not trimmed_prefix:
            rel, _, direct, m2m = self.model._meta.get_field_by_name(prefix)
            if not m2m:
                trimmed_prefix = rel.field.rel.field_name
            else:
                if direct:
                    trimmed_prefix = rel.m2m_target_field_name()
                else:
                    trimmed_prefix = rel.field.m2m_reverse_target_field_name()

        # Still make sure that the trimmed parts in the inner query and
        # trimmed prefix are in sync. So, use the trimmed_joins to make sure
        # as many path elements are in the prefix as there were trimmed joins.
        # In addition, convert the path elements back to names so that
        # add_filter() can handle them.
        trimmed_prefix = []
        paths_in_prefix = trimmed_joins
        for name, path in names_with_path:
            if paths_in_prefix - len(path) > 0:
                trimmed_prefix.append(name)
                paths_in_prefix -= len(path)
            else:
                trimmed_prefix.append(
                    path[paths_in_prefix - len(path)].from_field.name)
                break
        trimmed_prefix = LOOKUP_SEP.join(trimmed_prefix)
        self.add_filter(('%s__in' % trimmed_prefix, query), negate=True,
                        can_reuse=can_reuse)

        # If there's more than one join in the inner query (before any initial
        # bits were trimmed -- which means the last active table is more than
        # two places into the alias list), we need to also handle the
        # possibility that the earlier joins don't match anything by adding a
        # comparison to NULL (e.g. in
        # Tag.objects.exclude(parent__parent__name='t1'), a tag with no parent
        # would otherwise be overlooked).
        active_positions = len([count for count
                                in query.alias_refcount.items() if count])
        if active_positions > 1:
        # If there's more than one join in the inner query, we need to also
        # handle the possibility that the earlier joins don't match anything
        # by adding a comparison to NULL (e.g. in
        #     Tag.objects.exclude(parent__parent__name='t1')
        # a tag with no parent would otherwise be overlooked).
        if trimmed_joins > 1:
            self.add_filter(('%s__isnull' % trimmed_prefix, False), negate=True,
                            can_reuse=can_reuse)

@@ -1869,42 +1871,33 @@ class Query(object):
            return self.extra
    extra_select = property(_extra_select)

    def set_start(self, start):
        """
        Sets the table from which to start joining. The start position is
        specified by the related attribute from the base model. This will
        automatically set to the select column to be the column linked from the
        previous table.

        This method is primarily for internal use and the error checking isn't
        as friendly as add_filter(). Mostly useful for querying directly
        against the join table of many-to-many relation in a subquery.
        """
        opts = self.model._meta
        alias = self.get_initial_alias()
        field, col, opts, joins, extra = self.setup_joins(
                start.split(LOOKUP_SEP), opts, alias)
        select_col = self.alias_map[joins[1]].lhs_join_col
        select_alias = alias

        # The call to setup_joins added an extra reference to everything in
        # joins. Reverse that.
        for alias in joins:
            self.unref_alias(alias)

        # We might be able to trim some joins from the front of this query,
        # providing that we only traverse "always equal" connections (i.e. rhs
        # is *always* the same value as lhs).
        for alias in joins[1:]:
            join_info = self.alias_map[alias]
            if (join_info.lhs_join_col != select_col
                    or join_info.join_type != self.INNER):
    def trim_start(self, names_with_path):
        """
        Trims joins from the start of the join path. The candidates for trim
        are the PathInfos in names_with_path structure. Outer joins are not
        eligible for removal. Also sets the select column so the start
        matches the join.

        This method is mostly useful for generating the subquery joins & col
        in "WHERE somecol IN (subquery)". This construct is needed by
        split_exclude().
        _"""
        join_pos = 0
        for _, paths in names_with_path:
            for path in paths:
                peek = self.tables[join_pos + 1]
                if self.alias_map[peek].join_type == self.LOUTER:
                    # Back up one level and break
                    select_alias = self.tables[join_pos]
                    select_col = path.from_field.column
                    break
            self.unref_alias(select_alias)
            select_alias = join_info.rhs_alias
            select_col = join_info.rhs_join_col
                select_alias = self.tables[join_pos + 1]
                select_col = path.to_field.column
                self.unref_alias(self.tables[join_pos])
                join_pos += 1
        self.select = [SelectInfo((select_alias, select_col), None)]
        self.remove_inherited_models()
        return join_pos

    def is_nullable(self, field):
        """
+14 −0
Original line number Diff line number Diff line
@@ -439,3 +439,17 @@ class BaseA(models.Model):
    a = models.ForeignKey(FK1, null=True)
    b = models.ForeignKey(FK2, null=True)
    c = models.ForeignKey(FK3, null=True)

@python_2_unicode_compatible
class Identifier(models.Model):
    name = models.CharField(max_length=100)

    def __str__(self):
        return self.name

class Program(models.Model):
    identifier = models.OneToOneField(Identifier)

class Channel(models.Model):
    programs = models.ManyToManyField(Program)
    identifier = models.OneToOneField(Identifier)
+20 −1
Original line number Diff line number Diff line
@@ -24,7 +24,7 @@ from .models import (Annotation, Article, Author, Celebrity, Child, Cover,
    Node, ObjectA, ObjectB, ObjectC, CategoryItem, SimpleCategory,
    SpecialCategory, OneToOneCategory, NullableName, ProxyCategory,
    SingleObject, RelatedObject, ModelA, ModelD, Responsibility, Job,
    JobResponsibilities, BaseA)
    JobResponsibilities, BaseA, Identifier, Program, Channel)


class BaseQuerysetTest(TestCase):
@@ -2612,3 +2612,22 @@ class DisjunctionPromotionTests(TestCase):
        qs = BaseA.objects.filter(Q(a__f1=F('c__f1')) | (Q(pk=1) & Q(pk=2)))
        self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 2)
        self.assertEqual(str(qs.query).count('INNER JOIN'), 0)


class ManyToManyExcludeTest(TestCase):
    def test_exclude_many_to_many(self):
        Identifier.objects.create(name='extra')
        program = Program.objects.create(identifier=Identifier.objects.create(name='program'))
        channel = Channel.objects.create(identifier=Identifier.objects.create(name='channel'))
        channel.programs.add(program)

        # channel contains 'program1', so all Identifiers except that one
        # should be returned
        self.assertQuerysetEqual(
            Identifier.objects.exclude(program__channel=channel).order_by('name'),
            ['<Identifier: channel>', '<Identifier: extra>']
        )
        self.assertQuerysetEqual(
            Identifier.objects.exclude(program__channel=None).order_by('name'),
            ['<Identifier: program>']
        )

tests/tmp.txt

0 → 100644
+1 −0
Original line number Diff line number Diff line
SELECT "queries_tag"."id", "queries_tag"."name", "queries_tag"."parent_id", "queries_tag"."category_id" FROM "queries_tag" WHERE NOT (("queries_tag"."id" IN (SELECT U0."id" FROM "queries_tag" U0 LEFT OUTER JOIN "queries_tag" U1 ON (U0."id" = U1."parent_id") WHERE (U1."id" IS NULL AND U0."id" IS NOT NULL)) AND "queries_tag"."id" IS NOT NULL)) ORDER BY "queries_tag"."name" ASC