Commit b0ce6fe6 authored by Tim Graham's avatar Tim Graham
Browse files

Fixed #20922 -- Allowed customizing the serializer used by contrib.sessions

Added settings.SESSION_SERIALIZER which is the import path of a serializer
to use for sessions.

Thanks apollo13, carljm, shaib, akaariai, charettes, and dstufft for reviews.
parent 9d1987d7
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -475,6 +475,7 @@ SESSION_SAVE_EVERY_REQUEST = False # Whether to save the se
SESSION_EXPIRE_AT_BROWSER_CLOSE = False                 # Whether a user's session cookie expires when the Web browser is closed.
SESSION_ENGINE = 'django.contrib.sessions.backends.db'  # The module to store session data
SESSION_FILE_PATH = None                                # Directory to store session files if using the file session module. If None, the backend will use a sensible default.
SESSION_SERIALIZER = 'django.contrib.sessions.serializers.JSONSerializer'  # class to serialize session data

#########
# CACHE #
+15 −2
Original line number Diff line number Diff line
import json

from django.contrib.messages.storage.base import BaseStorage
from django.contrib.messages.storage.cookie import MessageEncoder, MessageDecoder
from django.utils import six


class SessionStorage(BaseStorage):
@@ -20,14 +24,23 @@ class SessionStorage(BaseStorage):
        always stores everything it is given, so return True for the
        all_retrieved flag.
        """
        return self.request.session.get(self.session_key), True
        return self.deserialize_messages(self.request.session.get(self.session_key)), True

    def _store(self, messages, response, *args, **kwargs):
        """
        Stores a list of messages to the request's session.
        """
        if messages:
            self.request.session[self.session_key] = messages
            self.request.session[self.session_key] = self.serialize_messages(messages)
        else:
            self.request.session.pop(self.session_key, None)
        return []

    def serialize_messages(self, messages):
        encoder = MessageEncoder(separators=(',', ':'))
        return encoder.encode(messages)

    def deserialize_messages(self, data):
        if data and isinstance(data, six.string_types):
            return json.loads(data, cls=MessageDecoder)
        return data
+1 −0
Original line number Diff line number Diff line
@@ -61,6 +61,7 @@ class BaseTests(object):
            MESSAGE_TAGS    = '',
            MESSAGE_STORAGE = '%s.%s' % (self.storage_class.__module__,
                                         self.storage_class.__name__),
            SESSION_SERIALIZER = 'django.contrib.sessions.serializers.JSONSerializer',
        )
        self.settings_override.enable()

+2 −2
Original line number Diff line number Diff line
@@ -11,13 +11,13 @@ def set_session_data(storage, messages):
    Sets the messages into the backend request's session and remove the
    backend's loaded data cache.
    """
    storage.request.session[storage.session_key] = messages
    storage.request.session[storage.session_key] = storage.serialize_messages(messages)
    if hasattr(storage, '_loaded_data'):
        del storage._loaded_data


def stored_session_messages_count(storage):
    data = storage.request.session.get(storage.session_key, [])
    data = storage.deserialize_messages(storage.request.session.get(storage.session_key, []))
    return len(data)


+9 −12
Original line number Diff line number Diff line
@@ -3,11 +3,6 @@ from __future__ import unicode_literals
import base64
from datetime import datetime, timedelta
import logging

try:
    from django.utils.six.moves import cPickle as pickle
except ImportError:
    import pickle
import string

from django.conf import settings
@@ -17,6 +12,7 @@ from django.utils.crypto import get_random_string
from django.utils.crypto import salted_hmac
from django.utils import timezone
from django.utils.encoding import force_bytes, force_text
from django.utils.module_loading import import_by_path

from django.contrib.sessions.exceptions import SuspiciousSession

@@ -42,6 +38,7 @@ class SessionBase(object):
        self._session_key = session_key
        self.accessed = False
        self.modified = False
        self.serializer = import_by_path(settings.SESSION_SERIALIZER)

    def __contains__(self, key):
        return key in self._session
@@ -86,21 +83,21 @@ class SessionBase(object):
        return salted_hmac(key_salt, value).hexdigest()

    def encode(self, session_dict):
        "Returns the given session dictionary pickled and encoded as a string."
        pickled = pickle.dumps(session_dict, pickle.HIGHEST_PROTOCOL)
        hash = self._hash(pickled)
        return base64.b64encode(hash.encode() + b":" + pickled).decode('ascii')
        "Returns the given session dictionary serialized and encoded as a string."
        serialized = self.serializer().dumps(session_dict)
        hash = self._hash(serialized)
        return base64.b64encode(hash.encode() + b":" + serialized).decode('ascii')

    def decode(self, session_data):
        encoded_data = base64.b64decode(force_bytes(session_data))
        try:
            # could produce ValueError if there is no ':'
            hash, pickled = encoded_data.split(b':', 1)
            expected_hash = self._hash(pickled)
            hash, serialized = encoded_data.split(b':', 1)
            expected_hash = self._hash(serialized)
            if not constant_time_compare(hash.decode(), expected_hash):
                raise SuspiciousSession("Session data corrupted")
            else:
                return pickle.loads(pickled)
                return self.serializer().loads(serialized)
        except Exception as e:
            # ValueError, SuspiciousOperation, unpickling exceptions. If any of
            # these happen, just return an empty dictionary (an empty session).
Loading