Commit 11699ac4 authored by Anssi Kääriäinen's avatar Anssi Kääriäinen
Browse files

Fixed #19190 -- Refactored Query select clause attributes

The Query.select and Query.select_fields were collapsed into one list
because the attributes had to be always in sync. Now that they are in
one attribute it is impossible to edit them out of sync.

Similar collapse was done for Query.related_select_cols and
Query.related_select_fields.
parent 789ea334
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -39,7 +39,7 @@ class GeoSQLCompiler(compiler.SQLCompiler):
        if self.query.select:
            only_load = self.deferred_to_columns()
            # This loop customized for GeoQuery.
            for col, field in zip(self.query.select, self.query.select_fields):
            for col, field in self.query.select:
                if isinstance(col, (list, tuple)):
                    alias, column = col
                    table = self.query.alias_map[alias].table_name
@@ -85,7 +85,7 @@ class GeoSQLCompiler(compiler.SQLCompiler):
        ])

        # This loop customized for GeoQuery.
        for (table, col), field in zip(self.query.related_select_cols, self.query.related_select_fields):
        for (table, col), field in self.query.related_select_cols:
            r = self.get_field_select(field, table, col)
            if with_aliases and col in col_aliases:
                c_alias = 'Col%d' % len(col_aliases)
+14 −14
Original line number Diff line number Diff line
@@ -6,7 +6,7 @@ from django.db.backends.util import truncate_name
from django.db.models.constants import LOOKUP_SEP
from django.db.models.query_utils import select_related_descend
from django.db.models.sql.constants import (SINGLE, MULTI, ORDER_DIR,
        GET_ITERATOR_CHUNK_SIZE)
        GET_ITERATOR_CHUNK_SIZE, SelectInfo)
from django.db.models.sql.datastructures import EmptyResultSet
from django.db.models.sql.expressions import SQLEvaluator
from django.db.models.sql.query import get_order_dir, Query
@@ -188,7 +188,7 @@ class SQLCompiler(object):
            col_aliases = set()
        if self.query.select:
            only_load = self.deferred_to_columns()
            for col in self.query.select:
            for col, _ in self.query.select:
                if isinstance(col, (list, tuple)):
                    alias, column = col
                    table = self.query.alias_map[alias].table_name
@@ -233,7 +233,7 @@ class SQLCompiler(object):
            for alias, aggregate in self.query.aggregate_select.items()
        ])

        for table, col in self.query.related_select_cols:
        for (table, col), _ in self.query.related_select_cols:
            r = '%s.%s' % (qn(table), qn(col))
            if with_aliases and col in col_aliases:
                c_alias = 'Col%d' % len(col_aliases)
@@ -557,8 +557,9 @@ class SQLCompiler(object):
            for extra_select, extra_params in six.itervalues(self.query.extra_select):
                extra_selects.append(extra_select)
                params.extend(extra_params)
            cols = (group_by + self.query.select +
                self.query.related_select_cols + extra_selects)
            select_cols = [s.col for s in self.query.select]
            related_select_cols = [s.col for s in self.query.related_select_cols]
            cols = (group_by + select_cols + related_select_cols + extra_selects)
            seen = set()
            for col in cols:
                if col in seen:
@@ -589,7 +590,6 @@ class SQLCompiler(object):
            opts = self.query.get_meta()
            root_alias = self.query.get_initial_alias()
            self.query.related_select_cols = []
            self.query.related_select_fields = []
        if not used:
            used = set()
        if dupe_set is None:
@@ -664,8 +664,8 @@ class SQLCompiler(object):
            used.add(alias)
            columns, aliases = self.get_default_columns(start_alias=alias,
                    opts=f.rel.to._meta, as_pairs=True)
            self.query.related_select_cols.extend(columns)
            self.query.related_select_fields.extend(f.rel.to._meta.fields)
            self.query.related_select_cols.extend(
                SelectInfo(col, field) for col, field in zip(columns, f.rel.to._meta.fields))
            if restricted:
                next = requested.get(f.name, {})
            else:
@@ -734,8 +734,8 @@ class SQLCompiler(object):
                used.add(alias)
                columns, aliases = self.get_default_columns(start_alias=alias,
                    opts=model._meta, as_pairs=True, local_only=True)
                self.query.related_select_cols.extend(columns)
                self.query.related_select_fields.extend(model._meta.fields)
                self.query.related_select_cols.extend(
                    SelectInfo(col, field) for col, field in zip(columns, model._meta.fields))

                next = requested.get(f.related_query_name(), {})
                # Use True here because we are looking at the _reverse_ side of
@@ -772,7 +772,7 @@ class SQLCompiler(object):
                if resolve_columns:
                    if fields is None:
                        # We only set this up here because
                        # related_select_fields isn't populated until
                        # related_select_cols isn't populated until
                        # execute_sql() has been called.

                        # We also include types of fields of related models that
@@ -782,11 +782,11 @@ class SQLCompiler(object):

                        # This code duplicates the logic for the order of fields
                        # found in get_columns(). It would be nice to clean this up.
                        if self.query.select_fields:
                            fields = self.query.select_fields
                        if self.query.select:
                            fields = [f.field for f in self.query.select]
                        else:
                            fields = self.query.model._meta.fields
                        fields = fields + self.query.related_select_fields
                        fields = fields + [f.field for f in self.query.related_select_cols]

                        # If the field was deferred, exclude it from being passed
                        # into `resolve_columns` because it wasn't selected.
+3 −0
Original line number Diff line number Diff line
@@ -25,6 +25,9 @@ JoinInfo = namedtuple('JoinInfo',
                      'table_name rhs_alias join_type lhs_alias '
                      'lhs_join_col rhs_join_col nullable')

# Pairs of column clauses to select, and (possibly None) field for the clause.
SelectInfo = namedtuple('SelectInfo', 'col field')

# How many results to expect from a cursor.execute call
MULTI = 'multi'
SINGLE = 'single'
+40 −47
Original line number Diff line number Diff line
@@ -20,7 +20,7 @@ from django.db.models.expressions import ExpressionNode
from django.db.models.fields import FieldDoesNotExist
from django.db.models.sql import aggregates as base_aggregates_module
from django.db.models.sql.constants import (QUERY_TERMS, ORDER_DIR, SINGLE,
        ORDER_PATTERN, JoinInfo)
        ORDER_PATTERN, JoinInfo, SelectInfo)
from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin
from django.db.models.sql.expressions import SQLEvaluator
from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode,
@@ -115,17 +115,20 @@ class Query(object):
        self.default_ordering = True
        self.standard_ordering = True
        self.ordering_aliases = []
        self.related_select_fields = []
        self.dupe_avoidance = {}
        self.used_aliases = set()
        self.filter_is_sticky = False
        self.included_inherited_models = {}

        # SQL-related attributes  
        # Select and related select clauses as SelectInfo instances.
        # The select is used for cases where we want to set up the select
        # clause to contain other than default fields (values(), annotate(),
        # subqueries...)
        self.select = []
        # For each to-be-selected field in self.select there must be a
        # corresponding entry in self.select - git seems to need this.
        self.select_fields = []
        # The related_select_cols is used for columns needed for
        # select_related - this is populated in compile stage.
        self.related_select_cols = []
        self.tables = []    # Aliases in the order they are created.
        self.where = where()
        self.where_class = where
@@ -138,7 +141,6 @@ class Query(object):
        self.select_for_update = False
        self.select_for_update_nowait = False
        self.select_related = False
        self.related_select_cols = []

        # SQL aggregate-related attributes
        self.aggregates = SortedDict() # Maps alias -> SQL aggregate function
@@ -191,15 +193,14 @@ class Query(object):
        Pickling support.
        """
        obj_dict = self.__dict__.copy()
        obj_dict['related_select_fields'] = []
        obj_dict['related_select_cols'] = []

        # Fields can't be pickled, so if a field list has been
        # specified, we pickle the list of field names instead.
        # None is also a possible value; that can pass as-is
        obj_dict['select_fields'] = [
            f is not None and f.name or None
            for f in obj_dict['select_fields']
        obj_dict['select'] = [
            (s.col, s.field is not None and s.field.name or None)
            for s in obj_dict['select']
        ]
        return obj_dict

@@ -209,9 +210,9 @@ class Query(object):
        """
        # Rebuild list of field instances
        opts = obj_dict['model']._meta
        obj_dict['select_fields'] = [
            name is not None and opts.get_field(name) or None
            for name in obj_dict['select_fields']
        obj_dict['select'] = [
            SelectInfo(tpl[0], tpl[1] is not None and opts.get_field(tpl[1]) or None)
            for tpl in obj_dict['select']
        ]

        self.__dict__.update(obj_dict)
@@ -256,10 +257,9 @@ class Query(object):
        obj.standard_ordering = self.standard_ordering
        obj.included_inherited_models = self.included_inherited_models.copy()
        obj.ordering_aliases = []
        obj.select_fields = self.select_fields[:]
        obj.related_select_fields = self.related_select_fields[:]
        obj.dupe_avoidance = self.dupe_avoidance.copy()
        obj.select = self.select[:]
        obj.related_select_cols = []
        obj.tables = self.tables[:]
        obj.where = copy.deepcopy(self.where, memo=memo)
        obj.where_class = self.where_class
@@ -275,7 +275,6 @@ class Query(object):
        obj.select_for_update = self.select_for_update
        obj.select_for_update_nowait = self.select_for_update_nowait
        obj.select_related = self.select_related
        obj.related_select_cols = []
        obj.aggregates = copy.deepcopy(self.aggregates, memo=memo)
        if self.aggregate_select_mask is None:
            obj.aggregate_select_mask = None
@@ -384,7 +383,6 @@ class Query(object):
        query.select_for_update = False
        query.select_related = False
        query.related_select_cols = []
        query.related_select_fields = []

        result = query.get_compiler(using).execute_sql(SINGLE)
        if result is None:
@@ -527,14 +525,14 @@ class Query(object):

        # Selection columns and extra extensions are those provided by 'rhs'.
        self.select = []
        for col in rhs.select:
        for col, field in rhs.select:
            if isinstance(col, (list, tuple)):
                self.select.append((change_map.get(col[0], col[0]), col[1]))
                new_col = change_map.get(col[0], col[0]), col[1]
                self.select.append(SelectInfo(new_col, field))
            else:
                item = copy.deepcopy(col)
                item.relabel_aliases(change_map)
                self.select.append(item)
        self.select_fields = rhs.select_fields[:]
                self.select.append(SelectInfo(item, field))

        if connector == OR:
            # It would be nice to be able to handle this, but the queries don't
@@ -750,24 +748,23 @@ class Query(object):
        """
        assert set(change_map.keys()).intersection(set(change_map.values())) == set()

        # 1. Update references in "select" (normal columns plus aliases),
        # "group by", "where" and "having".
        self.where.relabel_aliases(change_map)
        self.having.relabel_aliases(change_map)
        for columns in [self.select, self.group_by or []]:
            for pos, col in enumerate(columns):
        def relabel_column(col):
            if isinstance(col, (list, tuple)):
                old_alias = col[0]
                    columns[pos] = (change_map.get(old_alias, old_alias), col[1])
                else:
                    col.relabel_aliases(change_map)
        for mapping in [self.aggregates]:
            for key, col in mapping.items():
                if isinstance(col, (list, tuple)):
                    old_alias = col[0]
                    mapping[key] = (change_map.get(old_alias, old_alias), col[1])
                return (change_map.get(old_alias, old_alias), col[1])
            else:
                col.relabel_aliases(change_map)
                return col
        # 1. Update references in "select" (normal columns plus aliases),
        # "group by", "where" and "having".
        self.where.relabel_aliases(change_map)
        self.having.relabel_aliases(change_map)
        if self.group_by:
            self.group_by = [relabel_column(col) for col in self.group_by]
        self.select = [SelectInfo(relabel_column(s.col), s.field)
                       for s in self.select]
        self.aggregates = SortedDict(
            (key, relabel_column(col)) for key, col in self.aggregates.items())

        # 2. Rename the alias in the internal table/alias datastructures.
        for k, aliases in self.join_map.items():
@@ -1570,7 +1567,7 @@ class Query(object):
        # since we are adding a IN <subquery> clause. This prevents the
        # database from tripping over IN (...,NULL,...) selects and returning
        # nothing
        alias, col = query.select[0]
        alias, col = query.select[0].col
        query.where.add((Constraint(alias, col, None), 'isnull', False), AND)

        self.add_filter(('%s__in' % prefix, query), negate=True, trim=True,
@@ -1629,7 +1626,6 @@ class Query(object):
        Removes all fields from SELECT clause.
        """
        self.select = []
        self.select_fields = []
        self.default_cols = False
        self.select_related = False
        self.set_extra_mask(())
@@ -1642,7 +1638,6 @@ class Query(object):
        columns.
        """
        self.select = []
        self.select_fields = []

    def add_distinct_fields(self, *field_names):
        """
@@ -1674,8 +1669,7 @@ class Query(object):
                        col = join.lhs_join_col
                        joins = joins[:-1]
                self.promote_joins(joins[1:])
                self.select.append((final_alias, col))
                self.select_fields.append(field)
                self.select.append(SelectInfo((final_alias, col), field))
        except MultiJoin:
            raise FieldError("Invalid field name: '%s'" % name)
        except FieldError:
@@ -1731,8 +1725,8 @@ class Query(object):
        """
        self.group_by = []

        for sel in self.select:
            self.group_by.append(sel)
        for col, _ in self.select:
            self.group_by.append(col)

    def add_count_column(self):
        """
@@ -1745,7 +1739,7 @@ class Query(object):
            else:
                assert len(self.select) == 1, \
                        "Cannot add count col with multiple cols in 'select': %r" % self.select
                count = self.aggregates_module.Count(self.select[0])
                count = self.aggregates_module.Count(self.select[0].col)
        else:
            opts = self.model._meta
            if not self.select:
@@ -1757,7 +1751,7 @@ class Query(object):
                assert len(self.select) == 1, \
                        "Cannot add count col with multiple cols in 'select'."

                count = self.aggregates_module.Count(self.select[0], distinct=True)
                count = self.aggregates_module.Count(self.select[0].col, distinct=True)
            # Distinct handling is done in Count(), so don't do it at this
            # level.
            self.distinct = False
@@ -1781,7 +1775,6 @@ class Query(object):
                d = d.setdefault(part, {})
        self.select_related = field_dict
        self.related_select_cols = []
        self.related_select_fields = []

    def add_extra(self, select, select_params, where, params, tables, order_by):
        """
@@ -1975,7 +1968,7 @@ class Query(object):
            self.unref_alias(select_alias)
            select_alias = join_info.rhs_alias
            select_col = join_info.rhs_join_col
        self.select = [(select_alias, select_col)]
        self.select = [SelectInfo((select_alias, select_col), None)]
        self.remove_inherited_models()

    def is_nullable(self, field):
+2 −2
Original line number Diff line number Diff line
@@ -76,7 +76,7 @@ class DeleteQuery(Query):
                return
            else:
                innerq.clear_select_clause()
                innerq.select, innerq.select_fields = [(self.get_initial_alias(), pk.column)], [None]
                innerq.select = [SelectInfo((self.get_initial_alias(), pk.column), None)]
                values = innerq
            where = self.where_class()
            where.add((Constraint(None, pk.column, pk), 'in', values), AND)
@@ -244,7 +244,7 @@ class DateQuery(Query):
        alias = result[3][-1]
        select = Date((alias, field.column), lookup_type)
        self.clear_select_clause()
        self.select, self.select_fields = [select], [None]
        self.select = [SelectInfo(select, None)]
        self.distinct = True
        self.order_by = order == 'ASC' and [1] or [-1]