Commit 71e20814 authored by Claude Paroz's avatar Claude Paroz
Browse files

Added MySQL support to GIS functions

parent 44bdbbc3
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -6,6 +6,8 @@ from django.db.backends.mysql.features import \
class DatabaseFeatures(BaseSpatialFeatures, MySQLDatabaseFeatures):
    has_spatialrefsys_table = False
    supports_add_srs_entry = False
    supports_distance_geodetic = False
    supports_length_geodetic = False
    supports_distances_lookups = False
    supports_transform = False
    supports_real_shape_operations = False
+23 −1
Original line number Diff line number Diff line
@@ -4,6 +4,7 @@ from django.contrib.gis.db.backends.base.operations import \
from django.contrib.gis.db.backends.utils import SpatialOperator
from django.contrib.gis.db.models import aggregates
from django.db.backends.mysql.operations import DatabaseOperations
from django.utils.functional import cached_property


class MySQLOperations(BaseSpatialOperations, DatabaseOperations):
@@ -32,7 +33,28 @@ class MySQLOperations(BaseSpatialOperations, DatabaseOperations):
        'within': SpatialOperator(func='MBRWithin'),
    }

    disallowed_aggregates = (aggregates.Collect, aggregates.Extent, aggregates.Extent3D, aggregates.MakeLine, aggregates.Union)
    function_names = {
        'Distance': 'ST_Distance',
        'Length': 'GLength',
        'Union': 'ST_Union',
    }

    disallowed_aggregates = (
        aggregates.Collect, aggregates.Extent, aggregates.Extent3D,
        aggregates.MakeLine, aggregates.Union,
    )

    @cached_property
    def unsupported_functions(self):
        unsupported = {
            'AsGeoJSON', 'AsGML', 'AsKML', 'AsSVG', 'BoundingCircle',
            'Difference', 'ForceRHR', 'GeoHash', 'Intersection', 'MemSize',
            'Perimeter', 'PointOnSurface', 'Reverse', 'Scale', 'SnapToGrid',
            'SymDifference', 'Transform', 'Translate',
        }
        if self.connection.mysql_version < (5, 6, 1):
            unsupported.update({'Distance', 'Union'})
        return unsupported

    def geo_db_type(self, f):
        return f.geom_type
+4 −1
Original line number Diff line number Diff line
@@ -169,7 +169,10 @@ class GeometryField(GeoSelectFormatMixin, Field):
        Returns true if this field's SRID corresponds with a coordinate
        system that uses non-projected units (e.g., latitude/longitude).
        """
        return self.units_name(connection).lower() in self.geodetic_units
        units_name = self.units_name(connection)
        # Some backends like MySQL cannot determine units name. In that case,
        # test if srid is 4326 (WGS84), even if this is over-simplification.
        return units_name.lower() in self.geodetic_units if units_name else self.srid == 4326

    def get_distance(self, value, lookup_type, connection):
        """
+23 −4
Original line number Diff line number Diff line
@@ -79,6 +79,9 @@ class GeomValue(Value):
            self.value = connection.ops.Adapter(self.value)
        return super(GeomValue, self).as_sql(compiler, connection)

    def as_mysql(self, compiler, connection):
        return 'GeomFromText(%%s, %s)' % self.srid, [connection.ops.Adapter(self.value)]

    def as_sqlite(self, compiler, connection):
        return 'GeomFromText(%%s, %s)' % self.srid, [connection.ops.Adapter(self.value)]

@@ -119,8 +122,12 @@ class Area(GeoFunc):
                self.output_field = AreaField('sq_m')
            elif not self.output_field.geodetic(connection):
                # Getting the area units of the geographic field.
                units = self.output_field.units_name(connection)
                if units:
                    self.output_field = AreaField(
                        AreaMeasure.unit_attname(self.output_field.units_name(connection)))
                else:
                    self.output_field = FloatField()
            else:
                # TODO: Do we want to support raw number areas for geodetic fields?
                raise NotImplementedError('Area on geodetic coordinate systems not supported.')
@@ -198,8 +205,14 @@ class DistanceResultMixin(object):
        if geo_field.geodetic(connection):
            dist_att = 'm'
        else:
            dist_att = DistanceMeasure.unit_attname(geo_field.units_name(connection))
            units = geo_field.units_name(connection)
            if units:
                dist_att = DistanceMeasure.unit_attname(units)
            else:
                dist_att = None
        if dist_att:
            return DistanceMeasure(**{dist_att: value})
        return value


class Distance(DistanceResultMixin, GeoFuncWithGeoParam):
@@ -263,6 +276,12 @@ class Length(DistanceResultMixin, GeoFunc):
        self.spheroid = spheroid
        super(Length, self).__init__(expr1, **extra)

    def as_sql(self, compiler, connection):
        geo_field = GeometryField(srid=self.srid)  # Fake field to get SRID info
        if geo_field.geodetic(connection) and not connection.features.supports_length_geodetic:
            raise NotImplementedError("This backend doesn't support Length on geodetic fields")
        return super(Length, self).as_sql(compiler, connection)

    def as_postgresql(self, compiler, connection):
        geo_field = GeometryField(srid=self.srid)  # Fake field to get SRID info
        src_field = self.get_source_fields()[0]
+6 −5
Original line number Diff line number Diff line
@@ -438,7 +438,8 @@ class DistanceFunctionsTests(TestCase):
        # Tolerance has to be lower for Oracle
        tol = 2
        for i, z in enumerate(SouthTexasZipcode.objects.annotate(area=Area('poly')).order_by('name')):
            self.assertAlmostEqual(area_sq_m[i], z.area.sq_m, tol)
            # MySQL is returning a raw float value
            self.assertAlmostEqual(area_sq_m[i], z.area.sq_m if hasattr(z.area, 'sq_m') else z.area, tol)

    @skipUnlessDBFeature("has_Distance_function")
    def test_distance_simple(self):
@@ -624,12 +625,12 @@ class DistanceFunctionsTests(TestCase):
            # TODO: test with spheroid argument (True and False)
        else:
            # Does not support geodetic coordinate systems.
            with self.assertRaises(ValueError):
                Interstate.objects.annotate(length=Length('path'))
            with self.assertRaises(NotImplementedError):
                list(Interstate.objects.annotate(length=Length('path')))

        # Now doing length on a projected coordinate system.
        i10 = SouthTexasInterstate.objects.annotate(length=Length('path')).get(name='I-10')
        self.assertAlmostEqual(len_m2, i10.length.m, 2)
        self.assertAlmostEqual(len_m2, i10.length.m if isinstance(i10.length, D) else i10.length, 2)
        self.assertTrue(
            SouthTexasInterstate.objects.annotate(length=Length('path')).filter(length__gt=4000).exists()
        )
@@ -652,7 +653,7 @@ class DistanceFunctionsTests(TestCase):
        for city in qs:
            self.assertEqual(0, city.perim.m)

    @skipUnlessDBFeature("has_Area_function", "has_Distance_function")
    @skipUnlessDBFeature("supports_null_geometries", "has_Area_function", "has_Distance_function")
    def test_measurement_null_fields(self):
        """
        Test the measurement functions on fields with NULL values.
Loading