chore: Removed edx-token-utils dep and moved necessary logic to the repo

This commit is contained in:
Marcos
2025-01-07 10:32:08 -03:00
committed by Marcos Rigoli
parent 2a07080a08
commit 825931a9b4
5 changed files with 267 additions and 4 deletions

View File

@@ -0,0 +1,91 @@
"""
JWT Token handling and signing functions.
"""
import json
from time import time
from django.conf import settings
from jwkest import Expired, Invalid, MissingKey, jwk
from jwkest.jws import JWS
def create_jwt(lms_user_id, expires_in_seconds, additional_token_claims, now=None):
"""
Produce an encoded JWT (string) indicating some temporary permission for the indicated user.
What permission that is must be encoded in additional_claims.
Arguments:
lms_user_id (int): LMS user ID this token is being generated for
expires_in_seconds (int): Time to token expiry, specified in seconds.
additional_token_claims (dict): Additional claims to include in the token.
now(int): optional now value for testing
"""
now = now or int(time())
payload = {
'lms_user_id': lms_user_id,
'exp': now + expires_in_seconds,
'iat': now,
'iss': settings.TOKEN_SIGNING['JWT_ISSUER'],
'version': settings.TOKEN_SIGNING['JWT_SUPPORTED_VERSION'],
}
payload.update(additional_token_claims)
return _encode_and_sign(payload)
def _encode_and_sign(payload):
"""
Encode and sign the provided payload.
The signing key and algorithm are pulled from settings.
"""
keys = jwk.KEYS()
serialized_keypair = json.loads(settings.TOKEN_SIGNING['JWT_PRIVATE_SIGNING_JWK'])
keys.add(serialized_keypair)
algorithm = settings.TOKEN_SIGNING['JWT_SIGNING_ALGORITHM']
data = json.dumps(payload)
jws = JWS(data, alg=algorithm)
return jws.sign_compact(keys=keys)
def unpack_jwt(token, lms_user_id, now=None):
"""
Unpack and verify an encoded JWT.
Validate the user and expiration.
Arguments:
token (string): The token to be unpacked and verified.
lms_user_id (int): LMS user ID this token should match with.
now (int): Optional now value for testing.
Returns a valid, decoded json payload (string).
"""
now = now or int(time())
payload = _unpack_and_verify(token)
if "lms_user_id" not in payload:
raise MissingKey("LMS user id is missing")
if "exp" not in payload:
raise MissingKey("Expiration is missing")
if payload["lms_user_id"] != lms_user_id:
raise Invalid("User does not match")
if payload["exp"] < now:
raise Expired("Token is expired")
return payload
def _unpack_and_verify(token):
"""
Unpack and verify the provided token.
The signing key and algorithm are pulled from settings.
"""
keys = jwk.KEYS()
keys.load_jwks(settings.TOKEN_SIGNING['JWT_PUBLIC_SIGNING_JWK_SET'])
decoded = JWS().verify_compact(token.encode('utf-8'), keys)
return decoded

View File

@@ -0,0 +1,120 @@
"""
Tests for token handling
"""
import unittest
from django.conf import settings
from jwkest import BadSignature, Expired, Invalid, MissingKey, jwk
from jwkest.jws import JWS
from lms.djangoapps.courseware.jwt import _encode_and_sign, create_jwt, unpack_jwt
import unittest
from unittest.mock import patch
test_user_id = 121
invalid_test_user_id = 120
test_timeout = 60
test_now = 1661432902
test_claims = {"foo": "bar", "baz": "quux", "meaning": 42}
expected_full_token = {
"lms_user_id": test_user_id,
"iat": 1661432902,
"exp": 1661432902 + 60,
"iss": "token-test-issuer", # these lines from test_settings.py
"version": "1.2.0", # these lines from test_settings.py
}
class TestSign(unittest.TestCase):
def test_create_jwt(self):
token = create_jwt(test_user_id, test_timeout, {}, test_now)
decoded = _verify_jwt(token)
self.assertEqual(expected_full_token, decoded)
def test_create_jwt_with_claims(self):
token = create_jwt(test_user_id, test_timeout, test_claims, test_now)
expected_token_with_claims = expected_full_token.copy()
expected_token_with_claims.update(test_claims)
decoded = _verify_jwt(token)
self.assertEqual(expected_token_with_claims, decoded)
def test_malformed_token(self):
token = create_jwt(test_user_id, test_timeout, test_claims, test_now)
token = token + "a"
expected_token_with_claims = expected_full_token.copy()
expected_token_with_claims.update(test_claims)
with self.assertRaises(BadSignature):
_verify_jwt(token)
def _verify_jwt(jwt_token):
"""
Helper function which verifies the signature and decodes the token
from string back to claims form
"""
keys = jwk.KEYS()
keys.load_jwks(settings.TOKEN_SIGNING['JWT_PUBLIC_SIGNING_JWK_SET'])
decoded = JWS().verify_compact(jwt_token.encode('utf-8'), keys)
return decoded
class TestUnpack(unittest.TestCase):
def test_unpack_jwt(self):
token = create_jwt(test_user_id, test_timeout, {}, test_now)
decoded = unpack_jwt(token, test_user_id, test_now)
self.assertEqual(expected_full_token, decoded)
def test_unpack_jwt_with_claims(self):
token = create_jwt(test_user_id, test_timeout, test_claims, test_now)
expected_token_with_claims = expected_full_token.copy()
expected_token_with_claims.update(test_claims)
decoded = unpack_jwt(token, test_user_id, test_now)
self.assertEqual(expected_token_with_claims, decoded)
def test_malformed_token(self):
token = create_jwt(test_user_id, test_timeout, test_claims, test_now)
token = token + "a"
expected_token_with_claims = expected_full_token.copy()
expected_token_with_claims.update(test_claims)
with self.assertRaises(BadSignature):
unpack_jwt(token, test_user_id, test_now)
def test_unpack_token_with_invalid_user(self):
token = create_jwt(invalid_test_user_id, test_timeout, {}, test_now)
with self.assertRaises(Invalid):
unpack_jwt(token, test_user_id, test_now)
def test_unpack_expired_token(self):
token = create_jwt(test_user_id, test_timeout, {}, test_now)
with self.assertRaises(Expired):
unpack_jwt(token, test_user_id, test_now + test_timeout + 1)
def test_missing_expired_lms_user_id(self):
payload = expected_full_token.copy()
del payload['lms_user_id']
token = _encode_and_sign(payload)
with self.assertRaises(MissingKey):
unpack_jwt(token, test_user_id, test_now)
def test_missing_expired_key(self):
payload = expected_full_token.copy()
del payload['exp']
token = _encode_and_sign(payload)
with self.assertRaises(MissingKey):
unpack_jwt(token, test_user_id, test_now)

View File

@@ -5,6 +5,8 @@ import datetime
import hashlib
import logging
from time import time
from django.conf import settings
from django.http import HttpResponse, HttpResponseBadRequest
from edx_rest_api_client.client import OAuthAPIClient
@@ -15,6 +17,9 @@ from xmodule.partitions.partitions import \
ENROLLMENT_TRACK_PARTITION_ID # lint-amnesty, pylint: disable=wrong-import-order
from xmodule.partitions.partitions_service import PartitionService # lint-amnesty, pylint: disable=wrong-import-order
from jwkest import Expired, Invalid, MissingKey, jwk
from jwkest.jws import JWS
from common.djangoapps.course_modes.models import CourseMode
from lms.djangoapps.commerce.utils import EcommerceService
from lms.djangoapps.courseware.config import ENABLE_NEW_FINANCIAL_ASSISTANCE_FLOW
@@ -229,3 +234,36 @@ def _use_new_financial_assistance_flow(course_id):
):
return True
return False
def unpack_jwt(token, lms_user_id, now=None):
"""
Unpack and verify an encoded JWT.
Validate the user and expiration.
Arguments:
token (string): The token to be unpacked and verified.
lms_user_id (int): LMS user ID this token should match with.
now (int): Optional now value for testing.
Returns a valid, decoded json payload (string).
"""
now = now or int(time())
# Unpack and verify token
keys = jwk.KEYS()
keys.load_jwks(settings.TOKEN_SIGNING['JWT_PUBLIC_SIGNING_JWK_SET'])
payload = JWS().verify_compact(token.encode('utf-8'), keys)
if "lms_user_id" not in payload:
raise MissingKey("LMS user id is missing")
if "exp" not in payload:
raise MissingKey("Expiration is missing")
if payload["lms_user_id"] != lms_user_id:
raise Invalid("User does not match")
if payload["exp"] < now:
raise Expired("Token is expired")
return payload

View File

@@ -46,7 +46,6 @@ from rest_framework import status
from rest_framework.decorators import api_view, throttle_classes
from rest_framework.response import Response
from rest_framework.throttling import UserRateThrottle
from token_utils.api import unpack_token_for
from web_fragments.fragment import Fragment
from xmodule.course_block import (
COURSE_VISIBILITY_PUBLIC,
@@ -106,7 +105,7 @@ from lms.djangoapps.courseware.user_state_client import DjangoXBlockUserStateCli
from lms.djangoapps.courseware.utils import (
_use_new_financial_assistance_flow,
create_financial_assistance_application,
is_eligible_for_financial_aid
is_eligible_for_financial_aid, unpack_jwt
)
from lms.djangoapps.edxnotes.helpers import is_feature_enabled
from lms.djangoapps.experiments.utils import get_experiment_user_metadata_context
@@ -1535,7 +1534,7 @@ def _check_sequence_exam_access(request, location):
try:
# unpack will validate both expiration and the requesting user matches the
# token user
exam_access_unpacked = unpack_token_for(exam_access_token, request.user.id)
exam_access_unpacked = unpack_jwt(exam_access_token, request.user.id)
except: # pylint: disable=bare-except
log.exception(f"Failed to validate exam access token. user_id={request.user.id} location={location}")
return False

View File

@@ -4315,7 +4315,22 @@ TOKEN_SIGNING = {
'JWT_ISSUER': 'http://127.0.0.1:8740',
'JWT_SIGNING_ALGORITHM': 'RS512',
'JWT_SUPPORTED_VERSION': '1.2.0',
'JWT_PUBLIC_SIGNING_JWK_SET': None,
'JWT_PUBLIC_SIGNING_JWK_SET': '''{
"keys": [
{
"kid":"token-test-wrong-key",
"e": "AQAB",
"kty": "RSA",
"n": "o5cn3ljSRi6FaDEKTn0PS-oL9EFyv1pI7dffgRQLD1qf5D6sprmYfWVokSsrWig8u2y0HChSygR6Jn5KXBqQn6FpM0dDJLnWQDRXHLl3Ey1iPYgDSmOIsIGrV9ZyNCQwk03wAgWbfdBTig3QSDYD-sTNOs3pc4UD_PqAvU2nz_1SS2ZiOwOn5F6gulE1L0iE3KEUEvOIagfHNVhz0oxa_VRZILkzV-zr6R_TW1m97h4H8jXl_VJyQGyhMGGypuDrQ9_vaY_RLEulLCyY0INglHWQ7pckxBtI5q55-Vio2wgewe2_qYcGsnBGaDNbySAsvYcWRrqDiFyzrJYivodqTQ"
},
{
"kid":"token-test-sign",
"e": "AQAB",
"kty": "RSA",
"n": "o5cn3ljSRi6FaDEKTn0PS-oL9EFyv1pI7dRgffQLD1qf5D6sprmYfWWokSsrWig8u2y0HChSygR6Jn5KXBqQn6FpM0dDJLnWQDRXHLl3Ey1iPYgDSmOIsIGrV9ZyNCQwk03wAgWbfdBTig3QSDYD-sTNOs3pc4UD_PqAvU2nz_1SS2ZiOwOn5F6gulE1L0iE3KEUEvOIagfHNVhz0oxa_VRZILkzV-zr6R_TW1m97h4H8jXl_VJyQGyhMGGypuDrQ9_vaY_RLEulLCyY0INglHWQ7pckxBtI5q55-Vio2wgewe2_qYcGsnBGaDNbySAsvYcWRrqDiFyzrJYivodqTQ"
}
]
}''',
}
COURSE_CATALOG_URL_ROOT = 'http://localhost:8008'