diff --git a/rest_framework_simplejwt/authentication.py b/rest_framework_simplejwt/authentication.py index 13767e1ee..9953e53b4 100644 --- a/rest_framework_simplejwt/authentication.py +++ b/rest_framework_simplejwt/authentication.py @@ -2,6 +2,7 @@ from django.contrib.auth import get_user_model from django.contrib.auth.models import AbstractBaseUser +from django.utils.crypto import constant_time_compare from django.utils.translation import gettext_lazy as _ from rest_framework import HTTP_HEADER_ENCODING, authentication from rest_framework.request import Request @@ -10,7 +11,7 @@ from .models import TokenUser from .settings import api_settings from .tokens import Token -from .utils import get_md5_hash_password +from .utils import get_fallback_token_auth_hash, get_token_auth_hash AUTH_HEADER_TYPES = api_settings.AUTH_HEADER_TYPES @@ -135,9 +136,17 @@ def get_user(self, validated_token: Token) -> AuthUser: raise AuthenticationFailed(_("User is inactive"), code="user_inactive") if api_settings.CHECK_REVOKE_TOKEN: - if validated_token.get( - api_settings.REVOKE_TOKEN_CLAIM - ) != get_md5_hash_password(user.password): + validation_claim = validated_token.get(api_settings.REVOKE_TOKEN_CLAIM) + if ( + validation_claim is None + or not constant_time_compare( + validation_claim, get_token_auth_hash(user) + ) + and not any( + constant_time_compare(validation_claim, fallback_auth_hash) + for fallback_auth_hash in get_fallback_token_auth_hash(user) + ) + ): raise AuthenticationFailed( _("The user's password has been changed."), code="password_changed" ) diff --git a/rest_framework_simplejwt/tokens.py b/rest_framework_simplejwt/tokens.py index b207ef27c..f6b1c4581 100644 --- a/rest_framework_simplejwt/tokens.py +++ b/rest_framework_simplejwt/tokens.py @@ -16,7 +16,7 @@ datetime_from_epoch, datetime_to_epoch, format_lazy, - get_md5_hash_password, + get_token_auth_hash, ) if TYPE_CHECKING: @@ -208,9 +208,7 @@ def for_user(cls, user: AuthUser) -> "Token": token[api_settings.USER_ID_CLAIM] = user_id if api_settings.CHECK_REVOKE_TOKEN: - token[api_settings.REVOKE_TOKEN_CLAIM] = get_md5_hash_password( - user.password - ) + token[api_settings.REVOKE_TOKEN_CLAIM] = get_token_auth_hash(user) return token diff --git a/rest_framework_simplejwt/utils.py b/rest_framework_simplejwt/utils.py index 4490fa23c..0c162cb7a 100644 --- a/rest_framework_simplejwt/utils.py +++ b/rest_framework_simplejwt/utils.py @@ -1,18 +1,44 @@ -import hashlib from calendar import timegm from datetime import datetime, timezone -from typing import Callable +from typing import TYPE_CHECKING, Callable, TypeVar from django.conf import settings +from django.contrib.auth.models import AbstractBaseUser +from django.utils.crypto import salted_hmac from django.utils.functional import lazy from django.utils.timezone import is_naive, make_aware +if TYPE_CHECKING: + from .models import TokenUser -def get_md5_hash_password(password: str) -> str: + AuthUser = TypeVar("AuthUser", AbstractBaseUser, TokenUser) + + +def _get_token_auth_hash(user: "AuthUser", secret=None) -> str: + key_salt = "rest_framework_simplejwt.utils.get_token_auth_hash" + return salted_hmac(key_salt, user.password, secret=secret).hexdigest() + + +def get_token_auth_hash(user: "AuthUser") -> str: """ - Returns MD5 hash of the given password + Return an HMAC of the given user password field. """ - return hashlib.md5(password.encode()).hexdigest().upper() + if hasattr(user, "get_session_auth_hash"): + return user.get_session_auth_hash() + return _get_token_auth_hash(user) + + +def get_fallback_token_auth_hash(user: "AuthUser") -> str: + """ + Yields a sequence of fallback HMACs of the given user password field. + """ + if hasattr(user, "get_session_auth_fallback_hash"): + yield from user.get_session_auth_fallback_hash() + + fallback_keys = getattr(settings, "SECRET_KEY_FALLBACKS", []) + yield from ( + _get_token_auth_hash(user, fallback_secret) for fallback_secret in fallback_keys + ) def make_utc(dt: datetime) -> datetime: diff --git a/tests/conftest.py b/tests/conftest.py index 5248ff08c..ec0d7557f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,6 +16,10 @@ def pytest_configure(): }, SITE_ID=1, SECRET_KEY="not very secret in tests", + SECRET_KEY_FALLBACKS=[ + "old not very secure secret", + "other old not very secure secret", + ], USE_I18N=True, STATIC_URL="/static/", ROOT_URLCONF="tests.urls", diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 99f0b5525..876147e56 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -10,7 +10,7 @@ from rest_framework_simplejwt.models import TokenUser from rest_framework_simplejwt.settings import api_settings from rest_framework_simplejwt.tokens import AccessToken, SlidingToken -from rest_framework_simplejwt.utils import get_md5_hash_password +from rest_framework_simplejwt.utils import _get_token_auth_hash, get_token_auth_hash from .utils import override_api_settings @@ -145,60 +145,47 @@ def test_get_user(self): with self.assertRaises(AuthenticationFailed): self.backend.get_user(payload) - u = User.objects.create_user(username="markhamill") - u.is_active = False - u.save() + user = User.objects.create_user(username="markhamill", is_active=False) - payload[api_settings.USER_ID_CLAIM] = getattr(u, api_settings.USER_ID_FIELD) + payload[api_settings.USER_ID_CLAIM] = getattr(user, api_settings.USER_ID_FIELD) # Should raise exception if user is inactive with self.assertRaises(AuthenticationFailed): self.backend.get_user(payload) - u.is_active = True - u.save() + user.is_active = True + user.save() # Otherwise, should return correct user - self.assertEqual(self.backend.get_user(payload).id, u.id) + self.assertEqual(self.backend.get_user(payload).id, user.id) @override_api_settings( CHECK_REVOKE_TOKEN=True, REVOKE_TOKEN_CLAIM="revoke_token_claim" ) def test_get_user_with_check_revoke_token(self): - payload = {"some_other_id": "foo"} - - # Should raise error if no recognizable user identification - with self.assertRaises(InvalidToken): - self.backend.get_user(payload) - - payload[api_settings.USER_ID_CLAIM] = 42 - - # Should raise exception if user not found - with self.assertRaises(AuthenticationFailed): - self.backend.get_user(payload) - - u = User.objects.create_user(username="markhamill") - u.is_active = False - u.save() + user = User.objects.create_user(username="markhamill") + payload = { + api_settings.USER_ID_CLAIM: getattr(user, api_settings.USER_ID_FIELD) + } - payload[api_settings.USER_ID_CLAIM] = getattr(u, api_settings.USER_ID_FIELD) - - # Should raise exception if user is inactive + # Should raise exception if claim is missing with self.assertRaises(AuthenticationFailed): self.backend.get_user(payload) - u.is_active = True - u.save() - - # Should raise exception if hash password is different + payload[api_settings.REVOKE_TOKEN_CLAIM] = "differenthash" + # Should raise exception if claim is different with self.assertRaises(AuthenticationFailed): self.backend.get_user(payload) - if api_settings.CHECK_REVOKE_TOKEN: - payload[api_settings.REVOKE_TOKEN_CLAIM] = get_md5_hash_password(u.password) + payload[api_settings.REVOKE_TOKEN_CLAIM] = _get_token_auth_hash( + user, "other old not very secure secret" + ) + # Should return correct user if claim was signed with an old key + self.assertEqual(self.backend.get_user(payload).id, user.id) + payload[api_settings.REVOKE_TOKEN_CLAIM] = get_token_auth_hash(user) # Otherwise, should return correct user - self.assertEqual(self.backend.get_user(payload).id, u.id) + self.assertEqual(self.backend.get_user(payload).id, user.id) class TestJWTStatelessUserAuthentication(TestCase):