Commit c7fd9b24 authored by Anssi Kääriäinen's avatar Anssi Kääriäinen Committed by Tim Graham
Browse files

Fixed #23875 -- cleaned up query.get_count()

parent 87bd1361
Loading
Loading
Loading
Loading
+3 −4
Original line number Diff line number Diff line
@@ -335,12 +335,11 @@ class QuerySet(object):
            kwargs[arg.default_alias] = arg

        query = self.query.clone()
        force_subq = query.low_mark != 0 or query.high_mark is not None
        for (alias, aggregate_expr) in kwargs.items():
            query.add_annotation(aggregate_expr, self.model, alias, is_summary=True)
            query.add_annotation(aggregate_expr, alias, is_summary=True)
            if not query.annotations[alias].contains_aggregate:
                raise TypeError("%s is not an aggregate expression" % alias)
        return query.get_aggregation(using=self.db, force_subq=force_subq)
        return query.get_aggregation(self.db, kwargs.keys())

    def count(self):
        """
@@ -824,7 +823,7 @@ class QuerySet(object):
            if alias in names:
                raise ValueError("The annotation '%s' conflicts with a field on "
                                 "the model." % alias)
            obj.query.add_annotation(annotation, self.model, alias, is_summary=False)
            obj.query.add_annotation(annotation, alias, is_summary=False)
        # expressions need to be added to the query before we know if they contain aggregates
        added_aggregates = []
        for alias, annotation in obj.query.annotations.items():
+5 −0
Original line number Diff line number Diff line
@@ -1097,6 +1097,11 @@ class SQLAggregateCompiler(SQLCompiler):
        Creates the SQL for this query. Returns the SQL string and list of
        parameters.
        """
        # Empty SQL for the inner query is a marker that the inner query
        # isn't going to produce any results. This can happen when doing
        # LIMIT 0 queries (generated by qs[:0]) for example.
        if not self.query.subquery:
            raise EmptyResultSet
        sql, params = [], []
        for annotation in self.query.annotation_select.values():
            agg_sql, agg_params = self.compile(annotation)
+25 −92
Original line number Diff line number Diff line
@@ -313,32 +313,35 @@ class Query(object):
        clone.change_aliases(change_map)
        return clone

    def get_aggregation(self, using, force_subq=False):
    def get_aggregation(self, using, added_aggregate_names):
        """
        Returns the dictionary with the values of the existing aggregations.
        """
        if not self.annotation_select:
            return {}

        # annotations must be forced into subquery
        has_annotation = any(
        has_limit = self.low_mark != 0 or self.high_mark is not None
        has_existing_annotations = any(
            annotation for alias, annotation
            in self.annotation_select.items()
            if not annotation.contains_aggregate)

        # If there is a group by clause, aggregating does not add useful
        # information but retrieves only the first row. Aggregate
        # over the subquery instead.
        if self.group_by is not None or force_subq or has_annotation:

            in self.annotations.items()
            if alias not in added_aggregate_names
        )
        # Decide if we need to use a subquery.
        #
        # Existing annotations would cause incorrect results as get_aggregation()
        # must produce just one result and thus must not use GROUP BY. But we
        # aren't smart enough to remove the existing annotations from the
        # query, so those would force us to use GROUP BY.
        #
        # If the query has limit or distinct, then those operations must be
        # done in a subquery so that we are aggregating on the limit and/or
        # distinct results instead of applying the distinct and limit after the
        # aggregation.
        if (self.group_by or has_limit or has_existing_annotations or self.distinct):
            from django.db.models.sql.subqueries import AggregateQuery
            outer_query = AggregateQuery(self.model)
            inner_query = self.clone()
            if not force_subq:
                # In forced subq case the ordering and limits will likely
                # affect the results.
            if not has_limit and not self.distinct_fields:
                inner_query.clear_ordering(True)
                inner_query.clear_limits()
            inner_query.select_for_update = False
            inner_query.select_related = False
            inner_query.related_select_cols = []
@@ -398,34 +401,10 @@ class Query(object):
        Performs a COUNT() query using the current filter constraints.
        """
        obj = self.clone()
        if len(self.select) > 1 or self.annotation_select or (self.distinct and self.distinct_fields):
            # If a select clause exists, then the query has already started to
            # specify the columns that are to be returned.
            # In this case, we need to use a subquery to evaluate the count.
            from django.db.models.sql.subqueries import AggregateQuery
            subquery = obj
            subquery.clear_ordering(True)
            subquery.clear_limits()

            obj = AggregateQuery(obj.model)
            try:
                obj.add_subquery(subquery, using=using)
            except EmptyResultSet:
                # add_subquery evaluates the query, if it's an EmptyResultSet
                # then there are can be no results, and therefore there the
                # count is obviously 0
                return 0

        obj.add_count_column()
        number = obj.get_aggregation(using=using)[None]

        # Apply offset and limit constraints manually, since using LIMIT/OFFSET
        # in SQL (in variants that provide them) doesn't change the COUNT
        # output.
        number = max(0, number - self.low_mark)
        if self.high_mark is not None:
            number = min(number, self.high_mark - self.low_mark)

        obj.add_annotation(Count('*'), alias='__count', is_summary=True)
        number = obj.get_aggregation(using, ['__count'])['__count']
        if number is None:
            number = 0
        return number

    def has_filters(self):
@@ -986,9 +965,9 @@ class Query(object):
        warnings.warn(
            "add_aggregate() is deprecated. Use add_annotation() instead.",
            RemovedInDjango20Warning, stacklevel=2)
        self.add_annotation(aggregate, model, alias, is_summary)
        self.add_annotation(aggregate, alias, is_summary)

    def add_annotation(self, annotation, model, alias, is_summary):
    def add_annotation(self, annotation, alias, is_summary):
        """
        Adds a single annotation expression to the Query
        """
@@ -1746,52 +1725,6 @@ class Query(object):
                for col in annotation.get_group_by_cols():
                    self.group_by.append(col)

    def add_count_column(self):
        """
        Converts the query to do count(...) or count(distinct(pk)) in order to
        get its size.
        """
        summarize = False
        if not self.distinct:
            if not self.select:
                count = Count('*')
                summarize = True
            else:
                assert len(self.select) == 1, \
                    "Cannot add count col with multiple cols in 'select': %r" % self.select
                col = self.select[0].col
                if isinstance(col, (tuple, list)):
                    count = Count(col[1])
                else:
                    count = Count(col)

        else:
            opts = self.get_meta()
            if not self.select:
                lookup = self.join((None, opts.db_table, None)), opts.pk.column
                count = Count(lookup[1], distinct=True)
                summarize = True
            else:
                # Because of SQL portability issues, multi-column, distinct
                # counts need a sub-query -- see get_count() for details.
                assert len(self.select) == 1, \
                    "Cannot add count col with multiple cols in 'select'."
                col = self.select[0].col
                if isinstance(col, (tuple, list)):
                    count = Count(col[1], distinct=True)
                else:
                    count = Count(col, distinct=True)
            # Distinct handling is done in Count(), so don't do it at this
            # level.
            self.distinct = False

        # Set only aggregate to be the count column.
        # Clear out the select cache to reflect the new unmasked annotations.
        count = count.resolve_expression(self, summarize=summarize)
        self._annotations = {None: count}
        self.set_annotation_mask(None)
        self.group_by = None

    def add_select_related(self, fields):
        """
        Sets up the select_related data structure so that we only select