Commit 2a4af0ea authored by Josh Smeaton's avatar Josh Smeaton
Browse files

Fixed #25774 -- Refactor datetime expressions into public API

parent 77b73e79
Loading
Loading
Loading
Loading
+1 −2
Original line number Diff line number Diff line
from django.db.models import DateTimeField
from django.db.models.functions import Func
from django.db.models import DateTimeField, Func


class TransactionNow(Func):
+1 −107
Original line number Diff line number Diff line
import copy
import datetime

from django.conf import settings
from django.core.exceptions import FieldError
from django.db.backends import utils as backend_utils
from django.db.models import fields
from django.db.models.query_utils import Q
from django.utils import six, timezone
from django.utils import six
from django.utils.functional import cached_property


@@ -860,111 +859,6 @@ class Case(Expression):
        return sql, sql_params


class Date(Expression):
    """
    Add a date selection column.
    """
    def __init__(self, lookup, lookup_type):
        super(Date, self).__init__(output_field=fields.DateField())
        self.lookup = lookup
        self.col = None
        self.lookup_type = lookup_type

    def __repr__(self):
        return "{}({}, {})".format(self.__class__.__name__, self.lookup, self.lookup_type)

    def get_source_expressions(self):
        return [self.col]

    def set_source_expressions(self, exprs):
        self.col, = exprs

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        copy = self.copy()
        copy.col = query.resolve_ref(self.lookup, allow_joins, reuse, summarize)
        field = copy.col.output_field
        assert isinstance(field, fields.DateField), "%r isn't a DateField." % field.name
        if settings.USE_TZ:
            assert not isinstance(field, fields.DateTimeField), (
                "%r is a DateTimeField, not a DateField." % field.name
            )
        return copy

    def as_sql(self, compiler, connection):
        sql, params = self.col.as_sql(compiler, connection)
        assert not(params)
        return connection.ops.date_trunc_sql(self.lookup_type, sql), []

    def copy(self):
        copy = super(Date, self).copy()
        copy.lookup = self.lookup
        copy.lookup_type = self.lookup_type
        return copy

    def convert_value(self, value, expression, connection, context):
        if isinstance(value, datetime.datetime):
            value = value.date()
        return value


class DateTime(Expression):
    """
    Add a datetime selection column.
    """
    def __init__(self, lookup, lookup_type, tzinfo):
        super(DateTime, self).__init__(output_field=fields.DateTimeField())
        self.lookup = lookup
        self.col = None
        self.lookup_type = lookup_type
        if tzinfo is None:
            self.tzname = None
        else:
            self.tzname = timezone._get_timezone_name(tzinfo)
        self.tzinfo = tzinfo

    def __repr__(self):
        return "{}({}, {}, {})".format(
            self.__class__.__name__, self.lookup, self.lookup_type, self.tzinfo)

    def get_source_expressions(self):
        return [self.col]

    def set_source_expressions(self, exprs):
        self.col, = exprs

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        copy = self.copy()
        copy.col = query.resolve_ref(self.lookup, allow_joins, reuse, summarize)
        field = copy.col.output_field
        assert isinstance(field, fields.DateTimeField), (
            "%r isn't a DateTimeField." % field.name
        )
        return copy

    def as_sql(self, compiler, connection):
        sql, params = self.col.as_sql(compiler, connection)
        assert not(params)
        return connection.ops.datetime_trunc_sql(self.lookup_type, sql, self.tzname)

    def copy(self):
        copy = super(DateTime, self).copy()
        copy.lookup = self.lookup
        copy.lookup_type = self.lookup_type
        copy.tzname = self.tzname
        return copy

    def convert_value(self, value, expression, connection, context):
        if settings.USE_TZ:
            if value is None:
                raise ValueError(
                    "Database returned an invalid value in QuerySet.datetimes(). "
                    "Are time zone definitions for your database and pytz installed?"
                )
            value = value.replace(tzinfo=None)
            value = timezone.make_aware(value, self.tzinfo)
        return value


class OrderBy(BaseExpression):
    template = '%(expression)s %(ordering)s'

+21 −0
Original line number Diff line number Diff line
from .base import (
    Cast, Coalesce, Concat, ConcatPair, Greatest, Least, Length, Lower, Now,
    Substr, Upper,
)

from .datetime import (
    Extract, ExtractDay, ExtractHour, ExtractMinute, ExtractMonth,
    ExtractSecond, ExtractWeekDay, ExtractYear, Trunc, TruncDate, TruncDay,
    TruncHour, TruncMinute, TruncMonth, TruncSecond, TruncYear,
)

__all__ = [
    # base
    'Cast', 'Coalesce', 'Concat', 'ConcatPair', 'Greatest', 'Least', 'Length',
    'Lower', 'Now', 'Substr', 'Upper',
    # datetime
    'Extract', 'ExtractDay', 'ExtractHour', 'ExtractMinute', 'ExtractMonth',
    'ExtractSecond', 'ExtractWeekDay', 'ExtractYear',
    'Trunc', 'TruncDate', 'TruncDay', 'TruncHour', 'TruncMinute', 'TruncMonth',
    'TruncSecond', 'TruncYear',
]
+250 −0
Original line number Diff line number Diff line
from __future__ import absolute_import
from datetime import datetime

from django.conf import settings
from django.db.models import (
    DateField, DateTimeField, IntegerField, TimeField, Transform,
)
from django.db.models.lookups import (
    YearExact, YearGt, YearGte, YearLt, YearLte,
)
from django.utils import timezone
from django.utils.functional import cached_property


class TimezoneMixin(object):
    tzinfo = None

    def get_tzname(self):
        # Timezone conversions must happen to the input datetime *before*
        # applying a function. 2015-12-31 23:00:00 -02:00 is stored in the
        # database as 2016-01-01 01:00:00 +00:00. Any results should be
        # based on the input datetime not the stored datetime.
        tzname = None
        if settings.USE_TZ:
            if self.tzinfo is None:
                tzname = timezone.get_current_timezone_name()
            else:
                tzname = timezone._get_timezone_name(self.tzinfo)
        return tzname


class Extract(TimezoneMixin, Transform):
    lookup_name = None

    def __init__(self, expression, lookup_name=None, tzinfo=None, **extra):
        if self.lookup_name is None:
            self.lookup_name = lookup_name
        if self.lookup_name is None:
            raise ValueError('lookup_name must be provided')
        self.tzinfo = tzinfo
        super(Extract, self).__init__(expression, **extra)

    def as_sql(self, compiler, connection):
        sql, params = compiler.compile(self.lhs)
        lhs_output_field = self.lhs.output_field
        if isinstance(lhs_output_field, DateTimeField):
            tzname = self.get_tzname()
            sql, tz_params = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname)
            params.extend(tz_params)
        elif isinstance(lhs_output_field, DateField):
            sql = connection.ops.date_extract_sql(self.lookup_name, sql)
        elif isinstance(lhs_output_field, TimeField):
            sql = connection.ops.time_extract_sql(self.lookup_name, sql)
        else:
            # resolve_expression has already validated the output_field so this
            # assert should never be hit.
            assert False, "Tried to Extract from an invalid type."
        return sql, params

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        copy = super(Extract, self).resolve_expression(query, allow_joins, reuse, summarize, for_save)
        field = copy.lhs.output_field
        if not isinstance(field, (DateField, DateTimeField, TimeField)):
            raise ValueError('Extract input expression must be DateField, DateTimeField, or TimeField.')
        # Passing dates to functions expecting datetimes is most likely a mistake.
        if type(field) == DateField and copy.lookup_name in ('hour', 'minute', 'second'):
            raise ValueError(
                "Cannot extract time component '%s' from DateField '%s'. " % (copy.lookup_name, field.name)
            )
        return copy

    @cached_property
    def output_field(self):
        return IntegerField()


class ExtractYear(Extract):
    lookup_name = 'year'


class ExtractMonth(Extract):
    lookup_name = 'month'


class ExtractDay(Extract):
    lookup_name = 'day'


class ExtractWeekDay(Extract):
    """
    Return Sunday=1 through Saturday=7.

    To replicate this in Python: (mydatetime.isoweekday() % 7) + 1
    """
    lookup_name = 'week_day'


class ExtractHour(Extract):
    lookup_name = 'hour'


class ExtractMinute(Extract):
    lookup_name = 'minute'


class ExtractSecond(Extract):
    lookup_name = 'second'


DateField.register_lookup(ExtractYear)
DateField.register_lookup(ExtractMonth)
DateField.register_lookup(ExtractDay)
DateField.register_lookup(ExtractWeekDay)

TimeField.register_lookup(ExtractHour)
TimeField.register_lookup(ExtractMinute)
TimeField.register_lookup(ExtractSecond)

DateTimeField.register_lookup(ExtractYear)
DateTimeField.register_lookup(ExtractMonth)
DateTimeField.register_lookup(ExtractDay)
DateTimeField.register_lookup(ExtractWeekDay)
DateTimeField.register_lookup(ExtractHour)
DateTimeField.register_lookup(ExtractMinute)
DateTimeField.register_lookup(ExtractSecond)

ExtractYear.register_lookup(YearExact)
ExtractYear.register_lookup(YearGt)
ExtractYear.register_lookup(YearGte)
ExtractYear.register_lookup(YearLt)
ExtractYear.register_lookup(YearLte)


class TruncBase(TimezoneMixin, Transform):
    arity = 1
    kind = None
    tzinfo = None

    def __init__(self, expression, output_field=None, tzinfo=None, **extra):
        self.tzinfo = tzinfo
        super(TruncBase, self).__init__(expression, output_field=output_field, **extra)

    def as_sql(self, compiler, connection):
        inner_sql, inner_params = compiler.compile(self.lhs)
        # Escape any params because trunc_sql will format the string.
        inner_sql = inner_sql.replace('%s', '%%s')
        if isinstance(self.output_field, DateTimeField):
            tzname = self.get_tzname()
            sql, params = connection.ops.datetime_trunc_sql(self.kind, inner_sql, tzname)
        elif isinstance(self.output_field, DateField):
            sql = connection.ops.date_trunc_sql(self.kind, inner_sql)
            params = []
        else:
            raise ValueError('Trunc only valid on DateField or DateTimeField.')
        return sql, inner_params + params

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        copy = super(TruncBase, self).resolve_expression(query, allow_joins, reuse, summarize, for_save)
        field = copy.lhs.output_field
        # DateTimeField is a subclass of DateField so this works for both.
        assert isinstance(field, DateField), (
            "%r isn't a DateField or DateTimeField." % field.name
        )
        # If self.output_field was None, then accessing the field will trigger
        # the resolver to assign it to self.lhs.output_field.
        if not isinstance(copy.output_field, (DateField, DateTimeField)):
            raise ValueError('output_field must be either DateField or DateTimeField')
        # Passing dates to functions expecting datetimes is most likely a
        # mistake.
        if type(field) == DateField and (
                isinstance(copy.output_field, DateTimeField) or copy.kind in ('hour', 'minute', 'second')):
            raise ValueError("Cannot truncate DateField '%s' to DateTimeField. " % field.name)
        return copy

    def convert_value(self, value, expression, connection, context):
        if isinstance(self.output_field, DateTimeField):
            if settings.USE_TZ:
                if value is None:
                    raise ValueError(
                        "Database returned an invalid datetime value. "
                        "Are time zone definitions for your database and pytz installed?"
                    )
                value = value.replace(tzinfo=None)
                value = timezone.make_aware(value, self.tzinfo)
        elif isinstance(value, datetime):
            # self.output_field is definitely a DateField here.
            value = value.date()
        return value


class Trunc(TruncBase):

    def __init__(self, expression, kind, output_field=None, tzinfo=None, **extra):
        self.kind = kind
        super(Trunc, self).__init__(expression, output_field=output_field, tzinfo=tzinfo, **extra)


class TruncYear(TruncBase):
    kind = 'year'


class TruncMonth(TruncBase):
    kind = 'month'


class TruncDay(TruncBase):
    kind = 'day'


class TruncDate(TruncBase):
    lookup_name = 'date'

    @cached_property
    def output_field(self):
        return DateField()

    def as_sql(self, compiler, connection):
        # Cast to date rather than truncate to date.
        lhs, lhs_params = compiler.compile(self.lhs)
        tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
        sql, tz_params = connection.ops.datetime_cast_date_sql(lhs, tzname)
        lhs_params.extend(tz_params)
        return sql, lhs_params


class TruncHour(TruncBase):
    kind = 'hour'

    @cached_property
    def output_field(self):
        return DateTimeField()


class TruncMinute(TruncBase):
    kind = 'minute'

    @cached_property
    def output_field(self):
        return DateTimeField()


class TruncSecond(TruncBase):
    kind = 'second'

    @cached_property
    def output_field(self):
        return DateTimeField()


DateTimeField.register_lookup(TruncDate)
Loading