Commit 312fc1af authored by Sergey Fedoseev's avatar Sergey Fedoseev Committed by Tim Graham
Browse files

Fixed #25961 -- Removed handling of thread-non-safe GEOS functions.

parent febe1321
Loading
Loading
Loading
Loading
+36 −10
Original line number Diff line number Diff line
@@ -9,12 +9,13 @@
import logging
import os
import re
import threading
from ctypes import CDLL, CFUNCTYPE, POINTER, Structure, c_char_p
from ctypes.util import find_library

from django.contrib.gis.geos.error import GEOSException
from django.core.exceptions import ImproperlyConfigured
from django.utils.functional import SimpleLazyObject
from django.utils.functional import SimpleLazyObject, cached_property

logger = logging.getLogger('django.contrib.gis')

@@ -63,10 +64,11 @@ def load_geos():
    _lgeos = CDLL(lib_path)
    # Here we set up the prototypes for the initGEOS_r and finishGEOS_r
    # routines.  These functions aren't actually called until they are
    # attached to a GEOS context handle -- this actually occurs in
    # geos/prototypes/threadsafe.py.
    # attached to a GEOS context handle.
    _lgeos.initGEOS_r.restype = CONTEXT_PTR
    _lgeos.finishGEOS_r.argtypes = [CONTEXT_PTR]
    # Ensures compatibility across 32 and 64-bit platforms.
    _lgeos.GEOSversion.restype = c_char_p
    return _lgeos


@@ -134,6 +136,27 @@ def get_pointer_arr(n):
lgeos = SimpleLazyObject(load_geos)


class GEOSContextHandle(object):
    def __init__(self):
        # Initializing the context handle for this thread with
        # the notice and error handler.
        self.ptr = lgeos.initGEOS_r(notice_h, error_h)

    def __del__(self):
        if self.ptr and lgeos:
            lgeos.finishGEOS_r(self.ptr)


class GEOSContext(threading.local):

    @cached_property
    def ptr(self):
        # Assign handle so it will will garbage collected when
        # thread is finished.
        self.handle = GEOSContextHandle()
        return self.handle.ptr


class GEOSFuncFactory(object):
    """
    Lazy loading of GEOS functions.
@@ -141,6 +164,7 @@ class GEOSFuncFactory(object):
    argtypes = None
    restype = None
    errcheck = None
    thread_context = GEOSContext()

    def __init__(self, func_name, *args, **kwargs):
        self.func_name = func_name
@@ -154,21 +178,23 @@ class GEOSFuncFactory(object):
    def __call__(self, *args, **kwargs):
        if self.func is None:
            self.func = self.get_func(*self.args, **self.kwargs)
        return self.func(*args, **kwargs)
        # Call the threaded GEOS routine with pointer of the context handle
        # as the first argument.
        return self.func(self.thread_context.ptr, *args)

    def get_func(self, *args, **kwargs):
        from django.contrib.gis.geos.prototypes.threadsafe import GEOSFunc
        func = GEOSFunc(self.func_name)
        func.argtypes = self.argtypes or []
        # GEOS thread-safe function signatures end with '_r' and
        # take an additional context handle parameter.
        func = getattr(lgeos, self.func_name + '_r')
        func.argtypes = [CONTEXT_PTR] + (self.argtypes or [])
        func.restype = self.restype
        if self.errcheck:
            func.errcheck = self.errcheck
        return func


# Returns the string version of the GEOS library. Have to set the restype
# explicitly to c_char_p to ensure compatibility across 32 and 64-bit platforms.
geos_version = GEOSFuncFactory('GEOSversion', restype=c_char_p)
# Returns the string version of the GEOS library.
geos_version = lambda: lgeos.GEOSversion()

# Regular expression should be able to parse version strings such as
# '3.0.0rc4-CAPI-1.3.3', '3.0.0-CAPI-1.4.1', '3.4.0dev-CAPI-1.8.0' or '3.4.0dev-CAPI-1.8.0 r0'
+0 −93
Original line number Diff line number Diff line
import threading

from django.contrib.gis.geos.libgeos import (
    CONTEXT_PTR, error_h, lgeos, notice_h,
)


class GEOSContextHandle(object):
    """
    Python object representing a GEOS context handle.
    """
    def __init__(self):
        # Initializing the context handler for this thread with
        # the notice and error handler.
        self.ptr = lgeos.initGEOS_r(notice_h, error_h)

    def __del__(self):
        if self.ptr and lgeos:
            lgeos.finishGEOS_r(self.ptr)


# Defining a thread-local object and creating an instance
# to hold a reference to GEOSContextHandle for this thread.
class GEOSContext(threading.local):
    handle = None

thread_context = GEOSContext()


class GEOSFunc(object):
    """
    Class that serves as a wrapper for GEOS C Functions, and will
    use thread-safe function variants when available.
    """
    def __init__(self, func_name):
        try:
            # GEOS thread-safe function signatures end with '_r', and
            # take an additional context handle parameter.
            self.cfunc = getattr(lgeos, func_name + '_r')
            self.threaded = True
            # Create a reference here to thread_context so it's not
            # garbage-collected before an attempt to call this object.
            self.thread_context = thread_context
        except AttributeError:
            # Otherwise, use usual function.
            self.cfunc = getattr(lgeos, func_name)
            self.threaded = False

    def __call__(self, *args):
        if self.threaded:
            # If a context handle does not exist for this thread, initialize one.
            if not self.thread_context.handle:
                self.thread_context.handle = GEOSContextHandle()
            # Call the threaded GEOS routine with pointer of the context handle
            # as the first argument.
            return self.cfunc(self.thread_context.handle.ptr, *args)
        else:
            return self.cfunc(*args)

    def __str__(self):
        return self.cfunc.__name__

    # argtypes property
    def _get_argtypes(self):
        return self.cfunc.argtypes

    def _set_argtypes(self, argtypes):
        if self.threaded:
            new_argtypes = [CONTEXT_PTR]
            new_argtypes.extend(argtypes)
            self.cfunc.argtypes = new_argtypes
        else:
            self.cfunc.argtypes = argtypes

    argtypes = property(_get_argtypes, _set_argtypes)

    # restype property
    def _get_restype(self):
        return self.cfunc.restype

    def _set_restype(self, restype):
        self.cfunc.restype = restype

    restype = property(_get_restype, _set_restype)

    # errcheck property
    def _get_errcheck(self):
        return self.cfunc.errcheck

    def _set_errcheck(self, errcheck):
        self.cfunc.errcheck = errcheck

    errcheck = property(_get_errcheck, _set_errcheck)
+44 −1
Original line number Diff line number Diff line
@@ -3,6 +3,7 @@ from __future__ import unicode_literals
import ctypes
import json
import random
import threading
from binascii import a2b_hex, b2a_hex
from io import BytesIO
from unittest import skipUnless
@@ -12,7 +13,7 @@ from django.contrib.gis.gdal import HAS_GDAL
from django.contrib.gis.geos import (
    HAS_GEOS, GeometryCollection, GEOSException, GEOSGeometry, LinearRing,
    LineString, MultiLineString, MultiPoint, MultiPolygon, Point, Polygon,
    fromfile, fromstr,
    fromfile, fromstr, libgeos,
)
from django.contrib.gis.geos.base import GEOSBase
from django.contrib.gis.geos.libgeos import geos_version_info
@@ -1232,6 +1233,48 @@ class GEOSTest(SimpleTestCase, TestDataMixin):
            self.assertEqual(m.group('version'), v_geos)
            self.assertEqual(m.group('capi_version'), v_capi)

    def test_geos_threads(self):
        pnt = Point()
        context_ptrs = []

        geos_init = libgeos.lgeos.initGEOS_r
        geos_finish = libgeos.lgeos.finishGEOS_r

        def init(*args, **kwargs):
            result = geos_init(*args, **kwargs)
            context_ptrs.append(result)
            return result

        def finish(*args, **kwargs):
            result = geos_finish(*args, **kwargs)
            destructor_called.set()
            return result

        for i in range(2):
            destructor_called = threading.Event()
            patch_path = 'django.contrib.gis.geos.libgeos.lgeos'
            with mock.patch.multiple(patch_path, initGEOS_r=mock.DEFAULT, finishGEOS_r=mock.DEFAULT) as mocked:
                mocked['initGEOS_r'].side_effect = init
                mocked['finishGEOS_r'].side_effect = finish
                with mock.patch('django.contrib.gis.geos.prototypes.predicates.geos_hasz.func') as mocked_hasz:
                    thread = threading.Thread(target=lambda: pnt.hasz)
                    thread.start()
                    thread.join()

                    # We can't be sure that members of thread locals are
                    # garbage collected right after `thread.join()` so
                    # we must wait until destructor is actually called.
                    # Fail if destructor wasn't called within a second.
                    self.assertTrue(destructor_called.wait(1))

                    context_ptr = context_ptrs[i]
                    self.assertIsInstance(context_ptr, libgeos.CONTEXT_PTR)
                    mocked_hasz.assert_called_once_with(context_ptr, pnt.ptr)
                    mocked['finishGEOS_r'].assert_called_once_with(context_ptr)

        # Check that different contexts were used for the different threads.
        self.assertNotEqual(context_ptrs[0], context_ptrs[1])

    @ignore_warnings(category=RemovedInDjango20Warning)
    def test_deprecated_srid_getters_setters(self):
        p = Point(1, 2, srid=123)