Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Harden revoke access token for password changes #746

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions rest_framework_simplejwt/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from django.contrib.auth import get_user_model
from django.contrib.auth.models import AbstractBaseUser
from django.utils.crypto import constant_time_compare
from django.utils.translation import gettext_lazy as _
from rest_framework import HTTP_HEADER_ENCODING, authentication
from rest_framework.request import Request
Expand All @@ -10,7 +11,7 @@
from .models import TokenUser
from .settings import api_settings
from .tokens import Token
from .utils import get_md5_hash_password
from .utils import get_fallback_token_auth_hash, get_token_auth_hash

AUTH_HEADER_TYPES = api_settings.AUTH_HEADER_TYPES

Expand Down Expand Up @@ -135,9 +136,17 @@ def get_user(self, validated_token: Token) -> AuthUser:
raise AuthenticationFailed(_("User is inactive"), code="user_inactive")

if api_settings.CHECK_REVOKE_TOKEN:
if validated_token.get(
api_settings.REVOKE_TOKEN_CLAIM
) != get_md5_hash_password(user.password):
validation_claim = validated_token.get(api_settings.REVOKE_TOKEN_CLAIM)
if (
validation_claim is None
or not constant_time_compare(
validation_claim, get_token_auth_hash(user)
)
and not any(
constant_time_compare(validation_claim, fallback_auth_hash)
for fallback_auth_hash in get_fallback_token_auth_hash(user)
)
):
raise AuthenticationFailed(
_("The user's password has been changed."), code="password_changed"
)
Expand Down
6 changes: 2 additions & 4 deletions rest_framework_simplejwt/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
datetime_from_epoch,
datetime_to_epoch,
format_lazy,
get_md5_hash_password,
get_token_auth_hash,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -208,9 +208,7 @@ def for_user(cls, user: AuthUser) -> "Token":
token[api_settings.USER_ID_CLAIM] = user_id

if api_settings.CHECK_REVOKE_TOKEN:
token[api_settings.REVOKE_TOKEN_CLAIM] = get_md5_hash_password(
user.password
)
token[api_settings.REVOKE_TOKEN_CLAIM] = get_token_auth_hash(user)

return token

Expand Down
36 changes: 31 additions & 5 deletions rest_framework_simplejwt/utils.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,44 @@
import hashlib
from calendar import timegm
from datetime import datetime, timezone
from typing import Callable
from typing import TYPE_CHECKING, Callable, TypeVar

from django.conf import settings
from django.contrib.auth.models import AbstractBaseUser
from django.utils.crypto import salted_hmac
from django.utils.functional import lazy
from django.utils.timezone import is_naive, make_aware

if TYPE_CHECKING:
from .models import TokenUser

def get_md5_hash_password(password: str) -> str:
AuthUser = TypeVar("AuthUser", AbstractBaseUser, TokenUser)


def _get_token_auth_hash(user: "AuthUser", secret=None) -> str:
key_salt = "rest_framework_simplejwt.utils.get_token_auth_hash"
return salted_hmac(key_salt, user.password, secret=secret).hexdigest()


def get_token_auth_hash(user: "AuthUser") -> str:
"""
Returns MD5 hash of the given password
Return an HMAC of the given user password field.
"""
return hashlib.md5(password.encode()).hexdigest().upper()
if hasattr(user, "get_session_auth_hash"):
return user.get_session_auth_hash()
return _get_token_auth_hash(user)


def get_fallback_token_auth_hash(user: "AuthUser") -> str:
"""
Yields a sequence of fallback HMACs of the given user password field.
"""
if hasattr(user, "get_session_auth_fallback_hash"):
yield from user.get_session_auth_fallback_hash()

fallback_keys = getattr(settings, "SECRET_KEY_FALLBACKS", [])
yield from (
_get_token_auth_hash(user, fallback_secret) for fallback_secret in fallback_keys
)


def make_utc(dt: datetime) -> datetime:
Expand Down
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ def pytest_configure():
},
SITE_ID=1,
SECRET_KEY="not very secret in tests",
SECRET_KEY_FALLBACKS=[
"old not very secure secret",
"other old not very secure secret",
],
USE_I18N=True,
STATIC_URL="/static/",
ROOT_URLCONF="tests.urls",
Expand Down
53 changes: 20 additions & 33 deletions tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from rest_framework_simplejwt.models import TokenUser
from rest_framework_simplejwt.settings import api_settings
from rest_framework_simplejwt.tokens import AccessToken, SlidingToken
from rest_framework_simplejwt.utils import get_md5_hash_password
from rest_framework_simplejwt.utils import _get_token_auth_hash, get_token_auth_hash

from .utils import override_api_settings

Expand Down Expand Up @@ -145,60 +145,47 @@ def test_get_user(self):
with self.assertRaises(AuthenticationFailed):
self.backend.get_user(payload)

u = User.objects.create_user(username="markhamill")
u.is_active = False
u.save()
user = User.objects.create_user(username="markhamill", is_active=False)

payload[api_settings.USER_ID_CLAIM] = getattr(u, api_settings.USER_ID_FIELD)
payload[api_settings.USER_ID_CLAIM] = getattr(user, api_settings.USER_ID_FIELD)

# Should raise exception if user is inactive
with self.assertRaises(AuthenticationFailed):
self.backend.get_user(payload)

u.is_active = True
u.save()
user.is_active = True
user.save()

# Otherwise, should return correct user
self.assertEqual(self.backend.get_user(payload).id, u.id)
self.assertEqual(self.backend.get_user(payload).id, user.id)

@override_api_settings(
CHECK_REVOKE_TOKEN=True, REVOKE_TOKEN_CLAIM="revoke_token_claim"
)
def test_get_user_with_check_revoke_token(self):
payload = {"some_other_id": "foo"}

# Should raise error if no recognizable user identification
with self.assertRaises(InvalidToken):
self.backend.get_user(payload)

payload[api_settings.USER_ID_CLAIM] = 42

# Should raise exception if user not found
with self.assertRaises(AuthenticationFailed):
self.backend.get_user(payload)

u = User.objects.create_user(username="markhamill")
u.is_active = False
u.save()
user = User.objects.create_user(username="markhamill")
payload = {
api_settings.USER_ID_CLAIM: getattr(user, api_settings.USER_ID_FIELD)
}

payload[api_settings.USER_ID_CLAIM] = getattr(u, api_settings.USER_ID_FIELD)

# Should raise exception if user is inactive
# Should raise exception if claim is missing
with self.assertRaises(AuthenticationFailed):
self.backend.get_user(payload)

u.is_active = True
u.save()

# Should raise exception if hash password is different
payload[api_settings.REVOKE_TOKEN_CLAIM] = "differenthash"
# Should raise exception if claim is different
with self.assertRaises(AuthenticationFailed):
self.backend.get_user(payload)

if api_settings.CHECK_REVOKE_TOKEN:
payload[api_settings.REVOKE_TOKEN_CLAIM] = get_md5_hash_password(u.password)
payload[api_settings.REVOKE_TOKEN_CLAIM] = _get_token_auth_hash(
user, "other old not very secure secret"
)
# Should return correct user if claim was signed with an old key
self.assertEqual(self.backend.get_user(payload).id, user.id)

payload[api_settings.REVOKE_TOKEN_CLAIM] = get_token_auth_hash(user)
# Otherwise, should return correct user
self.assertEqual(self.backend.get_user(payload).id, u.id)
self.assertEqual(self.backend.get_user(payload).id, user.id)


class TestJWTStatelessUserAuthentication(TestCase):
Expand Down