Commit cc4e4d9a authored by Russell Keith-Magee
Browse files

Fixed #3566 -- Added support for aggregation to the ORM. See the documentation...

Fixed #3566 -- Added support for aggregation to the ORM. See the documentation for details on usage.

Many thanks to:
 * Nicolas Lara, who worked on this feature during the 2008 Google Summer of Code.
 * Alex Gaynor for his help debugging and fixing a number of issues.
 * Justin Bronn for his help integrating with contrib.gis.
 * Karen Tracey for her help with cross-platform testing.
 * Ian Kelly for his help testing and fixing Oracle support.
 * Malcolm Tredinnick for his invaluable review notes.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@9742 bcc190cf-cafb-0310-a4f2-bffc1f526a37
parent 50a293a0
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -31,6 +31,7 @@ answer newbie questions, and generally made Django that much better:
    AgarFu <heaven@croasanaso.sytes.net>
    Dagur Páll Ammendrup <dagurp@gmail.com>
    Collin Anderson <cmawebsite@gmail.com>
    Nicolas Lara <nicolaslara@gmail.com>
    Jeff Anderson <jefferya@programmerq.net>
    Marian Andre <django@andre.sk>
    Andreas
+10 −0
Original line number Diff line number Diff line
from django.db.models import Aggregate

class Extent(Aggregate):
    """
    Aggregate object for the spatial extent; this is the class that
    `GeoQuerySet` hands to its `_spatial_aggregate` routine when the extent
    of the features is requested.
    """
    # Identifier of the aggregate.  Presumably used to look up the SQL-level
    # class of the same name in the query's aggregates module -- TODO confirm.
    name = 'Extent'

class MakeLine(Aggregate):
    """
    Aggregate object for building a line out of point geometries;
    `GeoQuerySet` passes it to `_spatial_aggregate` restricted to
    `PointField` columns.
    """
    # Identifier of the aggregate.  Presumably used to look up the SQL-level
    # class of the same name in the query's aggregates module -- TODO confirm.
    name = 'MakeLine'

class Union(Aggregate):
    """
    Aggregate object for the geometric union of the features in a
    `GeoQuerySet`.
    """
    # Identifier of the aggregate.  Presumably used to look up the SQL-level
    # class of the same name in the query's aggregates module -- TODO confirm.
    name = 'Union'
+68 −121
Original line number Diff line number Diff line
@@ -3,6 +3,7 @@ from django.db import connection
from django.db.models.query import sql, QuerySet, Q

from django.contrib.gis.db.backend import SpatialBackend
from django.contrib.gis.db.models import aggregates
from django.contrib.gis.db.models.fields import GeometryField, PointField
from django.contrib.gis.db.models.sql import AreaField, DistanceField, GeomField, GeoQuery, GeoWhereNode
from django.contrib.gis.measure import Area, Distance
@@ -98,20 +99,7 @@ class GeoQuerySet(QuerySet):
        Returns the extent (aggregate) of the features in the GeoQuerySet.  The
        extent will be returned as a 4-tuple, consisting of (xmin, ymin, xmax, ymax).
        """
        convert_extent = None
        if SpatialBackend.postgis:
            def convert_extent(box, geo_field):
                # TODO: Parsing of BOX3D, Oracle support (patches welcome!)
                # Box text will be something like "BOX(-90.0 30.0, -85.0 40.0)"; 
                # parsing out and returning as a 4-tuple.
                ll, ur = box[4:-1].split(',')
                xmin, ymin = map(float, ll.split())
                xmax, ymax = map(float, ur.split())
                return (xmin, ymin, xmax, ymax)
        elif SpatialBackend.oracle:
            def convert_extent(wkt, geo_field):
                raise NotImplementedError
        return self._spatial_aggregate('extent', convert_func=convert_extent, **kwargs)
        return self._spatial_aggregate(aggregates.Extent, **kwargs)

    def gml(self, precision=8, version=2, **kwargs):
        """
@@ -163,9 +151,7 @@ class GeoQuerySet(QuerySet):
        this GeoQuerySet and returns it.  This is a spatial aggregate
        method, and thus returns a geometry rather than a GeoQuerySet.
        """
        kwargs['geo_field_type'] = PointField
        kwargs['agg_field'] = GeometryField
        return self._spatial_aggregate('make_line', **kwargs)
        return self._spatial_aggregate(aggregates.MakeLine, geo_field_type=PointField, **kwargs)

    def mem_size(self, **kwargs):
        """
@@ -288,11 +274,10 @@ class GeoQuerySet(QuerySet):
        None if the GeoQuerySet is empty.  The `tolerance` keyword is for
        Oracle backends only.
        """
        kwargs['agg_field'] = GeometryField
        return self._spatial_aggregate('unionagg', **kwargs)
        return self._spatial_aggregate(aggregates.Union, **kwargs)

    ### Private API -- Abstracted DRY routines. ###
    def _spatial_setup(self, att, aggregate=False, desc=None, field_name=None, geo_field_type=None):
    def _spatial_setup(self, att, desc=None, field_name=None, geo_field_type=None):
        """
        Performs set up for executing the spatial function.
        """
@@ -316,71 +301,37 @@ class GeoQuerySet(QuerySet):
            raise TypeError('"%s" stored procedures may only be called on %ss.' % (func, geo_field_type.__name__))

        # Setting the procedure args.
        procedure_args['geo_col'] = self._geocol_select(geo_field, field_name, aggregate)
        procedure_args['geo_col'] = self._geocol_select(geo_field, field_name)

        return procedure_args, geo_field

    def _spatial_aggregate(self, att, field_name=None, 
                           agg_field=None, convert_func=None, 
                           geo_field_type=None, tolerance=0.0005):
    def _spatial_aggregate(self, aggregate, field_name=None,
                           geo_field_type=None, tolerance=0.05):
        """
        DRY routine for calling aggregate spatial stored procedures and
        returning their result to the caller of the function.
        """
        # Constructing the setup keyword arguments.
        setup_kwargs = {'aggregate' : True,
                        'field_name' : field_name,
                        'geo_field_type' : geo_field_type,
                        }
        procedure_args, geo_field = self._spatial_setup(att, **setup_kwargs)
        # Getting the field the geographic aggregate will be called on.
        geo_field = self.query._geo_field(field_name)
        if not geo_field:
            raise TypeError('%s aggregate only available on GeometryFields.' % aggregate.name)

        if SpatialBackend.oracle:
            procedure_args['tolerance'] = tolerance
            # Adding in selection SQL for Oracle geometry columns.
            if agg_field is GeometryField: 
                agg_sql = '%s' % SpatialBackend.select
            else: 
                agg_sql = '%s'
            agg_sql =  agg_sql % ('%(function)s(SDOAGGRTYPE(%(geo_col)s,%(tolerance)s))' % procedure_args)
        else:
            agg_sql = '%(function)s(%(geo_col)s)' % procedure_args

        # Wrapping our selection SQL in `GeomSQL` to bypass quoting, and
        # specifying the type of the aggregate field.
        self.query.select = [GeomSQL(agg_sql)]
        self.query.select_fields = [agg_field]

        try:
            # `asql` => not overriding `sql` module.
            asql, params = self.query.as_sql()
        except sql.datastructures.EmptyResultSet:
            return None   

        # Getting a cursor, executing the query, and extracting the returned
        # value from the aggregate function.
        cursor = connection.cursor()
        cursor.execute(asql, params)
        result = cursor.fetchone()[0]
        
        # If the `agg_field` is specified as a GeometryField, then automatically
        # set up the conversion function.
        if agg_field is GeometryField and not callable(convert_func):
            if SpatialBackend.postgis:
                def convert_geom(hex, geo_field):
                    if hex: return SpatialBackend.Geometry(hex)
                    else: return None
            elif SpatialBackend.oracle:
                def convert_geom(clob, geo_field):
                    if clob: return SpatialBackend.Geometry(clob.read(), geo_field._srid)
                    else: return None
            convert_func = convert_geom

        # Returning the callback function evaluated on the result culled
        # from the executed cursor.
        if callable(convert_func):
            return convert_func(result, geo_field)
        else:
            return result
        # Checking if there are any geo field type limitations on this
        # aggregate (e.g. ST_Makeline only operates on PointFields).
        if not geo_field_type is None and not isinstance(geo_field, geo_field_type):
            raise TypeError('%s aggregate may only be called on %ss.' % (aggregate.name, geo_field_type.__name__))

        # Getting the string expression of the field name, as this is the
        # argument taken by `Aggregate` objects.
        agg_col = field_name or geo_field.name

        # Adding any keyword parameters for the Aggregate object. Oracle backends
        # in particular need an additional `tolerance` parameter.
        agg_kwargs = {}
        if SpatialBackend.oracle: agg_kwargs['tolerance'] = tolerance

        # Calling the QuerySet.aggregate, and returning only the value of the aggregate.
        return self.aggregate(_geoagg=aggregate(agg_col, **agg_kwargs))['_geoagg']

    def _spatial_attribute(self, att, settings, field_name=None, model_att=None):
        """
@@ -595,16 +546,12 @@ class GeoQuerySet(QuerySet):
            s['procedure_args']['tolerance'] = tolerance
        return self._spatial_attribute(func, s, **kwargs)

    def _geocol_select(self, geo_field, field_name, aggregate=False):
    def _geocol_select(self, geo_field, field_name):
        """
        Helper routine for constructing the SQL to select the geographic
        column.  Takes into account if the geographic field is in a
        ForeignKey relation to the current model.
        """
        # If this is an aggregate spatial query, the flag needs to be
        # set on the `GeoQuery` object of this queryset.
        if aggregate: self.query.aggregate = True

        opts = self.model._meta
        if not geo_field in opts.fields:
            # Is this operation going to be on a related geographic field?
+36 −0
Original line number Diff line number Diff line
from django.db.models.sql.aggregates import *

from django.contrib.gis.db.models.fields import GeometryField
from django.contrib.gis.db.backend import SpatialBackend

# SQL template used to render geographic aggregates.  The plain
# `function(field)` form is the default; Oracle additionally wraps the
# column in SDOAGGRTYPE together with a tolerance value.
geo_template = '%(function)s(%(field)s)'
if SpatialBackend.oracle:
    geo_template = '%(function)s(SDOAGGRTYPE(%(field)s,%(tolerance)s))'

class GeoAggregate(Aggregate):
    """
    Base class for the SQL-level geographic aggregates.  Subclasses supply
    `sql_function` (taken from the active spatial backend) and may flag
    themselves as returning an extent via `is_extent`.
    """
    # Geographic aggregates are rendered with the backend-dependent
    # template instead of the one inherited from `Aggregate`.
    sql_template = geo_template

    # True only for aggregates whose result is an extent rather than a
    # geometry; `GeoQuery.resolve_aggregate` picks its converter from this.
    is_extent = False

    def __init__(self, col, source=None, is_summary=False, **extra):
        """
        Initializes the aggregate through the parent constructor, and then
        validates it.

        Raises ValueError if `self.source` is not a GeometryField, and
        NotImplementedError if the spatial backend provides no SQL
        function for this aggregate.
        """
        super(GeoAggregate, self).__init__(col, source, is_summary, **extra)

        # Geographic aggregates make no sense on non-geometry fields.
        source_is_geometry = isinstance(self.source, GeometryField)
        if not source_is_geometry:
            raise ValueError('Geospatial aggregates only allowed on geometry fields.')

        # A falsy `sql_function` means the spatial backend does not offer
        # this aggregate.
        if self.sql_function:
            return
        raise NotImplementedError('This aggregate functionality not implemented for your spatial backend.')

class Extent(GeoAggregate):
    """
    SQL-level extent aggregate.
    """
    # Name of the extent function, as provided by the spatial backend.
    sql_function = SpatialBackend.extent
    # The result is an extent rather than a geometry, so it is converted
    # with `convert_extent` when resolved by `GeoQuery`.
    is_extent = True

class MakeLine(GeoAggregate):
    """
    SQL-level make-line aggregate.
    """
    # Name of the make-line function, as provided by the spatial backend; a
    # falsy value makes `GeoAggregate.__init__` raise NotImplementedError.
    sql_function = SpatialBackend.make_line

class Union(GeoAggregate):
    """
    SQL-level geometric union aggregate.
    """
    # Name of the union aggregate function, as provided by the spatial
    # backend; a falsy value makes `GeoAggregate.__init__` raise
    # NotImplementedError.
    sql_function = SpatialBackend.unionagg
+85 −45
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@ from django.db.models.fields.related import ForeignKey

from django.contrib.gis.db.backend import SpatialBackend
from django.contrib.gis.db.models.fields import GeometryField
from django.contrib.gis.db.models.sql import aggregates as gis_aggregates_module
from django.contrib.gis.db.models.sql.where import GeoWhereNode
from django.contrib.gis.measure import Area, Distance

@@ -12,12 +13,35 @@ from django.contrib.gis.measure import Area, Distance
ALL_TERMS = sql.constants.QUERY_TERMS.copy()
ALL_TERMS.update(SpatialBackend.gis_terms)

# Conversion functions used in normalizing geographic aggregates.
if SpatialBackend.postgis:
    def convert_extent(box):
        """
        Converts the box text returned for an extent aggregate into a
        4-tuple of floats: (xmin, ymin, xmax, ymax).

        Returns None when `box` is empty or None (e.g., a NULL handed back
        for an aggregate computed over no rows), mirroring `convert_geom`
        instead of failing while slicing a non-string.
        """
        # TODO: Parsing of BOX3D, Oracle support (patches welcome!)
        if not box: return None
        # Box text will be something like "BOX(-90.0 30.0, -85.0 40.0)";
        # dropping the leading "BOX(" and the trailing ")" leaves the
        # lower-left and upper-right corners separated by a comma.
        ll, ur = box[4:-1].split(',')
        xmin, ymin = map(float, ll.split())
        xmax, ymax = map(float, ur.split())
        return (xmin, ymin, xmax, ymax)

    def convert_geom(hex, geo_field):
        """
        Wraps the value returned by a geometry aggregate in the backend's
        Geometry class; returns None for an empty value.  `geo_field` is
        accepted for signature parity with the non-PostGIS variant and is
        not used here.
        """
        if hex: return SpatialBackend.Geometry(hex)
        else: return None
else:
    def convert_extent(box):
        """
        Extent conversion is not available on this spatial backend; always
        raises NotImplementedError.
        """
        raise NotImplementedError('Aggregate extent not implemented for this spatial backend.')

    def convert_geom(clob, geo_field):
        """
        Reads the returned object via its `read()` method and wraps the
        contents in the backend's Geometry class, using the SRID of
        `geo_field`; returns None for an empty value.
        """
        if clob: return SpatialBackend.Geometry(clob.read(), geo_field._srid)
        else: return None

class GeoQuery(sql.Query):
    """
    A single spatial SQL query.
    """
    # Overriding the valid query terms.
    query_terms = ALL_TERMS
    aggregates_module = gis_aggregates_module

    #### Methods overridden from the base Query class ####
    def __init__(self, model, conn):
@@ -25,7 +49,6 @@ class GeoQuery(sql.Query):
        # The following attributes are customized for the GeoQuerySet.
        # The GeoWhereNode and SpatialBackend classes contain backend-specific
        # routines and functions.
        self.aggregate = False
        self.custom_select = {}
        self.transformed_srid = None
        self.extra_select_fields = {}
@@ -34,7 +57,6 @@ class GeoQuery(sql.Query):
        obj = super(GeoQuery, self).clone(*args, **kwargs)
        # Customized selection dictionary and transformed srid flag have
        # to also be added to obj.
        obj.aggregate = self.aggregate
        obj.custom_select = self.custom_select.copy()
        obj.transformed_srid = self.transformed_srid
        obj.extra_select_fields = self.extra_select_fields.copy()
@@ -67,27 +89,42 @@ class GeoQuery(sql.Query):
            for col, field in izip(self.select, self.select_fields):
                if isinstance(col, (list, tuple)):
                    r = self.get_field_select(field, col[0])
                    if with_aliases and col[1] in col_aliases:
                    if with_aliases:
                        if col[1] in col_aliases:
                            c_alias = 'Col%d' % len(col_aliases)
                            result.append('%s AS %s' % (r, c_alias))
                            aliases.add(c_alias)
                            col_aliases.add(c_alias)
                        else:
                            result.append('%s AS %s' % (r, col[1]))
                            aliases.add(r)
                            col_aliases.add(col[1])
                    else:
                        result.append(r)
                        aliases.add(r)
                        col_aliases.add(col[1])
                else:
                    result.append(col.as_sql(quote_func=qn))

                    if hasattr(col, 'alias'):
                        aliases.add(col.alias)
                        col_aliases.add(col.alias)

        elif self.default_cols:
            cols, new_aliases = self.get_default_columns(with_aliases,
                    col_aliases)
            result.extend(cols)
            aliases.update(new_aliases)

        result.extend([
                '%s%s' % (
                    aggregate.as_sql(quote_func=qn),
                    alias is not None and ' AS %s' % alias or ''
                    )
                for alias, aggregate in self.aggregate_select.items()
                ])

        # This loop customized for GeoQuery.
        if not self.aggregate:
        for (table, col), field in izip(self.related_select_cols, self.related_select_fields):
            r = self.get_field_select(field, table)
            if with_aliases and col in col_aliases:
@@ -154,16 +191,6 @@ class GeoQuery(sql.Query):
            return result, None
        return result, aliases

    def get_ordering(self):
        """
        Returns the ordering for this query.  Aggregate spatial queries
        (those with the `aggregate` flag set) get no ordering at all; every
        other query defers to the parent implementation.
        """
        # Ordering is disabled for aggregate spatial queries.
        if self.aggregate:
            return ()
        return super(GeoQuery, self).get_ordering()

    def resolve_columns(self, row, fields=()):
        """
        This routine is necessary so that distances and geometries returned
@@ -212,6 +239,19 @@ class GeoQuery(sql.Query):
            value = SpatialBackend.Geometry(value)
        return value

    def resolve_aggregate(self, value, aggregate):
        """
        Overrides the parent class's `resolve_aggregate` so that values
        produced by `GeoAggregate` objects are converted: extent aggregates
        go through `convert_extent`, all other geographic aggregates through
        `convert_geom` (which also receives the aggregate's source field).
        Non-geographic aggregates are left to the parent implementation.
        """
        geo_aggregate_class = self.aggregates_module.GeoAggregate
        if not isinstance(aggregate, geo_aggregate_class):
            # Not a geographic aggregate -- nothing special to do here.
            return super(GeoQuery, self).resolve_aggregate(value, aggregate)

        if aggregate.is_extent:
            return convert_extent(value)
        return convert_geom(value, aggregate.source)

    #### Routines unique to GeoQuery ####
    def get_extra_select_format(self, alias):
        sel_fmt = '%s'
Loading