Commit bbfad84d authored by Daniel Wiesmann's avatar Daniel Wiesmann Committed by Tim Graham
Browse files

Fixed #25588 -- Added spatial lookups to RasterField.

Thanks Tim Graham for the review.
parent 03efa304
Loading
Loading
Loading
Loading
+31 −13
Original line number Diff line number Diff line
@@ -6,16 +6,27 @@ from __future__ import unicode_literals
from psycopg2 import Binary
from psycopg2.extensions import ISQLQuote

from django.contrib.gis.db.backends.postgis.pgraster import to_pgraster
from django.contrib.gis.geometry.backend import Geometry


class PostGISAdapter(object):
    def __init__(self, geom, geography=False):
        "Initializes on the geometry."
    def __init__(self, obj, geography=False):
        """
        Initialize on the spatial object.
        """
        self.is_geometry = isinstance(obj, Geometry)

        # Getting the WKB (in string form, to allow easy pickling of
        # the adaptor) and the SRID from the geometry.
        self.ewkb = bytes(geom.ewkb)
        self.srid = geom.srid
        self.geography = geography
        # the adaptor) and the SRID from the geometry or raster.
        if self.is_geometry:
            self.ewkb = bytes(obj.ewkb)
            self._adapter = Binary(self.ewkb)
        else:
            self.ewkb = to_pgraster(obj)

        self.srid = obj.srid
        self.geography = geography

    def __conform__(self, proto):
        # Does the given protocol conform to what Psycopg2 expects?
@@ -40,12 +51,19 @@ class PostGISAdapter(object):
        This method allows escaping the binary in the style required by the
        server's `standard_conforming_string` setting.
        """
        if self.is_geometry:
            self._adapter.prepare(conn)

    def getquoted(self):
        "Returns a properly quoted string for use in PostgreSQL/PostGIS."
        # psycopg will figure out whether to use E'\\000' or '\000'
        """
        Return a properly quoted string for use in PostgreSQL/PostGIS.
        """
        if self.is_geometry:
            # Psycopg will figure out whether to use E'\\000' or '\000'.
            return str('%s(%s)' % (
                'ST_GeogFromWKB' if self.geography else 'ST_GeomFromEWKB',
                self._adapter.getquoted().decode())
            )
        else:
            # For rasters, add explicit type cast to WKB string.
            return "'%s'::raster" % self.ewkb
+81 −27
Original line number Diff line number Diff line
@@ -4,30 +4,83 @@ from django.conf import settings
from django.contrib.gis.db.backends.base.operations import \
    BaseSpatialOperations
from django.contrib.gis.db.backends.utils import SpatialOperator
from django.contrib.gis.gdal import GDALRaster
from django.contrib.gis.geometry.backend import Geometry
from django.contrib.gis.measure import Distance
from django.core.exceptions import ImproperlyConfigured
from django.db.backends.postgresql.operations import DatabaseOperations
from django.db.utils import ProgrammingError
from django.utils import six
from django.utils.functional import cached_property

from .adapter import PostGISAdapter
from .models import PostGISGeometryColumns, PostGISSpatialRefSys
from .pgraster import from_pgraster, get_pgraster_srid, to_pgraster

# Identifier to mark raster lookups as bilateral.
BILATERAL = 'bilateral'


class PostGISOperator(SpatialOperator):
    def __init__(self, geography=False, **kwargs):
        # Only a subset of the operators and functions are available
        # for the geography type.
    def __init__(self, geography=False, raster=False, **kwargs):
        # Only a subset of the operators and functions are available for the
        # geography type.
        self.geography = geography
        # Only a subset of the operators and functions are available for the
        # raster type. Lookups that don't suport raster will be converted to
        # polygons. If the raster argument is set to BILATERAL, then the
        # operator cannot handle mixed geom-raster lookups.
        self.raster = raster
        super(PostGISOperator, self).__init__(**kwargs)

    def as_sql(self, connection, lookup, *args):
    def as_sql(self, connection, lookup, template_params, *args):
        if lookup.lhs.output_field.geography and not self.geography:
            raise ValueError('PostGIS geography does not support the "%s" '
                             'function/operator.' % (self.func or self.op,))
        return super(PostGISOperator, self).as_sql(connection, lookup, *args)

        template_params = self.check_raster(lookup, template_params)
        return super(PostGISOperator, self).as_sql(connection, lookup, template_params, *args)

    def check_raster(self, lookup, template_params):
        # Get rhs value.
        if isinstance(lookup.rhs, (tuple, list)):
            rhs_val = lookup.rhs[0]
            spheroid = lookup.rhs[-1] == 'spheroid'
        else:
            rhs_val = lookup.rhs
            spheroid = False

        # Check which input is a raster.
        lhs_is_raster = lookup.lhs.field.geom_type == 'RASTER'
        rhs_is_raster = isinstance(rhs_val, GDALRaster)

        # Look for band indices and inject them if provided.
        if lookup.band_lhs is not None and lhs_is_raster:
            if not self.func:
                raise ValueError('Band indices are not allowed for this operator, it works on bbox only.')
            template_params['lhs'] = '%s, %s' % (template_params['lhs'], lookup.band_lhs)

        if lookup.band_rhs is not None and rhs_is_raster:
            if not self.func:
                raise ValueError('Band indices are not allowed for this operator, it works on bbox only.')
            template_params['rhs'] = '%s, %s' % (template_params['rhs'], lookup.band_rhs)

        # Convert rasters to polygons if necessary.
        if not self.raster or spheroid:
            # Operators without raster support.
            if lhs_is_raster:
                template_params['lhs'] = 'ST_Polygon(%s)' % template_params['lhs']
            if rhs_is_raster:
                template_params['rhs'] = 'ST_Polygon(%s)' % template_params['rhs']
        elif self.raster == BILATERAL:
            # Operators with raster support but don't support mixed (rast-geom)
            # lookups.
            if lhs_is_raster and not rhs_is_raster:
                template_params['lhs'] = 'ST_Polygon(%s)' % template_params['lhs']
            elif rhs_is_raster and not lhs_is_raster:
                template_params['rhs'] = 'ST_Polygon(%s)' % template_params['rhs']

        return template_params


class PostGISDistanceOperator(PostGISOperator):
@@ -35,6 +88,7 @@ class PostGISDistanceOperator(PostGISOperator):

    def as_sql(self, connection, lookup, template_params, sql_params):
        if not lookup.lhs.output_field.geography and lookup.lhs.output_field.geodetic(connection):
            template_params = self.check_raster(lookup, template_params)
            sql_template = self.sql_template
            if len(lookup.rhs) == 3 and lookup.rhs[-1] == 'spheroid':
                template_params.update({'op': self.op, 'func': 'ST_Distance_Spheroid'})
@@ -58,33 +112,33 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
    Adapter = PostGISAdapter

    gis_operators = {
        'bbcontains': PostGISOperator(op='~'),
        'bboverlaps': PostGISOperator(op='&&', geography=True),
        'contained': PostGISOperator(op='@'),
        'contains': PostGISOperator(func='ST_Contains'),
        'overlaps_left': PostGISOperator(op='&<'),
        'overlaps_right': PostGISOperator(op='&>'),
        'bbcontains': PostGISOperator(op='~', raster=True),
        'bboverlaps': PostGISOperator(op='&&', geography=True, raster=True),
        'contained': PostGISOperator(op='@', raster=True),
        'overlaps_left': PostGISOperator(op='&<', raster=BILATERAL),
        'overlaps_right': PostGISOperator(op='&>', raster=BILATERAL),
        'overlaps_below': PostGISOperator(op='&<|'),
        'overlaps_above': PostGISOperator(op='|&>'),
        'left': PostGISOperator(op='<<'),
        'right': PostGISOperator(op='>>'),
        'strictly_below': PostGISOperator(op='<<|'),
        'strictly_above': PostGISOperator(op='|>>'),
        'same_as': PostGISOperator(op='~='),
        'exact': PostGISOperator(op='~='),  # alias of same_as
        'contains_properly': PostGISOperator(func='ST_ContainsProperly'),
        'coveredby': PostGISOperator(func='ST_CoveredBy', geography=True),
        'covers': PostGISOperator(func='ST_Covers', geography=True),
        'same_as': PostGISOperator(op='~=', raster=BILATERAL),
        'exact': PostGISOperator(op='~=', raster=BILATERAL),  # alias of same_as
        'contains': PostGISOperator(func='ST_Contains', raster=BILATERAL),
        'contains_properly': PostGISOperator(func='ST_ContainsProperly', raster=BILATERAL),
        'coveredby': PostGISOperator(func='ST_CoveredBy', geography=True, raster=BILATERAL),
        'covers': PostGISOperator(func='ST_Covers', geography=True, raster=BILATERAL),
        'crosses': PostGISOperator(func='ST_Crosses'),
        'disjoint': PostGISOperator(func='ST_Disjoint'),
        'disjoint': PostGISOperator(func='ST_Disjoint', raster=BILATERAL),
        'equals': PostGISOperator(func='ST_Equals'),
        'intersects': PostGISOperator(func='ST_Intersects', geography=True),
        'intersects': PostGISOperator(func='ST_Intersects', geography=True, raster=BILATERAL),
        'isvalid': PostGISOperator(func='ST_IsValid'),
        'overlaps': PostGISOperator(func='ST_Overlaps'),
        'overlaps': PostGISOperator(func='ST_Overlaps', raster=BILATERAL),
        'relate': PostGISOperator(func='ST_Relate'),
        'touches': PostGISOperator(func='ST_Touches'),
        'within': PostGISOperator(func='ST_Within'),
        'dwithin': PostGISOperator(func='ST_DWithin', geography=True),
        'touches': PostGISOperator(func='ST_Touches', raster=BILATERAL),
        'within': PostGISOperator(func='ST_Within', raster=BILATERAL),
        'dwithin': PostGISOperator(func='ST_DWithin', geography=True, raster=BILATERAL),
        'distance_gt': PostGISDistanceOperator(func='ST_Distance', op='>', geography=True),
        'distance_gte': PostGISDistanceOperator(func='ST_Distance', op='>=', geography=True),
        'distance_lt': PostGISDistanceOperator(func='ST_Distance', op='<', geography=True),
@@ -272,14 +326,14 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):

    def get_geom_placeholder(self, f, value, compiler):
        """
        Provides a proper substitution value for Geometries that are not in the
        SRID of the field.  Specifically, this routine will substitute in the
        ST_Transform() function call.
        Provide a proper substitution value for Geometries or rasters that are
        not in the SRID of the field. Specifically, this routine will
        substitute in the ST_Transform() function call.
        """
        # Get the srid for this object
        if value is None:
            value_srid = None
        elif f.geom_type == 'RASTER':
        elif f.geom_type == 'RASTER' and isinstance(value, six.string_types):
            value_srid = get_pgraster_srid(value)
        else:
            value_srid = value.srid
@@ -288,7 +342,7 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
        # is not equal to the field srid.
        if value_srid is None or value_srid == f.srid:
            placeholder = '%s'
        elif f.geom_type == 'RASTER':
        elif f.geom_type == 'RASTER' and isinstance(value, six.string_types):
            placeholder = '%s((%%s)::raster, %s)' % (self.transform, f.srid)
        else:
            placeholder = '%s(%%s, %s)' % (self.transform, f.srid)
+96 −25
Original line number Diff line number Diff line
from django.contrib.gis import forms
from django.contrib.gis.db.models.lookups import gis_lookups
from django.contrib.gis.db.models.lookups import (
    RasterBandTransform, gis_lookups,
)
from django.contrib.gis.db.models.proxy import SpatialProxy
from django.contrib.gis.gdal import HAS_GDAL
from django.contrib.gis.gdal.error import GDALException
from django.contrib.gis.geometry.backend import Geometry, GeometryException
from django.core.exceptions import ImproperlyConfigured
from django.db.models.expressions import Expression
@@ -157,6 +160,82 @@ class BaseSpatialField(Field):
        """
        return connection.ops.get_geom_placeholder(self, value, compiler)

    def get_srid(self, obj):
        """
        Return the default SRID for the given geometry or raster, taking into
        account the SRID set for the field. For example, if the input geometry
        or raster doesn't have an SRID, then the SRID of the field will be
        returned.
        """
        srid = obj.srid  # SRID of given geometry.
        if srid is None or self.srid == -1 or (srid == -1 and self.srid != -1):
            return self.srid
        else:
            return srid

    def get_db_prep_save(self, value, connection):
        """
        Prepare the value for saving in the database.
        """
        if not value:
            return None
        else:
            return connection.ops.Adapter(self.get_prep_value(value))

    def get_prep_value(self, value):
        """
        Spatial lookup values are either a parameter that is (or may be
        converted to) a geometry or raster, or a sequence of lookup values
        that begins with a geometry or raster. This routine sets up the
        geometry or raster value properly and preserves any other lookup
        parameters.
        """
        from django.contrib.gis.gdal import GDALRaster

        value = super(BaseSpatialField, self).get_prep_value(value)
        # For IsValid lookups, boolean values are allowed.
        if isinstance(value, (Expression, bool)):
            return value
        elif isinstance(value, (tuple, list)):
            obj = value[0]
            seq_value = True
        else:
            obj = value
            seq_value = False

        # When the input is not a geometry or raster, attempt to construct one
        # from the given string input.
        if isinstance(obj, (Geometry, GDALRaster)):
            pass
        elif isinstance(obj, (bytes, six.string_types)) or hasattr(obj, '__geo_interface__'):
            try:
                obj = Geometry(obj)
            except (GeometryException, GDALException):
                try:
                    obj = GDALRaster(obj)
                except GDALException:
                    raise ValueError("Couldn't create spatial object from lookup value '%s'." % obj)
        elif isinstance(obj, dict):
            try:
                obj = GDALRaster(obj)
            except GDALException:
                raise ValueError("Couldn't create spatial object from lookup value '%s'." % obj)
        else:
            raise ValueError('Cannot use object with type %s for a spatial lookup parameter.' % type(obj).__name__)

        # Assigning the SRID value.
        obj.srid = self.get_srid(obj)

        if seq_value:
            lookup_val = [obj]
            lookup_val.extend(value[1:])
            return tuple(lookup_val)
        else:
            return obj

for klass in gis_lookups.values():
    BaseSpatialField.register_lookup(klass)


class GeometryField(GeoSelectFormatMixin, BaseSpatialField):
    """
@@ -224,6 +303,8 @@ class GeometryField(GeoSelectFormatMixin, BaseSpatialField):
        value properly, and preserve any other lookup parameters before
        returning to the caller.
        """
        from django.contrib.gis.gdal import GDALRaster

        value = super(GeometryField, self).get_prep_value(value)
        if isinstance(value, (Expression, bool)):
            return value
@@ -236,7 +317,7 @@ class GeometryField(GeoSelectFormatMixin, BaseSpatialField):

        # When the input is not a GEOS geometry, attempt to construct one
        # from the given string input.
        if isinstance(geom, Geometry):
        if isinstance(geom, (Geometry, GDALRaster)):
            pass
        elif isinstance(geom, (bytes, six.string_types)) or hasattr(geom, '__geo_interface__'):
            try:
@@ -265,18 +346,6 @@ class GeometryField(GeoSelectFormatMixin, BaseSpatialField):
                value.srid = self.srid
        return value

    def get_srid(self, geom):
        """
        Returns the default SRID for the given geometry, taking into account
        the SRID set for the field.  For example, if the input geometry
        has no SRID, then that of the field will be returned.
        """
        gsrid = geom.srid  # SRID of given geometry.
        if gsrid is None or self.srid == -1 or (gsrid == -1 and self.srid != -1):
            return self.srid
        else:
            return gsrid

    # ### Routines overloaded from Field ###
    def contribute_to_class(self, cls, name, **kwargs):
        super(GeometryField, self).contribute_to_class(cls, name, **kwargs)
@@ -316,17 +385,6 @@ class GeometryField(GeoSelectFormatMixin, BaseSpatialField):
            params = [connection.ops.Adapter(value)]
        return params

    def get_db_prep_save(self, value, connection):
        "Prepares the value for saving in the database."
        if not value:
            return None
        else:
            return connection.ops.Adapter(self.get_prep_value(value))


for klass in gis_lookups.values():
    GeometryField.register_lookup(klass)


# The OpenGIS Geometry Type Fields
class PointField(GeometryField):
@@ -387,6 +445,7 @@ class RasterField(BaseSpatialField):

    description = _("Raster Field")
    geom_type = 'RASTER'
    geography = False

    def __init__(self, *args, **kwargs):
        if not HAS_GDAL:
@@ -421,3 +480,15 @@ class RasterField(BaseSpatialField):
        # delays the instantiation of the objects to the moment of evaluation
        # of the raster attribute.
        setattr(cls, self.attname, SpatialProxy(GDALRaster, self))

    def get_transform(self, name):
        try:
            band_index = int(name)
            return type(
                'SpecificRasterBandTransform',
                (RasterBandTransform, ),
                {'band_index': band_index}
            )
        except ValueError:
            pass
        return super(RasterField, self).get_transform(name)
+59 −10
Original line number Diff line number Diff line
@@ -5,16 +5,23 @@ import re
from django.core.exceptions import FieldDoesNotExist
from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import Col, Expression
from django.db.models.lookups import BuiltinLookup, Lookup
from django.db.models.lookups import BuiltinLookup, Lookup, Transform
from django.utils import six

gis_lookups = {}


class RasterBandTransform(Transform):
    def as_sql(self, compiler, connection):
        return compiler.compile(self.lhs)


class GISLookup(Lookup):
    sql_template = None
    transform_func = None
    distance = False
    band_rhs = None
    band_lhs = None

    def __init__(self, *args, **kwargs):
        super(GISLookup, self).__init__(*args, **kwargs)
@@ -28,10 +35,10 @@ class GISLookup(Lookup):
        'point, 'the_geom', or a related lookup on a geographic field like
        'address__point'.

        If a GeometryField exists according to the given lookup on the model
        options, it will be returned.  Otherwise returns None.
        If a BaseSpatialField exists according to the given lookup on the model
        options, it will be returned. Otherwise return None.
        """
        from django.contrib.gis.db.models.fields import GeometryField
        from django.contrib.gis.db.models.fields import BaseSpatialField
        # This takes into account the situation where the lookup is a
        # lookup to a related geographic field, e.g., 'address__point'.
        field_list = lookup.split(LOOKUP_SEP)
@@ -55,11 +62,34 @@ class GISLookup(Lookup):
            return False

        # Finally, make sure we got a Geographic field and return.
        if isinstance(geo_fld, GeometryField):
        if isinstance(geo_fld, BaseSpatialField):
            return geo_fld
        else:
            return False

    def process_band_indices(self, only_lhs=False):
        """
        Extract the lhs band index from the band transform class and the rhs
        band index from the input tuple.
        """
        # PostGIS band indices are 1-based, so the band index needs to be
        # increased to be consistent with the GDALRaster band indices.
        if only_lhs:
            self.band_rhs = 1
            self.band_lhs = self.lhs.band_index + 1
            return

        if isinstance(self.lhs, RasterBandTransform):
            self.band_lhs = self.lhs.band_index + 1
        else:
            self.band_lhs = 1

        self.band_rhs = self.rhs[1]
        if len(self.rhs) == 1:
            self.rhs = self.rhs[0]
        else:
            self.rhs = (self.rhs[0], ) + self.rhs[2:]

    def get_db_prep_lookup(self, value, connection):
        # get_db_prep_lookup is called by process_rhs from super class
        if isinstance(value, (tuple, list)):
@@ -70,10 +100,9 @@ class GISLookup(Lookup):
        return ('%s', params)

    def process_rhs(self, compiler, connection):
        rhs, rhs_params = super(GISLookup, self).process_rhs(compiler, connection)
        if hasattr(self.rhs, '_as_sql'):
            # If rhs is some QuerySet, don't touch it
            return rhs, rhs_params
            return super(GISLookup, self).process_rhs(compiler, connection)

        geom = self.rhs
        if isinstance(self.rhs, Col):
@@ -85,9 +114,19 @@ class GISLookup(Lookup):
                raise ValueError('No geographic field found in expression.')
            self.rhs.srid = geo_fld.srid
        elif isinstance(self.rhs, Expression):
            raise ValueError('Complex expressions not supported for GeometryField')
            raise ValueError('Complex expressions not supported for spatial fields.')
        elif isinstance(self.rhs, (list, tuple)):
            geom = self.rhs[0]
            # Check if a band index was passed in the query argument.
            if ((len(self.rhs) == 2 and not self.lookup_name == 'relate') or
                    (len(self.rhs) == 3 and self.lookup_name == 'relate')):
                self.process_band_indices()
            elif len(self.rhs) > 2:
                raise ValueError('Tuple too long for lookup %s.' % self.lookup_name)
        elif isinstance(self.lhs, RasterBandTransform):
            self.process_band_indices(only_lhs=True)

        rhs, rhs_params = super(GISLookup, self).process_rhs(compiler, connection)
        rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, geom, compiler)
        return rhs, rhs_params

@@ -274,6 +313,8 @@ class IsValidLookup(BuiltinLookup):
    lookup_name = 'isvalid'

    def as_sql(self, compiler, connection):
        if self.lhs.field.geom_type == 'RASTER':
            raise ValueError('The isvalid lookup is only available on geometry fields.')
        gis_op = connection.ops.gis_operators[self.lookup_name]
        sql, params = self.process_lhs(compiler, connection)
        sql = '%(func)s(%(lhs)s)' % {'func': gis_op.func, 'lhs': sql}
@@ -323,9 +364,17 @@ class DistanceLookupBase(GISLookup):
    sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %(value)s'

    def process_rhs(self, compiler, connection):
        if not isinstance(self.rhs, (tuple, list)) or not 2 <= len(self.rhs) <= 3:
            raise ValueError("2 or 3-element tuple required for '%s' lookup." % self.lookup_name)
        if not isinstance(self.rhs, (tuple, list)) or not 2 <= len(self.rhs) <= 4:
            raise ValueError("2, 3, or 4-element tuple required for '%s' lookup." % self.lookup_name)
        elif len(self.rhs) == 4 and not self.rhs[3] == 'spheroid':
            raise ValueError("For 4-element tuples the last argument must be the 'speroid' directive.")

        # Check if the second parameter is a band index.
        if len(self.rhs) > 2 and not self.rhs[2] == 'spheroid':
            self.process_band_indices()

        params = [connection.ops.Adapter(self.rhs[0])]

        # Getting the distance parameter in the units of the field.
        dist_param = self.rhs[1]
        if hasattr(dist_param, 'resolve_expression'):
+120 −47

File changed.

Preview size limit exceeded, changes collapsed.

Loading