Commit 20e69736 authored by Aymeric Augustin's avatar Aymeric Augustin
Browse files

Fixed #18323 -- Refactored date arithmetic

in date based generic views, in order to deal properly with both
DateFields and DateTimeFields.
parent dcd43831
Loading
Loading
Loading
Loading
+139 −63
Original line number Diff line number Diff line
@@ -23,7 +23,9 @@ class YearMixin(object):
        return self.year_format

    def get_year(self):
        "Return the year for which this view should display data"
        """
        Return the year for which this view should display data.
        """
        year = self.year
        if year is None:
            try:
@@ -35,6 +37,20 @@ class YearMixin(object):
                    raise Http404(_(u"No year specified"))
        return year

    def _get_next_year(self, date):
        """
        Return the start date of the next interval.

        The interval is defined by start date <= item date < next start date.
        """
        return date.replace(year=date.year + 1, month=1, day=1)

    def _get_current_year(self, date):
        """
        Return the start date of the current interval.
        """
        return date.replace(month=1, day=1)


class MonthMixin(object):
    month_format = '%b'
@@ -48,7 +64,9 @@ class MonthMixin(object):
        return self.month_format

    def get_month(self):
        "Return the month for which this view should display data"
        """
        Return the month for which this view should display data.
        """
        month = self.month
        if month is None:
            try:
@@ -64,20 +82,30 @@ class MonthMixin(object):
        """
        Get the next valid month.
        """
        # next must be the first day of the next month.
        if date.month == 12:
            next = date.replace(year=date.year + 1, month=1, day=1)
        else:
            next = date.replace(month=date.month + 1, day=1)
        return _get_next_prev(self, next, is_previous=False, period='month')
        return _get_next_prev(self, date, is_previous=False, period='month')

    def get_previous_month(self, date):
        """
        Get the previous valid month.
        """
        # prev must be the last day of the previous month.
        prev = date.replace(day=1) - datetime.timedelta(days=1)
        return _get_next_prev(self, prev, is_previous=True, period='month')
        return _get_next_prev(self, date, is_previous=True, period='month')

    def _get_next_month(self, date):
        """
        Return the start date of the next interval.

        The interval is defined by start date <= item date < next start date.
        """
        if date.month == 12:
            return date.replace(year=date.year + 1, month=1, day=1)
        else:
            return date.replace(month=date.month + 1, day=1)

    def _get_current_month(self, date):
        """
        Return the start date of the previous interval.
        """
        return date.replace(day=1)


class DayMixin(object):
@@ -92,7 +120,9 @@ class DayMixin(object):
        return self.day_format

    def get_day(self):
        "Return the day for which this view should display data"
        """
        Return the day for which this view should display data.
        """
        day = self.day
        if day is None:
            try:
@@ -108,15 +138,27 @@ class DayMixin(object):
        """
        Get the next valid day.
        """
        next = date + datetime.timedelta(days=1)
        return _get_next_prev(self, next, is_previous=False, period='day')
        return _get_next_prev(self, date, is_previous=False, period='day')

    def get_previous_day(self, date):
        """
        Get the previous valid day.
        """
        prev = date - datetime.timedelta(days=1)
        return _get_next_prev(self, prev, is_previous=True, period='day')
        return _get_next_prev(self, date, is_previous=True, period='day')

    def _get_next_day(self, date):
        """
        Return the start date of the next interval.

        The interval is defined by start date <= item date < next start date.
        """
        return date + datetime.timedelta(days=1)

    def _get_current_day(self, date):
        """
        Return the start date of the current interval.
        """
        return date


class WeekMixin(object):
@@ -131,7 +173,9 @@ class WeekMixin(object):
        return self.week_format

    def get_week(self):
        "Return the week for which this view should display data"
        """
        Return the week for which this view should display data
        """
        week = self.week
        if week is None:
            try:
@@ -147,19 +191,34 @@ class WeekMixin(object):
        """
        Get the next valid week.
        """
        # next must be the first day of the next week.
        next = date + datetime.timedelta(days=7 - self._get_weekday(date))
        return _get_next_prev(self, next, is_previous=False, period='week')
        return _get_next_prev(self, date, is_previous=False, period='week')

    def get_previous_week(self, date):
        """
        Get the previous valid week.
        """
        # prev must be the last day of the previous week.
        prev = date - datetime.timedelta(days=self._get_weekday(date) + 1)
        return _get_next_prev(self, prev, is_previous=True, period='week')
        return _get_next_prev(self, date, is_previous=True, period='week')

    def _get_next_week(self, date):
        """
        Return the start date of the next interval.

        The interval is defined by start date <= item date < next start date.
        """
        return date + datetime.timedelta(days=7 - self._get_weekday(date))

    def _get_current_week(self, date):
        """
        Return the start date of the current interval.
        """
        return date - datetime.timedelta(self._get_weekday(date))

    def _get_weekday(self, date):
        """
        Return the weekday for a given date.

        The first day according to the week format is 0 and the last day is 6.
        """
        week_format = self.get_week_format()
        if week_format == '%W':                 # week starts on Monday
            return date.weekday()
@@ -168,6 +227,7 @@ class WeekMixin(object):
        else:
            raise ValueError("unknown week format: %s" % week_format)


class DateMixin(object):
    """
    Mixin class for views manipulating date-based data.
@@ -267,7 +327,7 @@ class BaseDateListView(MultipleObjectMixin, DateMixin, View):
        paginate_by = self.get_paginate_by(qs)

        if not allow_future:
            now = timezone.now() if self.uses_datetime_field else datetime.date.today()
            now = timezone.now() if self.uses_datetime_field else timezone_today()
            qs = qs.filter(**{'%s__lte' % date_field: now})

        if not allow_empty:
@@ -344,7 +404,7 @@ class BaseYearArchiveView(YearMixin, BaseDateListView):
        date = _date_from_string(year, self.get_year_format())

        since = self._make_date_lookup_arg(date)
        until = self._make_date_lookup_arg(datetime.date(date.year + 1, 1, 1))
        until = self._make_date_lookup_arg(self._get_next_year(date))
        lookup_kwargs = {
            '%s__gte' % date_field: since,
            '%s__lt' % date_field: until,
@@ -392,12 +452,8 @@ class BaseMonthArchiveView(YearMixin, MonthMixin, BaseDateListView):
        date = _date_from_string(year, self.get_year_format(),
                                 month, self.get_month_format())

        # Construct a date-range lookup.
        since = self._make_date_lookup_arg(date)
        if date.month == 12:
            until = self._make_date_lookup_arg(datetime.date(date.year + 1, 1, 1))
        else:
            until = self._make_date_lookup_arg(datetime.date(date.year, date.month + 1, 1))
        until = self._make_date_lookup_arg(self._get_next_month(date))
        lookup_kwargs = {
            '%s__gte' % date_field: since,
            '%s__lt' % date_field: until,
@@ -442,9 +498,8 @@ class BaseWeekArchiveView(YearMixin, WeekMixin, BaseDateListView):
                                 week_start, '%w',
                                 week, week_format)

        # Construct a date-range lookup.
        since = self._make_date_lookup_arg(date)
        until = self._make_date_lookup_arg(date + datetime.timedelta(days=7))
        until = self._make_date_lookup_arg(self._get_next_week(date))
        lookup_kwargs = {
            '%s__gte' % date_field: since,
            '%s__lt' % date_field: until,
@@ -585,22 +640,22 @@ def _date_from_string(year, year_format, month='', month_format='', day='', day_
        })


def _get_next_prev(generic_view, naive_result, is_previous, period):
def _get_next_prev(generic_view, date, is_previous, period):
    """
    Helper: Get the next or the previous valid date. The idea is to allow
    links on month/day views to never be 404s by never providing a date
    that'll be invalid for the given view.

    This is a bit complicated since it handles both next and previous months
    and days (for MonthArchiveView and DayArchiveView); hence the coupling to generic_view.
    This is a bit complicated since it handles different intervals of time,
    hence the coupling to generic_view.

    However in essence the logic comes down to:

        * If allow_empty and allow_future are both true, this is easy: just
          return the naive result (just the next/previous day or month,
          return the naive result (just the next/previous day/week/month,
          reguardless of object existence.)

        * If allow_empty is true, allow_future is false, and the naive month
        * If allow_empty is true, allow_future is false, and the naive result
          isn't in the future, then return it; otherwise return None.

        * If allow_empty is false and allow_future is true, return the next
@@ -616,9 +671,23 @@ def _get_next_prev(generic_view, naive_result, is_previous, period):
    allow_empty = generic_view.get_allow_empty()
    allow_future = generic_view.get_allow_future()

    # If allow_empty is True the naive value will be valid
    get_current = getattr(generic_view, '_get_current_%s' % period)
    get_next = getattr(generic_view, '_get_next_%s' % period)

    # Bounds of the current interval
    start, end = get_current(date), get_next(date)

    # If allow_empty is True, the naive result will be valid
    if allow_empty:
        result = naive_result
        if is_previous:
            result = get_current(start - datetime.timedelta(days=1))
        else:
            result = end

        if allow_future or result <= timezone_today():
            return result
        else:
            return None

    # Otherwise, we'll need to go to the database to look for an object
    # whose date_field is at least (greater than/less than) the given
@@ -627,12 +696,22 @@ def _get_next_prev(generic_view, naive_result, is_previous, period):
        # Construct a lookup and an ordering depending on whether we're doing
        # a previous date or a next date lookup.
        if is_previous:
            lookup = {'%s__lte' % date_field: generic_view._make_date_lookup_arg(naive_result)}
            lookup = {'%s__lt' % date_field: generic_view._make_date_lookup_arg(start)}
            ordering = '-%s' % date_field
        else:
            lookup = {'%s__gte' % date_field: generic_view._make_date_lookup_arg(naive_result)}
            lookup = {'%s__gte' % date_field: generic_view._make_date_lookup_arg(end)}
            ordering = date_field

        # Filter out objects in the future if appropriate.
        if not allow_future:
            # Fortunately, to match the implementation of allow_future,
            # we need __lte, which doesn't conflict with __lt above.
            if generic_view.uses_datetime_field:
                now = timezone.now()
            else:
                now = timezone_today()
            lookup['%s__lte' % date_field] = now

        qs = generic_view.get_queryset().filter(**lookup).order_by(ordering)

        # Snag the first object from the queryset; if it doesn't exist that
@@ -640,26 +719,23 @@ def _get_next_prev(generic_view, naive_result, is_previous, period):
        try:
            result = getattr(qs[0], date_field)
        except IndexError:
            result = None
            return None

    # Convert datetimes to a dates
    if result and generic_view.uses_datetime_field:
        # Convert datetimes to dates in the current time zone.
        if generic_view.uses_datetime_field:
            if settings.USE_TZ:
                result = timezone.localtime(result)
            result = result.date()

    if result:
        if period == 'month':
            # first day of the month
            result = result.replace(day=1)
        elif period == 'week':
            # monday of the week
            result = result - datetime.timedelta(days=generic_view._get_weekday(result))
        elif period != 'day':
            raise ValueError('invalid period: %s' % period)

    # Check against future dates.
    if result and (allow_future or result < datetime.date.today()):
        return result
        # Return the first day of the period.
        return get_current(result)


def timezone_today():
    """
    Return the current date in the current time zone.
    """
    if settings.USE_TZ:
        return timezone.localtime(timezone.now()).date()
    else:
        return None
        return datetime.date.today()