Commit f269c1d6 authored by Daniel Wiesmann's avatar Daniel Wiesmann Committed by Claude Paroz
Browse files

Added write support for GDALRaster

- Instantiation of GDALRaster instances from dict or json data.
- Retrieve and write pixel values in GDALBand objects.
- Support for the GDALFlushCache in gdal C prototypes
- Added private flush method to GDALRaster to make sure all
  data is written to files when file-based rasters are changed.
- Replaced ``ptr`` with ``_ptr`` for internal ptr variable

Refs #23804. Thanks Claude Paroz and Tim Graham for the reviews.
parent 8758a63d
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -31,6 +31,7 @@ get_driver_description = const_string_output(lgdal.GDALGetDescription, [c_void_p
create_ds = voidptr_output(lgdal.GDALCreate, [c_void_p, c_char_p, c_int, c_int, c_int, c_int])
open_ds = voidptr_output(lgdal.GDALOpen, [c_char_p, c_int])
close_ds = void_output(lgdal.GDALClose, [c_void_p])
flush_ds = int_output(lgdal.GDALFlushCache, [c_void_p])
copy_ds = voidptr_output(lgdal.GDALCreateCopy, [c_void_p, c_char_p, c_void_p, c_int,
                                                POINTER(c_char_p), c_void_p, c_void_p])
add_band_ds = void_output(lgdal.GDALAddBand, [c_void_p, c_int])
+89 −14
Original line number Diff line number Diff line
@@ -2,9 +2,11 @@ from ctypes import byref, c_int

from django.contrib.gis.gdal.base import GDALBase
from django.contrib.gis.gdal.prototypes import raster as capi
from django.contrib.gis.shortcuts import numpy
from django.utils import six
from django.utils.encoding import force_text

from .const import GDAL_PIXEL_TYPES
from .const import GDAL_PIXEL_TYPES, GDAL_TO_CTYPES


class GDALBand(GDALBase):
@@ -13,51 +15,49 @@ class GDALBand(GDALBase):
    """
    def __init__(self, source, index):
        self.source = source
        self.ptr = capi.get_ds_raster_band(source.ptr, index)
        self._ptr = capi.get_ds_raster_band(source._ptr, index)

    @property
    def description(self):
        """
        Returns the description string of the band.
        """
        return force_text(capi.get_band_description(self.ptr))
        return force_text(capi.get_band_description(self._ptr))

    @property
    def width(self):
        """
        Width (X axis) in pixels of the band.
        """
        return capi.get_band_xsize(self.ptr)
        return capi.get_band_xsize(self._ptr)

    @property
    def height(self):
        """
        Height (Y axis) in pixels of the band.
        """
        return capi.get_band_ysize(self.ptr)
        return capi.get_band_ysize(self._ptr)

    def datatype(self, as_string=False):
    @property
    def pixel_count(self):
        """
        Returns the GDAL Pixel Datatype for this band.
        Returns the total number of pixels in this band.
        """
        dtype = capi.get_band_datatype(self.ptr)
        if as_string:
            dtype = GDAL_PIXEL_TYPES[dtype]
        return dtype
        return self.width * self.height

    @property
    def min(self):
        """
        Returns the minimum pixel value for this band.
        """
        return capi.get_band_minimum(self.ptr, byref(c_int()))
        return capi.get_band_minimum(self._ptr, byref(c_int()))

    @property
    def max(self):
        """
        Returns the maximum pixel value for this band.
        """
        return capi.get_band_maximum(self.ptr, byref(c_int()))
        return capi.get_band_maximum(self._ptr, byref(c_int()))

    @property
    def nodata_value(self):
@@ -65,5 +65,80 @@ class GDALBand(GDALBase):
        Returns the nodata value for this band, or None if it isn't set.
        """
        nodata_exists = c_int()
        value = capi.get_band_nodata_value(self.ptr, nodata_exists)
        value = capi.get_band_nodata_value(self._ptr, nodata_exists)
        return value if nodata_exists else None

    @nodata_value.setter
    def nodata_value(self, value):
        """
        Sets the nodata value for this band.
        """
        if not isinstance(value, (int, float)):
            raise ValueError('Nodata value must be numeric.')
        capi.set_band_nodata_value(self._ptr, value)
        self.source._flush()

    def datatype(self, as_string=False):
        """
        Returns the GDAL Pixel Datatype for this band.
        """
        dtype = capi.get_band_datatype(self._ptr)
        if as_string:
            dtype = GDAL_PIXEL_TYPES[dtype]
        return dtype

    def data(self, data=None, offset=None, size=None, as_memoryview=False):
        """
        Reads or writes pixel values for this band. Blocks of data can
        be accessed by specifying the width, height and offset of the
        desired block. The same specification can be used to update
        parts of a raster by providing an array of values.

        Allowed input data types are bytes, memoryview, list, tuple, and array.
        """
        if not offset:
            offset = (0, 0)

        if not size:
            size = (self.width - offset[0], self.height - offset[1])

        if any(x <= 0 for x in size):
            raise ValueError('Offset too big for this raster.')

        if size[0] > self.width or size[1] > self.height:
            raise ValueError('Size is larger than raster.')

        # Create ctypes type array generator
        ctypes_array = GDAL_TO_CTYPES[self.datatype()] * (size[0] * size[1])

        if data is None:
            # Set read mode
            access_flag = 0
            # Prepare empty ctypes array
            data_array = ctypes_array()
        else:
            # Set write mode
            access_flag = 1

            # Instantiate ctypes array holding the input data
            if isinstance(data, (bytes, six.memoryview, numpy.ndarray)):
                data_array = ctypes_array.from_buffer_copy(data)
            else:
                data_array = ctypes_array(*data)

        # Access band
        capi.band_io(self._ptr, access_flag, offset[0], offset[1],
                     size[0], size[1], byref(data_array), size[0],
                     size[1], self.datatype(), 0, 0)

        # Return data as numpy array if possible, otherwise as list
        if data is None:
            if as_memoryview:
                return memoryview(data_array)
            elif numpy:
                return numpy.frombuffer(
                    data_array, dtype=numpy.dtype(data_array)).reshape(size)
            else:
                return list(data_array)
        else:
            self.source._flush()
+12 −0
Original line number Diff line number Diff line
"""
GDAL - Constant definitions
"""
from ctypes import (
    c_byte, c_double, c_float, c_int16, c_int32, c_uint16, c_uint32,
)

# See http://www.gdal.org/gdal_8h.html#a22e22ce0a55036a96f652765793fb7a4
GDAL_PIXEL_TYPES = {
@@ -17,3 +20,12 @@ GDAL_PIXEL_TYPES = {
    10: 'GDT_CFloat32',  # Complex Float32
    11: 'GDT_CFloat64',  # Complex Float64
}

# Lookup values to convert GDAL pixel type indices into ctypes objects.
# The GDAL band-io works with ctypes arrays to hold data to be written
# or to hold the space for data to be read into. The lookup below helps
# selecting the right ctypes object for a given gdal pixel type.
GDAL_TO_CTYPES = [
    None, c_byte, c_uint16, c_int16, c_uint32, c_int32,
    c_float, c_double, None, None, None, None
]
+140 −20
Original line number Diff line number Diff line
import json
import os
from ctypes import addressof, byref, c_double

@@ -7,6 +8,7 @@ from django.contrib.gis.gdal.error import GDALException
from django.contrib.gis.gdal.prototypes import raster as capi
from django.contrib.gis.gdal.raster.band import GDALBand
from django.contrib.gis.gdal.srs import SpatialReference, SRSException
from django.contrib.gis.geometry.regex import json_regex
from django.utils import six
from django.utils.encoding import (
    force_bytes, force_text, python_2_unicode_compatible,
@@ -33,10 +35,22 @@ class TransformPoint(list):
    def x(self):
        return self[0]

    @x.setter
    def x(self, value):
        gtf = self._raster.geotransform
        gtf[self.indices[self._prop][0]] = value
        self._raster.geotransform = gtf

    @property
    def y(self):
        return self[1]

    @y.setter
    def y(self, value):
        gtf = self._raster.geotransform
        gtf[self.indices[self._prop][1]] = value
        self._raster.geotransform = gtf


@python_2_unicode_compatible
class GDALRaster(GDALBase):
@@ -47,17 +61,64 @@ class GDALRaster(GDALBase):
        self._write = 1 if write else 0
        Driver.ensure_registered()

        # Preprocess json inputs. This converts json strings to dictionaries,
        # which are parsed below the same way as direct dictionary inputs.
        if isinstance(ds_input, six.string_types) and json_regex.match(ds_input):
            ds_input = json.loads(ds_input)

        # If input is a valid file path, try setting file as source.
        if isinstance(ds_input, six.string_types):
            if os.path.exists(ds_input):
            if not os.path.exists(ds_input):
                raise GDALException('Unable to read raster source input "{}"'.format(ds_input))
            try:
                # GDALOpen will auto-detect the data source type.
                    self.ptr = capi.open_ds(force_bytes(ds_input), self._write)
                self._ptr = capi.open_ds(force_bytes(ds_input), self._write)
            except GDALException as err:
                    raise GDALException('Could not open the datasource at "{}" ({}).'.format(
                        ds_input, err))
            else:
                raise GDALException('Unable to read raster source input "{}"'.format(ds_input))
                raise GDALException('Could not open the datasource at "{}" ({}).'.format(ds_input, err))
        elif isinstance(ds_input, dict):
            # A new raster needs to be created in write mode
            self._write = 1

            # Create driver (in memory by default)
            driver = Driver(ds_input.get('driver', 'MEM'))

            # For out of memory drivers, check filename argument
            if driver.name != 'MEM' and 'name' not in ds_input:
                raise GDALException('Specify name for creation of raster with driver "{}".'.format(driver.name))

            # Check if width and height where specified
            if 'width' not in ds_input or 'height' not in ds_input:
                raise GDALException('Specify width and height attributes for JSON or dict input.')

            # Create GDAL Raster
            self._ptr = capi.create_ds(
                driver._ptr,
                force_bytes(ds_input.get('name', '')),
                ds_input['width'],
                ds_input['height'],
                ds_input.get('nr_of_bands', len(ds_input.get('bands', []))),
                ds_input.get('datatype', 6),
                None
            )

            # Set band data if provided
            for i, band_input in enumerate(ds_input.get('bands', [])):
                self.bands[i].data(band_input['data'])
                if 'nodata_value' in band_input:
                    self.bands[i].nodata_value = band_input['nodata_value']

            # Set SRID, default to 0 (this assures SRS is always instanciated)
            self.srs = ds_input.get('srid', 0)

            # Set additional properties if provided
            if 'origin' in ds_input:
                self.origin.x, self.origin.y = ds_input['origin']

            if 'scale' in ds_input:
                self.scale.x, self.scale.y = ds_input['scale']

            if 'skew' in ds_input:
                self.skew.x, self.skew.y = ds_input['skew']
        else:
            raise GDALException('Invalid data source input type: "{}".'.format(type(ds_input)))

@@ -72,15 +133,34 @@ class GDALRaster(GDALBase):
        """
        Short-hand representation because WKB may be very large.
        """
        return '<Raster object at %s>' % hex(addressof(self.ptr))
        return '<Raster object at %s>' % hex(addressof(self._ptr))

    def _flush(self):
        """
        Flush all data from memory into the source file if it exists.
        The data that needs flushing are geotransforms, coordinate systems,
        nodata_values and pixel values. This function will be called
        automatically wherever it is needed.
        """
        # Raise an Exception if the value is being changed in read mode.
        if not self._write:
            raise GDALException('Raster needs to be opened in write mode to change values.')
        capi.flush_ds(self._ptr)

    @property
    def name(self):
        return force_text(capi.get_ds_description(self.ptr))
        """
        Returns the name of this raster. Corresponds to filename
        for file-based rasters.
        """
        return force_text(capi.get_ds_description(self._ptr))

    @cached_property
    def driver(self):
        ds_driver = capi.get_ds_driver(self.ptr)
        """
        Returns the GDAL Driver used for this raster.
        """
        ds_driver = capi.get_ds_driver(self._ptr)
        return Driver(ds_driver)

    @property
@@ -88,14 +168,14 @@ class GDALRaster(GDALBase):
        """
        Width (X axis) in pixels.
        """
        return capi.get_ds_xsize(self.ptr)
        return capi.get_ds_xsize(self._ptr)

    @property
    def height(self):
        """
        Height (Y axis) in pixels.
        """
        return capi.get_ds_ysize(self.ptr)
        return capi.get_ds_ysize(self._ptr)

    @property
    def srs(self):
@@ -103,33 +183,70 @@ class GDALRaster(GDALBase):
        Returns the SpatialReference used in this GDALRaster.
        """
        try:
            wkt = capi.get_ds_projection_ref(self.ptr)
            wkt = capi.get_ds_projection_ref(self._ptr)
            if not wkt:
                return None
            return SpatialReference(wkt, srs_type='wkt')
        except SRSException:
            return None

    @cached_property
    @srs.setter
    def srs(self, value):
        """
        Sets the spatial reference used in this GDALRaster. The input can be
        a SpatialReference or any parameter accepted by the SpatialReference
        constructor.
        """
        if isinstance(value, SpatialReference):
            srs = value
        elif isinstance(value, six.integer_types + six.string_types):
            srs = SpatialReference(value)
        else:
            raise ValueError('Could not create a SpatialReference from input.')
        capi.set_ds_projection_ref(self._ptr, srs.wkt.encode())
        self._flush()

    @property
    def geotransform(self):
        """
        Returns the geotransform of the data source.
        Returns the default geotransform if it does not exist or has not been
        set previously. The default is (0.0, 1.0, 0.0, 0.0, 0.0, -1.0).
        set previously. The default is [0.0, 1.0, 0.0, 0.0, 0.0, -1.0].
        """
        # Create empty ctypes double array for data
        gtf = (c_double * 6)()
        capi.get_ds_geotransform(self.ptr, byref(gtf))
        return tuple(gtf)
        capi.get_ds_geotransform(self._ptr, byref(gtf))
        return list(gtf)

    @geotransform.setter
    def geotransform(self, values):
        "Sets the geotransform for the data source."
        if sum([isinstance(x, (int, float)) for x in values]) != 6:
            raise ValueError('Geotransform must consist of 6 numeric values.')
        # Create ctypes double array with input and write data
        values = (c_double * 6)(*values)
        capi.set_ds_geotransform(self._ptr, byref(values))
        self._flush()

    @property
    def origin(self):
        """
        Coordinates of the raster origin.
        """
        return TransformPoint(self, 'origin')

    @property
    def scale(self):
        """
        Pixel scale in units of the raster projection.
        """
        return TransformPoint(self, 'scale')

    @property
    def skew(self):
        """
        Skew of pixels (rotation parameters).
        """
        return TransformPoint(self, 'skew')

    @property
@@ -150,7 +267,10 @@ class GDALRaster(GDALBase):

    @cached_property
    def bands(self):
        """
        Returns the bands of this raster as a list of GDALBand instances.
        """
        bands = []
        for idx in range(1, capi.get_ds_raster_count(self.ptr) + 1):
        for idx in range(1, capi.get_ds_raster_count(self._ptr) + 1):
            bands.append(GDALBand(self, idx))
        return bands
+183 −14

File changed.

Preview size limit exceeded, changes collapsed.

Loading