Commit 92b2dec9 authored by Claude Paroz's avatar Claude Paroz
Browse files

[py3] Made signing infrastructure pass tests with Python 3

parent 72755762
Loading
Loading
Loading
Loading
+9 −10
Original line number Diff line number Diff line
@@ -32,6 +32,8 @@ start of the base64 JSON.
There are 65 url-safe characters: the 64 used by url-safe base64 and the ':'.
These functions make use of all of them.
"""
from __future__ import unicode_literals

import base64
import json
import time
@@ -41,7 +43,7 @@ from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.utils import baseconv
from django.utils.crypto import constant_time_compare, salted_hmac
from django.utils.encoding import force_text, smart_bytes
from django.utils.encoding import smart_bytes
from django.utils.importlib import import_module


@@ -60,12 +62,12 @@ class SignatureExpired(BadSignature):


def b64_encode(s):
    return base64.urlsafe_b64encode(s).strip('=')
    return base64.urlsafe_b64encode(smart_bytes(s)).decode('ascii').strip('=')


def b64_decode(s):
    pad = '=' * (-len(s) % 4)
    return base64.urlsafe_b64decode(s + pad)
    return base64.urlsafe_b64decode(smart_bytes(s + pad)).decode('ascii')


def base64_hmac(salt, value, key):
@@ -121,7 +123,7 @@ def dumps(obj, key=None, salt='django.core.signing', serializer=JSONSerializer,

    if compress:
        # Avoid zlib dependency unless compress is being used
        compressed = zlib.compress(data)
        compressed = zlib.compress(smart_bytes(data))
        if len(compressed) < (len(data) - 1):
            data = compressed
            is_compressed = True
@@ -135,8 +137,7 @@ def loads(s, key=None, salt='django.core.signing', serializer=JSONSerializer, ma
    """
    Reverse of dumps(), raises BadSignature if signature fails
    """
    base64d = smart_bytes(
        TimestampSigner(key, salt=salt).unsign(s, max_age=max_age))
    base64d = TimestampSigner(key, salt=salt).unsign(s, max_age=max_age)
    decompress = False
    if base64d[0] == '.':
        # It's compressed; uncompress it first
@@ -159,16 +160,14 @@ class Signer(object):
        return base64_hmac(self.salt + 'signer', value, self.key)

    def sign(self, value):
        value = smart_bytes(value)
        return '%s%s%s' % (value, self.sep, self.signature(value))

    def unsign(self, signed_value):
        signed_value = smart_bytes(signed_value)
        if not self.sep in signed_value:
            raise BadSignature('No "%s" found in value' % self.sep)
        value, sig = signed_value.rsplit(self.sep, 1)
        if constant_time_compare(sig, self.signature(value)):
            return force_text(value)
            return value
        raise BadSignature('Signature "%s" does not match' % sig)


@@ -178,7 +177,7 @@ class TimestampSigner(Signer):
        return baseconv.base62.encode(int(time.time()))

    def sign(self, value):
        value = smart_bytes('%s%s%s' % (value, self.sep, self.timestamp()))
        value = '%s%s%s' % (value, self.sep, self.timestamp())
        return '%s%s%s' % (value, self.sep, self.signature(value))

    def unsign(self, value, max_age=None):
+7 −3
Original line number Diff line number Diff line
@@ -4,6 +4,7 @@ import time

from django.core import signing
from django.test import TestCase
from django.utils import six
from django.utils.encoding import force_text


@@ -69,15 +70,18 @@ class TestSigner(TestCase):

    def test_dumps_loads(self):
        "dumps and loads be reversible for any JSON serializable object"
        objects = (
        objects = [
            ['a', 'list'],
            b'a string',
            'a unicode string \u2019',
            {'a': 'dictionary'},
        )
        ]
        if not six.PY3:
            objects.append(b'a byte string')
        for o in objects:
            self.assertNotEqual(o, signing.dumps(o))
            self.assertEqual(o, signing.loads(signing.dumps(o)))
            self.assertNotEqual(o, signing.dumps(o, compress=True))
            self.assertEqual(o, signing.loads(signing.dumps(o, compress=True)))

    def test_decode_detects_tampering(self):
        "loads should raise exception for tampered objects"