Commit e2d6e146 authored by Josh Smeaton's avatar Josh Smeaton
Browse files

Refs #14030 -- Improved expression support for python values

parent 07cfe1bd
Loading
Loading
Loading
Loading
+1 −2
Original line number Diff line number Diff line
from django.db.models.aggregates import StdDev
from django.db.models.expressions import Value
from django.db.utils import ProgrammingError
from django.utils.functional import cached_property

@@ -232,7 +231,7 @@ class BaseDatabaseFeatures(object):
    def supports_stddev(self):
        """Confirm support for STDDEV and related stats functions."""
        try:
            self.connection.ops.check_expression_support(StdDev(Value(1)))
            self.connection.ops.check_expression_support(StdDev(1))
            return True
        except NotImplementedError:
            return False
+12 −14
Original line number Diff line number Diff line
@@ -7,7 +7,7 @@ from django.db.backends import utils as backend_utils
from django.db.models import fields
from django.db.models.constants import LOOKUP_SEP
from django.db.models.query_utils import Q, refs_aggregate
from django.utils import timezone
from django.utils import six, timezone
from django.utils.functional import cached_property


@@ -138,6 +138,13 @@ class BaseExpression(object):
    def set_source_expressions(self, exprs):
        assert len(exprs) == 0

    def _parse_expressions(self, *expressions):
        return [
            arg if hasattr(arg, 'resolve_expression') else (
                F(arg) if isinstance(arg, six.string_types) else Value(arg)
            ) for arg in expressions
        ]

    def as_sql(self, compiler, connection):
        """
        Responsible for returning a (sql, [params]) tuple to be included
@@ -466,12 +473,6 @@ class Func(ExpressionNode):
    def set_source_expressions(self, exprs):
        self.source_expressions = exprs

    def _parse_expressions(self, *expressions):
        return [
            arg if hasattr(arg, 'resolve_expression') else F(arg)
            for arg in expressions
        ]

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        c = self.copy()
        c.is_summary = summarize
@@ -639,14 +640,14 @@ class Ref(ExpressionNode):
class When(ExpressionNode):
    template = 'WHEN %(condition)s THEN %(result)s'

    def __init__(self, condition=None, then=Value(None), **lookups):
    def __init__(self, condition=None, then=None, **lookups):
        if lookups and condition is None:
            condition, lookups = Q(**lookups), None
        if condition is None or not isinstance(condition, Q) or lookups:
            raise TypeError("__init__() takes either a Q object or lookups as keyword arguments")
        super(When, self).__init__(output_field=None)
        self.condition = condition
        self.result = self._parse_expression(then)
        self.result = self._parse_expressions(then)[0]

    def __str__(self):
        return "WHEN %r THEN %r" % (self.condition, self.result)
@@ -664,9 +665,6 @@ class When(ExpressionNode):
        # We're only interested in the fields of the result expressions.
        return [self.result._output_field_or_none]

    def _parse_expression(self, expression):
        return expression if hasattr(expression, 'resolve_expression') else F(expression)

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        c = self.copy()
        c.is_summary = summarize
@@ -713,11 +711,11 @@ class Case(ExpressionNode):
    def __init__(self, *cases, **extra):
        if not all(isinstance(case, When) for case in cases):
            raise TypeError("Positional arguments must all be When objects.")
        default = extra.pop('default', Value(None))
        default = extra.pop('default', None)
        output_field = extra.pop('output_field', None)
        super(Case, self).__init__(output_field)
        self.cases = list(cases)
        self.default = default if hasattr(default, 'resolve_expression') else F(default)
        self.default = self._parse_expressions(default)[0]

    def __str__(self):
        return "CASE %s, ELSE %r" % (', '.join(str(c) for c in self.cases), self.default)
+6 −6
Original line number Diff line number Diff line
@@ -39,7 +39,7 @@ We'll be using the following model in the subsequent examples::
When
----

.. class:: When(condition=None, then=Value(None), **lookups)
.. class:: When(condition=None, then=None, **lookups)

A ``When()`` object is used to encapsulate a condition and its result for use
in the conditional expression. Using a ``When()`` object is similar to using
@@ -73,8 +73,8 @@ Keep in mind that each of these values can be an expression.
    resolved in two ways::

        >>> from django.db.models import Value
        >>> When(then__exact=0, then=Value(1))
        >>> When(Q(then=0), then=Value(1))
        >>> When(then__exact=0, then=1)
        >>> When(Q(then=0), then=1)

Case
----
@@ -197,15 +197,15 @@ What if we want to find out how many clients there are for each
    >>> from django.db.models import IntegerField, Sum
    >>> Client.objects.aggregate(
    ...     regular=Sum(
    ...         Case(When(account_type=Client.REGULAR, then=Value(1)),
    ...         Case(When(account_type=Client.REGULAR, then=1),
    ...              output_field=IntegerField())
    ...     ),
    ...     gold=Sum(
    ...         Case(When(account_type=Client.GOLD, then=Value(1)),
    ...         Case(When(account_type=Client.GOLD, then=1),
    ...              output_field=IntegerField())
    ...     ),
    ...     platinum=Sum(
    ...         Case(When(account_type=Client.PLATINUM, then=Value(1)),
    ...         Case(When(account_type=Client.PLATINUM, then=1),
    ...              output_field=IntegerField())
    ...     )
    ... )
+4 −0
Original line number Diff line number Diff line
@@ -217,6 +217,10 @@ function will be applied to. The expressions will be converted to strings,
joined together with ``arg_joiner``, and then interpolated into the ``template``
as the ``expressions`` placeholder.

Positional arguments can be expressions or Python values. Strings are
assumed to be column references and will be wrapped in ``F()`` expressions
while other values will be wrapped in ``Value()`` expressions.

The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated
into the ``template`` attribute. Note that the keywords ``function`` and
``template`` can be used to replace the ``function`` and ``template``
+5 −5
Original line number Diff line number Diff line
@@ -8,7 +8,7 @@ from django.core.exceptions import FieldError
from django.db import connection
from django.db.models import (
    F, Aggregate, Avg, Count, DecimalField, FloatField, Func, IntegerField,
    Max, Min, Sum, Value,
    Max, Min, Sum,
)
from django.test import TestCase, ignore_warnings
from django.test.utils import Approximate, CaptureQueriesContext
@@ -706,14 +706,14 @@ class ComplexAggregateTestCase(TestCase):
            Book.objects.aggregate(fail=F('price'))

    def test_nonfield_annotation(self):
        book = Book.objects.annotate(val=Max(Value(2, output_field=IntegerField())))[0]
        book = Book.objects.annotate(val=Max(2, output_field=IntegerField()))[0]
        self.assertEqual(book.val, 2)
        book = Book.objects.annotate(val=Max(Value(2), output_field=IntegerField()))[0]
        book = Book.objects.annotate(val=Max(2, output_field=IntegerField()))[0]
        self.assertEqual(book.val, 2)

    def test_missing_output_field_raises_error(self):
        with six.assertRaisesRegex(self, FieldError, 'Cannot resolve expression type, unknown output_field'):
            Book.objects.annotate(val=Max(Value(2)))[0]
            Book.objects.annotate(val=Max(2))[0]

    def test_annotation_expressions(self):
        authors = Author.objects.annotate(combined_ages=Sum(F('age') + F('friends__age'))).order_by('name')
@@ -772,7 +772,7 @@ class ComplexAggregateTestCase(TestCase):
        with six.assertRaisesRegex(self, TypeError, 'Complex aggregates require an alias'):
            Author.objects.aggregate(Sum('age') / Count('age'))
        with six.assertRaisesRegex(self, TypeError, 'Complex aggregates require an alias'):
            Author.objects.aggregate(Sum(Value(1)))
            Author.objects.aggregate(Sum(1))

    def test_aggregate_over_complex_annotation(self):
        qs = Author.objects.annotate(
Loading