Commit bd283aa8 authored by Anssi Kääriäinen's avatar Anssi Kääriäinen
Browse files

Refactored the empty/full result logic in WhereNode.as_sql()

Made sure the WhereNode.as_sql() handles various EmptyResultSet and
FullResultSet conditions correctly. Also, got rid of the FullResultSet
exception class. It is now represented by '', [] return value in the
as_sql() methods.
parent 2b9fb2e6
Loading
Loading
Loading
Loading
+0 −3
Original line number Diff line number Diff line
@@ -6,9 +6,6 @@ the SQL domain.
class EmptyResultSet(Exception):
    pass

class FullResultSet(Exception):
    pass

class MultiJoin(Exception):
    """
    Used by join construction code to indicate the point at which a
+48 −35
Original line number Diff line number Diff line
@@ -10,7 +10,7 @@ from itertools import repeat

from django.utils import tree
from django.db.models.fields import Field
from django.db.models.sql.datastructures import EmptyResultSet, FullResultSet
from django.db.models.sql.datastructures import EmptyResultSet
from django.db.models.sql.aggregates import Aggregate

# Connection types
@@ -75,17 +75,21 @@ class WhereNode(tree.Node):
    def as_sql(self, qn, connection):
        """
        Returns the SQL version of the where clause and the value to be
        substituted in. Returns None, None if this node is empty.

        If 'node' is provided, that is the root of the SQL generation
        (generally not needed except by the internal implementation for
        recursion).
        substituted in. Returns '', [] if this node matches everything,
        None, [] if this node is empty, and raises EmptyResultSet if this
        node can't match anything.
        """
        if not self.children:
            return None, []
        # Note that the logic here is made slightly more complex than
        # necessary because there are two kind of empty nodes: Nodes
        # containing 0 children, and nodes that are known to match everything.
        # A match-everything node is different than empty node (which also
        # technically matches everything) for backwards compatibility reasons.
        # Refs #5261.
        result = []
        result_params = []
        empty = True
        everything_childs, nothing_childs = 0, 0
        non_empty_childs = len(self.children)

        for child in self.children:
            try:
                if hasattr(child, 'as_sql'):
@@ -93,39 +97,48 @@ class WhereNode(tree.Node):
                else:
                    # A leaf node in the tree.
                    sql, params = self.make_atom(child, qn, connection)

            except EmptyResultSet:
                if self.connector == AND and not self.negated:
                    # We can bail out early in this particular case (only).
                    raise
                elif self.negated:
                    empty = False
                nothing_childs += 1
            else:
                if sql:
                    result.append(sql)
                    result_params.extend(params)
                else:
                    if sql is None:
                        # Skip empty childs totally.
                        non_empty_childs -= 1
                        continue
            except FullResultSet:
                if self.connector == OR:
                    everything_childs += 1
            # Check if this node matches nothing or everything.
            # First check the amount of full nodes and empty nodes
            # to make this node empty/full.
            if self.connector == AND:
                full_needed, empty_needed = non_empty_childs, 1
            else:
                full_needed, empty_needed = 1, non_empty_childs
            # Now, check if this node is full/empty using the
            # counts.
            if empty_needed - nothing_childs <= 0:
                if self.negated:
                        empty = True
                        break
                    # We match everything. No need for any constraints.
                    return '', []
                else:
                    raise EmptyResultSet
            if full_needed - everything_childs <= 0:
                if self.negated:
                    empty = True
                continue

            empty = False
            if sql:
                result.append(sql)
                result_params.extend(params)
        if empty:
                    raise EmptyResultSet
                else:
                    return '', []

        if non_empty_childs == 0:
            # All the child nodes were empty, so this one is empty, too.
            return None, []
        conn = ' %s ' % self.connector
        sql_string = conn.join(result)
        if sql_string:
            if self.negated:
                sql_string = 'NOT (%s)' % sql_string
            elif len(self.children) != 1:
            if len(result) > 1:
                sql_string = '(%s)' % sql_string
            if self.negated:
                sql_string = 'NOT %s' % sql_string
        return sql_string, result_params

    def make_atom(self, child, qn, connection):
@@ -261,7 +274,7 @@ class EverythingNode(object):
    """

    def as_sql(self, qn=None, connection=None):
        raise FullResultSet
        return '', []

    def relabel_aliases(self, change_map, node=None):
        return
+86 −0
Original line number Diff line number Diff line
@@ -10,6 +10,8 @@ from django.core.exceptions import FieldError
from django.db import DatabaseError, connection, connections, DEFAULT_DB_ALIAS
from django.db.models import Count
from django.db.models.query import Q, ITER_CHUNK_SIZE, EmptyQuerySet
from django.db.models.sql.where import WhereNode, EverythingNode, NothingNode
from django.db.models.sql.datastructures import EmptyResultSet
from django.test import TestCase, skipUnlessDBFeature
from django.test.utils import str_prefix
from django.utils import unittest
@@ -1316,10 +1318,23 @@ class Queries5Tests(TestCase):
        )

    def test_ticket5261(self):
        # Test different empty excludes.
        self.assertQuerysetEqual(
            Note.objects.exclude(Q()),
            ['<Note: n1>', '<Note: n2>']
        )
        self.assertQuerysetEqual(
            Note.objects.filter(~Q()),
            ['<Note: n1>', '<Note: n2>']
        )
        self.assertQuerysetEqual(
            Note.objects.filter(~Q()|~Q()),
            ['<Note: n1>', '<Note: n2>']
        )
        self.assertQuerysetEqual(
            Note.objects.exclude(~Q()&~Q()),
            ['<Note: n1>', '<Note: n2>']
        )


class SelectRelatedTests(TestCase):
@@ -2020,3 +2035,74 @@ class ProxyQueryCleanupTest(TestCase):
        self.assertEqual(qs.count(), 1)
        str(qs.query)
        self.assertEqual(qs.count(), 1)

class WhereNodeTest(TestCase):
    class DummyNode(object):
        def as_sql(self, qn, connection):
            return 'dummy', []

    def test_empty_full_handling_conjunction(self):
        qn = connection.ops.quote_name
        w = WhereNode(children=[EverythingNode()])
        self.assertEquals(w.as_sql(qn, connection), ('', []))
        w.negate()
        self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)
        w = WhereNode(children=[NothingNode()])
        self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)
        w.negate()
        self.assertEquals(w.as_sql(qn, connection), ('', []))
        w = WhereNode(children=[EverythingNode(), EverythingNode()])
        self.assertEquals(w.as_sql(qn, connection), ('', []))
        w.negate()
        self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)
        w = WhereNode(children=[EverythingNode(), self.DummyNode()])
        self.assertEquals(w.as_sql(qn, connection), ('dummy', []))
        w = WhereNode(children=[self.DummyNode(), self.DummyNode()])
        self.assertEquals(w.as_sql(qn, connection), ('(dummy AND dummy)', []))
        w.negate()
        self.assertEquals(w.as_sql(qn, connection), ('NOT (dummy AND dummy)', []))
        w = WhereNode(children=[NothingNode(), self.DummyNode()])
        self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)
        w.negate()
        self.assertEquals(w.as_sql(qn, connection), ('', []))

    def test_empty_full_handling_disjunction(self):
        qn = connection.ops.quote_name
        w = WhereNode(children=[EverythingNode()], connector='OR')
        self.assertEquals(w.as_sql(qn, connection), ('', []))
        w.negate()
        self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)
        w = WhereNode(children=[NothingNode()], connector='OR')
        self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)
        w.negate()
        self.assertEquals(w.as_sql(qn, connection), ('', []))
        w = WhereNode(children=[EverythingNode(), EverythingNode()], connector='OR')
        self.assertEquals(w.as_sql(qn, connection), ('', []))
        w.negate()
        self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)
        w = WhereNode(children=[EverythingNode(), self.DummyNode()], connector='OR')
        self.assertEquals(w.as_sql(qn, connection), ('', []))
        w.negate()
        self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)
        w = WhereNode(children=[self.DummyNode(), self.DummyNode()], connector='OR')
        self.assertEquals(w.as_sql(qn, connection), ('(dummy OR dummy)', []))
        w.negate()
        self.assertEquals(w.as_sql(qn, connection), ('NOT (dummy OR dummy)', []))
        w = WhereNode(children=[NothingNode(), self.DummyNode()], connector='OR')
        self.assertEquals(w.as_sql(qn, connection), ('dummy', []))
        w.negate()
        self.assertEquals(w.as_sql(qn, connection), ('NOT dummy', []))

    def test_empty_nodes(self):
        qn = connection.ops.quote_name
        empty_w = WhereNode()
        w = WhereNode(children=[empty_w, empty_w])
        self.assertEquals(w.as_sql(qn, connection), (None, []))
        w.negate()
        self.assertEquals(w.as_sql(qn, connection), (None, []))
        w.connector = 'OR'
        self.assertEquals(w.as_sql(qn, connection), (None, []))
        w.negate()
        self.assertEquals(w.as_sql(qn, connection), (None, []))
        w = WhereNode(children=[empty_w, NothingNode()], connector='OR')
        self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)