Commit 542709d0 authored by Russell Keith-Magee
Browse files

Fixed #10182 -- Corrected realiasing and the process of evaluating values()...

Fixed #10182 -- Corrected realiasing and the process of evaluating values() for queries with aggregate clauses. This means that aggregate queries can now be used as subqueries (such as in an __in clause). Thanks to omat for the report.

This involves a slight change to the interaction of annotate() and values() clauses that specify a list of columns. See the docs for details.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@9888 bcc190cf-cafb-0310-a4f2-bffc1f526a37
parent 4bd24474
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -46,7 +46,7 @@ class Aggregate(object):
        # Validate that the backend has a fully supported, correct
        # implementation of this aggregate
        query.connection.ops.check_aggregate_support(aggregate)
        query.aggregate_select[alias] = aggregate
        query.aggregates[alias] = aggregate

class Avg(Aggregate):
    name = 'Avg'
+12 −4
Original line number Diff line number Diff line
@@ -596,7 +596,7 @@ class QuerySet(object):

        obj = self._clone()

        obj._setup_aggregate_query()
        obj._setup_aggregate_query(kwargs.keys())

        # Add the aggregates to the query
        for (alias, aggregate_expr) in kwargs.items():
@@ -693,7 +693,7 @@ class QuerySet(object):
        """
        pass

    def _setup_aggregate_query(self):
    def _setup_aggregate_query(self, aggregates):
        """
        Prepare the query for computing a result that contains aggregate annotations.
        """
@@ -773,6 +773,8 @@ class ValuesQuerySet(QuerySet):

        self.query.select = []
        self.query.add_fields(self.field_names, False)
        if self.aggregate_names is not None:
            self.query.set_aggregate_mask(self.aggregate_names)

    def _clone(self, klass=None, setup=False, **kwargs):
        """
@@ -798,13 +800,17 @@ class ValuesQuerySet(QuerySet):
            raise TypeError("Merging '%s' classes must involve the same values in each case."
                    % self.__class__.__name__)

    def _setup_aggregate_query(self):
    def _setup_aggregate_query(self, aggregates):
        """
        Prepare the query for computing a result that contains aggregate annotations.
        """
        self.query.set_group_by()

        super(ValuesQuerySet, self)._setup_aggregate_query()
        if self.aggregate_names is not None:
            self.aggregate_names.extend(aggregates)
            self.query.set_aggregate_mask(self.aggregate_names)

        super(ValuesQuerySet, self)._setup_aggregate_query(aggregates)

    def as_sql(self):
        """
@@ -824,6 +830,7 @@ class ValuesListQuerySet(ValuesQuerySet):
    def iterator(self):
        if self.extra_names is not None:
            self.query.trim_extra_select(self.extra_names)

        if self.flat and len(self._fields) == 1:
            for row in self.query.results_iter():
                yield row[0]
@@ -837,6 +844,7 @@ class ValuesListQuerySet(ValuesQuerySet):
            extra_names = self.query.extra_select.keys()
            field_names = self.field_names
            aggregate_names = self.query.aggregate_select.keys()

            names = extra_names + field_names + aggregate_names

            # If a field list has been specified, use it. Otherwise, use the
+55 −15
Original line number Diff line number Diff line
@@ -77,7 +77,9 @@ class BaseQuery(object):
        self.related_select_cols = []

        # SQL aggregate-related attributes
        self.aggregate_select = SortedDict() # Maps alias -> SQL aggregate function
        self.aggregates = SortedDict() # Maps alias -> SQL aggregate function
        self.aggregate_select_mask = None
        self._aggregate_select_cache = None

        # Arbitrary maximum limit for select_related. Prevents infinite
        # recursion. Can be changed by the depth parameter to select_related().
@@ -187,7 +189,15 @@ class BaseQuery(object):
        obj.distinct = self.distinct
        obj.select_related = self.select_related
        obj.related_select_cols = []
        obj.aggregate_select = self.aggregate_select.copy()
        obj.aggregates = self.aggregates.copy()
        if self.aggregate_select_mask is None:
            obj.aggregate_select_mask = None
        else:
            obj.aggregate_select_mask = self.aggregate_select_mask[:]
        if self._aggregate_select_cache is None:
            obj._aggregate_select_cache = None
        else:
            obj._aggregate_select_cache = self._aggregate_select_cache.copy()
        obj.max_depth = self.max_depth
        obj.extra_select = self.extra_select.copy()
        obj.extra_tables = self.extra_tables
@@ -940,12 +950,15 @@ class BaseQuery(object):
        """
        assert set(change_map.keys()).intersection(set(change_map.values())) == set()

        # 1. Update references in "select" and "where".
        # 1. Update references in "select" (normal columns plus aliases),
        # "group by", "where" and "having".
        self.where.relabel_aliases(change_map)
        for pos, col in enumerate(self.select):
        self.having.relabel_aliases(change_map)
        for columns in (self.select, self.aggregates.values(), self.group_by or []):
            for pos, col in enumerate(columns):
                if isinstance(col, (list, tuple)):
                    old_alias = col[0]
                self.select[pos] = (change_map.get(old_alias, old_alias), col[1])
                    columns[pos] = (change_map.get(old_alias, old_alias), col[1])
                else:
                    col.relabel_aliases(change_map)

@@ -1205,11 +1218,11 @@ class BaseQuery(object):
        opts = model._meta
        field_list = aggregate.lookup.split(LOOKUP_SEP)
        if (len(field_list) == 1 and
            aggregate.lookup in self.aggregate_select.keys()):
            aggregate.lookup in self.aggregates.keys()):
            # Aggregate is over an annotation
            field_name = field_list[0]
            col = field_name
            source = self.aggregate_select[field_name]
            source = self.aggregates[field_name]
        elif (len(field_list) > 1 or
            field_list[0] not in [i.name for i in opts.fields]):
            field, source, opts, join_list, last, _ = self.setup_joins(
@@ -1299,7 +1312,7 @@ class BaseQuery(object):
            value = SQLEvaluator(value, self)
            having_clause = value.contains_aggregate

        for alias, aggregate in self.aggregate_select.items():
        for alias, aggregate in self.aggregates.items():
            if alias == parts[0]:
                entry = self.where_class()
                entry.add((aggregate, lookup_type, value), AND)
@@ -1824,8 +1837,8 @@ class BaseQuery(object):
        self.group_by = []
        if self.connection.features.allows_group_by_pk:
            if len(self.select) == len(self.model._meta.fields):
                self.group_by.append('.'.join([self.model._meta.db_table,
                                               self.model._meta.pk.column]))
                self.group_by.append((self.model._meta.db_table,
                                      self.model._meta.pk.column))
                return

        for sel in self.select:
@@ -1858,7 +1871,11 @@ class BaseQuery(object):
            # Distinct handling is done in Count(), so don't do it at this
            # level.
            self.distinct = False
        self.aggregate_select = {None: count}

        # Set only aggregate to be the count column.
        # Clear out the select cache to reflect the new unmasked aggregates.
        self.aggregates = {None: count}
        self.set_aggregate_mask(None)

    def add_select_related(self, fields):
        """
@@ -1920,6 +1937,29 @@ class BaseQuery(object):
        for key in set(self.extra_select).difference(set(names)):
            del self.extra_select[key]

    def set_aggregate_mask(self, names):
        """
        Restrict the aggregates returned by the SELECT clause to ``names``.

        ``names`` is stored as given (it is not copied) as the new mask. The
        cached masked selection is discarded so that the ``aggregate_select``
        property recomputes it against the new mask on its next access.
        Passing ``None`` removes the mask.
        """
        # Invalidate first: any cached selection was built from the old mask.
        self._aggregate_select_cache = None
        self.aggregate_select_mask = names

    def _aggregate_select(self):
        """
        Return the aggregate columns that the SELECT clause should include.

        When no mask is set, this is ``self.aggregates`` itself. When a mask
        is set, it is a SortedDict containing only the aggregates whose alias
        appears in the mask, built by walking ``self.aggregates`` in order.
        The masked result is stored in ``self._aggregate_select_cache`` and
        returned directly on later calls for as long as that cache is not
        ``None``.
        """
        cached = self._aggregate_select_cache
        if cached is not None:
            return cached

        mask = self.aggregate_select_mask
        if mask is None:
            # Nothing is masked, so every aggregate is selected.
            return self.aggregates

        visible = SortedDict()
        for alias, aggregate in self.aggregates.items():
            if alias in mask:
                visible[alias] = aggregate
        self._aggregate_select_cache = visible
        return visible
    aggregate_select = property(_aggregate_select)

    def set_start(self, start):
        """
        Sets the table from which to start joining. The start position is
+8 −4
Original line number Diff line number Diff line
@@ -213,10 +213,14 @@ class WhereNode(tree.Node):
            elif isinstance(child, tree.Node):
                self.relabel_aliases(change_map, child)
            else:
                if isinstance(child[0], (list, tuple)):
                    elt = list(child[0])
                    if elt[0] in change_map:
                        elt[0] = change_map[elt[0]]
                        node.children[pos] = (tuple(elt),) + child[1:]
                else:
                    child[0].relabel_aliases(change_map)

                # Check if the query value also requires relabelling
                if hasattr(child[3], 'relabel_aliases'):
                    child[3].relabel_aliases(change_map)
+10 −4
Original line number Diff line number Diff line
@@ -284,9 +284,6 @@ two authors with the same name, their results will be merged into a single
result in the output of the query; the average will be computed as the
average over the books written by both authors.

The annotation name will be added to the fields returned
as part of the ``ValuesQuerySet``.

Order of ``annotate()`` and ``values()`` clauses
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -303,12 +300,21 @@ output.
For example, if we reverse the order of the ``values()`` and ``annotate()``
clauses from our previous example::

    >>> Author.objects.annotate(average_rating=Avg('book__rating')).values('name')
    >>> Author.objects.annotate(average_rating=Avg('book__rating')).values('name', 'average_rating')

This will now yield one unique result for each author; however, only
the author's name and the ``average_rating`` annotation will be returned
in the output data.

You should also note that ``average_rating`` has been explicitly included
in the list of values to be returned. This is required because of the
ordering of the ``values()`` and ``annotate()`` clauses.

If the ``values()`` clause precedes the ``annotate()`` clause, any annotations
will be automatically added to the result set. However, if the ``values()``
clause is applied after the ``annotate()`` clause, you need to explicitly
include the aggregate column.

Aggregating annotations
-----------------------

Loading