Commit f1ada999 authored by Justin Bronn's avatar Justin Bronn
Browse files

Maintenance refactor of the GDAL (OGR) ctypes interface. Changes include:

* All C API method explictly called from their prototype module, no longer imported via *.
* Applied DRY to C pointer management, classes that do so subclass from `GDALBase`.
* `OGRGeometry`: Added `from_bbox` class method (patch from Christopher Schmidt) and `kml` property.
* `SpatialReference`: Now initialize with `SetFromUserInput` (initialization is now more simple and flexible); removed duplicate methods.
* `Envelope`: Added `expand_to_include` method and now allow same coordinates for lower left and upper right points.  Thanks to Paul Smith for tickets and patches.
* `OGRGeomType`: Now treat OGC 'Geometry' type as 'Unknown'.

Fixed #9855, #10368, #10380.  Refs #9806.


git-svn-id: http://code.djangoproject.com/svn/django/trunk@9985 bcc190cf-cafb-0310-a4f2-bffc1f526a37
parent 53da1e47
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
Copyright (c) 2007, Justin Bronn
Copyright (c) 2007-2009, Justin Bronn
All rights reserved.

Redistribution and use in source and binary forms, with or without modification,
+35 −0
Original line number Diff line number Diff line
from ctypes import c_void_p
from types import NoneType
from django.contrib.gis.gdal.error import GDALException

class GDALBase(object):
    """
    Base object for GDAL objects that has a pointer access property
    that controls access to the underlying C pointer.
    """
    # Initially the pointer is NULL.
    _ptr = None

    # Default allowed pointer type.
    ptr_type = c_void_p

    # Pointer access property.
    def _get_ptr(self):
        # Raise an exception if the pointer isn't valid don't
        # want to be passing NULL pointers to routines --
        # that's very bad.
        if self._ptr: return self._ptr
        else: raise GDALException('GDAL %s pointer no longer valid.' % self.__class__.__name__)

    def _set_ptr(self, ptr):
        # Only allow the pointer to be set with pointers of the
        # compatible type or None (NULL).
        if isinstance(ptr, int):
            self._ptr = self.ptr_type(ptr)
        elif isinstance(ptr, (self.ptr_type, NoneType)):
            self._ptr = ptr
        else:
            raise TypeError('Incompatible pointer type')

    ptr = property(_get_ptr, _set_ptr)
+15 −25
Original line number Diff line number Diff line
@@ -37,28 +37,23 @@
from ctypes import byref, c_void_p

# The GDAL C library, OGR exceptions, and the Layer object.
from django.contrib.gis.gdal.base import GDALBase
from django.contrib.gis.gdal.driver import Driver
from django.contrib.gis.gdal.error import OGRException, OGRIndexError
from django.contrib.gis.gdal.layer import Layer

# Getting the ctypes prototypes for the DataSource.
from django.contrib.gis.gdal.prototypes.ds import \
    destroy_ds, get_driver_count, register_all, open_ds, release_ds, \
    get_ds_name, get_layer, get_layer_count, get_layer_by_name
from django.contrib.gis.gdal.prototypes import ds as capi

# For more information, see the OGR C API source code:
#  http://www.gdal.org/ogr/ogr__api_8h.html
#
# The OGR_DS_* routines are relevant here.
class DataSource(object):
class DataSource(GDALBase):
    "Wraps an OGR Data Source object."

    #### Python 'magic' routines ####
    def __init__(self, ds_input, ds_driver=False, write=False):

        # DataSource pointer is initially NULL.
        self._ptr = None

        # The write flag.
        if write:
            self._write = 1
@@ -67,33 +62,34 @@ class DataSource(object):

        # Registering all the drivers, this needs to be done
        #  _before_ we try to open up a data source.
        if not get_driver_count(): register_all()
        if not capi.get_driver_count():
            capi.register_all()

        if isinstance(ds_input, basestring):
            # The data source driver is a void pointer.
            ds_driver = c_void_p()
            ds_driver = Driver.ptr_type()
            try:
                # OGROpen will auto-detect the data source type.
                ds = open_ds(ds_input, self._write, byref(ds_driver))
                ds = capi.open_ds(ds_input, self._write, byref(ds_driver))
            except OGRException:
                # Making the error message more clear rather than something
                # like "Invalid pointer returned from OGROpen".
                raise OGRException('Could not open the datasource at "%s"' % ds_input)
        elif isinstance(ds_input, c_void_p) and isinstance(ds_driver, c_void_p):
        elif isinstance(ds_input, self.ptr_type) and isinstance(ds_driver, Driver.ptr_type):
            ds = ds_input
        else:
            raise OGRException('Invalid data source input type: %s' % type(ds_input))

        if bool(ds):
            self._ptr = ds
            self._driver = Driver(ds_driver)
            self.ptr = ds
            self.driver = Driver(ds_driver)
        else:
            # Raise an exception if the returned pointer is NULL 
            raise OGRException('Invalid data source file "%s"' % ds_input)

    def __del__(self):
        "Destroys this DataStructure object."
        if self._ptr: destroy_ds(self._ptr)
        if self._ptr: capi.destroy_ds(self._ptr)

    def __iter__(self):
        "Allows for iteration over the layers in a data source."
@@ -103,12 +99,12 @@ class DataSource(object):
    def __getitem__(self, index):
        "Allows use of the index [] operator to get a layer at the index."
        if isinstance(index, basestring):
            l = get_layer_by_name(self._ptr, index)
            l = capi.get_layer_by_name(self.ptr, index)
            if not l: raise OGRIndexError('invalid OGR Layer name given: "%s"' % index)
        elif isinstance(index, int):
            if index < 0 or index >= self.layer_count:
                raise OGRIndexError('index out of range')
            l = get_layer(self._ptr, index)
            l = capi.get_layer(self._ptr, index)
        else:
            raise TypeError('Invalid index type: %s' % type(index))
        return Layer(l, self)
@@ -121,18 +117,12 @@ class DataSource(object):
        "Returns OGR GetName and Driver for the Data Source."
        return '%s (%s)' % (self.name, str(self.driver))

    #### DataSource Properties ####
    @property
    def driver(self):
        "Returns the Driver object for this Data Source."
        return self._driver
        
    @property
    def layer_count(self):
        "Returns the number of layers in the data source."
        return get_layer_count(self._ptr)
        return capi.get_layer_count(self._ptr)

    @property
    def name(self):
        "Returns the name of the data source."
        return get_ds_name(self._ptr)
        return capi.get_ds_name(self._ptr)
+9 −10
Original line number Diff line number Diff line
# prerequisites imports 
from ctypes import c_void_p
from django.contrib.gis.gdal.base import GDALBase
from django.contrib.gis.gdal.error import OGRException
from django.contrib.gis.gdal.prototypes.ds import \
    get_driver, get_driver_by_name, get_driver_count, get_driver_name, register_all
from django.contrib.gis.gdal.prototypes import ds as capi

# For more information, see the OGR C API source code:
#  http://www.gdal.org/ogr/ogr__api_8h.html
#
# The OGR_Dr_* routines are relevant here.
class Driver(object):
class Driver(GDALBase):
    "Wraps an OGR Data Source Driver."

    # Case-insensitive aliases for OGR Drivers.
@@ -24,7 +24,6 @@ class Driver(object):

        if isinstance(dr_input, basestring):
            # If a string name of the driver was passed in
            self._ptr = None # Initially NULL
            self._register()

            # Checking the alias dictionary (case-insensitive) to see if an alias
@@ -35,10 +34,10 @@ class Driver(object):
                name = dr_input

            # Attempting to get the OGR driver by the string name.
            dr = get_driver_by_name(name)
            dr = capi.get_driver_by_name(name)
        elif isinstance(dr_input, int):
            self._register()
            dr = get_driver(dr_input)
            dr = capi.get_driver(dr_input)
        elif isinstance(dr_input, c_void_p):
            dr = dr_input
        else:
@@ -47,20 +46,20 @@ class Driver(object):
        # Making sure we get a valid pointer to the OGR Driver
        if not dr:
            raise OGRException('Could not initialize OGR Driver on input: %s' % str(dr_input))
        self._ptr = dr
        self.ptr = dr

    def __str__(self):
        "Returns the string name of the OGR Driver."
        return get_driver_name(self._ptr)
        return capi.get_driver_name(self.ptr)

    def _register(self):
        "Attempts to register all the data source drivers."
        # Only register all if the driver count is 0 (or else all drivers
        # will be registered over and over again)
        if not self.driver_count: register_all()
        if not self.driver_count: capi.register_all()
                    
    # Driver properties
    @property
    def driver_count(self):
        "Returns the number of OGR data source drivers registered."
        return get_driver_count()
        return capi.get_driver_count()
+48 −7
Original line number Diff line number Diff line
@@ -11,7 +11,6 @@
 Lower left (min_x, min_y) o----------+
"""
from ctypes import Structure, c_double
from types import TupleType, ListType
from django.contrib.gis.gdal.error import OGRException

# The OGR definition of an Envelope is a C structure containing four doubles.
@@ -42,7 +41,7 @@ class Envelope(object):
            if isinstance(args[0], OGREnvelope):
                # OGREnvelope (a ctypes Structure) was passed in.
                self._envelope = args[0]
            elif isinstance(args[0], (TupleType, ListType)):
            elif isinstance(args[0], (tuple, list)):
                # A tuple was passed in.
                if len(args[0]) != 4:
                    raise OGRException('Incorrect number of tuple elements (%d).' % len(args[0]))
@@ -58,10 +57,10 @@ class Envelope(object):
            raise OGRException('Incorrect number (%d) of arguments.' % len(args))

        # Checking the x,y coordinates
        if self.min_x >= self.max_x:
            raise OGRException('Envelope minimum X >= maximum X.')
        if self.min_y >= self.max_y:
            raise OGRException('Envelope minimum Y >= maximum Y.')
        if self.min_x > self.max_x:
            raise OGRException('Envelope minimum X > maximum X.')
        if self.min_y > self.max_y:
            raise OGRException('Envelope minimum Y > maximum Y.')

    def __eq__(self, other):
        """
@@ -71,7 +70,7 @@ class Envelope(object):
        if isinstance(other, Envelope):
            return (self.min_x == other.min_x) and (self.min_y == other.min_y) and \
                   (self.max_x == other.max_x) and (self.max_y == other.max_y)
        elif isinstance(other, TupleType) and len(other) == 4:
        elif isinstance(other, tuple) and len(other) == 4:
            return (self.min_x == other[0]) and (self.min_y == other[1]) and \
                   (self.max_x == other[2]) and (self.max_y == other[3])
        else:
@@ -89,6 +88,48 @@ class Envelope(object):
        self._envelope.MaxX = seq[2]
        self._envelope.MaxY = seq[3]
    
    def expand_to_include(self, *args): 
        """ 
        Modifies the envelope to expand to include the boundaries of 
        the passed-in 2-tuple (a point), 4-tuple (an extent) or 
        envelope. 
        """ 
        # We provide a number of different signatures for this method, 
        # and the logic here is all about converting them into a 
        # 4-tuple single parameter which does the actual work of 
        # expanding the envelope. 
        if len(args) == 1: 
            if isinstance(args[0], Envelope): 
                return self.expand_to_include(args[0].tuple) 
            elif hasattr(args[0], 'x') and hasattr(args[0], 'y'):
                return self.expand_to_include(args[0].x, args[0].y, args[0].x, args[0].y) 
            elif isinstance(args[0], (tuple, list)): 
                # A tuple was passed in. 
                if len(args[0]) == 2: 
                    return self.expand_to_include((args[0][0], args[0][1], args[0][0], args[0][1])) 
                elif len(args[0]) == 4: 
                    (minx, miny, maxx, maxy) = args[0] 
                    if minx < self._envelope.MinX: 
                        self._envelope.MinX = minx 
                    if miny < self._envelope.MinY: 
                        self._envelope.MinY = miny 
                    if maxx > self._envelope.MaxX: 
                        self._envelope.MaxX = maxx 
                    if maxy > self._envelope.MaxY: 
                        self._envelope.MaxY = maxy 
                else: 
                    raise OGRException('Incorrect number of tuple elements (%d).' % len(args[0])) 
            else: 
                raise TypeError('Incorrect type of argument: %s' % str(type(args[0]))) 
        elif len(args) == 2: 
            # An x and an y parameter were passed in 
                return self.expand_to_include((args[0], args[1], args[0], args[1])) 
        elif len(args) == 4: 
            # Individiual parameters passed in. 
            return self.expand_to_include(args) 
        else: 
            raise OGRException('Incorrect number (%d) of arguments.' % len(args[0])) 

    @property
    def min_x(self):
        "Returns the value of the minimum X coordinate."
Loading