Replace pyjwkest with pyjwt (#32270)

* chore: replace pyjwkest with pyjwt
This commit is contained in:
Muhammad Umar Khan
2023-10-18 15:15:17 +05:00
committed by GitHub
parent 7ef865029a
commit 92731be0dc
6 changed files with 112 additions and 140 deletions

View File

@@ -517,25 +517,35 @@ VIDEO_TRANSCRIPTS_SETTINGS = dict(
####################### Authentication Settings ##########################
JWT_AUTH.update({
'JWT_PUBLIC_SIGNING_JWK_SET': (
'{"keys": [{"kid": "BTZ9HA6K", "e": "AQAB", "kty": "RSA", "n": "o5cn3ljSRi6FaDEKTn0PS-oL9EFyv1pI7dRgffQLD1qf5D6'
'sprmYfWWokSsrWig8u2y0HChSygR6Jn5KXBqQn6FpM0dDJLnWQDRXHLl3Ey1iPYgDSmOIsIGrV9ZyNCQwk03wAgWbfdBTig3QSDYD-sTNOs3pc'
'4UD_PqAvU2nz_1SS2ZiOwOn5F6gulE1L0iE3KEUEvOIagfHNVhz0oxa_VRZILkzV-zr6R_TW1m97h4H8jXl_VJyQGyhMGGypuDrQ9_vaY_RLEu'
'lLCyY0INglHWQ7pckxBtI5q55-Vio2wgewe2_qYcGsnBGaDNbySAsvYcWRrqDiFyzrJYivodqTQ"}]}'
),
'JWT_PRIVATE_SIGNING_JWK': (
'{"e": "AQAB", "d": "HIiV7KNjcdhVbpn3KT-I9n3JPf5YbGXsCIedmPqDH1d4QhBofuAqZ9zebQuxkRUpmqtYMv0Zi6ECSUqH387GYQF_Xv'
'FUFcjQRPycISd8TH0DAKaDpGr-AYNshnKiEtQpINhcP44I1AYNPCwyoxXA1fGTtmkKChsuWea7o8kytwU5xSejvh5-jiqu2SF4GEl0BEXIAPZs'
'gbzoPIWNxgO4_RzNnWs6nJZeszcaDD0CyezVSuH9QcI6g5QFzAC_YuykSsaaFJhZ05DocBsLczShJ9Omf6PnK9xlm26I84xrEh_7x4fVmNBg3x'
'WTLh8qOnHqGko93A1diLRCrKHOvnpvgQ", "n": "o5cn3ljSRi6FaDEKTn0PS-oL9EFyv1pI7dRgffQLD1qf5D6sprmYfWWokSsrWig8u2y0H'
'ChSygR6Jn5KXBqQn6FpM0dDJLnWQDRXHLl3Ey1iPYgDSmOIsIGrV9ZyNCQwk03wAgWbfdBTig3QSDYD-sTNOs3pc4UD_PqAvU2nz_1SS2ZiOwO'
'n5F6gulE1L0iE3KEUEvOIagfHNVhz0oxa_VRZILkzV-zr6R_TW1m97h4H8jXl_VJyQGyhMGGypuDrQ9_vaY_RLEulLCyY0INglHWQ7pckxBtI5'
'q55-Vio2wgewe2_qYcGsnBGaDNbySAsvYcWRrqDiFyzrJYivodqTQ", "q": "3T3DEtBUka7hLGdIsDlC96Uadx_q_E4Vb1cxx_4Ss_wGp1Lo'
'z3N3ZngGyInsKlmbBgLo1Ykd6T9TRvRNEWEtFSOcm2INIBoVoXk7W5RuPa8Cgq2tjQj9ziGQ08JMejrPlj3Q1wmALJr5VTfvSYBu0WkljhKNCy'
'1KB6fCby0C9WE", "p": "vUqzWPZnDG4IXyo-k5F0bHV0BNL_pVhQoLW7eyFHnw74IOEfSbdsMspNcPSFIrtgPsn7981qv3lN_staZ6JflKfH'
'ayjB_lvltHyZxfl0dvruShZOx1N6ykEo7YrAskC_qxUyrIvqmJ64zPW3jkuOYrFs7Ykj3zFx3Zq1H5568G0", "kid": "BTZ9HA6K", "kty"'
': "RSA"}'
),
'JWT_PUBLIC_SIGNING_JWK_SET': """
{
"keys":[
{
"kid":"BTZ9HA6K",
"e":"AQAB",
"kty":"RSA",
"n":"o5cn3ljSRi6FaDEKTn0PS-oL9EFyv1pI7dRgffQLD1qf5D6sprmYfWWokSsrWig8u2y0HChSygR6Jn5KXBqQn6FpM0dDJLnWQDRXHLl3Ey1iPYgDSmOIsIGrV9ZyNCQwk03wAgWbfdBTig3QSDYD-sTNOs3pc4UD_PqAvU2nz_1SS2ZiOwOn5F6gulE1L0iE3KEUEvOIagfHNVhz0oxa_VRZILkzV-zr6R_TW1m97h4H8jXl_VJyQGyhMGGypuDrQ9_vaY_RLEulLCyY0INglHWQ7pckxBtI5q55-Vio2wgewe2_qYcGsnBGaDNbySAsvYcWRrqDiFyzrJYivodqTQ"
}
]
}
""",
'JWT_PRIVATE_SIGNING_JWK': """
{
"kid": "BTZ9HA6K",
"kty": "RSA",
"key_ops": [
"sign"
],
"n": "o5cn3ljSRi6FaDEKTn0PS-oL9EFyv1pI7dRgffQLD1qf5D6sprmYfWWokSsrWig8u2y0HChSygR6Jn5KXBqQn6FpM0dDJLnWQDRXHLl3Ey1iPYgDSmOIsIGrV9ZyNCQwk03wAgWbfdBTig3QSDYD-sTNOs3pc4UD_PqAvU2nz_1SS2ZiOwOn5F6gulE1L0iE3KEUEvOIagfHNVhz0oxa_VRZILkzV-zr6R_TW1m97h4H8jXl_VJyQGyhMGGypuDrQ9_vaY_RLEulLCyY0INglHWQ7pckxBtI5q55-Vio2wgewe2_qYcGsnBGaDNbySAsvYcWRrqDiFyzrJYivodqTQ",
"e": "AQAB",
"d": "HIiV7KNjcdhVbpn3KT-I9n3JPf5YbGXsCIedmPqDH1d4QhBofuAqZ9zebQuxkRUpmqtYMv0Zi6ECSUqH387GYQF_XvFUFcjQRPycISd8TH0DAKaDpGr-AYNshnKiEtQpINhcP44I1AYNPCwyoxXA1fGTtmkKChsuWea7o8kytwU5xSejvh5-jiqu2SF4GEl0BEXIAPZsgbzoPIWNxgO4_RzNnWs6nJZeszcaDD0CyezVSuH9QcI6g5QFzAC_YuykSsaaFJhZ05DocBsLczShJ9Omf6PnK9xlm26I84xrEh_7x4fVmNBg3xWTLh8qOnHqGko93A1diLRCrKHOvnpvgQ",
"p": "3T3DEtBUka7hLGdIsDlC96Uadx_q_E4Vb1cxx_4Ss_wGp1Loz3N3ZngGyInsKlmbBgLo1Ykd6T9TRvRNEWEtFSOcm2INIBoVoXk7W5RuPa8Cgq2tjQj9ziGQ08JMejrPlj3Q1wmALJr5VTfvSYBu0WkljhKNCy1KB6fCby0C9WE",
"q": "vUqzWPZnDG4IXyo-k5F0bHV0BNL_pVhQoLW7eyFHnw74IOEfSbdsMspNcPSFIrtgPsn7981qv3lN_staZ6JflKfHayjB_lvltHyZxfl0dvruShZOx1N6ykEo7YrAskC_qxUyrIvqmJ64zPW3jkuOYrFs7Ykj3zFx3Zq1H5568G0",
"dp": "Azh08H8r2_sJuBXAzx_mQ6iZnAZQ619PnJFOXjTqnMgcaK8iSHLL2CgDIUQwteUcBphgP0uBrfWIBs5jmM8rUtVz4CcrPb5jdjhHjuu4NxmnFbPlhNoOp8OBUjPP3S-h-fPoaFjxDrUqz_zCdPVzp4S6UTkf6Hu-SiI9CFVFZ8E",
"dq": "WQ44_KTIbIej9qnYUPMA1DoaAF8ImVDIdiOp9c79dC7FvCpN3w-lnuugrYDM1j9Tk5bRrY7-JuE6OaKQgOtajoS1BIxjYHj5xAVPD15CVevOihqeq5Zx0ZAAYmmCKRrfUe0iLx2QnIcoKH1-Azs23OXeeo6nysznZjvv9NVJv60",
"qi": "KSWGH607H1kNG2okjYdmVdNgLxTUB-Wye9a9FNFE49UmQIOJeZYXtDzcjk8IiK3g-EU3CqBeDKVUgHvHFu4_Wj3IrIhKYizS4BeFmOcPDvylDQCmJcC9tXLQgHkxM_MEJ7iLn9FOLRshh7GPgZphXxMhezM26Cz-8r3_mACHu84"
}
""",
})
# pylint: enable=unicode-format-string # lint-amnesty, pylint: disable=bad-option-value
####################### Plugin Settings ##########################

View File

@@ -63,7 +63,7 @@ The code examples below show this in action.
Remove JWT_ISSUERS
~~~~~~~~~~~~~~~~~~
edx_rest_framework_extensions.settings_ supports having a list of **JWT_ISSUERS** instead of just a single
`edx_rest_framework_extensions.settings`_ supports having a list of **JWT_ISSUERS** instead of just a single
one. This support for configuring multiple issuers is present across many services. However, this does not
conform to the `JWT standard`_, where the `issuer`_ is intended to identify the entity that generates and
signs the JWT. In our case, that should be the single Auth service only.
@@ -81,70 +81,56 @@ issuer, but with (the potential of) multiple signing keys stored in a JWT Set.
.. _JSON Web Key Set (JWK Set): https://tools.ietf.org/html/draft-ietf-jose-json-web-key-36#section-5
.. _site configuration: https://github.com/openedx/edx-platform/blob/af841336c7e39d634c238cd8a11c5a3a661aa9e2/openedx/core/djangoapps/site_configuration/__init__.py
Example Code
------------
Features
--------
KeyPair Generation
~~~~~~~~~~~~~~~~~~
Here is code for generating a keypair::
Please have a look at ``openedx/core/djangoapps/oauth_dispatch/management/commands/generate_jwt_signing_key.py``
to get better understanding how to generate keypair using ``PyJWT``.
from Cryptodome.PublicKey import RSA
from jwkest import jwk
The public and private keypair would be similar to the following::
rsa_key = RSA.generate(2048)
rsa_jwk = jwk.RSAKey(kid="your_key_id", key=rsa_key)
## Public keyset
"""
{
"keys": [
{
"kty": "RSA",
"key_ops": ["verify"],
"n": "...",
"e": "...",
"kid": "your_key_id"
}
]
}
"""
To serialize the **public key** in a `JSON Web Key Set (JWK Set)`_::
public_keys = jwk.KEYS()
public_keys.append(rsa_jwk)
serialized_public_keys_json = public_keys.dump_jwks()
and its sample output::
{
"keys": [
{
"kid": "your_key_id",
"e": "strawberry",
"kty": "RSA",
"n": "something"
}
]
}
To serialize the **keypair** as a JWK::
serialized_keypair = rsa_jwk.serialize(private=True)
serialized_keypair_json = json.dumps(serialized_keypair)
and its sample output::
{
"e": "strawberry",
"d": "apple",
"n": "banana",
"q": "pear",
"p": "plum",
"kid": "your_key_id",
"kty": "RSA"
}
## Private key
"""
{
"kty": "RSA",
"key_ops": ["sign"],
"n": "...",
"e": "...",
"d": "...",
"p": "...",
"q": "...",
"dp": "...",
"dq": "...",
"qi": "...",
"kid": "your_key_id"
}
"""
Signing
~~~~~~~
To deserialize the keypair from above::
To create a signature you simply need a **payload**, **private key** and your hashing algorithm::
private_keys = jwk.KEYS()
serialized_keypair = json.loads(serialized_keypair_json)
private_keys.add(serialized_keypair)
To create a signature::
from jwkest.jws import JWS
jws = JWS("JWT payload", alg="RS512")
signed_message = jws.sign_compact(keys=private_keys)
signed_message = jwt.encode("JWT payload in dict format", key=private_key, algorithm="RS512")
Note: we specify **RS512** above to identify *RSASSA-PKCS1-v1_5 using SHA-512* as
the signature algorithm value as described in the `JSON Web Algorithms (JWA)`_ spec.
@@ -154,24 +140,20 @@ the signature algorithm value as described in the `JSON Web Algorithms (JWA)`_ s
Verify Signature
~~~~~~~~~~~~~~~~
To verify the signature from above::
To verify the signature we'll be looping through the public keys and try to verify the signature with each of them.
For more details you can have a look at `verify_jwk_signature_using_keyset`_. To generate ``keyset`` required for verification you
can use `get_verification_jwk_key_set`_ method.
public_keys = jwk.KEYS()
public_keys.load_jwks(serialized_public_keys_json)
jws.verify_compact(signed_message, public_keys)
.. _verify_jwk_signature_using_keyset: https://github.com/openedx/edx-drf-extensions/blob/master/edx_rest_framework_extensions/auth/jwt/decoder.py#L270
.. _get_verification_jwk_key_set : https://github.com/openedx/edx-drf-extensions/blob/master/edx_rest_framework_extensions/auth/jwt/decoder.py#L395
Key Rotation
~~~~~~~~~~~~
When a new public key is added in the future, it should have a unique "kid"
value and added to the public keys JWK set::
new_rsa_key = RSA.generate(2048)
new_rsa_jwk = jwk.RSAKey(kid="new_id", key=new_rsa_key)
public_keys.append(new_rsa_jwk)
When a JWS is created, it is signed with a certain "kid"-identified keypair. When it
is later verified, the public key with the matching "kid" in the JWK set is used.
In future if we plan to rotate the keys, we can simply add new key public key to the public keyset and remove the old private one.
Means, at any time there might be more than one public key but there will be only one private key. Considering that we are doing verification
by looping through all the available public keys, the ``kid`` parameter is not
as important as it was before. But it's still recommended to use it. It will help us to differentiate between the old and new public keys.
Consequences
------------

View File

@@ -5,12 +5,13 @@ import json
import logging
from time import time
import jwt
from django.conf import settings
from edx_django_utils.monitoring import increment, set_custom_attribute
from edx_rbac.utils import create_role_auth_claim_for_user
from edx_toggles.toggles import SettingToggle
from jwkest import jwk
from jwkest.jws import JWS
from jwt import PyJWK
from jwt.utils import base64url_encode
from common.djangoapps.student.models import UserProfile, anonymous_id_for_user
@@ -273,17 +274,14 @@ def _attach_profile_claim(payload, user):
def _encode_and_sign(payload, use_asymmetric_key, secret):
"""Encode and sign the provided payload."""
keys = jwk.KEYS()
if use_asymmetric_key:
serialized_keypair = json.loads(settings.JWT_AUTH['JWT_PRIVATE_SIGNING_JWK'])
keys.add(serialized_keypair)
key = json.loads(settings.JWT_AUTH['JWT_PRIVATE_SIGNING_JWK'])
algorithm = settings.JWT_AUTH['JWT_SIGNING_ALGORITHM']
else:
key = secret if secret else settings.JWT_AUTH['JWT_SECRET_KEY']
keys.add({'key': key, 'kty': 'oct'})
secret = secret if secret else settings.JWT_AUTH['JWT_SECRET_KEY']
key = {'k': base64url_encode(secret.encode('utf-8')), 'kty': 'oct'}
algorithm = settings.JWT_AUTH['JWT_ALGORITHM']
data = json.dumps(payload)
jws = JWS(data, alg=algorithm)
return jws.sign_compact(keys=keys)
jwk = PyJWK(key, algorithm)
return jwt.encode(payload, jwk.key, algorithm=algorithm)

View File

@@ -14,7 +14,7 @@ import yaml
from Cryptodome.PublicKey import RSA
from django.conf import settings
from django.core.management.base import BaseCommand
from jwkest import jwk
from jwt.algorithms import get_default_algorithms
log = logging.getLogger(__name__)
@@ -123,15 +123,23 @@ class Command(BaseCommand):
def _generate_key_pair(self, key_size, key_id):
log.info('Generating new JWT signing keypair for key id %s.', key_id)
rsa_key = RSA.generate(key_size)
rsa_jwk = jwk.RSAKey(kid=key_id, key=rsa_key)
return rsa_jwk
algo = get_default_algorithms()['RS512']
key_data = algo.prepare_key(rsa_key.export_key('PEM').decode())
rsa_jwk = json.loads(algo.to_jwk(key_data))
public_rsa_jwk = json.loads(algo.to_jwk(key_data.public_key()))
rsa_jwk['kid'] = key_id
public_rsa_jwk['kid'] = key_id
return {'private': rsa_jwk, 'public': public_rsa_jwk}
def _output_public_keys(self, jwk_key, add_previous, strip_prefix):
public_keys = jwk.KEYS()
public_keys = {'keys': []}
if add_previous:
self._add_previous_public_keys(public_keys)
public_keys.append(jwk_key)
serialized_public_keys = public_keys.dump_jwks()
public_keys['keys'].append(jwk_key['public'])
serialized_public_keys = json.dumps(public_keys)
prefix = '' if strip_prefix else 'COMMON_'
public_signing_key = f'{prefix}JWT_PUBLIC_SIGNING_JWK_SET'
@@ -155,11 +163,10 @@ class Command(BaseCommand):
previous_signing_keys = settings.JWT_AUTH.get('JWT_PUBLIC_SIGNING_JWK_SET')
if previous_signing_keys:
log.info('Old JWT_PUBLIC_SIGNING_JWK_SET: %s.', previous_signing_keys)
public_keys.load_jwks(previous_signing_keys)
public_keys['keys'].extend(json.loads(previous_signing_keys)['keys'])
def _output_private_keys(self, jwk_key, strip_prefix):
serialized_keypair = jwk_key.serialize(private=True)
serialized_keypair_json = json.dumps(serialized_keypair)
serialized_keypair_json = json.dumps(jwk_key['private'])
prefix = '' if strip_prefix else 'EDXAPP_'
private_signing_key = f'{prefix}JWT_PRIVATE_SIGNING_JWK'

View File

@@ -3,10 +3,11 @@ OAuth Dispatch test mixins
"""
import pytest
import jwt
from django.conf import settings
from jwkest.jwk import KEYS
from jwkest.jws import JWS
from edx_rest_framework_extensions.auth.jwt.decoder import (
get_verification_jwk_key_set,
verify_jwk_signature_using_keyset
)
from jwt.exceptions import ExpiredSignatureError
from common.djangoapps.student.models import UserProfile, anonymous_id_for_user
@@ -33,25 +34,15 @@ class AccessTokenMixin:
Helper method to decode a JWT with the ability to
verify the expiration of said token
"""
keys = KEYS()
if should_be_asymmetric_key:
keys.load_jwks(settings.JWT_AUTH['JWT_PUBLIC_SIGNING_JWK_SET'])
else:
keys.add({'key': secret_key, 'kty': 'oct'})
asymmetric_keys = settings.JWT_AUTH.get('JWT_PUBLIC_SIGNING_JWK_SET') if should_be_asymmetric_key else None
key_set = get_verification_jwk_key_set(asymmetric_keys=asymmetric_keys, secret_key=secret_key)
data = verify_jwk_signature_using_keyset(access_token,
key_set,
iss=issuer,
aud=aud,
verify_exp=verify_expiration)
_ = JWS().verify_compact(access_token.encode('utf-8'), keys)
return jwt.decode(
access_token,
secret_key,
algorithms=[settings.JWT_AUTH['JWT_ALGORITHM']],
audience=audience,
issuer=issuer,
options={
'verify_signature': False,
"verify_exp": verify_expiration
},
)
return data
# Note that if we expect the claims to have expired
# then we ask the JWT library not to verify expiration

View File

@@ -9,12 +9,10 @@ from unittest.mock import call, patch
import ddt
import httpretty
from Cryptodome.PublicKey import RSA
from django.conf import settings
from django.test import RequestFactory, TestCase
from django.urls import reverse
from edx_toggles.toggles.testutils import override_waffle_switch
from jwkest import jwk
from oauth2_provider import models as dot_models
from common.djangoapps.student.tests.factories import UserFactory
@@ -164,20 +162,6 @@ class TestAccessTokenView(AccessTokenLoginMixin, mixins.AccessTokenMixin, _Dispa
return body
def _generate_key_pair(self):
""" Generates an asymmetric key pair and returns the JWK of its public keys and keypair. """
rsa_key = RSA.generate(2048)
rsa_jwk = jwk.RSAKey(kid="key_id", key=rsa_key)
public_keys = jwk.KEYS()
public_keys.append(rsa_jwk)
serialized_public_keys_json = public_keys.dump_jwks()
serialized_keypair = rsa_jwk.serialize(private=True)
serialized_keypair_json = json.dumps(serialized_keypair)
return serialized_public_keys_json, serialized_keypair_json
def _test_jwt_access_token(self, client_attr, token_type=None, headers=None, grant_type=None, asymmetric_jwt=False):
"""
Test response for JWT token.