Commit a08d2463 authored by Daniel Wiesmann's avatar Daniel Wiesmann Committed by Claude Paroz
Browse files

Fixed #26112 -- Error when computing aggregate of GIS areas.

Thanks Simon Charette and Claude Paroz for the reviews.
parent 16baec5c
Loading
Loading
Loading
Loading
+12 −13
Original line number Diff line number Diff line
@@ -117,24 +117,23 @@ class OracleToleranceMixin(object):


class Area(OracleToleranceMixin, GeoFunc):
    output_field_class = AreaField
    arity = 1

    def as_sql(self, compiler, connection):
        if connection.ops.geography:
            # Geography fields support area calculation, returns square meters.
            self.output_field = AreaField('sq_m')
        elif not self.output_field.geodetic(connection):
            # Getting the area units of the geographic field.
            units = self.output_field.units_name(connection)
            if units:
                self.output_field = AreaField(
                    AreaMeasure.unit_attname(self.output_field.units_name(connection))
                )
            else:
                self.output_field = FloatField()
            self.output_field.area_att = 'sq_m'
        else:
            # Getting the area units of the geographic field.
            source_fields = self.get_source_fields()
            if len(source_fields):
                source_field = source_fields[0]
                if source_field.geodetic(connection):
                    # TODO: Do we want to support raw number areas for geodetic fields?
                    raise NotImplementedError('Area on geodetic coordinate systems not supported.')
                units_name = source_field.units_name(connection)
                if units_name:
                    self.output_field.area_att = AreaMeasure.unit_attname(units_name)
        return super(Area, self).as_sql(compiler, connection)

    def as_oracle(self, compiler, connection):
+3 −2
Original line number Diff line number Diff line
@@ -21,13 +21,14 @@ class BaseField(object):

class AreaField(BaseField):
    "Wrapper for Area values."
    def __init__(self, area_att):
    def __init__(self, area_att=None):
        self.area_att = area_att

    def from_db_value(self, value, expression, connection, context):
        if connection.features.interprets_empty_strings_as_nulls and value == '':
            value = None
        if value is not None:
        # If the units are known, convert value into area measure.
        if value is not None and self.area_att:
            value = Area(**{self.area_att: value})
        return value

+4 −0
Original line number Diff line number Diff line
@@ -22,6 +22,10 @@ class Country(NamedModel):
    mpoly = models.MultiPolygonField()  # SRID, by default, is 4326


class CountryWebMercator(NamedModel):
    mpoly = models.MultiPolygonField(srid=3857)


class City(NamedModel):
    point = models.PointField()

+17 −1
Original line number Diff line number Diff line
@@ -5,12 +5,14 @@ from decimal import Decimal

from django.contrib.gis.db.models import functions
from django.contrib.gis.geos import LineString, Point, Polygon, fromstr
from django.contrib.gis.measure import Area
from django.db import connection
from django.db.models import Sum
from django.test import TestCase, skipUnlessDBFeature
from django.utils import six

from ..utils import mysql, oracle, postgis, spatialite
from .models import City, Country, State, Track
from .models import City, Country, CountryWebMercator, State, Track


@skipUnlessDBFeature("gis_enabled")
@@ -231,6 +233,20 @@ class GISFunctionsTests(TestCase):
                expected = c.mpoly.intersection(geom)
            self.assertEqual(c.inter, expected)

    @skipUnlessDBFeature("has_Area_function")
    def test_area_with_regular_aggregate(self):
        # Create projected country objects, for this test to work on all backends.
        for c in Country.objects.all():
            CountryWebMercator.objects.create(name=c.name, mpoly=c.mpoly)
        # Test in projected coordinate system
        qs = CountryWebMercator.objects.annotate(area_sum=Sum(functions.Area('mpoly')))
        for c in qs:
            result = c.area_sum
            # If the result is a measure object, get value.
            if isinstance(result, Area):
                result = result.sq_m
            self.assertAlmostEqual((result - c.mpoly.area) / c.mpoly.area, 0)

    @skipUnlessDBFeature("has_MemSize_function")
    def test_memsize(self):
        ptown = City.objects.annotate(size=functions.MemSize('point')).get(name='Pueblo')