Commit c03f0c28 authored by Ben Kraft's avatar Ben Kraft Committed by Tim Graham
Browse files

[1.8.x] Fixed #25389 -- Fixed pickling a SimpleLazyObject wrapping a model.

Pickling a `SimpleLazyObject` wrapping a model did not work correctly; in
particular it did not add the `_django_version` attribute added in 42736ac8.
Now it will handle this and other custom `__reduce__` methods correctly.

Backport of 35355a4f from master
parent 29c9a7d2
Loading
Loading
Loading
Loading
+31 −25
Original line number Diff line number Diff line
@@ -6,7 +6,6 @@ from functools import wraps

from django.utils import six
from django.utils.deprecation import RemovedInDjango19Warning
from django.utils.six.moves import copyreg


# You can't trivially replace this with `functools.partial` because this binds
@@ -268,32 +267,30 @@ class LazyObject(object):
        raise NotImplementedError('subclasses of LazyObject must provide a _setup() method')

    # Because we have messed with __class__ below, we confuse pickle as to what
    # class we are pickling. It also appears to stop __reduce__ from being
    # called. So, we define __getstate__ in a way that cooperates with the way
    # that pickle interprets this class.  This fails when the wrapped class is
    # a builtin, but it is better than nothing.
    def __getstate__(self):
    # class we are pickling. We're going to have to initialize the wrapped
    # object to successfully pickle it, so we might as well just pickle the
    # wrapped object since they're supposed to act the same way.
    #
    # Unfortunately, if we try to simply act like the wrapped object, the ruse
    # will break down when pickle gets our id(). Thus we end up with pickle
    # thinking, in effect, that we are a distinct object from the wrapped
    # object, but with the same __dict__. This can cause problems (see #25389).
    #
    # So instead, we define our own __reduce__ method and custom unpickler. We
    # pickle the wrapped object as the unpickler's argument, so that pickle
    # will pickle it normally, and then the unpickler simply returns its
    # argument.
    def __reduce__(self):
        if self._wrapped is empty:
            self._setup()
        return self._wrapped.__dict__
        return (unpickle_lazyobject, (self._wrapped,))

    # Python 3.3 will call __reduce__ when pickling; this method is needed
    # to serialize and deserialize correctly.
    @classmethod
    def __newobj__(cls, *args):
        return cls.__new__(cls, *args)

    def __reduce_ex__(self, proto):
        if proto >= 2:
            # On Py3, since the default protocol is 3, pickle uses the
            # ``__newobj__`` method (& more efficient opcodes) for writing.
            return (self.__newobj__, (self.__class__,), self.__getstate__())
        else:
            # On Py2, the default protocol is 0 (for back-compat) & the above
            # code fails miserably (see regression test). Instead, we return
            # exactly what's returned if there's no ``__reduce__`` method at
            # all.
            return (copyreg._reconstructor, (self.__class__, object, None), self.__getstate__())
    # We have to explicitly override __getstate__ so that older versions of
    # pickle don't try to pickle the __dict__ (which in the case of a
    # SimpleLazyObject may contain a lambda). The value will end up being
    # ignored by our __reduce__ and custom unpickler.
    def __getstate__(self):
        return {}

    def __deepcopy__(self, memo):
        if self._wrapped is empty:
@@ -332,6 +329,15 @@ class LazyObject(object):
    __contains__ = new_method_proxy(operator.contains)


def unpickle_lazyobject(wrapped):
    """
    Used to unpickle lazy objects. Just return its argument, which will be the
    wrapped object.
    """
    return wrapped
unpickle_lazyobject.__safe_for_unpickling__ = True


# Workaround for http://bugs.python.org/issue12370
_super = super

+2 −0
Original line number Diff line number Diff line
@@ -56,3 +56,5 @@ Bugfixes
* Fixed incorrect queries with multiple many-to-many fields on a model with the
  same 'to' model and with ``related_name`` set to '+' (:ticket:`24505`,
  :ticket:`25486`).

* Fixed pickling a ``SimpleLazyObject`` wrapping a model (:ticket:`25389`).
+4 −0
Original line number Diff line number Diff line
@@ -11,3 +11,7 @@ class Category(models.Model):
class Thing(models.Model):
    name = models.CharField(max_length=100)
    category = models.ForeignKey(Category)


class CategoryInfo(models.Model):
    category = models.OneToOneField(Category)
+92 −0
Original line number Diff line number Diff line
@@ -3,11 +3,14 @@ from __future__ import unicode_literals
import copy
import pickle
import sys
import warnings
from unittest import TestCase

from django.utils import six
from django.utils.functional import LazyObject, SimpleLazyObject, empty

from .models import Category, CategoryInfo


class Foo(object):
    """
@@ -273,3 +276,92 @@ class SimpleLazyObjectTestCase(LazyObjectTestCase):
        self.assertNotIn(6, lazy_set)
        self.assertEqual(len(lazy_list), 5)
        self.assertEqual(len(lazy_set), 4)


class BaseBaz(object):
    """
    A base class with a funky __reduce__ method, meant to simulate the
    __reduce__ method of Model, which sets self._django_version.
    """
    def __init__(self):
        self.baz = 'wrong'

    def __reduce__(self):
        self.baz = 'right'
        return super(BaseBaz, self).__reduce__()

    def __eq__(self, other):
        if self.__class__ != other.__class__:
            return False
        for attr in ['bar', 'baz', 'quux']:
            if hasattr(self, attr) != hasattr(other, attr):
                return False
            elif getattr(self, attr, None) != getattr(other, attr, None):
                return False
        return True


class Baz(BaseBaz):
    """
    A class that inherits from BaseBaz and has its own __reduce_ex__ method.
    """
    def __init__(self, bar):
        self.bar = bar
        super(Baz, self).__init__()

    def __reduce_ex__(self, proto):
        self.quux = 'quux'
        return super(Baz, self).__reduce_ex__(proto)


class BazProxy(Baz):
    """
    A class that acts as a proxy for Baz. It does some scary mucking about with
    dicts, which simulates some crazy things that people might do with
    e.g. proxy models.
    """
    def __init__(self, baz):
        self.__dict__ = baz.__dict__
        self._baz = baz
        super(BaseBaz, self).__init__()


class SimpleLazyObjectPickleTestCase(TestCase):
    """
    Regression test for pickling a SimpleLazyObject wrapping a model (#25389).
    Also covers other classes with a custom __reduce__ method.
    """
    def test_pickle_with_reduce(self):
        """
        Test in a fairly synthetic setting.
        """
        # Test every pickle protocol available
        for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
            lazy_objs = [
                SimpleLazyObject(lambda: BaseBaz()),
                SimpleLazyObject(lambda: Baz(1)),
                SimpleLazyObject(lambda: BazProxy(Baz(2))),
            ]
            for obj in lazy_objs:
                pickled = pickle.dumps(obj, protocol)
                unpickled = pickle.loads(pickled)
                self.assertEqual(unpickled, obj)
                self.assertEqual(unpickled.baz, 'right')

    def test_pickle_model(self):
        """
        Test on an actual model, based on the report in #25426.
        """
        category = Category.objects.create(name="thing1")
        CategoryInfo.objects.create(category=category)
        # Test every pickle protocol available
        for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
            lazy_category = SimpleLazyObject(lambda: category)
            # Test both if we accessed a field on the model and if we didn't.
            lazy_category.categoryinfo
            lazy_category_2 = SimpleLazyObject(lambda: category)
            with warnings.catch_warnings(record=True) as recorded:
                self.assertEqual(pickle.loads(pickle.dumps(lazy_category, protocol)), category)
                self.assertEqual(pickle.loads(pickle.dumps(lazy_category_2, protocol)), category)
                # Assert that there were no warnings.
                self.assertEqual(len(recorded), 0)