Commit 53361589 authored by Kai Feldhoff's avatar Kai Feldhoff Committed by Tim Graham
Browse files

Fixed #25759 -- Added keyword arguments to customize Expressions' as_sql().

parent f1db8c36
Loading
Loading
Loading
Loading
+26 −15
Original line number Diff line number Diff line
@@ -534,7 +534,7 @@ class Func(Expression):
            c.source_expressions[pos] = arg.resolve_expression(query, allow_joins, reuse, summarize, for_save)
        return c

    def as_sql(self, compiler, connection, function=None, template=None):
    def as_sql(self, compiler, connection, function=None, template=None, arg_joiner=None, **extra_context):
        connection.ops.check_expression_support(self)
        sql_parts = []
        params = []
@@ -542,13 +542,19 @@ class Func(Expression):
            arg_sql, arg_params = compiler.compile(arg)
            sql_parts.append(arg_sql)
            params.extend(arg_params)
        if function is None:
            self.extra['function'] = self.extra.get('function', self.function)
        data = self.extra.copy()
        data.update(**extra_context)
        # Use the first supplied value in this order: the parameter to this
        # method, a value supplied in __init__()'s **extra (the value in
        # `data`), or the value defined on the class.
        if function is not None:
            data['function'] = function
        else:
            self.extra['function'] = function
        self.extra['expressions'] = self.extra['field'] = self.arg_joiner.join(sql_parts)
        template = template or self.extra.get('template', self.template)
        return template % self.extra, params
            data.setdefault('function', self.function)
        template = template or data.get('template', self.template)
        arg_joiner = arg_joiner or data.get('arg_joiner', self.arg_joiner)
        data['expressions'] = data['field'] = arg_joiner.join(sql_parts)
        return template % data, params

    def as_sqlite(self, compiler, connection):
        sql, params = self.as_sql(compiler, connection)
@@ -778,9 +784,9 @@ class When(Expression):
        c.result = c.result.resolve_expression(query, allow_joins, reuse, summarize, for_save)
        return c

    def as_sql(self, compiler, connection, template=None):
    def as_sql(self, compiler, connection, template=None, **extra_context):
        connection.ops.check_expression_support(self)
        template_params = {}
        template_params = extra_context
        sql_params = []
        condition_sql, condition_params = compiler.compile(self.condition)
        template_params['condition'] = condition_sql
@@ -822,6 +828,7 @@ class Case(Expression):
        super(Case, self).__init__(output_field)
        self.cases = list(cases)
        self.default = self._parse_expressions(default)[0]
        self.extra = extra

    def __str__(self):
        return "CASE %s, ELSE %r" % (', '.join(str(c) for c in self.cases), self.default)
@@ -849,22 +856,24 @@ class Case(Expression):
        c.cases = c.cases[:]
        return c

    def as_sql(self, compiler, connection, template=None, extra=None):
    def as_sql(self, compiler, connection, template=None, case_joiner=None, **extra_context):
        connection.ops.check_expression_support(self)
        if not self.cases:
            return compiler.compile(self.default)
        template_params = dict(extra) if extra else {}
        template_params = self.extra.copy()
        template_params.update(extra_context)
        case_parts = []
        sql_params = []
        for case in self.cases:
            case_sql, case_params = compiler.compile(case)
            case_parts.append(case_sql)
            sql_params.extend(case_params)
        template_params['cases'] = self.case_joiner.join(case_parts)
        case_joiner = case_joiner or self.case_joiner
        template_params['cases'] = case_joiner.join(case_parts)
        default_sql, default_params = compiler.compile(self.default)
        template_params['default'] = default_sql
        sql_params.extend(default_params)
        template = template or self.template
        template = template or template_params.get('template', self.template)
        sql = template % template_params
        if self._output_field_or_none is not None:
            sql = connection.ops.unification_cast_sql(self.output_field) % sql
@@ -995,14 +1004,16 @@ class OrderBy(BaseExpression):
    def get_source_expressions(self):
        return [self.expression]

    def as_sql(self, compiler, connection):
    def as_sql(self, compiler, connection, template=None, **extra_context):
        connection.ops.check_expression_support(self)
        expression_sql, params = compiler.compile(self.expression)
        placeholders = {
            'expression': expression_sql,
            'ordering': 'DESC' if self.descending else 'ASC',
        }
        return (self.template % placeholders).rstrip(), params
        placeholders.update(extra_context)
        template = template or self.template
        return (template % placeholders).rstrip(), params

    def get_group_by_cols(self):
        cols = []
+1 −2
Original line number Diff line number Diff line
@@ -43,9 +43,8 @@ class ConcatPair(Func):

    def as_sqlite(self, compiler, connection):
        coalesced = self.coalesce()
        coalesced.arg_joiner = ' || '
        return super(ConcatPair, coalesced).as_sql(
            compiler, connection, template='%(expressions)s',
            compiler, connection, template='%(expressions)s', arg_joiner=' || '
        )

    def as_mysql(self, compiler, connection):
+13 −7
Original line number Diff line number Diff line
@@ -261,12 +261,13 @@ The ``Func`` API is as follows:
        different number of expressions, ``TypeError`` will be raised. Defaults
        to ``None``.

    .. method:: as_sql(compiler, connection, function=None, template=None)
    .. method:: as_sql(compiler, connection, function=None, template=None, arg_joiner=None, **extra_context)

        Generates the SQL for the database function.

        The ``as_vendor()`` methods should use the ``function`` and
        ``template`` parameters to customize the SQL as needed. For example:
        The ``as_vendor()`` methods should use the ``function``, ``template``,
        ``arg_joiner``, and any other ``**extra_context`` parameters to
        customize the SQL as needed. For example:

        .. snippet::
            :filename: django/db/models/functions.py
@@ -283,6 +284,11 @@ The ``Func`` API is as follows:
                        template="%(function)s('', %(expressions)s)",
                    )

        .. versionchanged:: 1.10

            Support for the ``arg_joiner`` and ``**extra_context`` parameters
            was added.

The ``*expressions`` argument is a list of positional expressions that the
function will be applied to. The expressions will be converted to strings,
joined together with ``arg_joiner``, and then interpolated into the ``template``
@@ -293,10 +299,10 @@ assumed to be column references and will be wrapped in ``F()`` expressions
while other values will be wrapped in ``Value()`` expressions.

The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated
into the ``template`` attribute. Note that the keywords ``function`` and
``template`` can be used to replace the ``function`` and ``template``
attributes respectively, without having to define your own class.
``output_field`` can be used to define the expected return type.
into the ``template`` attribute. The ``function``, ``template``, and
``arg_joiner`` keywords can be used to replace the attributes of the same name
without having to define your own class. ``output_field`` can be used to define
the expected return type.

``Aggregate()`` expressions
---------------------------
+7 −0
Original line number Diff line number Diff line
@@ -212,6 +212,13 @@ Database backends
  ``DatabaseOperations.fetch_returned_insert_ids()`` to set primary keys
  on objects created using ``QuerySet.bulk_create()``.

* Added keyword arguments to the ``as_sql()`` methods of various expressions
  (``Func``, ``When``, ``Case``, and ``OrderBy``) to allow database backends to
  customize them without mutating ``self``, which isn't safe when using
  different database backends. See the ``arg_joiner`` and ``**extra_context``
  parameters of :meth:`Func.as_sql() <django.db.models.Func.as_sql>` for an
  example.

Email
~~~~~