Commit a34fba5e authored by Claude Paroz's avatar Claude Paroz
Browse files

Simplified a bit GeoAggregate classes

Thanks Josh Smeaton for the review. Refs #24152.
parent 28db4af8
Loading
Loading
Loading
Loading
+5 −3
Original line number Diff line number Diff line
from functools import partial

from django.contrib.gis.db.models import aggregates


class BaseSpatialFeatures(object):
    gis_enabled = True
@@ -61,15 +63,15 @@ class BaseSpatialFeatures(object):
    # Specifies whether the Collect and Extent aggregates are supported by the database
    @property
    def supports_collect_aggr(self):
        return 'Collect' in self.connection.ops.valid_aggregates
        return aggregates.Collect not in self.connection.ops.disallowed_aggregates

    @property
    def supports_extent_aggr(self):
        return 'Extent' in self.connection.ops.valid_aggregates
        return aggregates.Extent not in self.connection.ops.disallowed_aggregates

    @property
    def supports_make_line_aggr(self):
        return 'MakeLine' in self.connection.ops.valid_aggregates
        return aggregates.MakeLine not in self.connection.ops.disallowed_aggregates

    def __init__(self, *args):
        super(BaseSpatialFeatures, self).__init__(*args)
+7 −10
Original line number Diff line number Diff line
@@ -46,11 +46,7 @@ class BaseSpatialOperations(object):
    union = False

    # Aggregates
    collect = False
    extent = False
    extent3d = False
    make_line = False
    unionagg = False
    disallowed_aggregates = ()

    # Serialization
    geohash = False
@@ -103,12 +99,13 @@ class BaseSpatialOperations(object):
        raise NotImplementedError('subclasses of BaseSpatialOperations must provide a geo_db_placeholder() method')

    def check_aggregate_support(self, aggregate):
        if aggregate.contains_aggregate == 'gis':
            return aggregate.name in self.valid_aggregates
        return super(BaseSpatialOperations, self).check_aggregate_support(aggregate)
        if isinstance(aggregate, self.disallowed_aggregates):
            raise NotImplementedError(
                "%s spatial aggregation is not supported by this database backend." % aggregate.name
            )
        super(BaseSpatialOperations, self).check_aggregate_support(aggregate)

    # Spatial SQL Construction
    def spatial_aggregate_sql(self, agg):
    def spatial_aggregate_name(self, agg_name):
        raise NotImplementedError('Aggregate support not implemented for this spatial backend.')

    # Routines for getting the OGC-compliant models.
+3 −0
Original line number Diff line number Diff line
from django.contrib.gis.db.backends.base.adapter import WKTAdapter
from django.contrib.gis.db.backends.base.operations import BaseSpatialOperations
from django.contrib.gis.db.backends.utils import SpatialOperator
from django.contrib.gis.db.models import aggregates
from django.db.backends.mysql.operations import DatabaseOperations


@@ -30,6 +31,8 @@ class MySQLOperations(BaseSpatialOperations, DatabaseOperations):
        'within': SpatialOperator(func='MBRWithin'),
    }

    disallowed_aggregates = (aggregates.Collect, aggregates.Extent, aggregates.Extent3D, aggregates.MakeLine, aggregates.Union)

    def geo_db_type(self, f):
        return f.geom_type

+6 −13
Original line number Diff line number Diff line
@@ -12,6 +12,7 @@ import re
from django.contrib.gis.db.backends.base.operations import BaseSpatialOperations
from django.contrib.gis.db.backends.oracle.adapter import OracleSpatialAdapter
from django.contrib.gis.db.backends.utils import SpatialOperator
from django.contrib.gis.db.models import aggregates
from django.contrib.gis.geometry.backend import Geometry
from django.contrib.gis.measure import Distance
from django.db.backends.oracle.base import Database
@@ -56,7 +57,7 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations):

    name = 'oracle'
    oracle = True
    valid_aggregates = {'Union', 'Extent'}
    disallowed_aggregates = (aggregates.Collect, aggregates.Extent3D, aggregates.MakeLine)

    Adapter = OracleSpatialAdapter
    Adaptor = Adapter  # Backwards-compatibility alias.
@@ -223,20 +224,12 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations):
            else:
                return 'SDO_GEOMETRY(%%s, %s)' % f.srid

    def spatial_aggregate_sql(self, agg):
    def spatial_aggregate_name(self, agg_name):
        """
        Returns the spatial aggregate SQL template and function for the
        given Aggregate instance.
        Returns the spatial aggregate SQL name.
        """
        agg_name = agg.__class__.__name__.lower()
        if agg_name == 'union':
            agg_name += 'agg'
        if agg.is_extent:
            sql_template = '%(function)s(%(expressions)s)'
        else:
            sql_template = '%(function)s(SDOAGGRTYPE(%(expressions)s,%(tolerance)s))'
        sql_function = getattr(self, agg_name)
        return sql_template, sql_function
        agg_name = 'unionagg' if agg_name.lower() == 'union' else agg_name.lower()
        return getattr(self, agg_name)

    # Routines for getting the OGC-compliant models.
    def geometry_columns(self):
+5 −15
Original line number Diff line number Diff line
@@ -49,7 +49,6 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
    geography = True
    geom_func_prefix = 'ST_'
    version_regex = re.compile(r'^(?P<major>\d)\.(?P<minor1>\d)\.(?P<minor2>\d+)')
    valid_aggregates = {'Collect', 'Extent', 'Extent3D', 'MakeLine', 'Union'}

    Adapter = PostGISAdapter
    Adaptor = Adapter  # Backwards-compatibility alias.
@@ -360,20 +359,11 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
        else:
            raise Exception('Could not determine PROJ.4 version from PostGIS.')

    def spatial_aggregate_sql(self, agg):
        """
        Returns the spatial aggregate SQL template and function for the
        given Aggregate instance.
        """
        agg_name = agg.__class__.__name__
        if not self.check_aggregate_support(agg):
            raise NotImplementedError('%s spatial aggregate is not implemented for this backend.' % agg_name)
        agg_name = agg_name.lower()
        if agg_name == 'union':
            agg_name += 'agg'
        sql_template = '%(function)s(%(expressions)s)'
        sql_function = getattr(self, agg_name)
        return sql_template, sql_function
    def spatial_aggregate_name(self, agg_name):
        if agg_name == 'Extent3D':
            return self.extent3d
        else:
            return self.geom_func_prefix + agg_name

    # Routines for getting the OGC-compliant models.
    def geometry_columns(self):
Loading