Commit 4a66a692 authored by Greg Chapple's avatar Greg Chapple Committed by Tim Graham
Browse files

Fixed #24887 -- Removed one-arg limit from models.aggregate

parent 6c592e79
Loading
Loading
Loading
Loading
+3 −2
Original line number Diff line number Diff line
@@ -25,7 +25,8 @@ class GeoAggregate(Aggregate):

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        c = super(GeoAggregate, self).resolve_expression(query, allow_joins, reuse, summarize, for_save)
        if not hasattr(c.input_field.field, 'geom_type'):
        for expr in c.get_source_expressions():
            if not hasattr(expr.field, 'geom_type'):
                raise ValueError('Geospatial aggregates only allowed on geometry fields.')
        return c

+13 −11
Original line number Diff line number Diff line
@@ -35,16 +35,18 @@ class DatabaseOperations(BaseDatabaseOperations):
        bad_fields = (fields.DateField, fields.DateTimeField, fields.TimeField)
        bad_aggregates = (aggregates.Sum, aggregates.Avg, aggregates.Variance, aggregates.StdDev)
        if isinstance(expression, bad_aggregates):
            for expr in expression.get_source_expressions():
                try:
                output_field = expression.input_field.output_field
                    output_field = expr.output_field
                    if isinstance(output_field, bad_fields):
                        raise NotImplementedError(
                        'You cannot use Sum, Avg, StdDev and Variance aggregations '
                        'on date/time fields in sqlite3 '
                        'since date/time is saved as text.')
                            'You cannot use Sum, Avg, StdDev, and Variance '
                            'aggregations on date/time fields in sqlite3 '
                            'since date/time is saved as text.'
                        )
                except FieldError:
                # not every sub-expression has an output_field which is fine to
                # ignore
                    # Not every subexpression has an output_field which is fine
                    # to ignore.
                    pass

    def date_extract_sql(self, lookup_type, field_name):
+10 −7
Original line number Diff line number Diff line
@@ -15,13 +15,15 @@ class Aggregate(Func):
    name = None

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        assert len(self.source_expressions) == 1
        # Aggregates are not allowed in UPDATE queries, so ignore for_save
        c = super(Aggregate, self).resolve_expression(query, allow_joins, reuse, summarize)
        if c.source_expressions[0].contains_aggregate and not summarize:
            name = self.source_expressions[0].name
            raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (
                c.name, name, name))
        if not summarize:
            expressions = c.get_source_expressions()
            for index, expr in enumerate(expressions):
                if expr.contains_aggregate:
                    before_resolved = self.get_source_expressions()[index]
                    name = before_resolved.name if hasattr(before_resolved, 'name') else repr(before_resolved)
                    raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (c.name, name, name))
        c._patch_aggregate(query)  # backward-compatibility support
        return c

@@ -31,8 +33,9 @@ class Aggregate(Func):

    @property
    def default_alias(self):
        if hasattr(self.source_expressions[0], 'name'):
            return '%s__%s' % (self.source_expressions[0].name, self.name.lower())
        expressions = self.get_source_expressions()
        if len(expressions) == 1 and hasattr(expressions[0], 'name'):
            return '%s__%s' % (expressions[0].name, self.name.lower())
        raise TypeError("Complex expressions require an alias")

    def get_group_by_cols(self):
+23 −1
Original line number Diff line number Diff line
@@ -985,9 +985,31 @@ class AggregateTestCase(TestCase):
        self.assertEqual(author.sum_age, other_author.sum_age)

    def test_annotated_aggregate_over_annotated_aggregate(self):
        with six.assertRaisesRegex(self, FieldError, "Cannot compute Sum\('id__max'\): 'id__max' is an aggregate"):
        with self.assertRaisesMessage(FieldError, "Cannot compute Sum('id__max'): 'id__max' is an aggregate"):
            Book.objects.annotate(Max('id')).annotate(Sum('id__max'))

        class MyMax(Max):
            def as_sql(self, compiler, connection):
                self.set_source_expressions(self.get_source_expressions()[0:1])
                return super(MyMax, self).as_sql(compiler, connection)

        with self.assertRaisesMessage(FieldError, "Cannot compute Max('id__max'): 'id__max' is an aggregate"):
            Book.objects.annotate(Max('id')).annotate(my_max=MyMax('id__max', 'price'))

    def test_multi_arg_aggregate(self):
        class MyMax(Max):
            def as_sql(self, compiler, connection):
                self.set_source_expressions(self.get_source_expressions()[0:1])
                return super(MyMax, self).as_sql(compiler, connection)

        with self.assertRaisesMessage(TypeError, 'Complex aggregates require an alias'):
            Book.objects.aggregate(MyMax('pages', 'price'))

        with self.assertRaisesMessage(TypeError, 'Complex annotations require an alias'):
            Book.objects.annotate(MyMax('pages', 'price'))

        Book.objects.aggregate(max_field=MyMax('pages', 'price'))

    def test_add_implementation(self):
        class MySum(Sum):
            pass