Commit 3dcd435a authored by Anssi Kääriäinen's avatar Anssi Kääriäinen
Browse files

Fixed #19500 -- Solved a regression in join reuse

The ORM didn't reuse joins for direct foreign key traversals when using
chained filters. For example:
    qs.filter(fk__somefield=1).filter(fk__somefield=2))
produced two joins.

As a bonus, reverse onetoone filters can now reuse joins correctly

The regression was caused by the join() method refactor in commit
68847135

Thanks for Simon Charette for spotting some issues with the first draft
of the patch.
parent c04c03da
Loading
Loading
Loading
Loading
+4 −4
Original line number Diff line number Diff line
@@ -6,7 +6,7 @@ from django.db.backends.util import truncate_name
from django.db.models.constants import LOOKUP_SEP
from django.db.models.query_utils import select_related_descend
from django.db.models.sql.constants import (SINGLE, MULTI, ORDER_DIR,
        GET_ITERATOR_CHUNK_SIZE, REUSE_ALL, SelectInfo)
        GET_ITERATOR_CHUNK_SIZE, SelectInfo)
from django.db.models.sql.datastructures import EmptyResultSet
from django.db.models.sql.expressions import SQLEvaluator
from django.db.models.sql.query import get_order_dir, Query
@@ -317,7 +317,7 @@ class SQLCompiler(object):

        for name in self.query.distinct_fields:
            parts = name.split(LOOKUP_SEP)
            field, col, alias, _, _ = self._setup_joins(parts, opts, None)
            field, col, alias, _, _ = self._setup_joins(parts, opts)
            col, alias = self._final_join_removal(col, alias)
            result.append("%s.%s" % (qn(alias), qn2(col)))
        return result
@@ -450,7 +450,7 @@ class SQLCompiler(object):
        if not alias:
            alias = self.query.get_initial_alias()
        field, target, opts, joins, _ = self.query.setup_joins(
            pieces, opts, alias, REUSE_ALL)
            pieces, opts, alias)
        # We will later on need to promote those joins that were added to the
        # query afresh above.
        joins_to_promote = [j for j in joins if self.query.alias_refcount[j] < 2]
@@ -688,7 +688,7 @@ class SQLCompiler(object):
                        int_opts = int_model._meta
                        alias = self.query.join(
                            (alias, int_opts.db_table, lhs_col, int_opts.pk.column),
                            promote=True,
                            promote=True
                        )
                        alias_chain.append(alias)
                alias = self.query.join(
+0 −3
Original line number Diff line number Diff line
@@ -44,6 +44,3 @@ ORDER_DIR = {
    'ASC': ('ASC', 'DESC'),
    'DESC': ('DESC', 'ASC'),
}

# A marker for join-reusability.
REUSE_ALL = object()
+2 −3
Original line number Diff line number Diff line
from django.core.exceptions import FieldError
from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields import FieldDoesNotExist
from django.db.models.sql.constants import REUSE_ALL

class SQLEvaluator(object):
    def __init__(self, expression, query, allow_joins=True, reuse=REUSE_ALL):
    def __init__(self, expression, query, allow_joins=True, reuse=None):
        self.expression = expression
        self.opts = query.get_meta()
        self.cols = []
@@ -54,7 +53,7 @@ class SQLEvaluator(object):
                    field_list, query.get_meta(),
                    query.get_initial_alias(), self.reuse)
                col, _, join_list = query.trim_joins(source, join_list, path)
                if self.reuse is not None and self.reuse != REUSE_ALL:
                if self.reuse is not None:
                    self.reuse.update(join_list)
                self.cols.append((node, (join_list[-1], col)))
            except FieldDoesNotExist:
+19 −20
Original line number Diff line number Diff line
@@ -20,7 +20,7 @@ from django.db.models.fields import FieldDoesNotExist
from django.db.models.loading import get_model
from django.db.models.sql import aggregates as base_aggregates_module
from django.db.models.sql.constants import (QUERY_TERMS, ORDER_DIR, SINGLE,
        ORDER_PATTERN, REUSE_ALL, JoinInfo, SelectInfo, PathInfo)
        ORDER_PATTERN, JoinInfo, SelectInfo, PathInfo)
from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin
from django.db.models.sql.expressions import SQLEvaluator
from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode,
@@ -891,7 +891,7 @@ class Query(object):
        """
        return len([1 for count in self.alias_refcount.values() if count])

    def join(self, connection, reuse=REUSE_ALL, promote=False,
    def join(self, connection, reuse=None, promote=False,
             outer_if_first=False, nullable=False, join_field=None):
        """
        Returns an alias for the join in 'connection', either reusing an
@@ -902,10 +902,9 @@ class Query(object):

            lhs.lhs_col = table.col

        The 'reuse' parameter can be used in three ways: it can be REUSE_ALL
        which means all joins (matching the connection) are reusable, it can
        be a set containing the aliases that can be reused, or it can be None
        which means a new join is always created.
        The 'reuse' parameter can be either None which means all joins
        (matching the connection) are reusable, or it can be a set containing
        the aliases that can be reused.

        If 'promote' is True, the join type for the alias will be LOUTER (if
        the alias previously existed, the join type will be promoted from INNER
@@ -926,10 +925,8 @@ class Query(object):
        """
        lhs, table, lhs_col, col = connection
        existing = self.join_map.get(connection, ())
        if reuse == REUSE_ALL:
        if reuse is None:
            reuse = existing
        elif reuse is None:
            reuse = set()
        else:
            reuse = [a for a in existing if a in reuse]
        for alias in reuse:
@@ -1040,7 +1037,7 @@ class Query(object):
            # then we need to explore the joins that are required.

            field, source, opts, join_list, path = self.setup_joins(
                field_list, opts, self.get_initial_alias(), REUSE_ALL)
                field_list, opts, self.get_initial_alias())

            # Process the join chain to see if it can be trimmed
            col, _, join_list = self.trim_joins(source, join_list, path)
@@ -1441,7 +1438,7 @@ class Query(object):
            raise MultiJoin(multijoin_pos + 1)
        return path, final_field, target

    def setup_joins(self, names, opts, alias, can_reuse, allow_many=True,
    def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True,
                    allow_explicit_fk=False):
        """
        Compute the necessary table joins for the passage through the fields
@@ -1450,9 +1447,9 @@ class Query(object):
        the table to start the joining from.

        The 'can_reuse' defines the reverse foreign key joins we can reuse. It
        can be sql.constants.REUSE_ALL in which case all joins are reusable
        or a set of aliases that can be reused. Note that Non-reverse foreign
        keys are always reusable.
        can be None in which case all joins are reusable or a set of aliases
        that can be reused. Note that non-reverse foreign keys are always
        reusable when using setup_joins().

        If 'allow_many' is False, then any reverse foreign key seen will
        generate a MultiJoin exception.
@@ -1485,8 +1482,9 @@ class Query(object):
            else:
                nullable = True
            connection = alias, opts.db_table, from_field.column, to_field.column
            alias = self.join(connection, reuse=can_reuse, nullable=nullable,
                              join_field=join_field)
            reuse = None if direct or to_field.unique else can_reuse
            alias = self.join(connection, reuse=reuse,
                              nullable=nullable, join_field=join_field)
            joins.append(alias)
        return final_field, target, opts, joins, path

@@ -1643,7 +1641,7 @@ class Query(object):
        try:
            for name in field_names:
                field, target, u2, joins, u3 = self.setup_joins(
                        name.split(LOOKUP_SEP), opts, alias, REUSE_ALL, allow_m2m,
                        name.split(LOOKUP_SEP), opts, alias, None, allow_m2m,
                        True)
                final_alias = joins[-1]
                col = target.column
@@ -1729,7 +1727,8 @@ class Query(object):
        else:
            opts = self.model._meta
            if not self.select:
                count = self.aggregates_module.Count((self.join((None, opts.db_table, None, None)), opts.pk.column),
                count = self.aggregates_module.Count(
                    (self.join((None, opts.db_table, None, None)), opts.pk.column),
                    is_summary=True, distinct=True)
            else:
                # Because of SQL portability issues, multi-column, distinct
@@ -1934,7 +1933,7 @@ class Query(object):
        opts = self.model._meta
        alias = self.get_initial_alias()
        field, col, opts, joins, extra = self.setup_joins(
                start.split(LOOKUP_SEP), opts, alias, REUSE_ALL)
                start.split(LOOKUP_SEP), opts, alias)
        select_col = self.alias_map[joins[1]].lhs_join_col
        select_alias = alias

+0 −1
Original line number Diff line number Diff line
@@ -232,7 +232,6 @@ class DateQuery(Query):
                field_name.split(LOOKUP_SEP),
                self.get_meta(),
                self.get_initial_alias(),
                False
            )
        except FieldError:
            raise FieldDoesNotExist("%s has no field named '%s'" % (
Loading