Commit ab89414f authored by Anssi Kääriäinen's avatar Anssi Kääriäinen Committed by Tim Graham
Browse files

Fixed #23853 -- Added Join class to replace JoinInfo

Also removed Query.join_map. This structure was used to speed up join
reuse calculation. Initial benchmarking shows that this isn't actually
needed. If there are use cases where the removal has real-world
performance implications, it should be relatively straightforward to
reintroduce it as map {alias: [Join-like objects]}.
parent c7175fcd
Loading
Loading
Loading
Loading
+10 −35
Original line number Diff line number Diff line
@@ -37,7 +37,7 @@ class SQLCompiler(object):
        # cleaned. We are not using a clone() of the query here.
        """
        if not self.query.tables:
            self.query.join((None, self.query.get_meta().db_table, None))
            self.query.get_initial_alias()
        if (not self.query.select and self.query.default_cols and not
                self.query.included_inherited_models):
            self.query.setup_inherited_models()
@@ -171,7 +171,6 @@ class SQLCompiler(object):

        # Finally do cleanup - get rid of the joins we created above.
        self.query.reset_refcounts(refcounts_before)

        return ' '.join(result), tuple(params)

    def as_nested_sql(self):
@@ -511,51 +510,27 @@ class SQLCompiler(object):
        ordering and distinct must be done first.
        """
        result = []
        qn = self.quote_name_unless_alias
        qn2 = self.connection.ops.quote_name
        first = True
        from_params = []
        params = []
        for alias in self.query.tables:
            if not self.query.alias_refcount[alias]:
                continue
            try:
                name, alias, join_type, lhs, join_cols, _, join_field = self.query.alias_map[alias]
                from_clause = self.query.alias_map[alias]
            except KeyError:
                # Extra tables can end up in self.tables, but not in the
                # alias_map if they aren't in a join. That's OK. We skip them.
                continue
            alias_str = '' if alias == name else (' %s' % alias)
            if join_type and not first:
                extra_cond = join_field.get_extra_restriction(
                    self.query.where_class, alias, lhs)
                if extra_cond:
                    extra_sql, extra_params = self.compile(extra_cond)
                    extra_sql = 'AND (%s)' % extra_sql
                    from_params.extend(extra_params)
                else:
                    extra_sql = ""
                result.append('%s %s%s ON ('
                        % (join_type, qn(name), alias_str))
                for index, (lhs_col, rhs_col) in enumerate(join_cols):
                    if index != 0:
                        result.append(' AND ')
                    result.append('%s.%s = %s.%s' %
                    (qn(lhs), qn2(lhs_col), qn(alias), qn2(rhs_col)))
                result.append('%s)' % extra_sql)
            else:
                connector = '' if first else ', '
                result.append('%s%s%s' % (connector, qn(name), alias_str))
            first = False
            clause_sql, clause_params = self.compile(from_clause)
            result.append(clause_sql)
            params.extend(clause_params)
        for t in self.query.extra_tables:
            alias, _ = self.query.table_alias(t)
            # Only add the alias if it's not already present (the table_alias()
            # calls increments the refcount, so an alias refcount of one means
            # this is the only reference.
            # call increments the refcount, so an alias refcount of one means
            # this is the only reference).
            if alias not in self.query.alias_map or self.query.alias_refcount[alias] == 1:
                connector = '' if first else ', '
                result.append('%s%s' % (connector, qn(alias)))
                first = False
        return result, from_params
                result.append(', %s' % self.quote_name_unless_alias(alias))
        return result, params

    def get_grouping(self, having_group_by, ordering_group_by):
        """
+4 −6
Original line number Diff line number Diff line
@@ -21,12 +21,6 @@ GET_ITERATOR_CHUNK_SIZE = 100

# Namedtuples for sql.* internal use.

# Join lists (indexes into the tuples that are values in the alias_map
# dictionary in the Query class).
JoinInfo = namedtuple('JoinInfo',
                      'table_name rhs_alias join_type lhs_alias '
                      'join_cols nullable join_field')

# Pairs of column clauses to select, and (possibly None) field for the clause.
SelectInfo = namedtuple('SelectInfo', 'col field')

@@ -41,3 +35,7 @@ ORDER_DIR = {
    'ASC': ('ASC', 'DESC'),
    'DESC': ('DESC', 'ASC'),
}

# SQL join types.
INNER = 'INNER JOIN'
LOUTER = 'LEFT OUTER JOIN'
+117 −0
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@
Useful auxiliary data structures for query construction. Not useful outside
the SQL domain.
"""
from django.db.models.sql.constants import INNER, LOUTER


class EmptyResultSet(Exception):
@@ -22,3 +23,119 @@ class MultiJoin(Exception):

class Empty(object):
    pass


class Join(object):
    """
    Used by sql.Query and sql.SQLCompiler to generate JOIN clauses into the
    FROM entry. For example, the SQL generated could be
        LEFT OUTER JOIN "sometable" T1 ON ("othertable"."sometable_id" = T1."id")

    This class is primarily used in Query.alias_map. All entries in alias_map
    must be Join compatible by providing the following attributes and methods:
        - table_name (string)
        - table_alias (possible alias for the table, can be None)
        - join_type (can be None for those entries that aren't joined from
          anything)
        - parent_alias (which table is this join's parent, can be None similarly
          to join_type)
        - as_sql()
        - relabeled_clone()

    Two joins compare (and hash) equal when they join the same table from the
    same parent alias along the same field; table_alias, join_type and
    nullable are deliberately not part of the identity.
    """
    def __init__(self, table_name, parent_alias, table_alias, join_type,
                 join_field, nullable):
        # Join table
        self.table_name = table_name
        self.parent_alias = parent_alias
        # Note: table_alias is not necessarily known at instantiation time.
        self.table_alias = table_alias
        # LOUTER or INNER
        self.join_type = join_type
        # A list of 2-tuples to use in the ON clause of the JOIN.
        # Each 2-tuple will create one join condition in the ON clause.
        self.join_cols = join_field.get_joining_columns()
        # Along which field (or RelatedObject in the reverse join case)
        self.join_field = join_field
        # Is this join nullable?
        self.nullable = nullable

    def as_sql(self, compiler, connection):
        """
        Generates the full
           LEFT OUTER JOIN sometable ON sometable.somecol = othertable.othercol, params
        clause for this join.

        Returns a 2-tuple (sql, params): the JOIN clause as a string and the
        list of parameters collected from the join field's extra restriction
        (empty when the field defines none).
        """
        params = []
        sql = []
        # The alias is only spelled out when it differs from the table name.
        alias_str = '' if self.table_alias == self.table_name else (' %s' % self.table_alias)
        qn = compiler.quote_name_unless_alias
        qn2 = connection.ops.quote_name
        sql.append('%s %s%s ON (' % (self.join_type, qn(self.table_name), alias_str))
        # One "parent.col = child.col" condition per joining column pair.
        for index, (lhs_col, rhs_col) in enumerate(self.join_cols):
            if index != 0:
                sql.append(' AND ')
            sql.append('%s.%s = %s.%s' % (
                qn(self.parent_alias),
                qn2(lhs_col),
                qn(self.table_alias),
                qn2(rhs_col),
            ))
        # The join field may contribute an additional condition for the ON
        # clause; it is compiled like any other node and AND-ed in.
        extra_cond = self.join_field.get_extra_restriction(
            compiler.query.where_class, self.table_alias, self.parent_alias)
        if extra_cond:
            extra_sql, extra_params = compiler.compile(extra_cond)
            params.extend(extra_params)
            sql.append('AND (%s)' % extra_sql)
        sql.append(')')
        return ' '.join(sql), params

    def relabeled_clone(self, change_map):
        """
        Returns a new instance of the same class with parent_alias and
        table_alias mapped through change_map ({old_alias: new_alias});
        aliases missing from the map are kept as they are.
        """
        new_parent_alias = change_map.get(self.parent_alias, self.parent_alias)
        new_table_alias = change_map.get(self.table_alias, self.table_alias)
        return self.__class__(
            self.table_name, new_parent_alias, new_table_alias, self.join_type,
            self.join_field, self.nullable)

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return (
                self.table_name == other.table_name and
                self.parent_alias == other.parent_alias and
                self.join_field == other.join_field
            )
        return False

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; without this, two equal
        # joins would still compare as "not equal" there.
        return not self == other

    def __hash__(self):
        # Defining __eq__ alone makes instances unhashable on Python 3. Hash
        # exactly the attributes __eq__ compares so equal joins hash equal.
        return hash((self.table_name, self.parent_alias, self.join_field))

    def demote(self):
        """
        Returns a copy of this join with join_type set to INNER.
        """
        new = self.relabeled_clone({})
        new.join_type = INNER
        return new

    def promote(self):
        """
        Returns a copy of this join with join_type set to LOUTER.
        """
        new = self.relabeled_clone({})
        new.join_type = LOUTER
        return new


class BaseTable(object):
    """
    Represents the base table reference of a FROM clause, i.e. an entry that
    is not joined from anything. In
        SELECT * FROM "foo" WHERE somecond
    the "foo" part is the kind of SQL this class produces.

    join_type and parent_alias are always None so that instances can stand in
    wherever a Join-like object is expected.
    """
    join_type = None
    parent_alias = None

    def __init__(self, table_name, alias):
        self.table_name, self.table_alias = table_name, alias

    def as_sql(self, compiler, connection):
        """
        Returns (sql, params) for this table reference. The alias is appended
        after the quoted table name only when it differs from the table name;
        params is always an empty list.
        """
        if self.table_alias == self.table_name:
            suffix = ''
        else:
            suffix = ' %s' % self.table_alias
        quoted_table = compiler.quote_name_unless_alias(self.table_name)
        return quoted_table + suffix, []

    def relabeled_clone(self, change_map):
        """
        Returns a new instance of the same class whose alias is looked up in
        change_map ({old_alias: new_alias}); an unmapped alias is kept.
        """
        current = self.table_alias
        return self.__class__(self.table_name, change_map.get(current, current))
+53 −86
Original line number Diff line number Diff line
@@ -20,8 +20,9 @@ from django.db.models.query_utils import Q, refs_aggregate
from django.db.models.related import PathInfo
from django.db.models.aggregates import Count
from django.db.models.sql.constants import (QUERY_TERMS, ORDER_DIR, SINGLE,
        ORDER_PATTERN, JoinInfo, SelectInfo)
from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin
        ORDER_PATTERN, SelectInfo, INNER, LOUTER)
from django.db.models.sql.datastructures import (
    EmptyResultSet, Empty, MultiJoin, Join, BaseTable)
from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode,
    ExtraWhere, AND, OR, EmptyWhere)
from django.utils import six
@@ -87,10 +88,6 @@ class Query(object):
    """
    A single SQL query.
    """
    # SQL join types. These are part of the class because their string forms
    # vary from database to database and can be customised by a subclass.
    INNER = 'INNER JOIN'
    LOUTER = 'LEFT OUTER JOIN'

    alias_prefix = 'T'
    subq_aliases = frozenset([alias_prefix])
@@ -103,15 +100,15 @@ class Query(object):
        self.alias_refcount = {}
        # alias_map is the most important data structure regarding joins.
        # It's used for recording which joins exist in the query and what
        # type they are. The key is the alias of the joined table (possibly
        # the table name) and the value is JoinInfo from constants.py.
        # types they are. The key is the alias of the joined table (possibly
        # the table name) and the value is a Join-like object (see
        # sql.datastructures.Join for more information).
        self.alias_map = {}
        # Sometimes the query contains references to aliases in outer queries (as
        # a result of split_exclude). Correct alias quoting needs to know these
        # aliases too.
        self.external_aliases = set()
        self.table_map = {}     # Maps table names to list of aliases.
        self.join_map = {}
        self.default_cols = True
        self.default_ordering = True
        self.standard_ordering = True
@@ -246,7 +243,6 @@ class Query(object):
        obj.alias_map = self.alias_map.copy()
        obj.external_aliases = self.external_aliases.copy()
        obj.table_map = self.table_map.copy()
        obj.join_map = self.join_map.copy()
        obj.default_cols = self.default_cols
        obj.default_ordering = self.default_ordering
        obj.standard_ordering = self.standard_ordering
@@ -495,19 +491,17 @@ class Query(object):
        self.get_initial_alias()
        joinpromoter = JoinPromoter(connector, 2, False)
        joinpromoter.add_votes(
            j for j in self.alias_map if self.alias_map[j].join_type == self.INNER)
            j for j in self.alias_map if self.alias_map[j].join_type == INNER)
        rhs_votes = set()
        # Now, add the joins from rhs query into the new query (skipping base
        # table).
        for alias in rhs.tables[1:]:
            table, _, join_type, lhs, join_cols, nullable, join_field = rhs.alias_map[alias]
            join = rhs.alias_map[alias]
            # If the left side of the join was already relabeled, use the
            # updated alias.
            lhs = change_map.get(lhs, lhs)
            new_alias = self.join(
                (lhs, table, join_cols), reuse=reuse,
                nullable=nullable, join_field=join_field)
            if join_type == self.INNER:
            join = join.relabeled_clone(change_map)
            new_alias = self.join(join, reuse=reuse)
            if join.join_type == INNER:
                rhs_votes.add(new_alias)
            # We can't reuse the same join again in the query. If we have two
            # distinct joins for the same connection in rhs query, then the
@@ -714,27 +708,26 @@ class Query(object):
        aliases = list(aliases)
        while aliases:
            alias = aliases.pop(0)
            if self.alias_map[alias].join_cols[0][1] is None:
            if self.alias_map[alias].join_type is None:
                # This is the base table (first FROM entry) - this table
                # isn't really joined at all in the query, so we should not
                # alter its join type.
                continue
            # Only the first alias (skipped above) should have None join_type
            assert self.alias_map[alias].join_type is not None
            parent_alias = self.alias_map[alias].lhs_alias
            parent_alias = self.alias_map[alias].parent_alias
            parent_louter = (
                parent_alias
                and self.alias_map[parent_alias].join_type == self.LOUTER)
            already_louter = self.alias_map[alias].join_type == self.LOUTER
                and self.alias_map[parent_alias].join_type == LOUTER)
            already_louter = self.alias_map[alias].join_type == LOUTER
            if ((self.alias_map[alias].nullable or parent_louter) and
                    not already_louter):
                data = self.alias_map[alias]._replace(join_type=self.LOUTER)
                self.alias_map[alias] = data
                self.alias_map[alias] = self.alias_map[alias].promote()
                # Join type of 'alias' changed, so re-examine all aliases that
                # refer to this one.
                aliases.extend(
                    join for join in self.alias_map.keys()
                    if (self.alias_map[join].lhs_alias == alias
                    if (self.alias_map[join].parent_alias == alias
                        and join not in aliases))

    def demote_joins(self, aliases):
@@ -750,10 +743,10 @@ class Query(object):
        aliases = list(aliases)
        while aliases:
            alias = aliases.pop(0)
            if self.alias_map[alias].join_type == self.LOUTER:
                self.alias_map[alias] = self.alias_map[alias]._replace(join_type=self.INNER)
                parent_alias = self.alias_map[alias].lhs_alias
                if self.alias_map[parent_alias].join_type == self.INNER:
            if self.alias_map[alias].join_type == LOUTER:
                self.alias_map[alias] = self.alias_map[alias].demote()
                parent_alias = self.alias_map[alias].parent_alias
                if self.alias_map[parent_alias].join_type == INNER:
                    aliases.append(parent_alias)

    def reset_refcounts(self, to_counts):
@@ -792,19 +785,13 @@ class Query(object):
                (key, relabel_column(col)) for key, col in self._annotations.items())

        # 2. Rename the alias in the internal table/alias datastructures.
        for ident, aliases in self.join_map.items():
            del self.join_map[ident]
            aliases = tuple(change_map.get(a, a) for a in aliases)
            ident = (change_map.get(ident[0], ident[0]),) + ident[1:]
            self.join_map[ident] = aliases
        for old_alias, new_alias in six.iteritems(change_map):
            alias_data = self.alias_map.get(old_alias)
            if alias_data is None:
            if old_alias not in self.alias_map:
                continue
            alias_data = alias_data._replace(rhs_alias=new_alias)
            alias_data = self.alias_map[old_alias].relabeled_clone(change_map)
            self.alias_map[new_alias] = alias_data
            self.alias_refcount[new_alias] = self.alias_refcount[old_alias]
            del self.alias_refcount[old_alias]
            self.alias_map[new_alias] = alias_data
            del self.alias_map[old_alias]

            table_aliases = self.table_map[alias_data.table_name]
@@ -819,14 +806,6 @@ class Query(object):
        for key, alias in self.included_inherited_models.items():
            if alias in change_map:
                self.included_inherited_models[key] = change_map[alias]

        # 3. Update any joins that refer to the old alias.
        for alias, data in six.iteritems(self.alias_map):
            lhs = data.lhs_alias
            if lhs in change_map:
                data = data._replace(lhs_alias=change_map[lhs])
                self.alias_map[alias] = data

        self.external_aliases = {change_map.get(alias, alias)
                                 for alias in self.external_aliases}

@@ -862,7 +841,7 @@ class Query(object):
            alias = self.tables[0]
            self.ref_alias(alias)
        else:
            alias = self.join((None, self.get_meta().db_table, None))
            alias = self.join(BaseTable(self.get_meta().db_table, None))
        return alias

    def count_active_tables(self):
@@ -873,7 +852,7 @@ class Query(object):
        """
        return len([1 for count in self.alias_refcount.values() if count])

    def join(self, connection, reuse=None, nullable=False, join_field=None):
    def join(self, join, reuse=None):
        """
        Returns an alias for the join in 'connection', either reusing an
        existing alias for that join or creating a new one. 'connection' is a
@@ -897,40 +876,22 @@ class Query(object):

        The 'join_field' is the field we are joining along (if any).
        """
        lhs, table, join_cols = connection
        assert lhs is None or join_field is not None
        existing = self.join_map.get(connection, ())
        if reuse is None:
            reuse = existing
        else:
            reuse = [a for a in existing if a in reuse]
        for alias in reuse:
            if join_field and self.alias_map[alias].join_field != join_field:
                # The join_map doesn't contain join_field (mainly because
                # fields in Query structs are problematic in pickling), so
                # check that the existing join is created using the same
                # join_field used for the under work join.
                continue
            self.ref_alias(alias)
            return alias
        reuse = [a for a, j in self.alias_map.items()
                 if (reuse is None or a in reuse) and j == join]
        if reuse:
            self.ref_alias(reuse[0])
            return reuse[0]

        # No reuse is possible, so we need a new alias.
        alias, _ = self.table_alias(table, create=True)
        if not lhs:
            # Not all tables need to be joined to anything. No join type
            # means the later columns are ignored.
            join_type = None
        elif self.alias_map[lhs].join_type == self.LOUTER or nullable:
            join_type = self.LOUTER
        alias, _ = self.table_alias(join.table_name, create=True)
        if join.join_type:
            if self.alias_map[join.parent_alias].join_type == LOUTER or join.nullable:
                join_type = LOUTER
            else:
            join_type = self.INNER
        join = JoinInfo(table, alias, join_type, lhs, join_cols or ((None, None),), nullable,
                        join_field)
                join_type = INNER
            join.join_type = join_type
        join.table_alias = alias
        self.alias_map[alias] = join
        if connection in self.join_map:
            self.join_map[connection] += (alias,)
        else:
            self.join_map[connection] = (alias,)
        return alias

    def setup_inherited_models(self):
@@ -1249,7 +1210,7 @@ class Query(object):
            require_outer = True
            if (lookup_type != 'isnull' and (
                    self.is_nullable(targets[0]) or
                    self.alias_map[join_list[-1]].join_type == self.LOUTER)):
                    self.alias_map[join_list[-1]].join_type == LOUTER)):
                # The condition added here will be SQL like this:
                # NOT (col IS NOT NULL), where the first NOT is added in
                # upper layers of code. The reason for addition is that if col
@@ -1326,7 +1287,7 @@ class Query(object):
        # rel_a doesn't produce any rows, then the whole condition must fail.
        # So, demotion is OK.
        existing_inner = set(
            (a for a in self.alias_map if self.alias_map[a].join_type == self.INNER))
            (a for a in self.alias_map if self.alias_map[a].join_type == INNER))
        clause, require_inner = self._add_q(where_part, self.used_aliases)
        self.where.add(clause, AND)
        for hp in having_parts:
@@ -1490,10 +1451,9 @@ class Query(object):
                nullable = self.is_nullable(join.join_field)
            else:
                nullable = True
            connection = alias, opts.db_table, join.join_field.get_joining_columns()
            connection = Join(opts.db_table, alias, None, INNER, join.join_field, nullable)
            reuse = can_reuse if join.m2m else None
            alias = self.join(
                connection, reuse=reuse, nullable=nullable, join_field=join.join_field)
            alias = self.join(connection, reuse=reuse)
            joins.append(alias)
        if hasattr(final_field, 'field'):
            final_field = final_field.field
@@ -1991,9 +1951,10 @@ class Query(object):
        for trimmed_paths, path in enumerate(all_paths):
            if path.m2m:
                break
            if self.alias_map[lookup_tables[trimmed_paths + 1]].join_type == self.LOUTER:
            if self.alias_map[lookup_tables[trimmed_paths + 1]].join_type == LOUTER:
                contains_louter = True
            self.unref_alias(lookup_tables[trimmed_paths])
            alias = lookup_tables[trimmed_paths]
            self.unref_alias(alias)
        # The path.join_field is a Rel, lets get the other side's field
        join_field = path.join_field.field
        # Build the filter prefix.
@@ -2010,7 +1971,7 @@ class Query(object):
        # Lets still see if we can trim the first join from the inner query
        # (that is, self). We can't do this for LEFT JOINs because we would
        # miss those rows that have nothing on the outer side.
        if self.alias_map[lookup_tables[trimmed_paths + 1]].join_type != self.LOUTER:
        if self.alias_map[lookup_tables[trimmed_paths + 1]].join_type != LOUTER:
            select_fields = [r[0] for r in join_field.related_fields]
            select_alias = lookup_tables[trimmed_paths + 1]
            self.unref_alias(lookup_tables[trimmed_paths])
@@ -2024,6 +1985,12 @@ class Query(object):
            # values in select_fields. Lets punt this one for now.
            select_fields = [r[1] for r in join_field.related_fields]
            select_alias = lookup_tables[trimmed_paths]
        # The found starting point is likely a Join instead of a BaseTable reference.
        # But the first entry in the query's FROM clause must not be a JOIN.
        for table in self.tables:
            if self.alias_refcount[table] > 0:
                self.alias_map[table] = BaseTable(self.alias_map[table].table_name, table)
                break
        self.select = [SelectInfo((select_alias, f.column), f) for f in select_fields]
        return trimmed_prefix, contains_louter

+5 −4
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@ from django.core.exceptions import FieldError
from django.db import connection, DEFAULT_DB_ALIAS
from django.db.models import Count, F, Q
from django.db.models.sql.where import WhereNode, EverythingNode, NothingNode
from django.db.models.sql.constants import LOUTER
from django.db.models.sql.datastructures import EmptyResultSet
from django.test import TestCase, skipUnlessDBFeature
from django.test.utils import CaptureQueriesContext
@@ -128,7 +129,7 @@ class Queries1Tests(BaseQuerysetTest):
    def test_ticket2306(self):
        # Checking that no join types are "left outer" joins.
        query = Item.objects.filter(tags=self.t2).query
        self.assertNotIn(query.LOUTER, [x[2] for x in query.alias_map.values()])
        self.assertNotIn(LOUTER, [x.join_type for x in query.alias_map.values()])

        self.assertQuerysetEqual(
            Item.objects.filter(Q(tags=self.t1)).order_by('name'),
@@ -336,7 +337,7 @@ class Queries1Tests(BaseQuerysetTest):

        # Excluding from a relation that cannot be NULL should not use outer joins.
        query = Item.objects.exclude(creator__in=[self.a1, self.a2]).query
        self.assertNotIn(query.LOUTER, [x[2] for x in query.alias_map.values()])
        self.assertNotIn(LOUTER, [x.join_type for x in query.alias_map.values()])

        # Similarly, when one of the joins cannot possibly, ever, involve NULL
        # values (Author -> ExtraInfo, in the following), it should never be
@@ -344,7 +345,7 @@ class Queries1Tests(BaseQuerysetTest):
        # involve one "left outer" join (Author -> Item is 0-to-many).
        qs = Author.objects.filter(id=self.a1.id).filter(Q(extra__note=self.n1) | Q(item__note=self.n3))
        self.assertEqual(
            len([x[2] for x in qs.query.alias_map.values() if x[2] == query.LOUTER and qs.query.alias_refcount[x[1]]]),
            len([x for x in qs.query.alias_map.values() if x.join_type == LOUTER and qs.query.alias_refcount[x.table_alias]]),
            1
        )

@@ -855,7 +856,7 @@ class Queries1Tests(BaseQuerysetTest):
        )
        q = Note.objects.filter(Q(extrainfo__author=self.a1) | Q(extrainfo=xx)).query
        self.assertEqual(
            len([x[2] for x in q.alias_map.values() if x[2] == q.LOUTER and q.alias_refcount[x[1]]]),
            len([x for x in q.alias_map.values() if x.join_type == LOUTER and q.alias_refcount[x.table_alias]]),
            1
        )