Commit 534aaf56 authored by Josh Smeaton's avatar Josh Smeaton
Browse files

Fixed #24629 -- Unified Transform and Expression APIs

parent 8dc3ba5c
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -81,14 +81,14 @@ class KeyTransformFactory(object):


@HStoreField.register_lookup
class KeysTransform(lookups.FunctionTransform):
class KeysTransform(Transform):
    lookup_name = 'keys'
    function = 'akeys'
    output_field = ArrayField(TextField())


@HStoreField.register_lookup
class ValuesTransform(lookups.FunctionTransform):
class ValuesTransform(Transform):
    lookup_name = 'values'
    function = 'avals'
    output_field = ArrayField(TextField())
+3 −3
Original line number Diff line number Diff line
@@ -173,7 +173,7 @@ class AdjacentToLookup(lookups.PostgresSimpleLookup):


@RangeField.register_lookup
class RangeStartsWith(lookups.FunctionTransform):
class RangeStartsWith(models.Transform):
    lookup_name = 'startswith'
    function = 'lower'

@@ -183,7 +183,7 @@ class RangeStartsWith(lookups.FunctionTransform):


@RangeField.register_lookup
class RangeEndsWith(lookups.FunctionTransform):
class RangeEndsWith(models.Transform):
    lookup_name = 'endswith'
    function = 'upper'

@@ -193,7 +193,7 @@ class RangeEndsWith(lookups.FunctionTransform):


@RangeField.register_lookup
class IsEmpty(lookups.FunctionTransform):
class IsEmpty(models.Transform):
    lookup_name = 'isempty'
    function = 'isempty'
    output_field = models.BooleanField()
+1 −7
Original line number Diff line number Diff line
@@ -9,12 +9,6 @@ class PostgresSimpleLookup(Lookup):
        return '%s %s %s' % (lhs, self.operator, rhs), params


class FunctionTransform(Transform):
    def as_sql(self, qn, connection):
        lhs, params = qn.compile(self.lhs)
        return "%s(%s)" % (self.function, lhs), params


class DataContains(PostgresSimpleLookup):
    lookup_name = 'contains'
    operator = '@>'
@@ -45,7 +39,7 @@ class HasAnyKeys(PostgresSimpleLookup):
    operator = '?|'


class Unaccent(FunctionTransform):
class Unaccent(Transform):
    bilateral = True
    lookup_name = 'unaccent'
    function = 'UNACCENT'
+1 −164
Original line number Diff line number Diff line
@@ -20,10 +20,7 @@ from django.core import checks, exceptions, validators
# purposes.
from django.core.exceptions import FieldDoesNotExist  # NOQA
from django.db import connection, connections, router
from django.db.models.lookups import (
    Lookup, RegisterLookupMixin, Transform, default_lookups,
)
from django.db.models.query_utils import QueryWrapper
from django.db.models.query_utils import QueryWrapper, RegisterLookupMixin
from django.utils import six, timezone
from django.utils.datastructures import DictWrapper
from django.utils.dateparse import (
@@ -120,7 +117,6 @@ class Field(RegisterLookupMixin):
        'unique_for_date': _("%(field_label)s must be unique for "
                             "%(date_field_label)s %(lookup_type)s."),
    }
    class_lookups = default_lookups.copy()
    system_check_deprecated_details = None
    system_check_removed_details = None

@@ -1492,22 +1488,6 @@ class DateTimeField(DateField):
        return super(DateTimeField, self).formfield(**defaults)


@DateTimeField.register_lookup
class DateTimeDateTransform(Transform):
    lookup_name = 'date'

    @cached_property
    def output_field(self):
        return DateField()

    def as_sql(self, compiler, connection):
        lhs, lhs_params = compiler.compile(self.lhs)
        tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
        sql, tz_params = connection.ops.datetime_cast_date_sql(lhs, tzname)
        lhs_params.extend(tz_params)
        return sql, lhs_params


class DecimalField(Field):
    empty_strings_allowed = False
    default_error_messages = {
@@ -2450,146 +2430,3 @@ class UUIDField(Field):
        }
        defaults.update(kwargs)
        return super(UUIDField, self).formfield(**defaults)


class DateTransform(Transform):
    def as_sql(self, compiler, connection):
        sql, params = compiler.compile(self.lhs)
        lhs_output_field = self.lhs.output_field
        if isinstance(lhs_output_field, DateTimeField):
            tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
            sql, tz_params = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname)
            params.extend(tz_params)
        elif isinstance(lhs_output_field, DateField):
            sql = connection.ops.date_extract_sql(self.lookup_name, sql)
        elif isinstance(lhs_output_field, TimeField):
            sql = connection.ops.time_extract_sql(self.lookup_name, sql)
        else:
            raise ValueError('DateTransform only valid on Date/Time/DateTimeFields')
        return sql, params

    @cached_property
    def output_field(self):
        return IntegerField()


class YearTransform(DateTransform):
    lookup_name = 'year'


class YearLookup(Lookup):
    def year_lookup_bounds(self, connection, year):
        output_field = self.lhs.lhs.output_field
        if isinstance(output_field, DateTimeField):
            bounds = connection.ops.year_lookup_bounds_for_datetime_field(year)
        else:
            bounds = connection.ops.year_lookup_bounds_for_date_field(year)
        return bounds


@YearTransform.register_lookup
class YearExact(YearLookup):
    lookup_name = 'exact'

    def as_sql(self, compiler, connection):
        # We will need to skip the extract part and instead go
        # directly with the originating field, that is self.lhs.lhs.
        lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        bounds = self.year_lookup_bounds(connection, rhs_params[0])
        params.extend(bounds)
        return '%s BETWEEN %%s AND %%s' % lhs_sql, params


class YearComparisonLookup(YearLookup):
    def as_sql(self, compiler, connection):
        # We will need to skip the extract part and instead go
        # directly with the originating field, that is self.lhs.lhs.
        lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        rhs_sql = self.get_rhs_op(connection, rhs_sql)
        start, finish = self.year_lookup_bounds(connection, rhs_params[0])
        params.append(self.get_bound(start, finish))
        return '%s %s' % (lhs_sql, rhs_sql), params

    def get_rhs_op(self, connection, rhs):
        return connection.operators[self.lookup_name] % rhs

    def get_bound(self):
        raise NotImplementedError(
            'subclasses of YearComparisonLookup must provide a get_bound() method'
        )


@YearTransform.register_lookup
class YearGt(YearComparisonLookup):
    lookup_name = 'gt'

    def get_bound(self, start, finish):
        return finish


@YearTransform.register_lookup
class YearGte(YearComparisonLookup):
    lookup_name = 'gte'

    def get_bound(self, start, finish):
        return start


@YearTransform.register_lookup
class YearLt(YearComparisonLookup):
    lookup_name = 'lt'

    def get_bound(self, start, finish):
        return start


@YearTransform.register_lookup
class YearLte(YearComparisonLookup):
    lookup_name = 'lte'

    def get_bound(self, start, finish):
        return finish


class MonthTransform(DateTransform):
    lookup_name = 'month'


class DayTransform(DateTransform):
    lookup_name = 'day'


class WeekDayTransform(DateTransform):
    lookup_name = 'week_day'


class HourTransform(DateTransform):
    lookup_name = 'hour'


class MinuteTransform(DateTransform):
    lookup_name = 'minute'


class SecondTransform(DateTransform):
    lookup_name = 'second'


DateField.register_lookup(YearTransform)
DateField.register_lookup(MonthTransform)
DateField.register_lookup(DayTransform)
DateField.register_lookup(WeekDayTransform)

TimeField.register_lookup(HourTransform)
TimeField.register_lookup(MinuteTransform)
TimeField.register_lookup(SecondTransform)

DateTimeField.register_lookup(YearTransform)
DateTimeField.register_lookup(MonthTransform)
DateTimeField.register_lookup(DayTransform)
DateTimeField.register_lookup(WeekDayTransform)
DateTimeField.register_lookup(HourTransform)
DateTimeField.register_lookup(MinuteTransform)
DateTimeField.register_lookup(SecondTransform)
+9 −5
Original line number Diff line number Diff line
"""
Classes that represent database functions.
"""
from django.db.models import DateTimeField, IntegerField
from django.db.models.expressions import Func, Value
from django.db.models import (
    DateTimeField, Func, IntegerField, Transform, Value,
)


class Coalesce(Func):
@@ -123,9 +124,10 @@ class Least(Func):
        return super(Least, self).as_sql(compiler, connection, function='MIN')


class Length(Func):
class Length(Transform):
    """Returns the number of characters in the expression"""
    function = 'LENGTH'
    lookup_name = 'length'

    def __init__(self, expression, **extra):
        output_field = extra.pop('output_field', IntegerField())
@@ -136,8 +138,9 @@ class Length(Func):
        return super(Length, self).as_sql(compiler, connection)


class Lower(Func):
class Lower(Transform):
    function = 'LOWER'
    lookup_name = 'lower'

    def __init__(self, expression, **extra):
        super(Lower, self).__init__(expression, **extra)
@@ -188,8 +191,9 @@ class Substr(Func):
        return super(Substr, self).as_sql(compiler, connection)


class Upper(Func):
class Upper(Transform):
    function = 'UPPER'
    lookup_name = 'upper'

    def __init__(self, expression, **extra):
        super(Upper, self).__init__(expression, **extra)
Loading