From 825931a9b4e591df6248f359d79c5452835ec21a Mon Sep 17 00:00:00 2001 From: Marcos Date: Tue, 7 Jan 2025 10:32:08 -0300 Subject: [PATCH] chore: Removed edx-token-utils dep and moved necessary logic to the repo --- lms/djangoapps/courseware/jwt.py | 91 +++++++++++++++ lms/djangoapps/courseware/tests/test_jwt.py | 120 ++++++++++++++++++++ lms/djangoapps/courseware/utils.py | 38 +++++++ lms/djangoapps/courseware/views/views.py | 5 +- lms/envs/common.py | 17 ++- 5 files changed, 267 insertions(+), 4 deletions(-) create mode 100644 lms/djangoapps/courseware/jwt.py create mode 100644 lms/djangoapps/courseware/tests/test_jwt.py diff --git a/lms/djangoapps/courseware/jwt.py b/lms/djangoapps/courseware/jwt.py new file mode 100644 index 0000000000..47642b8695 --- /dev/null +++ b/lms/djangoapps/courseware/jwt.py @@ -0,0 +1,91 @@ +""" +JWT Token handling and signing functions. +""" + +import json +from time import time + +from django.conf import settings +from jwkest import Expired, Invalid, MissingKey, jwk +from jwkest.jws import JWS + + +def create_jwt(lms_user_id, expires_in_seconds, additional_token_claims, now=None): + """ + Produce an encoded JWT (string) indicating some temporary permission for the indicated user. + + What permission that is must be encoded in additional_claims. + Arguments: + lms_user_id (int): LMS user ID this token is being generated for + expires_in_seconds (int): Time to token expiry, specified in seconds. + additional_token_claims (dict): Additional claims to include in the token. + now(int): optional now value for testing + """ + now = now or int(time()) + + payload = { + 'lms_user_id': lms_user_id, + 'exp': now + expires_in_seconds, + 'iat': now, + 'iss': settings.TOKEN_SIGNING['JWT_ISSUER'], + 'version': settings.TOKEN_SIGNING['JWT_SUPPORTED_VERSION'], + } + payload.update(additional_token_claims) + return _encode_and_sign(payload) + + +def _encode_and_sign(payload): + """ + Encode and sign the provided payload. + + The signing key and algorithm are pulled from settings. + """ + keys = jwk.KEYS() + + serialized_keypair = json.loads(settings.TOKEN_SIGNING['JWT_PRIVATE_SIGNING_JWK']) + keys.add(serialized_keypair) + algorithm = settings.TOKEN_SIGNING['JWT_SIGNING_ALGORITHM'] + + data = json.dumps(payload) + jws = JWS(data, alg=algorithm) + return jws.sign_compact(keys=keys) + + +def unpack_jwt(token, lms_user_id, now=None): + """ + Unpack and verify an encoded JWT. + + Validate the user and expiration. + + Arguments: + token (string): The token to be unpacked and verified. + lms_user_id (int): LMS user ID this token should match with. + now (int): Optional now value for testing. + + Returns a valid, decoded json payload (string). + """ + now = now or int(time()) + payload = _unpack_and_verify(token) + + if "lms_user_id" not in payload: + raise MissingKey("LMS user id is missing") + if "exp" not in payload: + raise MissingKey("Expiration is missing") + if payload["lms_user_id"] != lms_user_id: + raise Invalid("User does not match") + if payload["exp"] < now: + raise Expired("Token is expired") + + return payload + + +def _unpack_and_verify(token): + """ + Unpack and verify the provided token. + + The signing key and algorithm are pulled from settings. + """ + keys = jwk.KEYS() + keys.load_jwks(settings.TOKEN_SIGNING['JWT_PUBLIC_SIGNING_JWK_SET']) + decoded = JWS().verify_compact(token.encode('utf-8'), keys) + return decoded diff --git a/lms/djangoapps/courseware/tests/test_jwt.py b/lms/djangoapps/courseware/tests/test_jwt.py new file mode 100644 index 0000000000..d21ac1778f --- /dev/null +++ b/lms/djangoapps/courseware/tests/test_jwt.py @@ -0,0 +1,120 @@ +""" +Tests for token handling +""" +import unittest + +from django.conf import settings +from jwkest import BadSignature, Expired, Invalid, MissingKey, jwk +from jwkest.jws import JWS + +from lms.djangoapps.courseware.jwt import _encode_and_sign, create_jwt, unpack_jwt + +import unittest +from unittest.mock import patch + + +test_user_id = 121 +invalid_test_user_id = 120 +test_timeout = 60 +test_now = 1661432902 +test_claims = {"foo": "bar", "baz": "quux", "meaning": 42} +expected_full_token = { + "lms_user_id": test_user_id, + "iat": 1661432902, + "exp": 1661432902 + 60, + "iss": "token-test-issuer", # these lines from test_settings.py + "version": "1.2.0", # these lines from test_settings.py +} + + +class TestSign(unittest.TestCase): + def test_create_jwt(self): + token = create_jwt(test_user_id, test_timeout, {}, test_now) + + decoded = _verify_jwt(token) + self.assertEqual(expected_full_token, decoded) + + def test_create_jwt_with_claims(self): + token = create_jwt(test_user_id, test_timeout, test_claims, test_now) + + expected_token_with_claims = expected_full_token.copy() + expected_token_with_claims.update(test_claims) + + decoded = _verify_jwt(token) + self.assertEqual(expected_token_with_claims, decoded) + + def test_malformed_token(self): + token = create_jwt(test_user_id, test_timeout, test_claims, test_now) + token = token + "a" + + expected_token_with_claims = expected_full_token.copy() + expected_token_with_claims.update(test_claims) + + with self.assertRaises(BadSignature): + _verify_jwt(token) + +def _verify_jwt(jwt_token): + """ + Helper function which verifies the signature and decodes the token + from string back to claims form + """ + keys = jwk.KEYS() + keys.load_jwks(settings.TOKEN_SIGNING['JWT_PUBLIC_SIGNING_JWK_SET']) + decoded = JWS().verify_compact(jwt_token.encode('utf-8'), keys) + return decoded + + +class TestUnpack(unittest.TestCase): + def test_unpack_jwt(self): + token = create_jwt(test_user_id, test_timeout, {}, test_now) + decoded = unpack_jwt(token, test_user_id, test_now) + + self.assertEqual(expected_full_token, decoded) + + def test_unpack_jwt_with_claims(self): + token = create_jwt(test_user_id, test_timeout, test_claims, test_now) + + expected_token_with_claims = expected_full_token.copy() + expected_token_with_claims.update(test_claims) + + decoded = unpack_jwt(token, test_user_id, test_now) + + self.assertEqual(expected_token_with_claims, decoded) + + def test_malformed_token(self): + token = create_jwt(test_user_id, test_timeout, test_claims, test_now) + token = token + "a" + + expected_token_with_claims = expected_full_token.copy() + expected_token_with_claims.update(test_claims) + + with self.assertRaises(BadSignature): + unpack_jwt(token, test_user_id, test_now) + + def test_unpack_token_with_invalid_user(self): + token = create_jwt(invalid_test_user_id, test_timeout, {}, test_now) + + with self.assertRaises(Invalid): + unpack_jwt(token, test_user_id, test_now) + + def test_unpack_expired_token(self): + token = create_jwt(test_user_id, test_timeout, {}, test_now) + + with self.assertRaises(Expired): + unpack_jwt(token, test_user_id, test_now + test_timeout + 1) + + def test_missing_expired_lms_user_id(self): + payload = expected_full_token.copy() + del payload['lms_user_id'] + token = _encode_and_sign(payload) + + with self.assertRaises(MissingKey): + unpack_jwt(token, test_user_id, test_now) + + def test_missing_expired_key(self): + payload = expected_full_token.copy() + del payload['exp'] + token = _encode_and_sign(payload) + + with self.assertRaises(MissingKey): + unpack_jwt(token, test_user_id, test_now) diff --git a/lms/djangoapps/courseware/utils.py b/lms/djangoapps/courseware/utils.py index 5409c89f63..413e024d28 100644 --- a/lms/djangoapps/courseware/utils.py +++ b/lms/djangoapps/courseware/utils.py @@ -5,6 +5,8 @@ import datetime import hashlib import logging +from time import time + from django.conf import settings from django.http import HttpResponse, HttpResponseBadRequest from edx_rest_api_client.client import OAuthAPIClient @@ -15,6 +17,9 @@ from xmodule.partitions.partitions import \ ENROLLMENT_TRACK_PARTITION_ID # lint-amnesty, pylint: disable=wrong-import-order from xmodule.partitions.partitions_service import PartitionService # lint-amnesty, pylint: disable=wrong-import-order +from jwkest import Expired, Invalid, MissingKey, jwk +from jwkest.jws import JWS + from common.djangoapps.course_modes.models import CourseMode from lms.djangoapps.commerce.utils import EcommerceService from lms.djangoapps.courseware.config import ENABLE_NEW_FINANCIAL_ASSISTANCE_FLOW @@ -229,3 +234,36 @@ def _use_new_financial_assistance_flow(course_id): ): return True return False + + + +def unpack_jwt(token, lms_user_id, now=None): + """ + Unpack and verify an encoded JWT. + + Validate the user and expiration. + + Arguments: + token (string): The token to be unpacked and verified. + lms_user_id (int): LMS user ID this token should match with. + now (int): Optional now value for testing. + + Returns a valid, decoded json payload (string). + """ + now = now or int(time()) + + # Unpack and verify token + keys = jwk.KEYS() + keys.load_jwks(settings.TOKEN_SIGNING['JWT_PUBLIC_SIGNING_JWK_SET']) + payload = JWS().verify_compact(token.encode('utf-8'), keys) + + if "lms_user_id" not in payload: + raise MissingKey("LMS user id is missing") + if "exp" not in payload: + raise MissingKey("Expiration is missing") + if payload["lms_user_id"] != lms_user_id: + raise Invalid("User does not match") + if payload["exp"] < now: + raise Expired("Token is expired") + + return payload diff --git a/lms/djangoapps/courseware/views/views.py b/lms/djangoapps/courseware/views/views.py index 6e0804db8c..33bd4013f3 100644 --- a/lms/djangoapps/courseware/views/views.py +++ b/lms/djangoapps/courseware/views/views.py @@ -46,7 +46,6 @@ from rest_framework import status from rest_framework.decorators import api_view, throttle_classes from rest_framework.response import Response from rest_framework.throttling import UserRateThrottle -from token_utils.api import unpack_token_for from web_fragments.fragment import Fragment from xmodule.course_block import ( COURSE_VISIBILITY_PUBLIC, @@ -106,7 +105,7 @@ from lms.djangoapps.courseware.user_state_client import DjangoXBlockUserStateCli from lms.djangoapps.courseware.utils import ( _use_new_financial_assistance_flow, create_financial_assistance_application, - is_eligible_for_financial_aid + is_eligible_for_financial_aid, unpack_jwt ) from lms.djangoapps.edxnotes.helpers import is_feature_enabled from lms.djangoapps.experiments.utils import get_experiment_user_metadata_context @@ -1535,7 +1534,7 @@ def _check_sequence_exam_access(request, location): try: # unpack will validate both expiration and the requesting user matches the # token user - exam_access_unpacked = unpack_token_for(exam_access_token, request.user.id) + exam_access_unpacked = unpack_jwt(exam_access_token, request.user.id) except: # pylint: disable=bare-except log.exception(f"Failed to validate exam access token. user_id={request.user.id} location={location}") return False diff --git a/lms/envs/common.py b/lms/envs/common.py index cb7643c366..e17b88fe91 100644 --- a/lms/envs/common.py +++ b/lms/envs/common.py @@ -4315,7 +4315,22 @@ TOKEN_SIGNING = { 'JWT_ISSUER': 'http://127.0.0.1:8740', 'JWT_SIGNING_ALGORITHM': 'RS512', 'JWT_SUPPORTED_VERSION': '1.2.0', - 'JWT_PUBLIC_SIGNING_JWK_SET': None, + 'JWT_PUBLIC_SIGNING_JWK_SET': '''{ + "keys": [ + { + "kid":"token-test-wrong-key", + "e": "AQAB", + "kty": "RSA", + "n": "o5cn3ljSRi6FaDEKTn0PS-oL9EFyv1pI7dffgRQLD1qf5D6sprmYfWVokSsrWig8u2y0HChSygR6Jn5KXBqQn6FpM0dDJLnWQDRXHLl3Ey1iPYgDSmOIsIGrV9ZyNCQwk03wAgWbfdBTig3QSDYD-sTNOs3pc4UD_PqAvU2nz_1SS2ZiOwOn5F6gulE1L0iE3KEUEvOIagfHNVhz0oxa_VRZILkzV-zr6R_TW1m97h4H8jXl_VJyQGyhMGGypuDrQ9_vaY_RLEulLCyY0INglHWQ7pckxBtI5q55-Vio2wgewe2_qYcGsnBGaDNbySAsvYcWRrqDiFyzrJYivodqTQ" + }, + { + "kid":"token-test-sign", + "e": "AQAB", + "kty": "RSA", + "n": "o5cn3ljSRi6FaDEKTn0PS-oL9EFyv1pI7dRgffQLD1qf5D6sprmYfWWokSsrWig8u2y0HChSygR6Jn5KXBqQn6FpM0dDJLnWQDRXHLl3Ey1iPYgDSmOIsIGrV9ZyNCQwk03wAgWbfdBTig3QSDYD-sTNOs3pc4UD_PqAvU2nz_1SS2ZiOwOn5F6gulE1L0iE3KEUEvOIagfHNVhz0oxa_VRZILkzV-zr6R_TW1m97h4H8jXl_VJyQGyhMGGypuDrQ9_vaY_RLEulLCyY0INglHWQ7pckxBtI5q55-Vio2wgewe2_qYcGsnBGaDNbySAsvYcWRrqDiFyzrJYivodqTQ" + } + ] + }''', } COURSE_CATALOG_URL_ROOT = 'http://localhost:8008'