Commit d9ff5ef3 authored by Claude Paroz's avatar Claude Paroz
Browse files

Fixed #24214 -- Added GIS functions to replace geoqueryset's methods

Thanks Simon Charette and Tim Graham for the reviews.
parent 1418f753
Loading
Loading
Loading
Loading
+14 −5
Original line number Diff line number Diff line
import re
from functools import partial

from django.contrib.gis.db.models import aggregates
@@ -59,11 +60,11 @@ class BaseSpatialFeatures(object):
    # `has_<name>_method` (defined in __init__) which accesses connection.ops
    # to determine GIS method availability.
    geoqueryset_methods = (
        'area', 'centroid', 'difference', 'distance', 'distance_spheroid',
        'envelope', 'force_rhr', 'geohash', 'gml', 'intersection', 'kml',
        'length', 'num_geom', 'perimeter', 'point_on_surface', 'reverse',
        'scale', 'snap_to_grid', 'svg', 'sym_difference', 'transform',
        'translate', 'union', 'unionagg',
        'area', 'bounding_circle', 'centroid', 'difference', 'distance',
        'distance_spheroid', 'envelope', 'force_rhr', 'geohash', 'gml',
        'intersection', 'kml', 'length', 'mem_size', 'num_geom', 'num_points',
        'perimeter', 'point_on_surface', 'reverse', 'scale', 'snap_to_grid',
        'svg', 'sym_difference', 'transform', 'translate', 'union', 'unionagg',
    )

    # Specifies whether the Collect and Extent aggregates are supported by the database
@@ -86,5 +87,13 @@ class BaseSpatialFeatures(object):
            setattr(self.__class__, 'has_%s_method' % method,
                    property(partial(BaseSpatialFeatures.has_ops_method, method=method)))

    def __getattr__(self, name):
        m = re.match(r'has_(\w*)_function$', name)
        if m:
            func_name = m.group(1)
            if func_name not in self.connection.ops.unsupported_functions:
                return True
        return False

    def has_ops_method(self, method):
        return getattr(self.connection.ops, method, False)
+23 −2
Original line number Diff line number Diff line
@@ -22,6 +22,7 @@ class BaseSpatialOperations(object):
    geometry = False

    area = False
    bounding_circle = False
    centroid = False
    difference = False
    distance = False
@@ -30,7 +31,6 @@ class BaseSpatialOperations(object):
    envelope = False
    force_rhr = False
    mem_size = False
    bounding_circle = False
    num_geom = False
    num_points = False
    perimeter = False
@@ -48,6 +48,22 @@ class BaseSpatialOperations(object):
    # Aggregates
    disallowed_aggregates = ()

    geom_func_prefix = ''

    # Mapping between Django function names and backend names, when names do not
    # match; used in spatial_function_name().
    function_names = {}

    # Blacklist/set of known unsupported functions of the backend
    unsupported_functions = {
        'Area', 'AsGeoHash', 'AsGeoJSON', 'AsGML', 'AsKML', 'AsSVG',
        'BoundingCircle', 'Centroid', 'Difference', 'Distance', 'Envelope',
        'ForceRHR', 'Intersection', 'Length', 'MemSize', 'NumGeometries',
        'NumPoints', 'Perimeter', 'PointOnSurface', 'Reverse', 'Scale',
        'SnapToGrid', 'SymDifference', 'Transform', 'Translate',
        'Union',
    }

    # Serialization
    geohash = False
    geojson = False
@@ -108,9 +124,14 @@ class BaseSpatialOperations(object):
    def spatial_aggregate_name(self, agg_name):
        raise NotImplementedError('Aggregate support not implemented for this spatial backend.')

    def spatial_function_name(self, func_name):
        if func_name in self.unsupported_functions:
            raise NotImplementedError("This backend doesn't support the %s function." % func_name)
        return self.function_names.get(func_name, self.geom_func_prefix + func_name)

    # Routines for getting the OGC-compliant models.
    def geometry_columns(self):
        raise NotImplementedError('subclasses of BaseSpatialOperations must a provide geometry_columns() method')
        raise NotImplementedError('Subclasses of BaseSpatialOperations must provide a geometry_columns() method.')

    def spatial_ref_sys(self):
        raise NotImplementedError('subclasses of BaseSpatialOperations must a provide spatial_ref_sys() method')
+6 −2
Original line number Diff line number Diff line
@@ -8,12 +8,13 @@ from psycopg2.extensions import ISQLQuote


class PostGISAdapter(object):
    def __init__(self, geom):
    def __init__(self, geom, geography=False):
        "Initializes on the geometry."
        # Getting the WKB (in string form, to allow easy pickling of
        # the adaptor) and the SRID from the geometry.
        self.ewkb = bytes(geom.ewkb)
        self.srid = geom.srid
        self.geography = geography
        self._adapter = Binary(self.ewkb)

    def __conform__(self, proto):
@@ -44,4 +45,7 @@ class PostGISAdapter(object):
    def getquoted(self):
        "Returns a properly quoted string for use in PostgreSQL/PostGIS."
        # psycopg will figure out whether to use E'\\000' or '\000'
        return str('ST_GeomFromEWKB(%s)' % self._adapter.getquoted().decode())
        return str('%s(%s)' % (
            'ST_GeogFromWKB' if self.geography else 'ST_GeomFromEWKB',
            self._adapter.getquoted().decode())
        )
+7 −0
Original line number Diff line number Diff line
@@ -88,6 +88,13 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
        'distance_lte': PostGISDistanceOperator(func='ST_Distance', op='<=', geography=True),
    }

    unsupported_functions = set()
    function_names = {
        'BoundingCircle': 'ST_MinimumBoundingCircle',
        'MemSize': 'ST_Mem_Size',
        'NumPoints': 'ST_NPoints',
    }

    def __init__(self, connection):
        super(PostGISOperations, self).__init__(connection)

+351 −0
Original line number Diff line number Diff line
from decimal import Decimal

from django.contrib.gis.db.models.fields import GeometryField
from django.contrib.gis.db.models.sql import AreaField
from django.contrib.gis.geos.geometry import GEOSGeometry
from django.contrib.gis.measure import (
    Area as AreaMeasure, Distance as DistanceMeasure,
)
from django.core.exceptions import FieldError
from django.db.models import FloatField, IntegerField, TextField
from django.db.models.expressions import Func, Value
from django.utils import six

NUMERIC_TYPES = six.integer_types + (float, Decimal)


class GeoFunc(Func):
    function = None
    output_field_class = None
    geom_param_pos = 0

    def __init__(self, *expressions, **extra):
        if 'output_field' not in extra and self.output_field_class:
            extra['output_field'] = self.output_field_class()
        super(GeoFunc, self).__init__(*expressions, **extra)

    @property
    def name(self):
        return self.__class__.__name__

    @property
    def srid(self):
        expr = self.source_expressions[self.geom_param_pos]
        if hasattr(expr, 'srid'):
            return expr.srid
        try:
            return expr.field.srid
        except (AttributeError, FieldError):
            return None

    def as_sql(self, compiler, connection):
        if self.function is None:
            self.function = connection.ops.spatial_function_name(self.name)
        return super(GeoFunc, self).as_sql(compiler, connection)

    def resolve_expression(self, *args, **kwargs):
        res = super(GeoFunc, self).resolve_expression(*args, **kwargs)
        base_srid = res.srid
        if not base_srid:
            raise TypeError("Geometry functions can only operate on geometric content.")

        for pos, expr in enumerate(res.source_expressions[1:], start=1):
            if isinstance(expr, GeomValue) and expr.srid != base_srid:
                # Automatic SRID conversion so objects are comparable
                res.source_expressions[pos] = Transform(expr, base_srid).resolve_expression(*args, **kwargs)
        return res

    def _handle_param(self, value, param_name='', check_types=None):
        if not hasattr(value, 'resolve_expression'):
            if check_types and not isinstance(value, check_types):
                raise TypeError(
                    "The %s parameter has the wrong type: should be %s." % (
                        param_name, str(check_types))
                )
        return value


class GeomValue(Value):
    geography = False

    @property
    def srid(self):
        return self.value.srid

    def as_sql(self, compiler, connection):
        if self.geography:
            self.value = connection.ops.Adapter(self.value, geography=self.geography)
        else:
            self.value = connection.ops.Adapter(self.value)
        return super(GeomValue, self).as_sql(compiler, connection)


class GeoFuncWithGeoParam(GeoFunc):
    def __init__(self, expression, geom, *expressions, **extra):
        if not hasattr(geom, 'srid'):
            # Try to interpret it as a geometry input
            try:
                geom = GEOSGeometry(geom)
            except Exception:
                raise ValueError("This function requires a geometric parameter.")
        if not geom.srid:
            raise ValueError("Please provide a geometry attribute with a defined SRID.")
        geom = GeomValue(geom)
        super(GeoFuncWithGeoParam, self).__init__(expression, geom, *expressions, **extra)


class Area(GeoFunc):
    def as_sql(self, compiler, connection):
        if connection.ops.oracle:
            self.output_field = AreaField('sq_m')  # Oracle returns area in units of meters.
        else:
            if connection.ops.geography:
                # Geography fields support area calculation, returns square meters.
                self.output_field = AreaField('sq_m')
            elif not self.output_field.geodetic(connection):
                # Getting the area units of the geographic field.
                self.output_field = AreaField(
                    AreaMeasure.unit_attname(self.output_field.units_name(connection)))
            else:
                # TODO: Do we want to support raw number areas for geodetic fields?
                raise NotImplementedError('Area on geodetic coordinate systems not supported.')
        return super(Area, self).as_sql(compiler, connection)


class AsGeoJSON(GeoFunc):
    output_field_class = TextField

    def __init__(self, expression, bbox=False, crs=False, precision=8, **extra):
        expressions = [expression]
        if precision is not None:
            expressions.append(self._handle_param(precision, 'precision', six.integer_types))
        options = 0
        if crs and bbox:
            options = 3
        elif bbox:
            options = 1
        elif crs:
            options = 2
        if options:
            expressions.append(options)
        super(AsGeoJSON, self).__init__(*expressions, **extra)


class AsGML(GeoFunc):
    geom_param_pos = 1
    output_field_class = TextField

    def __init__(self, expression, version=2, precision=8, **extra):
        expressions = [version, expression]
        if precision is not None:
            expressions.append(self._handle_param(precision, 'precision', six.integer_types))
        super(AsGML, self).__init__(*expressions, **extra)


class AsKML(AsGML):
    pass


class AsSVG(GeoFunc):
    output_field_class = TextField

    def __init__(self, expression, relative=False, precision=8, **extra):
        relative = relative if hasattr(relative, 'resolve_expression') else int(relative)
        expressions = [
            expression,
            relative,
            self._handle_param(precision, 'precision', six.integer_types),
        ]
        super(AsSVG, self).__init__(*expressions, **extra)


class BoundingCircle(GeoFunc):
    def __init__(self, expression, num_seg=48, **extra):
        super(BoundingCircle, self).__init__(*[expression, num_seg], **extra)


class Centroid(GeoFunc):
    pass


class Difference(GeoFuncWithGeoParam):
    pass


class DistanceResultMixin(object):
    def convert_value(self, value, expression, connection, context):
        if value is None:
            return None
        geo_field = GeometryField(srid=self.srid)  # Fake field to get SRID info
        if geo_field.geodetic(connection):
            dist_att = 'm'
        else:
            dist_att = DistanceMeasure.unit_attname(geo_field.units_name(connection))
        return DistanceMeasure(**{dist_att: value})


class Distance(DistanceResultMixin, GeoFuncWithGeoParam):
    output_field_class = FloatField
    spheroid = None

    def __init__(self, expr1, expr2, spheroid=None, **extra):
        expressions = [expr1, expr2]
        if spheroid is not None:
            self.spheroid = spheroid
            expressions += (self._handle_param(spheroid, 'spheroid', bool),)
        super(Distance, self).__init__(*expressions, **extra)

    def as_postgresql(self, compiler, connection):
        geo_field = GeometryField(srid=self.srid)  # Fake field to get SRID info
        src_field = self.get_source_fields()[0]
        geography = src_field.geography and self.srid == 4326
        if geography:
            # Set parameters as geography if base field is geography
            for pos, expr in enumerate(
                    self.source_expressions[self.geom_param_pos + 1:], start=self.geom_param_pos + 1):
                if isinstance(expr, GeomValue):
                    expr.geography = True
        elif geo_field.geodetic(connection):
            # Geometry fields with geodetic (lon/lat) coordinates need special distance functions
            if self.spheroid:
                self.function = 'ST_Distance_Spheroid'  # More accurate, resource intensive
                # Replace boolean param by the real spheroid of the base field
                self.source_expressions[2] = Value(geo_field._spheroid)
            else:
                self.function = 'ST_Distance_Sphere'
        return super(Distance, self).as_sql(compiler, connection)


class Envelope(GeoFunc):
    pass


class ForceRHR(GeoFunc):
    pass


class GeoHash(GeoFunc):
    output_field_class = TextField

    def __init__(self, expression, precision=None, **extra):
        expressions = [expression]
        if precision is not None:
            expressions.append(self._handle_param(precision, 'precision', six.integer_types))
        super(GeoHash, self).__init__(*expressions, **extra)


class Intersection(GeoFuncWithGeoParam):
    pass


class Length(DistanceResultMixin, GeoFunc):
    output_field_class = FloatField

    def __init__(self, expr1, spheroid=True, **extra):
        self.spheroid = spheroid
        super(Length, self).__init__(expr1, **extra)

    def as_postgresql(self, compiler, connection):
        geo_field = GeometryField(srid=self.srid)  # Fake field to get SRID info
        src_field = self.get_source_fields()[0]
        geography = src_field.geography and self.srid == 4326
        if geography:
            self.source_expressions.append(Value(self.spheroid))
        elif geo_field.geodetic(connection):
            # Geometry fields with geodetic (lon/lat) coordinates need length_spheroid
            self.function = 'ST_Length_Spheroid'
            self.source_expressions.append(Value(geo_field._spheroid))
        else:
            dim = min(f.dim for f in self.get_source_fields() if f)
            if dim > 2:
                self.function = connection.ops.length3d
        return super(Length, self).as_sql(compiler, connection)


class MemSize(GeoFunc):
    output_field_class = IntegerField


class NumGeometries(GeoFunc):
    output_field_class = IntegerField


class NumPoints(GeoFunc):
    output_field_class = IntegerField


class Perimeter(DistanceResultMixin, GeoFunc):
    output_field_class = FloatField

    def as_postgresql(self, compiler, connection):
        dim = min(f.dim for f in self.get_source_fields())
        if dim > 2:
            self.function = connection.ops.perimeter3d
        return super(Perimeter, self).as_sql(compiler, connection)


class PointOnSurface(GeoFunc):
    pass


class Reverse(GeoFunc):
    pass


class Scale(GeoFunc):
    def __init__(self, expression, x, y, z=0.0, **extra):
        expressions = [
            expression,
            self._handle_param(x, 'x', NUMERIC_TYPES),
            self._handle_param(y, 'y', NUMERIC_TYPES),
        ]
        if z != 0.0:
            expressions.append(self._handle_param(z, 'z', NUMERIC_TYPES))
        super(Scale, self).__init__(*expressions, **extra)


class SnapToGrid(GeoFunc):
    def __init__(self, expression, *args, **extra):
        nargs = len(args)
        expressions = [expression]
        if nargs in (1, 2):
            expressions.extend(
                [self._handle_param(arg, '', NUMERIC_TYPES) for arg in args]
            )
        elif nargs == 4:
            # Reverse origin and size param ordering
            expressions.extend(
                [self._handle_param(arg, '', NUMERIC_TYPES) for arg in args[2:]]
            )
            expressions.extend(
                [self._handle_param(arg, '', NUMERIC_TYPES) for arg in args[0:2]]
            )
        else:
            raise ValueError('Must provide 1, 2, or 4 arguments to `SnapToGrid`.')
        super(SnapToGrid, self).__init__(*expressions, **extra)


class SymDifference(GeoFuncWithGeoParam):
    pass


class Transform(GeoFunc):
    def __init__(self, expression, srid, **extra):
        expressions = [
            expression,
            self._handle_param(srid, 'srid', six.integer_types),
        ]
        super(Transform, self).__init__(*expressions, **extra)

    @property
    def srid(self):
        # Make srid the resulting srid of the transformation
        return self.source_expressions[self.geom_param_pos + 1].value


class Translate(Scale):
    pass


class Union(GeoFuncWithGeoParam):
    pass
Loading