Commit 469f1e36 authored by Andriy Sokolovskiy's avatar Andriy Sokolovskiy Committed by Tim Graham
Browse files

[1.8.x] Fixed #24833 -- Fixed Case expressions with exclude().

parent 7dcfbb2e
Loading
Loading
Loading
Loading
+14 −0
Original line number Diff line number Diff line
@@ -263,3 +263,17 @@ def refs_aggregate(lookup_parts, aggregates):
        if level_n_lookup in aggregates and aggregates[level_n_lookup].contains_aggregate:
            return aggregates[level_n_lookup], lookup_parts[n:]
    return False, ()


def refs_expression(lookup_parts, annotations):
    """
    A helper method to check if the lookup_parts contains references
    to the given annotations set. Because the LOOKUP_SEP is contained in the
    default annotation names we must check each prefix of the lookup_parts
    for a match.
    """
    for n in range(len(lookup_parts) + 1):
        level_n_lookup = LOOKUP_SEP.join(lookup_parts[0:n])
        if level_n_lookup in annotations and annotations[level_n_lookup]:
            return annotations[level_n_lookup], lookup_parts[n:]
    return False, ()
+10 −8
Original line number Diff line number Diff line
@@ -17,7 +17,9 @@ from django.db import DEFAULT_DB_ALIAS, connections
from django.db.models.aggregates import Count
from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import Col, Ref
from django.db.models.query_utils import Q, PathInfo, refs_aggregate
from django.db.models.query_utils import (
    Q, PathInfo, refs_aggregate, refs_expression,
)
from django.db.models.sql.constants import (
    INNER, LOUTER, ORDER_DIR, ORDER_PATTERN, QUERY_TERMS, SINGLE,
)
@@ -1027,9 +1029,9 @@ class Query(object):
        """
        lookup_splitted = lookup.split(LOOKUP_SEP)
        if self._annotations:
            aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.annotations)
            if aggregate:
                return aggregate_lookups, (), aggregate
            expression, expression_lookups = refs_expression(lookup_splitted, self.annotations)
            if expression:
                return expression_lookups, (), expression
        _, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta())
        field_parts = lookup_splitted[0:len(lookup_splitted) - len(lookup_parts)]
        if len(lookup_parts) == 0:
@@ -1144,7 +1146,7 @@ class Query(object):
        arg, value = filter_expr
        if not arg:
            raise FieldError("Cannot parse keyword query %r" % arg)
        lookups, parts, reffed_aggregate = self.solve_lookup_type(arg)
        lookups, parts, reffed_expression = self.solve_lookup_type(arg)
        if not allow_joins and len(parts) > 1:
            raise FieldError("Joined field references are not permitted in this query")

@@ -1153,12 +1155,12 @@ class Query(object):
        value, lookups, used_joins = self.prepare_lookup_value(value, lookups, can_reuse, allow_joins)

        clause = self.where_class()
        if reffed_aggregate:
            condition = self.build_lookup(lookups, reffed_aggregate, value)
        if reffed_expression:
            condition = self.build_lookup(lookups, reffed_expression, value)
            if not condition:
                # Backwards compat for custom lookups
                assert len(lookups) == 1
                condition = (reffed_aggregate, lookups[0], value)
                condition = (reffed_expression, lookups[0], value)
            clause.add(condition, AND)
            return clause, []

+3 −0
Original line number Diff line number Diff line
@@ -50,3 +50,6 @@ Bugfixes

* Fixed recording of applied status for squashed (replacement) migrations
  (:ticket:`24628`).

* Fixed queryset annotations when using ``Case`` expressions with ``exclude()``
  (:ticket:`24833`).
+12 −0
Original line number Diff line number Diff line
@@ -240,6 +240,18 @@ class CaseExpressionTests(TestCase):
            transform=itemgetter('integer', 'max', 'test')
        )

    def test_annotate_exclude(self):
        self.assertQuerysetEqual(
            CaseTestModel.objects.annotate(test=Case(
                When(integer=1, then=Value('one')),
                When(integer=2, then=Value('two')),
                default=Value('other'),
                output_field=models.CharField(),
            )).exclude(test='other').order_by('pk'),
            [(1, 'one'), (2, 'two'), (2, 'two')],
            transform=attrgetter('integer', 'test')
        )

    def test_combined_expression(self):
        self.assertQuerysetEqual(
            CaseTestModel.objects.annotate(