Commit 00aa5628 authored by Thomas Chaumeny's avatar Thomas Chaumeny Committed by Anssi Kääriäinen
Browse files

Fixed #23493 -- Added bilateral attribute to Transform

parent 6b39401b
Loading
Loading
Loading
Loading
+83 −20
Original line number Diff line number Diff line
from copy import copy
from itertools import repeat
import inspect

from django.conf import settings
@@ -7,6 +6,8 @@ from django.utils import timezone
from django.utils.functional import cached_property
from django.utils.six.moves import xrange

from .query_utils import QueryWrapper


class RegisterLookupMixin(object):
    def _get_lookup(self, lookup_name):
@@ -57,6 +58,9 @@ class RegisterLookupMixin(object):


class Transform(RegisterLookupMixin):

    bilateral = False

    def __init__(self, lhs, lookups):
        self.lhs = lhs
        self.init_lookups = lookups[:]
@@ -78,9 +82,42 @@ class Transform(RegisterLookupMixin):
class Lookup(RegisterLookupMixin):
    lookup_name = None

    def __init__(self, lhs, rhs):
    def __init__(self, lhs, rhs, bilateral_transforms=None):
        self.lhs, self.rhs = lhs, rhs
        self.rhs = self.get_prep_lookup()
        if bilateral_transforms is None:
            bilateral_transforms = []
        if bilateral_transforms:
            # We should warn the user as soon as possible if he is trying to apply
            # a bilateral transformation on a nested QuerySet: that won't work.
            # We need to import QuerySet here so as to avoid circular
            from django.db.models.query import QuerySet
            if isinstance(rhs, QuerySet):
                raise NotImplementedError("Bilateral transformations on nested querysets are not supported.")
        self.bilateral_transforms = bilateral_transforms

    def apply_bilateral_transforms(self, value):
        for transform, lookups in self.bilateral_transforms:
            value = transform(value, lookups)
        return value

    def batch_process_rhs(self, qn, connection, rhs=None):
        if rhs is None:
            rhs = self.rhs
        if self.bilateral_transforms:
            sqls, sqls_params = [], []
            for p in rhs:
                value = QueryWrapper('%s',
                    [self.lhs.output_field.get_db_prep_value(p, connection)])
                value = self.apply_bilateral_transforms(value)
                sql, sql_params = qn.compile(value)
                sqls.append(sql)
                sqls_params.extend(sql_params)
        else:
            params = self.lhs.output_field.get_db_prep_lookup(
                self.lookup_name, rhs, connection, prepared=True)
            sqls, sqls_params = ['%s'] * len(params), params
        return sqls, sqls_params

    def get_prep_lookup(self):
        return self.lhs.output_field.get_prep_lookup(self.lookup_name, self.rhs)
@@ -96,6 +133,13 @@ class Lookup(RegisterLookupMixin):

    def process_rhs(self, qn, connection):
        value = self.rhs
        if self.bilateral_transforms:
            if self.rhs_is_direct_value():
                # Do not call get_db_prep_lookup here as the value will be
                # transformed before being used for lookup
                value = QueryWrapper("%s",
                    [self.lhs.output_field.get_db_prep_value(value, connection)])
            value = self.apply_bilateral_transforms(value)
        # Due to historical reasons there are a couple of different
        # ways to produce sql here. get_compiler is likely a Query
        # instance, _as_sql QuerySet and as_sql just something with
@@ -203,15 +247,19 @@ default_lookups['lte'] = LessThanOrEqual
class In(BuiltinLookup):
    lookup_name = 'in'

    def get_db_prep_lookup(self, value, connection):
        params = self.lhs.output_field.get_db_prep_lookup(
            self.lookup_name, value, connection, prepared=True)
        if not params:
            # TODO: check why this leads to circular import
    def process_rhs(self, qn, connection):
        if self.rhs_is_direct_value():
            # rhs should be an iterable, we use batch_process_rhs
            # to prepare/transform those values
            rhs = list(self.rhs)
            if not rhs:
                from django.db.models.sql.datastructures import EmptyResultSet
                raise EmptyResultSet
        placeholder = '(' + ', '.join('%s' for p in params) + ')'
        return (placeholder, params)
            sqls, sqls_params = self.batch_process_rhs(qn, connection, rhs)
            placeholder = '(' + ', '.join(sqls) + ')'
            return (placeholder, sqls_params)
        else:
            return super(In, self).process_rhs(qn, connection)

    def get_rhs_op(self, connection, rhs):
        return 'IN %s' % rhs
@@ -220,8 +268,10 @@ class In(BuiltinLookup):
        max_in_list_size = connection.ops.max_in_list_size()
        if self.rhs_is_direct_value() and (max_in_list_size and
                                           len(self.rhs) > max_in_list_size):
            rhs, rhs_params = self.process_rhs(qn, connection)
            # This is a special case for Oracle which limits the number of elements
            # which can appear in an 'IN' clause.
            lhs, lhs_params = self.process_lhs(qn, connection)
            rhs, rhs_params = self.batch_process_rhs(qn, connection)
            in_clause_elements = ['(']
            params = []
            for offset in xrange(0, len(rhs_params), max_in_list_size):
@@ -229,11 +279,12 @@ class In(BuiltinLookup):
                    in_clause_elements.append(' OR ')
                in_clause_elements.append('%s IN (' % lhs)
                params.extend(lhs_params)
                group_size = min(len(rhs_params) - offset, max_in_list_size)
                param_group = ', '.join(repeat('%s', group_size))
                sqls = rhs[offset: offset + max_in_list_size]
                sqls_params = rhs_params[offset: offset + max_in_list_size]
                param_group = ', '.join(sqls)
                in_clause_elements.append(param_group)
                in_clause_elements.append(')')
                params.extend(rhs_params[offset: offset + max_in_list_size])
                params.extend(sqls_params)
            in_clause_elements.append(')')
            return ''.join(in_clause_elements), params
        else:
@@ -252,10 +303,10 @@ class PatternLookup(BuiltinLookup):
        # we need to add the % pattern match to the lookup by something like
        #     col LIKE othercol || '%%'
        # So, for Python values we don't need any special pattern, but for
        # SQL reference values we need the correct pattern added.
        value = self.rhs
        if (hasattr(value, 'get_compiler') or hasattr(value, 'as_sql')
                or hasattr(value, '_as_sql')):
        # SQL reference values or SQL transformations we need the correct
        # pattern added.
        if (hasattr(self.rhs, 'get_compiler') or hasattr(self.rhs, 'as_sql')
                or hasattr(self.rhs, '_as_sql') or self.bilateral_transforms):
            return connection.pattern_ops[self.lookup_name] % rhs
        else:
            return super(PatternLookup, self).get_rhs_op(connection, rhs)
@@ -291,8 +342,20 @@ class Year(Between):
default_lookups['year'] = Year


class Range(Between):
class Range(BuiltinLookup):
    lookup_name = 'range'

    def get_rhs_op(self, connection, rhs):
        return "BETWEEN %s AND %s" % (rhs[0], rhs[1])

    def process_rhs(self, qn, connection):
        if self.rhs_is_direct_value():
            # rhs should be an iterable of 2 values, we use batch_process_rhs
            # to prepare/transform those values
            return self.batch_process_rhs(qn, connection)
        else:
            return super(Range, self).process_rhs(qn, connection)

default_lookups['range'] = Range


+4 −1
Original line number Diff line number Diff line
@@ -1111,18 +1111,21 @@ class Query(object):

    def build_lookup(self, lookups, lhs, rhs):
        lookups = lookups[:]
        bilaterals = []
        while lookups:
            lookup = lookups[0]
            if len(lookups) == 1:
                final_lookup = lhs.get_lookup(lookup)
                if final_lookup:
                    return final_lookup(lhs, rhs)
                    return final_lookup(lhs, rhs, bilaterals)
                # We didn't find a lookup, so we are going to try get_transform
                # + get_lookup('exact').
                lookups.append('exact')
            next = lhs.get_transform(lookup)
            if next:
                lhs = next(lhs, lookups)
                if getattr(next, 'bilateral', False):
                    bilaterals.append((next, lookups))
            else:
                raise FieldError(
                    "Unsupported lookup '%s' for %s or join on the field not "
+43 −5
Original line number Diff line number Diff line
@@ -127,7 +127,7 @@ function ``ABS()`` to transform the value before comparison::
          lhs, params = qn.compile(self.lhs)
          return "ABS(%s)" % lhs, params

Next, lets register it for ``IntegerField``::
Next, let's register it for ``IntegerField``::

  from django.db.models import IntegerField
  IntegerField.register_lookup(AbsoluteValue)
@@ -144,9 +144,7 @@ SQL::

    SELECT ... WHERE ABS("experiments"."change") < 27

Subclasses of ``Transform`` usually only operate on the left-hand side of the
expression. Further lookups will work on the transformed value. Note that in
this case where there is no other lookup specified, Django interprets
Note that in case there is no other lookup specified, Django interprets
``change__abs=27`` as ``change__abs__exact=27``.

When looking for which lookups are allowable after the ``Transform`` has been
@@ -197,7 +195,7 @@ Notice also that as both sides are used multiple times in the query the params
need to contain ``lhs_params`` and ``rhs_params`` multiple times.

The final query does the inversion (``27`` to ``-27``) directly in the
database. The reason for doing this is that if the self.rhs is something else
database. The reason for doing this is that if the ``self.rhs`` is something else
than a plain integer value (for example an ``F()`` reference) we can't do the
transformations in Python.

@@ -208,6 +206,46 @@ transformations in Python.
    want to add an index on ``abs(change)`` which would allow these queries to
    be very efficient.

A bilateral transformer example
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The ``AbsoluteValue`` example we discussed previously is a transformation which
applies to the left-hand side of the lookup. There may be some cases where you
want the transformation to be applied to both the left-hand side and the
right-hand side. For instance, if you want to filter a queryset based on the
equality of the left and right-hand side insensitively to some SQL function.

Let's examine the simple example of case-insensitive transformation here. This
transformation isn't very useful in practice as Django already comes with a bunch
of built-in case-insensitive lookups, but it will be a nice demonstration of
bilateral transformations in a database-agnostic way.

We define an ``UpperCase`` transformer which uses the SQL function ``UPPER()`` to
transform the values before comparison. We define
:attr:`bilateral = True <django.db.models.Transform.bilateral>` to indicate that
this transformation should apply to both ``lhs`` and ``rhs``::

  from django.db.models import Transform

  class UpperCase(Transform):
      lookup_name = 'upper'
      bilateral = True

      def as_sql(self, qn, connection):
          lhs, params = qn.compile(self.lhs)
          return "UPPER(%s)" % lhs, params

Next, let's register it::

  from django.db.models import CharField, TextField
  CharField.register_lookup(UpperCase)
  TextField.register_lookup(UpperCase)

Now, the queryset ``Author.objects.filter(name__upper="doe")`` will generate a case
insensitive query like this::

    SELECT ... WHERE UPPER("author"."name") = UPPER('doe')

Writing alternative implementations for existing lookups
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

+9 −0
Original line number Diff line number Diff line
@@ -129,6 +129,15 @@ Transform reference
    This class follows the :ref:`Query Expression API <query-expression>`, which
    implies that you can use ``<expression>__<transform1>__<transform2>``.

    .. attribute:: bilateral

        .. versionadded:: 1.8

        A boolean indicating whether this transformation should apply to both
        ``lhs`` and ``rhs``. Bilateral transformations will be applied to ``rhs`` in
        the same order as they appear in the lookup expression. By default it is set
        to ``False``. For example usage, see :doc:`/howto/custom-lookups`.

    .. attribute:: lhs

        The left-hand side - what is being transformed. It must follow the
+5 −0
Original line number Diff line number Diff line
@@ -306,6 +306,11 @@ Models
* :doc:`Custom Lookups</howto/custom-lookups>` can now be registered using
  a decorator pattern.

* The new :attr:`Transform.bilateral <django.db.models.Transform.bilateral>`
  attribute allows creating bilateral transformations. These transformations
  are applied to both ``lhs`` and ``rhs`` when used in a lookup expression,
  providing opportunities for more sophisticated lookups.

Signals
^^^^^^^

Loading