Commit 6ff118cd authored by Adrian Holovaty's avatar Adrian Holovaty
Browse files

Fixed #17644 -- Changed Query.alias_map to use namedtuples

This makes the code easier to understand and may even have a benefit in memory usage (namedtuples instead of dicts). Thanks, lrekucki and akaariai
parent af71ce06
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
from itertools import izip
from django.db.backends.util import truncate_name, typecast_timestamp
from django.db.models.sql import compiler
from django.db.models.sql.constants import TABLE_NAME, MULTI
from django.db.models.sql.constants import MULTI

SQLCompiler = compiler.SQLCompiler

@@ -35,7 +35,7 @@ class GeoSQLCompiler(compiler.SQLCompiler):
            for col, field in izip(self.query.select, self.query.select_fields):
                if isinstance(col, (list, tuple)):
                    alias, column = col
                    table = self.query.alias_map[alias][TABLE_NAME]
                    table = self.query.alias_map[alias].table_name
                    if table in only_load and column not in only_load[table]:
                        continue
                    r = self.get_field_select(field, alias, column)
@@ -138,7 +138,7 @@ class GeoSQLCompiler(compiler.SQLCompiler):
                # aliases will have already been set up in pre_sql_setup(), so
                # we can save time here.
                alias = self.query.included_inherited_models[model]
            table = self.query.alias_map[alias][TABLE_NAME]
            table = self.query.alias_map[alias].table_name
            if table in only_load and field.column not in only_load[table]:
                continue
            if as_pairs:
+9 −9
Original line number Diff line number Diff line
@@ -188,7 +188,7 @@ class SQLCompiler(object):
            for col in self.query.select:
                if isinstance(col, (list, tuple)):
                    alias, column = col
                    table = self.query.alias_map[alias][TABLE_NAME]
                    table = self.query.alias_map[alias].table_name
                    if table in only_load and column not in only_load[table]:
                        continue
                    r = '%s.%s' % (qn(alias), qn(column))
@@ -289,7 +289,7 @@ class SQLCompiler(object):
                # aliases will have already been set up in pre_sql_setup(), so
                # we can save time here.
                alias = self.query.included_inherited_models[model]
            table = self.query.alias_map[alias][TABLE_NAME]
            table = self.query.alias_map[alias].table_name
            if table in only_load and field.column not in only_load[table]:
                continue
            if as_pairs:
@@ -432,7 +432,7 @@ class SQLCompiler(object):
            # Firstly, avoid infinite loops.
            if not already_seen:
                already_seen = set()
            join_tuple = tuple([self.query.alias_map[j][TABLE_NAME] for j in joins])
            join_tuple = tuple([self.query.alias_map[j].table_name for j in joins])
            if join_tuple in already_seen:
                raise FieldError('Infinite loop caused by ordering.')
            already_seen.add(join_tuple)
@@ -470,7 +470,7 @@ class SQLCompiler(object):
        # Ordering or distinct must not affect the returned set, and INNER
        # JOINS for nullable fields could do this.
        self.query.promote_alias_chain(joins,
            self.query.alias_map[joins[0]][JOIN_TYPE] == self.query.LOUTER)
            self.query.alias_map[joins[0]].join_type == self.query.LOUTER)
        return field, col, alias, joins, opts

    def _final_join_removal(self, col, alias):
@@ -485,11 +485,11 @@ class SQLCompiler(object):
        if alias:
            while 1:
                join = self.query.alias_map[alias]
                if col != join[RHS_JOIN_COL]:
                if col != join.rhs_join_col:
                    break
                self.query.unref_alias(alias)
                alias = join[LHS_ALIAS]
                col = join[LHS_JOIN_COL]
                alias = join.lhs_alias
                col = join.lhs_join_col
        return col, alias

    def get_from_clause(self):
@@ -641,7 +641,7 @@ class SQLCompiler(object):
                    alias_chain.append(alias)
                    for (dupe_opts, dupe_col) in dupe_set:
                        self.query.update_dupe_avoidance(dupe_opts, dupe_col, alias)
                if self.query.alias_map[root_alias][JOIN_TYPE] == self.query.LOUTER:
                if self.query.alias_map[root_alias].join_type == self.query.LOUTER:
                    self.query.promote_alias_chain(alias_chain, True)
            else:
                alias = root_alias
@@ -659,7 +659,7 @@ class SQLCompiler(object):
            columns, aliases = self.get_default_columns(start_alias=alias,
                    opts=f.rel.to._meta, as_pairs=True)
            self.query.related_select_cols.extend(columns)
            if self.query.alias_map[alias][JOIN_TYPE] == self.query.LOUTER:
            if self.query.alias_map[alias].join_type == self.query.LOUTER:
                self.query.promote_alias_chain(aliases, True)
            self.query.related_select_fields.extend(f.rel.to._meta.fields)
            if restricted:
+7 −11
Original line number Diff line number Diff line
from collections import namedtuple
import re

# Valid query types (a dictionary is used for speedy lookups).
@@ -17,13 +18,9 @@ LOOKUP_SEP = '__'
# Constants to make looking up tuple values clearer.
# Join lists (indexes into the tuples that are values in the alias_map
# dictionary in the Query class).
TABLE_NAME = 0
RHS_ALIAS = 1
JOIN_TYPE = 2
LHS_ALIAS = 3
LHS_JOIN_COL = 4
RHS_JOIN_COL = 5
NULLABLE = 6
JoinInfo = namedtuple('JoinInfo',
                      'table_name rhs_alias join_type lhs_alias '
                      'lhs_join_col rhs_join_col nullable')

# How many results to expect from a cursor.execute call
MULTI = 'multi'
@@ -32,6 +29,5 @@ SINGLE = 'single'
ORDER_PATTERN = re.compile(r'\?|[-+]?[.\w]+$')
ORDER_DIR = {
    'ASC': ('ASC', 'DESC'),
    'DESC': ('DESC', 'ASC')}

    'DESC': ('DESC', 'ASC'),
}
+37 −34
Original line number Diff line number Diff line
@@ -101,7 +101,11 @@ class Query(object):
    def __init__(self, model, where=WhereNode):
        self.model = model
        self.alias_refcount = SortedDict()
        self.alias_map = {}     # Maps alias to join information
        # alias_map is the most important data structure regarding joins.
        # It's used for recording which joins exist in the query and what
        # type they are. The key is the alias of the joined table (possibly
        # the table name) and the value is JoinInfo from constants.py.
        self.alias_map = {}
        self.table_map = {}     # Maps table names to list of aliases.
        self.join_map = {}
        self.default_cols = True
@@ -686,11 +690,11 @@ class Query(object):

        Returns True if the join was promoted by this call.
        """
        if ((unconditional or self.alias_map[alias][NULLABLE]) and
                self.alias_map[alias][JOIN_TYPE] != self.LOUTER):
            data = list(self.alias_map[alias])
            data[JOIN_TYPE] = self.LOUTER
            self.alias_map[alias] = tuple(data)
        if ((unconditional or self.alias_map[alias].nullable) and
                self.alias_map[alias].join_type != self.LOUTER):
            data = self.alias_map[alias]
            data = data._replace(join_type=self.LOUTER)
            self.alias_map[alias] = data
            return True
        return False

@@ -730,7 +734,7 @@ class Query(object):
                continue
            if (alias not in initial_refcounts or
                    self.alias_refcount[alias] == initial_refcounts[alias]):
                parent = self.alias_map[alias][LHS_ALIAS]
                parent = self.alias_map[alias].lhs_alias
                must_promote = considered.get(parent, False)
                promoted = self.promote_alias(alias, must_promote)
                considered[alias] = must_promote or promoted
@@ -767,14 +771,14 @@ class Query(object):
            aliases = tuple([change_map.get(a, a) for a in aliases])
            self.join_map[k] = aliases
        for old_alias, new_alias in change_map.iteritems():
            alias_data = list(self.alias_map[old_alias])
            alias_data[RHS_ALIAS] = new_alias
            alias_data = self.alias_map[old_alias]
            alias_data = alias_data._replace(rhs_alias=new_alias)
            self.alias_refcount[new_alias] = self.alias_refcount[old_alias]
            del self.alias_refcount[old_alias]
            self.alias_map[new_alias] = tuple(alias_data)
            self.alias_map[new_alias] = alias_data
            del self.alias_map[old_alias]

            table_aliases = self.table_map[alias_data[TABLE_NAME]]
            table_aliases = self.table_map[alias_data.table_name]
            for pos, alias in enumerate(table_aliases):
                if alias == old_alias:
                    table_aliases[pos] = new_alias
@@ -789,11 +793,10 @@ class Query(object):

        # 3. Update any joins that refer to the old alias.
        for alias, data in self.alias_map.iteritems():
            lhs = data[LHS_ALIAS]
            lhs = data.lhs_alias
            if lhs in change_map:
                data = list(data)
                data[LHS_ALIAS] = change_map[lhs]
                self.alias_map[alias] = tuple(data)
                data = data._replace(lhs_alias=change_map[lhs])
                self.alias_map[alias] = data

    def bump_prefix(self, exceptions=()):
        """
@@ -876,7 +879,7 @@ class Query(object):
        """
        lhs, table, lhs_col, col = connection
        if lhs in self.alias_map:
            lhs_table = self.alias_map[lhs][TABLE_NAME]
            lhs_table = self.alias_map[lhs].table_name
        else:
            lhs_table = lhs

@@ -889,11 +892,11 @@ class Query(object):
        if not always_create:
            for alias in self.join_map.get(t_ident, ()):
                if alias not in exclusions:
                    if lhs_table and not self.alias_refcount[self.alias_map[alias][LHS_ALIAS]]:
                    if lhs_table and not self.alias_refcount[self.alias_map[alias].lhs_alias]:
                        # The LHS of this join tuple is no longer part of the
                        # query, so skip this possibility.
                        continue
                    if self.alias_map[alias][LHS_ALIAS] != lhs:
                    if self.alias_map[alias].lhs_alias != lhs:
                        continue
                    self.ref_alias(alias)
                    if promote:
@@ -910,7 +913,7 @@ class Query(object):
            join_type = self.LOUTER
        else:
            join_type = self.INNER
        join = (table, alias, join_type, lhs, lhs_col, col, nullable)
        join = JoinInfo(table, alias, join_type, lhs, lhs_col, col, nullable)
        self.alias_map[alias] = join
        if t_ident in self.join_map:
            self.join_map[t_ident] += (alias,)
@@ -1150,7 +1153,7 @@ class Query(object):
                # also be promoted, regardless of whether they have been
                # promoted as a result of this pass through the tables.
                unconditional = (unconditional or
                    self.alias_map[join][JOIN_TYPE] == self.LOUTER)
                    self.alias_map[join].join_type == self.LOUTER)
                if join == table and self.alias_refcount[join] > 1:
                    # We have more than one reference to this join table.
                    # This means that we are dealing with two different query
@@ -1181,8 +1184,8 @@ class Query(object):
            if lookup_type != 'isnull':
                if len(join_list) > 1:
                    for alias in join_list:
                        if self.alias_map[alias][JOIN_TYPE] == self.LOUTER:
                            j_col = self.alias_map[alias][RHS_JOIN_COL]
                        if self.alias_map[alias].join_type == self.LOUTER:
                            j_col = self.alias_map[alias].rhs_join_col
                            entry = self.where_class()
                            entry.add(
                                (Constraint(alias, j_col, None), 'isnull', True),
@@ -1510,7 +1513,7 @@ class Query(object):
            join_list = join_list[:penultimate]
            final = penultimate
            penultimate = last.pop()
            col = self.alias_map[extra[0]][LHS_JOIN_COL]
            col = self.alias_map[extra[0]].lhs_join_col
            for alias in extra:
                self.unref_alias(alias)
        else:
@@ -1518,12 +1521,12 @@ class Query(object):
        alias = join_list[-1]
        while final > 1:
            join = self.alias_map[alias]
            if (col != join[RHS_JOIN_COL] or join[JOIN_TYPE] != self.INNER or
            if (col != join.rhs_join_col or join.join_type != self.INNER or
                    nonnull_check):
                break
            self.unref_alias(alias)
            alias = join[LHS_ALIAS]
            col = join[LHS_JOIN_COL]
            alias = join.lhs_alias
            col = join.lhs_join_col
            join_list.pop()
            final -= 1
            if final == penultimate:
@@ -1646,10 +1649,10 @@ class Query(object):
                col = target.column
                if len(joins) > 1:
                    join = self.alias_map[final_alias]
                    if col == join[RHS_JOIN_COL]:
                    if col == join.rhs_join_col:
                        self.unref_alias(final_alias)
                        final_alias = join[LHS_ALIAS]
                        col = join[LHS_JOIN_COL]
                        final_alias = join.lhs_alias
                        col = join.lhs_join_col
                        joins = joins[:-1]
                self.promote_alias_chain(joins[1:])
                self.select.append((final_alias, col))
@@ -1923,7 +1926,7 @@ class Query(object):
        alias = self.get_initial_alias()
        field, col, opts, joins, last, extra = self.setup_joins(
                start.split(LOOKUP_SEP), opts, alias, False)
        select_col = self.alias_map[joins[1]][LHS_JOIN_COL]
        select_col = self.alias_map[joins[1]].lhs_join_col
        select_alias = alias

        # The call to setup_joins added an extra reference to everything in
@@ -1936,12 +1939,12 @@ class Query(object):
        # is *always* the same value as lhs).
        for alias in joins[1:]:
            join_info = self.alias_map[alias]
            if (join_info[LHS_JOIN_COL] != select_col
                    or join_info[JOIN_TYPE] != self.INNER):
            if (join_info.lhs_join_col != select_col
                    or join_info.join_type != self.INNER):
                break
            self.unref_alias(select_alias)
            select_alias = join_info[RHS_ALIAS]
            select_col = join_info[RHS_JOIN_COL]
            select_alias = join_info.rhs_alias
            select_col = join_info.rhs_join_col
        self.select = [(select_alias, select_col)]
        self.remove_inherited_models()