Commit e9103402 authored by Marc Tamlyn's avatar Marc Tamlyn
Browse files

Fixed #18757, #14462, #21565 -- Reworked database-python type conversions

Complete rework of translating data values from database

Deprecation of SubfieldBase, removal of resolve_columns and
convert_values in favour of a more general converter based approach and
public API Field.from_db_value(). Now works seamlessly with aggregation,
.values() and raw queries.

Thanks to akaariai in particular for extensive advice and inspiration,
also to shaib, manfre and timograham for their reviews.
parent 89559bcf
Loading
Loading
Loading
Loading
+0 −41
Original line number Diff line number Diff line
from django.contrib.gis.db.models.sql.compiler import GeoSQLCompiler as BaseGeoSQLCompiler
from django.db.backends.mysql import compiler

SQLCompiler = compiler.SQLCompiler


class GeoSQLCompiler(BaseGeoSQLCompiler, SQLCompiler):
    def resolve_columns(self, row, fields=()):
        """
        Integrate the cases handled both by the base GeoSQLCompiler and the
        main MySQL compiler (converting 0/1 to True/False for boolean fields).

        Refs #15169.

        """
        row = BaseGeoSQLCompiler.resolve_columns(self, row, fields)
        return SQLCompiler.resolve_columns(self, row, fields)


class SQLInsertCompiler(compiler.SQLInsertCompiler, GeoSQLCompiler):
    pass


class SQLDeleteCompiler(compiler.SQLDeleteCompiler, GeoSQLCompiler):
    pass


class SQLUpdateCompiler(compiler.SQLUpdateCompiler, GeoSQLCompiler):
    pass


class SQLAggregateCompiler(compiler.SQLAggregateCompiler, GeoSQLCompiler):
    pass


class SQLDateCompiler(compiler.SQLDateCompiler, GeoSQLCompiler):
    pass


class SQLDateTimeCompiler(compiler.SQLDateTimeCompiler, GeoSQLCompiler):
    pass
+1 −1
Original line number Diff line number Diff line
@@ -6,7 +6,7 @@ from django.contrib.gis.db.backends.base import BaseSpatialOperations

class MySQLOperations(DatabaseOperations, BaseSpatialOperations):

    compiler_module = 'django.contrib.gis.db.backends.mysql.compiler'
    compiler_module = 'django.contrib.gis.db.models.sql.compiler'
    mysql = True
    name = 'mysql'
    select = 'AsText(%s)'
+5 −0
Original line number Diff line number Diff line
@@ -197,6 +197,11 @@ class GeometryField(Field):
        else:
            return geom

    def from_db_value(self, value, connection):
        if value is not None:
            value = Geometry(value)
        return value

    def get_srid(self, geom):
        """
        Returns the default SRID for the given geometry, taking into account
+1 −27
Original line number Diff line number Diff line
from django.db import connections
from django.db.models.query import QuerySet, ValuesQuerySet, ValuesListQuerySet
from django.db.models.query import QuerySet

from django.contrib.gis.db.models import aggregates
from django.contrib.gis.db.models.fields import get_srid_info, PointField, LineStringField
@@ -18,19 +18,6 @@ class GeoQuerySet(QuerySet):
        super(GeoQuerySet, self).__init__(model=model, query=query, using=using, hints=hints)
        self.query = query or GeoQuery(self.model)

    def values(self, *fields):
        return self._clone(klass=GeoValuesQuerySet, setup=True, _fields=fields)

    def values_list(self, *fields, **kwargs):
        flat = kwargs.pop('flat', False)
        if kwargs:
            raise TypeError('Unexpected keyword arguments to values_list: %s'
                    % (list(kwargs),))
        if flat and len(fields) > 1:
            raise TypeError("'flat' is not valid when values_list is called with more than one field.")
        return self._clone(klass=GeoValuesListQuerySet, setup=True, flat=flat,
                           _fields=fields)

    ### GeoQuerySet Methods ###
    def area(self, tolerance=0.05, **kwargs):
        """
@@ -767,16 +754,3 @@ class GeoQuerySet(QuerySet):
            return self.query.get_compiler(self.db)._field_column(geo_field, parent_model._meta.db_table)
        else:
            return self.query.get_compiler(self.db)._field_column(geo_field)


class GeoValuesQuerySet(ValuesQuerySet):
    def __init__(self, *args, **kwargs):
        super(GeoValuesQuerySet, self).__init__(*args, **kwargs)
        # This flag tells `resolve_columns` to run the values through
        # `convert_values`.  This ensures that Geometry objects instead
        # of string values are returned with `values()` or `values_list()`.
        self.query.geo_values = True


class GeoValuesListQuerySet(GeoValuesQuerySet, ValuesListQuerySet):
    pass
+10 −88
Original line number Diff line number Diff line
import datetime

from django.conf import settings
from django.db.backends.utils import truncate_name, typecast_date, typecast_timestamp
from django.db.backends.utils import truncate_name
from django.db.models.sql import compiler
from django.db.models.sql.constants import MULTI
from django.utils import six
from django.utils.six.moves import zip, zip_longest
from django.utils import timezone

SQLCompiler = compiler.SQLCompiler

@@ -153,38 +147,13 @@ class GeoSQLCompiler(compiler.SQLCompiler):
                    col_aliases.add(field.column)
        return result, aliases

    def resolve_columns(self, row, fields=()):
        """
        This routine is necessary so that distances and geometries returned
        from extra selection SQL get resolved appropriately into Python
        objects.
        """
        values = []
        aliases = list(self.query.extra_select)

        # Have to set a starting row number offset that is used for
        # determining the correct starting row index -- needed for
        # doing pagination with Oracle.
        rn_offset = 0
        if self.connection.ops.oracle:
            if self.query.high_mark is not None or self.query.low_mark:
                rn_offset = 1
        index_start = rn_offset + len(aliases)

        # Converting any extra selection values (e.g., geometries and
        # distance objects added by GeoQuerySet methods).
        values = [self.query.convert_values(v,
                               self.query.extra_select_fields.get(a, None),
                               self.connection)
                  for v, a in zip(row[rn_offset:index_start], aliases)]
        if self.connection.ops.oracle or getattr(self.query, 'geo_values', False):
            # We resolve the rest of the columns if we're on Oracle or if
            # the `geo_values` attribute is defined.
            for value, field in zip_longest(row[index_start:], fields):
                values.append(self.query.convert_values(value, field, self.connection))
        else:
            values.extend(row[index_start:])
        return tuple(values)
    def get_converters(self, fields):
        converters = super(GeoSQLCompiler, self).get_converters(fields)
        for i, alias in enumerate(self.query.extra_select):
            field = self.query.extra_select_fields.get(alias)
            if field:
                converters[i] = ([], [field.from_db_value], field)
        return converters

    #### Routines unique to GeoQuery ####
    def get_extra_select_format(self, alias):
@@ -268,55 +237,8 @@ class SQLAggregateCompiler(compiler.SQLAggregateCompiler, GeoSQLCompiler):


class SQLDateCompiler(compiler.SQLDateCompiler, GeoSQLCompiler):
    """
    This is overridden for GeoDjango to properly cast date columns, since
    `GeoQuery.resolve_columns` is used for spatial values.
    See #14648, #16757.
    """
    def results_iter(self):
        if self.connection.ops.oracle:
            from django.db.models.fields import DateTimeField
            fields = [DateTimeField()]
        else:
            needs_string_cast = self.connection.features.needs_datetime_string_cast

        offset = len(self.query.extra_select)
        for rows in self.execute_sql(MULTI):
            for row in rows:
                date = row[offset]
                if self.connection.ops.oracle:
                    date = self.resolve_columns(row, fields)[offset]
                elif needs_string_cast:
                    date = typecast_date(str(date))
                if isinstance(date, datetime.datetime):
                    date = date.date()
                yield date
    pass


class SQLDateTimeCompiler(compiler.SQLDateTimeCompiler, GeoSQLCompiler):
    """
    This is overridden for GeoDjango to properly cast date columns, since
    `GeoQuery.resolve_columns` is used for spatial values.
    See #14648, #16757.
    """
    def results_iter(self):
        if self.connection.ops.oracle:
            from django.db.models.fields import DateTimeField
            fields = [DateTimeField()]
        else:
            needs_string_cast = self.connection.features.needs_datetime_string_cast

        offset = len(self.query.extra_select)
        for rows in self.execute_sql(MULTI):
            for row in rows:
                datetime = row[offset]
                if self.connection.ops.oracle:
                    datetime = self.resolve_columns(row, fields)[offset]
                elif needs_string_cast:
                    datetime = typecast_timestamp(str(datetime))
                # Datetimes are artificially returned in UTC on databases that
                # don't support time zone. Restore the zone used in the query.
                if settings.USE_TZ:
                    datetime = datetime.replace(tzinfo=None)
                    datetime = timezone.make_aware(datetime, self.query.tzinfo)
                yield datetime
    pass
Loading