Commit c7805ee2 authored by Josh Smeaton's avatar Josh Smeaton Committed by Tim Graham
Browse files

Fixed #24699 -- Added aggregate support for DurationField on Oracle

parent e60cce4e
Loading
Loading
Loading
Loading
+0 −3
Original line number Diff line number Diff line
@@ -157,9 +157,6 @@ class BaseDatabaseFeatures(object):
    # Support for the DISTINCT ON clause
    can_distinct_on_fields = False

    # Can the backend use an Avg aggregate on DurationField?
    can_avg_on_durationfield = True

    # Does the backend decide to commit before SAVEPOINT statements
    # when autocommit is disabled? http://bugs.python.org/issue8145#msg109965
    autocommits_when_autocommit_is_off = False
+0 −1
Original line number Diff line number Diff line
@@ -39,7 +39,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
    uppercases_column_names = True
    # select for update with limit can be achieved on Oracle, but not with the current backend.
    supports_select_for_update_with_limit = False
    can_avg_on_durationfield = False  # Pending implementation (#24699).

    def introspected_boolean_field_type(self, field=None, created_separately=False):
        """
+24 −0
Original line number Diff line number Diff line
from django.db.models import DecimalField, DurationField, Func


class IntervalToSeconds(Func):
    function = ''
    template = """
    EXTRACT(day from %(expressions)s) * 86400 +
    EXTRACT(hour from %(expressions)s) * 3600 +
    EXTRACT(minute from %(expressions)s) * 60 +
    EXTRACT(second from %(expressions)s)
    """

    def __init__(self, expression, **extra):
        output_field = extra.pop('output_field', DecimalField())
        super(IntervalToSeconds, self).__init__(expression, output_field=output_field, **extra)


class SecondsToInterval(Func):
    function = 'NUMTODSINTERVAL'
    template = "%(function)s(%(expressions)s, 'SECOND')"

    def __init__(self, expression, **extra):
        output_field = extra.pop('output_field', DurationField())
        super(SecondsToInterval, self).__init__(expression, output_field=output_field, **extra)
+18 −0
Original line number Diff line number Diff line
@@ -78,6 +78,15 @@ class Avg(Aggregate):
        output_field = extra.pop('output_field', FloatField())
        super(Avg, self).__init__(expression, output_field=output_field, **extra)

    def as_oracle(self, compiler, connection):
        if self.output_field.get_internal_type() == 'DurationField':
            expression = self.get_source_expressions()[0]
            from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
            return compiler.compile(
                SecondsToInterval(Avg(IntervalToSeconds(expression)))
            )
        return super(Avg, self).as_sql(compiler, connection)


class Count(Aggregate):
    function = 'COUNT'
@@ -137,6 +146,15 @@ class Sum(Aggregate):
    function = 'SUM'
    name = 'Sum'

    def as_oracle(self, compiler, connection):
        if self.output_field.get_internal_type() == 'DurationField':
            expression = self.get_source_expressions()[0]
            from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
            return compiler.compile(
                SecondsToInterval(Sum(IntervalToSeconds(expression)))
            )
        return super(Sum, self).as_sql(compiler, connection)


class Variance(Aggregate):
    name = 'Variance'
+52 −44
Original line number Diff line number Diff line
@@ -10,7 +10,7 @@ from django.db.models import (
    F, Aggregate, Avg, Count, DecimalField, DurationField, FloatField, Func,
    IntegerField, Max, Min, Sum, Value,
)
from django.test import TestCase, ignore_warnings, skipUnlessDBFeature
from django.test import TestCase, ignore_warnings
from django.test.utils import Approximate, CaptureQueriesContext
from django.utils import six, timezone
from django.utils.deprecation import RemovedInDjango20Warning
@@ -441,11 +441,16 @@ class AggregateTestCase(TestCase):
        vals = Book.objects.annotate(num_authors=Count("authors__id")).aggregate(Avg("num_authors"))
        self.assertEqual(vals, {"num_authors__avg": Approximate(1.66, places=1)})

    @skipUnlessDBFeature('can_avg_on_durationfield')
    def test_avg_duration_field(self):
        self.assertEqual(
            Publisher.objects.aggregate(Avg('duration', output_field=DurationField())),
            {'duration__avg': datetime.timedelta(1, 43200)}  # 1.5 days
            {'duration__avg': datetime.timedelta(days=1, hours=12)}
        )

    def test_sum_duration_field(self):
        self.assertEqual(
            Publisher.objects.aggregate(Sum('duration', output_field=DurationField())),
            {'duration__sum': datetime.timedelta(days=3)}
        )

    def test_sum_distinct_aggregate(self):
@@ -984,17 +989,20 @@ class AggregateTestCase(TestCase):
            Book.objects.annotate(Max('id')).annotate(Sum('id__max'))

    def test_add_implementation(self):
        try:
        class MySum(Sum):
            pass

        # test completely changing how the output is rendered
        def lower_case_function_override(self, compiler, connection):
            sql, params = compiler.compile(self.source_expressions[0])
            substitutions = dict(function=self.function.lower(), expressions=sql)
            substitutions.update(self.extra)
            return self.template % substitutions, params
            setattr(Sum, 'as_' + connection.vendor, lower_case_function_override)
        setattr(MySum, 'as_' + connection.vendor, lower_case_function_override)

            qs = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'),
                                       output_field=IntegerField()))
        qs = Book.objects.annotate(
            sums=MySum(F('rating') + F('pages') + F('price'), output_field=IntegerField())
        )
        self.assertEqual(str(qs.query).count('sum('), 1)
        b1 = qs.get(pk=self.b4.pk)
        self.assertEqual(b1.sums, 383)
@@ -1002,11 +1010,12 @@ class AggregateTestCase(TestCase):
        # test changing the dict and delegating
        def lower_case_function_super(self, compiler, connection):
            self.extra['function'] = self.function.lower()
                return super(Sum, self).as_sql(compiler, connection)
            setattr(Sum, 'as_' + connection.vendor, lower_case_function_super)
            return super(MySum, self).as_sql(compiler, connection)
        setattr(MySum, 'as_' + connection.vendor, lower_case_function_super)

            qs = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'),
                                       output_field=IntegerField()))
        qs = Book.objects.annotate(
            sums=MySum(F('rating') + F('pages') + F('price'), output_field=IntegerField())
        )
        self.assertEqual(str(qs.query).count('sum('), 1)
        b1 = qs.get(pk=self.b4.pk)
        self.assertEqual(b1.sums, 383)
@@ -1016,15 +1025,14 @@ class AggregateTestCase(TestCase):
            substitutions = dict(function='MAX', expressions='2')
            substitutions.update(self.extra)
            return self.template % substitutions, ()
            setattr(Sum, 'as_' + connection.vendor, be_evil)
        setattr(MySum, 'as_' + connection.vendor, be_evil)

            qs = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'),
                                       output_field=IntegerField()))
        qs = Book.objects.annotate(
            sums=MySum(F('rating') + F('pages') + F('price'), output_field=IntegerField())
        )
        self.assertEqual(str(qs.query).count('MAX('), 1)
        b1 = qs.get(pk=self.b4.pk)
        self.assertEqual(b1.sums, 2)
        finally:
            delattr(Sum, 'as_' + connection.vendor)

    def test_complex_values_aggregation(self):
        max_rating = Book.objects.values('rating').aggregate(