diff --git a/openedx/core/djangoapps/oauth_dispatch/jwt.py b/openedx/core/djangoapps/oauth_dispatch/jwt.py index cf13a2259b..de493715a7 100644 --- a/openedx/core/djangoapps/oauth_dispatch/jwt.py +++ b/openedx/core/djangoapps/oauth_dispatch/jwt.py @@ -276,6 +276,19 @@ def _attach_profile_claim(payload, user): 'superuser': user.is_superuser, }) +# .. toggle_name: JWT_AUTH_ADD_KID_HEADER +# .. toggle_implementation: SettingToggle +# .. toggle_default: False +# .. toggle_description: When True, add KID header to JWT using asymmetrical key. +# .. toggle_use_cases: temporary +# .. toggle_creation_date: 2024-03-20 +# .. toggle_target_removal_date: 2024-04-20 +# .. toggle_tickets: +# https://2u-internal.atlassian.net/browse/AUTH-195?atlOrigin=eyJpIjoiODMzODBiODMwMjU5NGRiZTkyOTIzYThhZjZiNWE0MzMiLCJwIjoiaiJ9 +JWT_AUTH_ADD_KID_HEADER = SettingToggle( + 'JWT_AUTH_ADD_KID_HEADER', default=False, module_name=__name__ +) + def _encode_and_sign(payload, use_asymmetric_key, secret): """Encode and sign the provided payload.""" @@ -289,6 +302,9 @@ def _encode_and_sign(payload, use_asymmetric_key, secret): algorithm = settings.JWT_AUTH['JWT_ALGORITHM'] jwk = PyJWK(key, algorithm) + if JWT_AUTH_ADD_KID_HEADER.is_enabled() and jwk.key_id: + return jwt.encode(payload, jwk.key, algorithm=algorithm, headers={'kid': jwk.key_id}) + return jwt.encode(payload, jwk.key, algorithm=algorithm) diff --git a/openedx/core/djangoapps/oauth_dispatch/tests/test_jwt.py b/openedx/core/djangoapps/oauth_dispatch/tests/test_jwt.py index 7851158614..da95fd072d 100644 --- a/openedx/core/djangoapps/oauth_dispatch/tests/test_jwt.py +++ b/openedx/core/djangoapps/oauth_dispatch/tests/test_jwt.py @@ -79,6 +79,33 @@ class TestCreateJWTs(AccessTokenMixin, TestCase): jwt_token = self._create_jwt_for_token(DOTAdapter(), use_asymmetric_key=False) self._assert_jwt_is_valid(jwt_token, should_be_asymmetric_key=True) + def test_kid_not_in_jwt_header_with_symmetric_key_and_kid_disabled(self): + jwt_token = self._create_jwt_for_token(DOTAdapter(), use_asymmetric_key=False) + header = jwt_api.jwt.get_unverified_header(jwt_token) + assert 'kid' not in header + self._assert_jwt_is_valid(jwt_token, should_be_asymmetric_key=False) + + def test_kid_not_in_jwt_header_with_asymmetric_key_and_kid_disabled(self): + jwt_token = self._create_jwt_for_token(DOTAdapter(), use_asymmetric_key=True) + header = jwt_api.jwt.get_unverified_header(jwt_token) + assert 'kid' not in header + self._assert_jwt_is_valid(jwt_token, should_be_asymmetric_key=True) + + @override_settings(JWT_AUTH_ADD_KID_HEADER=True) + def test_kid_not_in_jwt_header_with_symmetric_key_and_kid_enabled(self): + jwt_token = self._create_jwt_for_token(DOTAdapter(), use_asymmetric_key=False) + header = jwt_api.jwt.get_unverified_header(jwt_token) + assert 'kid' not in header + self._assert_jwt_is_valid(jwt_token, should_be_asymmetric_key=False) + + @override_settings(JWT_AUTH_ADD_KID_HEADER=True) + def test_kid_in_jwt_header_with_asymmetric_key_and_kid_enabled(self): + jwt_token = self._create_jwt_for_token(DOTAdapter(), use_asymmetric_key=True) + header = jwt_api.jwt.get_unverified_header(jwt_token) + assert 'kid' in header + assert header['kid'] == 'BTZ9HA6K' + self._assert_jwt_is_valid(jwt_token, should_be_asymmetric_key=True) + def test_create_jwt_for_token_default_expire_seconds(self): oauth_adapter = DOTAdapter() jwt_token = self._create_jwt_for_token(oauth_adapter, use_asymmetric_key=False)