Files
edx-platform/lms/djangoapps/courseware/tests/test_jwt.py
2025-01-23 15:55:58 -03:00

127 lines
3.9 KiB
Python

"""
Tests for token handling
"""
import unittest
from django.conf import settings
from jwkest import BadSignature, Expired, Invalid, MissingKey, jwk
from jwkest.jws import JWS
from lms.djangoapps.courseware.jwt import _encode_and_sign, create_jwt, unpack_jwt
# Shared fixtures for the token tests below.
test_user_id = 121
invalid_test_user_id = 120  # deliberately different from test_user_id
test_timeout = 60
test_now = 1661432902
test_claims = {"foo": "bar", "baz": "quux", "meaning": 42}

# Payload expected for a token created with test_user_id, test_timeout and
# test_now and no extra claims. "iat"/"exp" are derived from the fixtures
# above so the expectation cannot drift from the inputs used by the tests.
expected_full_token = {
    "lms_user_id": test_user_id,
    "iat": test_now,
    "exp": test_now + test_timeout,
    "iss": "token-test-issuer",  # value taken from test_settings.py
    "version": "1.2.0",  # value taken from test_settings.py
}
class TestSign(unittest.TestCase):
    """
    Tests for JWT creation and signing.

    Each test builds a token with ``create_jwt`` and checks it with the
    module-level ``_verify_jwt`` helper.
    """

    def test_create_jwt(self):
        """A token created without extra claims decodes to the base payload."""
        token = create_jwt(test_user_id, test_timeout, {}, test_now)
        decoded = _verify_jwt(token)
        self.assertEqual(expected_full_token, decoded)

    def test_create_jwt_with_claims(self):
        """Extra claims show up in the decoded payload next to the base claims."""
        token = create_jwt(test_user_id, test_timeout, test_claims, test_now)
        expected_token_with_claims = {**expected_full_token, **test_claims}
        decoded = _verify_jwt(token)
        self.assertEqual(expected_token_with_claims, decoded)

    def test_malformed_token(self):
        """A token with a character appended fails signature verification."""
        token = create_jwt(test_user_id, test_timeout, test_claims, test_now)
        token = token + "a"
        with self.assertRaises(BadSignature):
            _verify_jwt(token)
def _verify_jwt(jwt_token):
    """
    Verify the signature of a compact token string and return the decoded
    result.

    The verification keys are loaded from the ``JWT_PUBLIC_SIGNING_JWK_SET``
    entry of ``settings.TOKEN_SIGNING``.
    """
    public_jwk_set = settings.TOKEN_SIGNING['JWT_PUBLIC_SIGNING_JWK_SET']
    verification_keys = jwk.KEYS()
    verification_keys.load_jwks(public_jwk_set)
    # The verifier is handed the token as UTF-8 encoded bytes.
    token_bytes = jwt_token.encode('utf-8')
    return JWS().verify_compact(token_bytes, verification_keys)
class TestUnpack(unittest.TestCase):
    """
    Tests for JWT unpacking.

    Tokens are produced with ``create_jwt`` (or ``_encode_and_sign`` for
    hand-built payloads) and then passed to ``unpack_jwt``.
    """

    def test_unpack_jwt(self):
        """A token created without extra claims unpacks to the base payload."""
        token = create_jwt(test_user_id, test_timeout, {}, test_now)
        decoded = unpack_jwt(token, test_user_id, test_now)
        self.assertEqual(expected_full_token, decoded)

    def test_unpack_jwt_with_claims(self):
        """Extra claims are returned together with the base claims."""
        token = create_jwt(test_user_id, test_timeout, test_claims, test_now)
        expected_token_with_claims = {**expected_full_token, **test_claims}
        decoded = unpack_jwt(token, test_user_id, test_now)
        self.assertEqual(expected_token_with_claims, decoded)

    def test_malformed_token(self):
        """A token with a character appended is rejected with BadSignature."""
        token = create_jwt(test_user_id, test_timeout, test_claims, test_now)
        token = token + "a"
        with self.assertRaises(BadSignature):
            unpack_jwt(token, test_user_id, test_now)

    def test_unpack_token_with_invalid_user(self):
        """A token created for another user id is rejected with Invalid."""
        token = create_jwt(invalid_test_user_id, test_timeout, {}, test_now)
        with self.assertRaises(Invalid):
            unpack_jwt(token, test_user_id, test_now)

    def test_unpack_expired_token(self):
        """Unpacking one second past the token's timeout raises Expired."""
        token = create_jwt(test_user_id, test_timeout, {}, test_now)
        with self.assertRaises(Expired):
            unpack_jwt(token, test_user_id, test_now + test_timeout + 1)

    def test_missing_expired_lms_user_id(self):
        """A signed payload without 'lms_user_id' raises MissingKey."""
        payload = expected_full_token.copy()
        del payload['lms_user_id']
        token = _encode_and_sign(payload)
        with self.assertRaises(MissingKey):
            unpack_jwt(token, test_user_id, test_now)

    def test_missing_expired_key(self):
        """A signed payload without 'exp' raises MissingKey."""
        payload = expected_full_token.copy()
        del payload['exp']
        token = _encode_and_sign(payload)
        with self.assertRaises(MissingKey):
            unpack_jwt(token, test_user_id, test_now)