Commit 21b858cb authored by Josh Smeaton's avatar Josh Smeaton
Browse files

Fixed #24060 -- Added OrderBy Expressions

parent f48e2258
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -300,7 +300,7 @@ class DatabaseOperations(BaseDatabaseOperations):
        columns. If no ordering would otherwise be applied, we don't want any
        implicit sorting going on.
        """
        return [(None, ("NULL", [], 'asc', False))]
        return [(None, ("NULL", [], False))]

    def fulltext_search_sql(self, field_name):
        return 'MATCH (%s) AGAINST (%%s IN BOOLEAN MODE)' % field_name
+67 −10
Original line number Diff line number Diff line
@@ -118,7 +118,7 @@ class CombinableMixin(object):
        )


class ExpressionNode(CombinableMixin):
class BaseExpression(object):
    """
    Base class for all query expressions.
    """
@@ -189,6 +189,10 @@ class ExpressionNode(CombinableMixin):
        """
        c = self.copy()
        c.is_summary = summarize
        c.set_source_expressions([
            expr.resolve_expression(query, allow_joins, reuse, summarize)
            for expr in c.get_source_expressions()
        ])
        return c

    def _prepare(self):
@@ -319,6 +323,22 @@ class ExpressionNode(CombinableMixin):
        """
        return [e._output_field_or_none for e in self.get_source_expressions()]

    def asc(self):
        return OrderBy(self)

    def desc(self):
        return OrderBy(self, descending=True)

    def reverse_ordering(self):
        return self


class ExpressionNode(BaseExpression, CombinableMixin):
    """
    An expression that can be combined with other expressions.
    """
    pass


class Expression(ExpressionNode):

@@ -412,6 +432,12 @@ class F(CombinableMixin):
    def refs_aggregate(self, existing_aggregates):
        return refs_aggregate(self.name.split(LOOKUP_SEP), existing_aggregates)

    def asc(self):
        return OrderBy(self)

    def desc(self):
        return OrderBy(self, descending=True)


class Func(ExpressionNode):
    """
@@ -526,15 +552,6 @@ class Random(ExpressionNode):
        return connection.ops.random_function_sql(), []


class ColIndexRef(ExpressionNode):
    def __init__(self, idx):
        self.idx = idx
        super(ColIndexRef, self).__init__()

    def as_sql(self, compiler, connection):
        return str(self.idx), []


class Col(ExpressionNode):
    def __init__(self, alias, target, source=None):
        if source is None:
@@ -678,3 +695,43 @@ class DateTime(ExpressionNode):
            value = value.replace(tzinfo=None)
            value = timezone.make_aware(value, self.tzinfo)
        return value


class OrderBy(BaseExpression):
    template = '%(expression)s %(ordering)s'
    descending_template = 'DESC'
    ascending_template = 'ASC'

    def __init__(self, expression, descending=False):
        self.descending = descending
        if not hasattr(expression, 'resolve_expression'):
            raise ValueError('expression must be an expression type')
        self.expression = expression

    def set_source_expressions(self, exprs):
        self.expression = exprs[0]

    def get_source_expressions(self):
        return [self.expression]

    def as_sql(self, compiler, connection):
        expression_sql, params = compiler.compile(self.expression)
        placeholders = {'expression': expression_sql}
        placeholders['ordering'] = 'DESC' if self.descending else 'ASC'
        return (self.template % placeholders).rstrip(), params

    def get_group_by_cols(self):
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols

    def reverse_ordering(self):
        self.descending = not self.descending
        return self

    def asc(self):
        self.descending = False

    def desc(self):
        self.descending = True
+55 −31
Original line number Diff line number Diff line
from itertools import chain
import re
import warnings

from django.core.exceptions import FieldError
from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import RawSQL, Ref, Random, ColIndexRef
from django.db.models.expressions import OrderBy, Random, RawSQL, Ref
from django.db.models.query_utils import select_related_descend, QueryWrapper
from django.db.models.sql.constants import (CURSOR, SINGLE, MULTI, NO_RESULTS,
        ORDER_DIR, GET_ITERATOR_CHUNK_SIZE)
@@ -28,6 +29,7 @@ class SQLCompiler(object):
        self.select = None
        self.annotation_col_map = None
        self.klass_info = None
        self.ordering_parts = re.compile(r'(.*)\s(ASC|DESC)(.*)')

    def setup_query(self):
        if all(self.query.alias_refcount[a] == 0 for a in self.query.tables):
@@ -105,14 +107,14 @@ class SQLCompiler(object):
            cols = expr.get_group_by_cols()
            for col in cols:
                expressions.append(col)
        for expr, _ in order_by:
        for expr, (sql, params, is_ref) in order_by:
            if expr.contains_aggregate:
                continue
            # We can skip References to select clause, as all expressions in
            # the select clause are already part of the group by.
            if isinstance(expr, Ref):
            if is_ref:
                continue
            expressions.append(expr)
            expressions.extend(expr.get_source_expressions())
        having = self.query.having.get_group_by_cols()
        for expr in having:
            expressions.append(expr)
@@ -234,54 +236,75 @@ class SQLCompiler(object):

        order_by = []
        for pos, field in enumerate(ordering):
            if field == '?':
                order_by.append((Random(), asc, False))
            if hasattr(field, 'resolve_expression'):
                if not isinstance(field, OrderBy):
                    field = field.asc()
                if not self.query.standard_ordering:
                    field.reverse_ordering()
                order_by.append((field, False))
                continue
            if isinstance(field, int):
                if field < 0:
                    field = -field
                    int_ord = desc
                order_by.append((ColIndexRef(field), int_ord, True))
            if field == '?':  # random
                order_by.append((OrderBy(Random()), False))
                continue

            col, order = get_order_dir(field, asc)
            descending = True if order == 'DESC' else False

            if col in self.query.annotation_select:
                order_by.append((Ref(col, self.query.annotation_select[col]), order, True))
                order_by.append((
                    OrderBy(Ref(col, self.query.annotation_select[col]), descending=descending),
                    True))
                continue

            if '.' in field:
                # This came in through an extra(order_by=...) addition. Pass it
                # on verbatim.
                table, col = col.split('.', 1)
                expr = RawSQL('%s.%s' % (self.quote_name_unless_alias(table), col), [])
                order_by.append((expr, order, False))
                order_by.append((
                    OrderBy(RawSQL('%s.%s' % (self.quote_name_unless_alias(table), col), [])),
                    False))
                continue
            if not self.query._extra or get_order_dir(field)[0] not in self.query._extra:

            if not self.query._extra or col not in self.query._extra:
                # 'col' is of the form 'field' or 'field1__field2' or
                # '-field1__field2__field', etc.
                order_by.extend(self.find_ordering_name(field, self.query.get_meta(),
                                                        default_order=asc))
                order_by.extend(self.find_ordering_name(
                    field, self.query.get_meta(), default_order=asc))
            else:
                if col not in self.query.extra_select:
                    order_by.append((RawSQL(*self.query.extra[col]), order, False))
                    order_by.append((
                        OrderBy(RawSQL(*self.query.extra[col]), descending=descending),
                        False))
                else:
                    order_by.append((Ref(col, RawSQL(*self.query.extra[col])),
                                     order, True))
                    order_by.append((
                        OrderBy(Ref(col, RawSQL(*self.query.extra[col])), descending=descending),
                        True))
        result = []
        seen = set()
        for expr, order, is_ref in order_by:
            sql, params = self.compile(expr)
            if (sql, tuple(params)) in seen:

        for expr, is_ref in order_by:
            resolved = expr.resolve_expression(
                self.query, allow_joins=True, reuse=None)
            sql, params = self.compile(resolved)
            # Don't add the same column twice, but the order direction is
            # not taken into account so we strip it. When this entire method
            # is refactored into expressions, then we can check each part as we
            # generate it.
            without_ordering = self.ordering_parts.search(sql).group(1)
            if (without_ordering, tuple(params)) in seen:
                continue
            seen.add((sql, tuple(params)))
            result.append((expr, (sql, params, order, is_ref)))
            seen.add((without_ordering, tuple(params)))
            result.append((resolved, (sql, params, is_ref)))
        return result

    def get_extra_select(self, order_by, select):
        extra_select = []
        select_sql = [t[1] for t in select]
        if self.query.distinct and not self.query.distinct_fields:
            for expr, (sql, params, _, is_ref) in order_by:
                if not is_ref and (sql, params) not in select_sql:
                    extra_select.append((expr, (sql, params), None))
            for expr, (sql, params, is_ref) in order_by:
                without_ordering = self.ordering_parts.search(sql).group(1)
                if not is_ref and (without_ordering, params) not in select_sql:
                    extra_select.append((expr, (without_ordering, params), None))
        return extra_select

    def __call__(self, name):
@@ -392,8 +415,8 @@ class SQLCompiler(object):

            if order_by:
                ordering = []
                for _, (o_sql, o_params, order, _) in order_by:
                    ordering.append('%s %s' % (o_sql, order))
                for _, (o_sql, o_params, _) in order_by:
                    ordering.append(o_sql)
                    params.extend(o_params)
                result.append('ORDER BY %s' % ', '.join(ordering))

@@ -514,6 +537,7 @@ class SQLCompiler(object):
        The 'name' is of the form 'field1__field2__...__fieldN'.
        """
        name, order = get_order_dir(name, default_order)
        descending = True if order == 'DESC' else False
        pieces = name.split(LOOKUP_SEP)
        field, targets, alias, joins, path, opts = self._setup_joins(pieces, opts, alias)

@@ -535,7 +559,7 @@ class SQLCompiler(object):
                                                       order, already_seen))
            return results
        targets, alias, _ = self.query.trim_joins(targets, joins, path)
        return [(t.get_col(alias), order, False) for t in targets]
        return [(OrderBy(t.get_col(alias), descending=descending), False) for t in targets]

    def _setup_joins(self, pieces, opts, alias):
        """
+3 −3
Original line number Diff line number Diff line
@@ -1691,14 +1691,14 @@ class Query(object):
        """
        Adds items from the 'ordering' sequence to the query's "order by"
        clause. These items are either field names (not column names) --
        possibly with a direction prefix ('-' or '?') -- or ordinals,
        corresponding to column positions in the 'select' list.
        possibly with a direction prefix ('-' or '?') -- or OrderBy
        expressions.

        If 'ordering' is empty, all ordering is cleared from the query.
        """
        errors = []
        for item in ordering:
            if not ORDER_PATTERN.match(item):
            if not hasattr(item, 'resolve_expression') and not ORDER_PATTERN.match(item):
                errors.append(item)
        if errors:
            raise FieldError('Invalid order_by arguments: %s' % errors)
+23 −1
Original line number Diff line number Diff line
@@ -5,7 +5,7 @@ Query Expressions
.. currentmodule:: django.db.models

Query expressions describe a value or a computation that can be used as part of
a filter, an annotation, or an aggregation. There are a number of built-in
a filter, order by, annotation, or aggregate. There are a number of built-in
expressions (documented below) that can be used to help you write queries.
Expressions can be combined, or in some cases nested, to form more complex
computations.
@@ -58,6 +58,10 @@ Some examples
    # Aggregates can contain complex computations also
    Company.objects.annotate(num_offerings=Count(F('products') + F('services')))

    # Expressions can also be used in order_by()
    Company.objects.order_by(Length('name').asc())
    Company.objects.order_by(Length('name').desc())


Built-in Expressions
====================
@@ -428,6 +432,24 @@ calling the appropriate methods on the wrapped expression.
        nested expressions. ``F()`` objects, in particular, hold a reference
        to a column.

    .. method:: asc()

        Returns the expression ready to be sorted in ascending order.

    .. method:: desc()

        Returns the expression ready to be sorted in descending order.

    .. method:: reverse_ordering()

        Returns ``self`` with any modifications required to reverse the sort
        order within an ``order_by`` call. As an example, an expression
        implementing ``NULLS LAST`` would change its value to be
        ``NULLS FIRST``. Modifications are only required for expressions that
        implement sort order like ``OrderBy``. This method is called when
        :meth:`~django.db.models.query.QuerySet.reverse()` is called on a
        queryset.

Writing your own Query Expressions
----------------------------------

Loading