Commit cbb5cdd1 authored by Anssi Kääriäinen's avatar Anssi Kääriäinen Committed by Tim Graham
Browse files

Fixed #23867 -- removed DateQuerySet hacks

The .dates() queries were implemented by using custom Query, QuerySet,
and Compiler classes. Instead implement them by using expressions and
database converters APIs.
parent cc870b8e
Loading
Loading
Loading
Loading
+0 −8
Original line number Diff line number Diff line
@@ -22,11 +22,3 @@ class SQLUpdateCompiler(compiler.SQLUpdateCompiler, GeoSQLCompiler):

class SQLAggregateCompiler(compiler.SQLAggregateCompiler, GeoSQLCompiler):
    pass


class SQLDateCompiler(compiler.SQLDateCompiler, GeoSQLCompiler):
    pass


class SQLDateTimeCompiler(compiler.SQLDateTimeCompiler, GeoSQLCompiler):
    pass
+0 −8
Original line number Diff line number Diff line
@@ -235,11 +235,3 @@ class SQLUpdateCompiler(compiler.SQLUpdateCompiler, GeoSQLCompiler):

class SQLAggregateCompiler(compiler.SQLAggregateCompiler, GeoSQLCompiler):
    pass


class SQLDateCompiler(compiler.SQLDateCompiler, GeoSQLCompiler):
    pass


class SQLDateTimeCompiler(compiler.SQLDateTimeCompiler, GeoSQLCompiler):
    pass
+0 −8
Original line number Diff line number Diff line
@@ -23,11 +23,3 @@ class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):

class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler):
    pass


class SQLDateCompiler(compiler.SQLDateCompiler, SQLCompiler):
    pass


class SQLDateTimeCompiler(compiler.SQLDateTimeCompiler, SQLCompiler):
    pass
+0 −8
Original line number Diff line number Diff line
@@ -54,11 +54,3 @@ class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):

class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler):
    pass


class SQLDateCompiler(compiler.SQLDateCompiler, SQLCompiler):
    pass


class SQLDateTimeCompiler(compiler.SQLDateTimeCompiler, SQLCompiler):
    pass
+66 −6
Original line number Diff line number Diff line
import copy
import datetime

from django.conf import settings
from django.core.exceptions import FieldError
from django.db.backends import utils as backend_utils
from django.db.models import fields
from django.db.models.constants import LOOKUP_SEP
from django.db.models.query_utils import refs_aggregate
from django.utils import timezone
from django.utils.functional import cached_property


@@ -124,6 +126,9 @@ class ExpressionNode(CombinableMixin):
    # aggregate specific fields
    is_summary = False

    def get_db_converters(self, connection):
        return [self.convert_value]

    def __init__(self, output_field=None):
        self._output_field = output_field

@@ -531,32 +536,60 @@ class Date(ExpressionNode):
    """
    Add a date selection column.
    """
    def __init__(self, col, lookup_type):
    def __init__(self, lookup, lookup_type):
        super(Date, self).__init__(output_field=fields.DateField())
        self.col = col
        self.lookup = lookup
        self.col = None
        self.lookup_type = lookup_type

    def get_source_expressions(self):
        return [self.col]

    def set_source_expressions(self, exprs):
        self.col, = self.exprs
        self.col, = exprs

    def resolve_expression(self, query, allow_joins, reuse, summarize):
        copy = self.copy()
        copy.col = query.resolve_ref(self.lookup, allow_joins, reuse, summarize)
        field = copy.col.output_field
        assert isinstance(field, fields.DateField), "%r isn't a DateField." % field.name
        if settings.USE_TZ:
            assert not isinstance(field, fields.DateTimeField), (
                "%r is a DateTimeField, not a DateField." % field.name
            )
        return copy

    def as_sql(self, compiler, connection):
        sql, params = self.col.as_sql(compiler, connection)
        assert not(params)
        return connection.ops.date_trunc_sql(self.lookup_type, sql), []

    def copy(self):
        copy = super(Date, self).copy()
        copy.lookup = self.lookup
        copy.lookup_type = self.lookup_type
        return copy

    def convert_value(self, value, connection):
        if isinstance(value, datetime.datetime):
            value = value.date()
        return value


class DateTime(ExpressionNode):
    """
    Add a datetime selection column.
    """
    def __init__(self, col, lookup_type, tzname):
    def __init__(self, lookup, lookup_type, tzinfo):
        super(DateTime, self).__init__(output_field=fields.DateTimeField())
        self.col = col
        self.lookup = lookup
        self.col = None
        self.lookup_type = lookup_type
        self.tzname = tzname
        if tzinfo is None:
            self.tzname = None
        else:
            self.tzname = timezone._get_timezone_name(tzinfo)
        self.tzinfo = tzinfo

    def get_source_expressions(self):
        return [self.col]
@@ -564,7 +597,34 @@ class DateTime(ExpressionNode):
    def set_source_expressions(self, exprs):
        self.col, = exprs

    def resolve_expression(self, query, allow_joins, reuse, summarize):
        copy = self.copy()
        copy.col = query.resolve_ref(self.lookup, allow_joins, reuse, summarize)
        field = copy.col.output_field
        assert isinstance(field, fields.DateTimeField), (
            "%r isn't a DateTimeField." % field.name
        )
        return copy

    def as_sql(self, compiler, connection):
        sql, params = self.col.as_sql(compiler, connection)
        assert not(params)
        return connection.ops.datetime_trunc_sql(self.lookup_type, sql, self.tzname)

    def copy(self):
        copy = super(DateTime, self).copy()
        copy.lookup = self.lookup
        copy.lookup_type = self.lookup_type
        copy.tzname = self.tzname
        return copy

    def convert_value(self, value, connection):
        if settings.USE_TZ:
            if value is None:
                raise ValueError(
                    "Database returned an invalid value in QuerySet.datetimes(). "
                    "Are time zone definitions for your database and pytz installed?"
                )
            value = value.replace(tzinfo=None)
            value = timezone.make_aware(value, self.tzinfo)
        return value
Loading