Commit 63378163 authored by Simon Charette's avatar Simon Charette Committed by Andrew Godwin
Browse files

Fixed #20943 -- Weakly reference senders when caching their associated receivers

parent 77478d84
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -14,7 +14,7 @@ post_delete = Signal(providing_args=["instance", "using"], use_caching=True)

pre_migrate = Signal(providing_args=["app", "create_models", "verbosity", "interactive", "db"])
pre_syncdb = pre_migrate
post_migrate = Signal(providing_args=["class", "app", "created_models", "verbosity", "interactive", "db"], use_caching=True)
post_migrate = Signal(providing_args=["class", "app", "created_models", "verbosity", "interactive", "db"])
post_syncdb = post_migrate

m2m_changed = Signal(providing_args=["action", "instance", "reverse", "model", "pk_set", "using"], use_caching=True)
+8 −4
Original line number Diff line number Diff line
@@ -4,8 +4,10 @@ import threading
from django.dispatch import saferef
from django.utils.six.moves import xrange


WEAKREF_TYPES = (weakref.ReferenceType, saferef.BoundMethodWeakref)


def _make_id(target):
    if hasattr(target, '__func__'):
        return (id(target.__self__), id(target.__func__))
@@ -15,6 +17,7 @@ NONE_ID = _make_id(None)
# A marker for caching
NO_RECEIVERS = object()


class Signal(object):
    """
    Base class for all signals
@@ -42,7 +45,7 @@ class Signal(object):
        # distinct sender we cache the receivers that sender has in
        # 'sender_receivers_cache'. The cache is cleaned when .connect() or
        # .disconnect() is called and populated on send().
        self.sender_receivers_cache = {}
        self.sender_receivers_cache = weakref.WeakKeyDictionary() if use_caching else {}

    def connect(self, receiver, sender=None, weak=True, dispatch_uid=None):
        """
@@ -116,7 +119,7 @@ class Signal(object):
                    break
            else:
                self.receivers.append((lookup_key, receiver))
            self.sender_receivers_cache = {}
            self.sender_receivers_cache.clear()

    def disconnect(self, receiver=None, sender=None, weak=True, dispatch_uid=None):
        """
@@ -151,7 +154,7 @@ class Signal(object):
                if r_key == lookup_key:
                    del self.receivers[index]
                    break
            self.sender_receivers_cache = {}
            self.sender_receivers_cache.clear()

    def has_listeners(self, sender=None):
        return bool(self._live_receivers(sender))
@@ -276,7 +279,8 @@ class Signal(object):
                for idx, (r_key, _) in enumerate(reversed(self.receivers)):
                    if r_key == key:
                        del self.receivers[last_idx - idx]
            self.sender_receivers_cache = {}
            self.sender_receivers_cache.clear()


def receiver(signal, **kwargs):
    """
+21 −0
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@ import gc
import sys
import time
import unittest
import weakref

from django.dispatch import Signal, receiver

@@ -35,6 +36,8 @@ class Callable(object):
a_signal = Signal(providing_args=["val"])
b_signal = Signal(providing_args=["val"])
c_signal = Signal(providing_args=["val"])
d_signal = Signal(providing_args=["val"], use_caching=True)


class DispatcherTests(unittest.TestCase):
    """Test suite for dispatcher (barely started)"""
@@ -72,6 +75,24 @@ class DispatcherTests(unittest.TestCase):
        self.assertEqual(result, expected)
        self._testIsClean(a_signal)

    def testCachedGarbagedCollected(self):
        """
        Make sure signal caching sender receivers don't prevent garbage
        collection of senders.
        """
        class sender:
            pass
        wref = weakref.ref(sender)
        d_signal.connect(receiver_1_arg)
        d_signal.send(sender, val='garbage')
        del sender
        garbage_collect()
        try:
            self.assertIsNone(wref())
        finally:
            # Disconnect after reference check since it flushes the tested cache.
            d_signal.disconnect(receiver_1_arg)

    def testMultipleRegistration(self):
        a = Callable()
        a_signal.connect(a)