Commit e654123f authored by Josh Smeaton's avatar Josh Smeaton
Browse files

Fixed #24485 -- Allowed combined expressions to set output_field

parent 3a1886d1
Loading
Loading
Loading
Loading
+3 −1
Original line number Diff line number Diff line
@@ -4,7 +4,9 @@ import warnings

from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured  # NOQA
from django.db.models.query import Q, QuerySet, Prefetch  # NOQA
from django.db.models.expressions import Expression, F, Value, Func, Case, When  # NOQA
from django.db.models.expressions import (  # NOQA
    Expression, ExpressionWrapper, F, Value, Func, Case, When,
)
from django.db.models.manager import Manager  # NOQA
from django.db.models.base import Model  # NOQA
from django.db.models.aggregates import *  # NOQA
+26 −3
Original line number Diff line number Diff line
@@ -126,12 +126,12 @@ class BaseExpression(object):
    # aggregate specific fields
    is_summary = False

    def get_db_converters(self, connection):
        return [self.convert_value] + self.output_field.get_db_converters(connection)

    def __init__(self, output_field=None):
        self._output_field = output_field

    def get_db_converters(self, connection):
        return [self.convert_value] + self.output_field.get_db_converters(connection)

    def get_source_expressions(self):
        return []

@@ -656,6 +656,29 @@ class Ref(Expression):
        return [self]


class ExpressionWrapper(Expression):
    """
    An expression that can wrap another expression so that it can provide
    extra context to the inner expression, such as the output_field.
    """

    def __init__(self, expression, output_field):
        super(ExpressionWrapper, self).__init__(output_field=output_field)
        self.expression = expression

    def set_source_expressions(self, exprs):
        self.expression = exprs[0]

    def get_source_expressions(self):
        return [self.expression]

    def as_sql(self, compiler, connection):
        return self.expression.as_sql(compiler, connection)

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__, self.expression)


class When(Expression):
    template = 'WHEN %(condition)s THEN %(result)s'

+34 −11
Original line number Diff line number Diff line
@@ -165,6 +165,27 @@ values, rather than on Python values.
This is documented in :ref:`using F() expressions in queries
<using-f-expressions-in-filters>`.

.. _using-f-with-annotations:

Using ``F()`` with annotations
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

``F()`` can be used to create dynamic fields on your models by combining
different fields with arithmetic::

    company = Company.objects.annotate(
        chairs_needed=F('num_employees') - F('num_chairs'))

If the fields that you're combining are of different types you'll need
to tell Django what kind of field will be returned. Since ``F()`` does not
directly support ``output_field`` you will need to wrap the expression with
:class:`ExpressionWrapper`::

    from django.db.models import DateTimeField, ExpressionWrapper, F

    Ticket.objects.annotate(
        expires=ExpressionWrapper(
            F('active_at') + F('duration'), output_field=DateTimeField()))

.. _func-expressions:

@@ -278,17 +299,6 @@ should define the desired ``output_field``. For example, adding an
``IntegerField()`` and a ``FloatField()`` together should probably have
``output_field=FloatField()`` defined.

.. note::

    When you need to define the ``output_field`` for ``F`` expression
    arithmetic between different types, it's necessary to surround the
    expression in another expression::

        from django.db.models import DateTimeField, Expression, F

        Race.objects.annotate(finish=Expression(
            F('start') + F('duration'), output_field=DateTimeField()))

.. versionchanged:: 1.8

    ``output_field`` is a new parameter.
@@ -347,6 +357,19 @@ instantiating the model field as any arguments relating to data validation
(``max_length``, ``max_digits``, etc.) will not be enforced on the expression's
output value.

``ExpressionWrapper()`` expressions
-----------------------------------

.. class:: ExpressionWrapper(expression, output_field)

.. versionadded:: 1.8

``ExpressionWrapper`` simply surrounds another expression and provides access
to properties, such as ``output_field``, that may not be available on other
expressions. ``ExpressionWrapper`` is necessary when using arithmetic on
``F()`` expressions with different types as described in
:ref:`using-f-with-annotations`.

Conditional expressions
-----------------------

+9 −0
Original line number Diff line number Diff line
@@ -84,3 +84,12 @@ class Company(models.Model):
        return ('Company(name=%s, motto=%s, ticker_name=%s, description=%s)'
            % (self.name, self.motto, self.ticker_name, self.description)
        )


@python_2_unicode_compatible
class Ticket(models.Model):
    active_at = models.DateTimeField()
    duration = models.DurationField()

    def __str__(self):
        return '{} - {}'.format(self.active_at, self.duration)
+23 −2
Original line number Diff line number Diff line
@@ -5,12 +5,15 @@ from decimal import Decimal

from django.core.exceptions import FieldDoesNotExist, FieldError
from django.db.models import (
    F, BooleanField, CharField, Count, Func, IntegerField, Sum, Value,
    F, BooleanField, CharField, Count, DateTimeField, ExpressionWrapper, Func,
    IntegerField, Sum, Value,
)
from django.test import TestCase
from django.utils import six

from .models import Author, Book, Company, DepartmentStore, Employee, Store
from .models import (
    Author, Book, Company, DepartmentStore, Employee, Publisher, Store, Ticket,
)


def cxOracle_513_py3_bug(func):
@@ -52,6 +55,24 @@ class NonAggregateAnnotationTestCase(TestCase):
        for book in books:
            self.assertEqual(book.num_awards, book.publisher.num_awards)

    def test_mixed_type_annotation_date_interval(self):
        active = datetime.datetime(2015, 3, 20, 14, 0, 0)
        duration = datetime.timedelta(hours=1)
        expires = datetime.datetime(2015, 3, 20, 14, 0, 0) + duration
        Ticket.objects.create(active_at=active, duration=duration)
        t = Ticket.objects.annotate(
            expires=ExpressionWrapper(F('active_at') + F('duration'), output_field=DateTimeField())
        ).first()
        self.assertEqual(t.expires, expires)

    def test_mixed_type_annotation_numbers(self):
        test = self.b1
        b = Book.objects.annotate(
            combined=ExpressionWrapper(F('pages') + F('rating'), output_field=IntegerField())
        ).get(isbn=test.isbn)
        combined = int(test.pages + test.rating)
        self.assertEqual(b.combined, combined)

    def test_annotate_with_aggregation(self):
        books = Book.objects.annotate(
            is_book=Value(1, output_field=IntegerField()),
Loading