Commit 5b3c66d8 authored by Alberto Avila's avatar Alberto Avila Committed by Tim Graham
Browse files

[1.8.x] Fixed #26071 -- Fixed crash with __in lookup in a Case expression.

Partial backport of afe0bb7b from master.
parent e625859f
Loading
Loading
Loading
Loading
+8 −0
Original line number Diff line number Diff line
@@ -92,6 +92,10 @@ class Transform(RegisterLookupMixin):
            bilateral_transforms.append((self.__class__, self.init_lookups))
        return bilateral_transforms

    @cached_property
    def contains_aggregate(self):
        return self.lhs.contains_aggregate


class Lookup(RegisterLookupMixin):
    lookup_name = None
@@ -194,6 +198,10 @@ class Lookup(RegisterLookupMixin):
    def as_sql(self, compiler, connection):
        raise NotImplementedError

    @cached_property
    def contains_aggregate(self):
        return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False)


class BuiltinLookup(Lookup):
    def process_lhs(self, compiler, connection, lhs=None):
+13 −3
Original line number Diff line number Diff line
@@ -315,9 +315,9 @@ class WhereNode(tree.Node):

    @classmethod
    def _contains_aggregate(cls, obj):
        if not isinstance(obj, tree.Node):
            return getattr(obj.lhs, 'contains_aggregate', False) or getattr(obj.rhs, 'contains_aggregate', False)
        if isinstance(obj, tree.Node):
            return any(cls._contains_aggregate(c) for c in obj.children)
        return obj.contains_aggregate

    @cached_property
    def contains_aggregate(self):
@@ -336,6 +336,7 @@ class EverythingNode(object):
    """
    A node that matches everything.
    """
    contains_aggregate = False

    def as_sql(self, compiler=None, connection=None):
        return '', []
@@ -345,11 +346,16 @@ class NothingNode(object):
    """
    A node that matches nothing.
    """
    contains_aggregate = False

    def as_sql(self, compiler=None, connection=None):
        raise EmptyResultSet


class ExtraWhere(object):
    # The contents are a black box - assume no aggregates are used.
    contains_aggregate = False

    def __init__(self, sqls, params):
        self.sqls = sqls
        self.params = params
@@ -410,6 +416,10 @@ class Constraint(object):


class SubqueryConstraint(object):
    # Even if aggregates would be used in a subquery, the outer query isn't
    # interested about those.
    contains_aggregate = False

    def __init__(self, alias, columns, targets, query_object):
        self.alias = alias
        self.columns = columns
+3 −0
Original line number Diff line number Diff line
@@ -23,3 +23,6 @@ Bugfixes
  ``db_index=True`` or ``unique=True`` to a ``CharField`` or ``TextField`` that
  already had the other specified, or when removing one of them from a field
  that had both (:ticket:`26034`).

* Fixed a crash when using an ``__in`` lookup inside a ``Case`` expression
  (:ticket:`26071`).
+12 −1
Original line number Diff line number Diff line
@@ -8,7 +8,7 @@ from uuid import UUID

from django.core.exceptions import FieldError
from django.db import connection, models
from django.db.models import F, Q, Max, Min, Value
from django.db.models import F, Q, Max, Min, Sum, Value
from django.db.models.expressions import Case, When
from django.test import TestCase
from django.utils import six
@@ -119,6 +119,17 @@ class CaseExpressionTests(TestCase):
            transform=attrgetter('integer', 'join_test')
        )

    def test_annotate_with_in_clause(self):
        fk_rels = FKCaseTestModel.objects.filter(integer__in=[5])
        self.assertQuerysetEqual(
            CaseTestModel.objects.only('pk', 'integer').annotate(in_test=Sum(Case(
                When(fk_rel__in=fk_rels, then=F('fk_rel__integer')),
                default=Value(0),
            ))).order_by('pk'),
            [(1, 0), (2, 0), (3, 0), (2, 0), (3, 0), (3, 0), (4, 5)],
            transform=attrgetter('integer', 'in_test')
        )

    def test_annotate_with_join_in_condition(self):
        self.assertQuerysetEqual(
            CaseTestModel.objects.annotate(join_test=Case(