Commit fd23c060 authored by Tim Graham's avatar Tim Graham
Browse files

Fixed #21649 -- Added optional invalidation of sessions when user password changes.

Thanks Paul McMillan, Aymeric Augustin, and Erik Romijn for reviews.
parent 9494f29d
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -43,6 +43,7 @@ MIDDLEWARE_CLASSES = (
    'django.middleware.common.CommonMiddleware',
    'django.middleware.csrf.CsrfViewMiddleware',
    'django.contrib.auth.middleware.AuthenticationMiddleware',
    'django.contrib.auth.middleware.SessionAuthenticationMiddleware',
    'django.contrib.messages.middleware.MessageMiddleware',
    'django.middleware.clickjacking.XFrameOptionsMiddleware',
)
+22 −2
Original line number Diff line number Diff line
@@ -12,6 +12,7 @@ from .signals import user_logged_in, user_logged_out, user_login_failed

SESSION_KEY = '_auth_user_id'
BACKEND_SESSION_KEY = '_auth_user_backend'
HASH_SESSION_KEY = '_auth_user_hash'
REDIRECT_FIELD_NAME = 'next'


@@ -76,11 +77,16 @@ def login(request, user):
    have to reauthenticate on every request. Note that data set during
    the anonymous session is retained when the user logs in.
    """
    session_auth_hash = ''
    if user is None:
        user = request.user
    # TODO: It would be nice to support different login methods, like signed cookies.
    if hasattr(user, 'get_session_auth_hash'):
        session_auth_hash = user.get_session_auth_hash()

    if SESSION_KEY in request.session:
        if request.session[SESSION_KEY] != user.pk:
        if request.session[SESSION_KEY] != user.pk or (
                session_auth_hash and
                request.session[HASH_SESSION_KEY] != session_auth_hash):
            # To avoid reusing another user's session, create a new, empty
            # session if the existing session corresponds to a different
            # authenticated user.
@@ -89,6 +95,7 @@ def login(request, user):
        request.session.cycle_key()
    request.session[SESSION_KEY] = user.pk
    request.session[BACKEND_SESSION_KEY] = user.backend
    request.session[HASH_SESSION_KEY] = session_auth_hash
    if hasattr(request, 'user'):
        request.user = user
    rotate_token(request)
@@ -158,4 +165,17 @@ def get_permission_codename(action, opts):
    return '%s_%s' % (action, opts.model_name)


def update_session_auth_hash(request, user):
    """
    Updating a user's password logs out all sessions for the user if
    django.contrib.auth.middleware.SessionAuthenticationMiddleware is enabled.

    This function takes the current request and the updated user object from
    which the new session hash will be derived and updates the session hash
    appropriately to prevent a password change from logging out the session
    from which the password was changed.
    """
    if hasattr(user, 'get_session_auth_hash') and request.user == user:
        request.session[HASH_SESSION_KEY] = user.get_session_auth_hash()

default_app_config = 'django.contrib.auth.apps.AuthConfig'
+2 −0
Original line number Diff line number Diff line
@@ -3,6 +3,7 @@ from django.conf import settings
from django.conf.urls import url
from django.contrib import admin
from django.contrib.admin.options import IS_POPUP_VAR
from django.contrib.auth import update_session_auth_hash
from django.contrib.auth.forms import (UserCreationForm, UserChangeForm,
    AdminPasswordChangeForm)
from django.contrib.auth.models import User, Group
@@ -131,6 +132,7 @@ class UserAdmin(admin.ModelAdmin):
                self.log_change(request, request.user, change_message)
                msg = ugettext('Password changed successfully.')
                messages.success(request, msg)
                update_session_auth_hash(request, form.user)
                return HttpResponseRedirect('..')
        else:
            form = self.change_password_form(user)
+19 −0
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@ from django.contrib import auth
from django.contrib.auth import load_backend
from django.contrib.auth.backends import RemoteUserBackend
from django.core.exceptions import ImproperlyConfigured
from django.utils.crypto import constant_time_compare
from django.utils.functional import SimpleLazyObject


@@ -22,6 +23,24 @@ class AuthenticationMiddleware(object):
        request.user = SimpleLazyObject(lambda: get_user(request))


class SessionAuthenticationMiddleware(object):
    """
    Middleware for invalidating a user's sessions that don't correspond to the
    user's current session authentication hash (generated based on the user's
    password for AbstractUser).
    """
    def process_request(self, request):
        user = request.user
        if user and hasattr(user, 'get_session_auth_hash'):
            session_hash = request.session.get(auth.HASH_SESSION_KEY)
            session_hash_verified = session_hash and constant_time_compare(
                session_hash,
                user.get_session_auth_hash()
            )
            if not session_hash_verified:
                auth.logout(request)


class RemoteUserMiddleware(object):
    """
    Middleware for utilizing Web-server-provided authentication.
+8 −1
Original line number Diff line number Diff line
@@ -4,7 +4,7 @@ from django.core.mail import send_mail
from django.core import validators
from django.db import models
from django.db.models.manager import EmptyManager
from django.utils.crypto import get_random_string
from django.utils.crypto import get_random_string, salted_hmac
from django.utils import six
from django.utils.translation import ugettext_lazy as _
from django.utils import timezone
@@ -249,6 +249,13 @@ class AbstractBaseUser(models.Model):
    def get_short_name(self):
        raise NotImplementedError('subclasses of AbstractBaseUser must provide a get_short_name() method.')

    def get_session_auth_hash(self):
        """
        Returns an HMAC of the password field.
        """
        key_salt = "django.contrib.auth.models.AbstractBaseUser.get_session_auth_hash"
        return salted_hmac(key_salt, self.password).hexdigest()


# A few helper functions for common logic between User and AnonymousUser.
def _user_get_all_permissions(user, obj):
Loading