Merge pull request #12954 from edx/renzo/extract-token-generation
Unify JWT generation code
This commit is contained in:
@@ -1,49 +1,14 @@
|
||||
""" Course Discovery API Service. """
|
||||
import datetime
|
||||
|
||||
import jwt
|
||||
from django.conf import settings
|
||||
from edx_rest_api_client.client import EdxRestApiClient
|
||||
|
||||
from openedx.core.djangoapps.theming import helpers
|
||||
from student.models import UserProfile, anonymous_id_for_user
|
||||
|
||||
|
||||
def get_id_token(user):
|
||||
"""
|
||||
Return a JWT for `user`, suitable for use with the course discovery service.
|
||||
|
||||
Arguments:
|
||||
user (User): User for whom to generate the JWT.
|
||||
|
||||
Returns:
|
||||
str: The JWT.
|
||||
"""
|
||||
try:
|
||||
# Service users may not have user profiles.
|
||||
full_name = UserProfile.objects.get(user=user).name
|
||||
except UserProfile.DoesNotExist:
|
||||
full_name = None
|
||||
|
||||
now = datetime.datetime.utcnow()
|
||||
expires_in = getattr(settings, 'OAUTH_ID_TOKEN_EXPIRATION', 30)
|
||||
|
||||
payload = {
|
||||
'preferred_username': user.username,
|
||||
'name': full_name,
|
||||
'email': user.email,
|
||||
'administrator': user.is_staff,
|
||||
'iss': helpers.get_value('OAUTH_OIDC_ISSUER', settings.OAUTH_OIDC_ISSUER),
|
||||
'exp': now + datetime.timedelta(seconds=expires_in),
|
||||
'iat': now,
|
||||
'aud': helpers.get_value('JWT_AUTH', settings.JWT_AUTH)['JWT_AUDIENCE'],
|
||||
'sub': anonymous_id_for_user(user, None),
|
||||
}
|
||||
secret_key = helpers.get_value('JWT_AUTH', settings.JWT_AUTH)['JWT_SECRET_KEY']
|
||||
|
||||
return jwt.encode(payload, secret_key).decode('utf-8')
|
||||
from openedx.core.lib.token_utils import JwtBuilder
|
||||
|
||||
|
||||
def course_discovery_api_client(user):
|
||||
""" Returns a Course Discovery API client setup with authentication for the specified user. """
|
||||
return EdxRestApiClient(settings.COURSE_CATALOG_API_URL, jwt=get_id_token(user))
|
||||
scopes = ['email', 'profile']
|
||||
expires_in = settings.OAUTH_ID_TOKEN_EXPIRATION
|
||||
jwt = JwtBuilder(user).build_token(scopes, expires_in)
|
||||
|
||||
return EdxRestApiClient(settings.COURSE_CATALOG_API_URL, jwt=jwt)
|
||||
|
||||
@@ -5,13 +5,15 @@ from celery import task
|
||||
from celery.utils.log import get_task_logger # pylint: disable=no-name-in-module, import-error
|
||||
from django.conf import settings
|
||||
from django.contrib.auth.models import User
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from edx_rest_api_client.client import EdxRestApiClient
|
||||
from provider.oauth2.models import Client
|
||||
|
||||
from openedx.core.djangoapps.credentials.models import CredentialsApiConfig
|
||||
from openedx.core.djangoapps.credentials.utils import get_user_credentials
|
||||
from openedx.core.djangoapps.programs.models import ProgramsApiConfig
|
||||
from openedx.core.djangoapps.programs.utils import ProgramProgressMeter
|
||||
from openedx.core.lib.token_utils import get_id_token
|
||||
from openedx.core.lib.token_utils import JwtBuilder
|
||||
|
||||
|
||||
LOGGER = get_task_logger(__name__)
|
||||
@@ -31,8 +33,21 @@ def get_api_client(api_config, student):
|
||||
EdxRestApiClient
|
||||
|
||||
"""
|
||||
id_token = get_id_token(student, api_config.OAUTH2_CLIENT_NAME)
|
||||
return EdxRestApiClient(api_config.internal_api_url, jwt=id_token)
|
||||
# TODO: Use the system's JWT_AUDIENCE and JWT_SECRET_KEY instead of client ID and name.
|
||||
client_name = api_config.OAUTH2_CLIENT_NAME
|
||||
|
||||
try:
|
||||
client = Client.objects.get(name=client_name)
|
||||
except Client.DoesNotExist:
|
||||
raise ImproperlyConfigured(
|
||||
'OAuth2 Client with name [{}] does not exist.'.format(client_name)
|
||||
)
|
||||
|
||||
scopes = ['email', 'profile']
|
||||
expires_in = settings.OAUTH_ID_TOKEN_EXPIRATION
|
||||
jwt = JwtBuilder(student, secret=client.client_secret).build_token(scopes, expires_in, aud=client.client_id)
|
||||
|
||||
return EdxRestApiClient(api_config.internal_api_url, jwt=jwt)
|
||||
|
||||
|
||||
def get_completed_programs(student):
|
||||
|
||||
@@ -34,8 +34,8 @@ class GetApiClientTestCase(TestCase, ProgramsApiConfigMixin):
|
||||
Test the get_api_client function
|
||||
"""
|
||||
|
||||
@mock.patch(TASKS_MODULE + '.get_id_token')
|
||||
def test_get_api_client(self, mock_get_id_token):
|
||||
@mock.patch(TASKS_MODULE + '.JwtBuilder.build_token')
|
||||
def test_get_api_client(self, mock_build_token):
|
||||
"""
|
||||
Ensure the function is making the right API calls based on inputs
|
||||
"""
|
||||
@@ -45,10 +45,9 @@ class GetApiClientTestCase(TestCase, ProgramsApiConfigMixin):
|
||||
internal_service_url='http://foo',
|
||||
api_version_number=99,
|
||||
)
|
||||
mock_get_id_token.return_value = 'test-token'
|
||||
mock_build_token.return_value = 'test-token'
|
||||
|
||||
api_client = tasks.get_api_client(api_config, student)
|
||||
self.assertEqual(mock_get_id_token.call_args[0], (student, 'programs'))
|
||||
self.assertEqual(api_client._store['base_url'], 'http://foo/api/v99/') # pylint: disable=protected-access
|
||||
self.assertEqual(api_client._store['session'].auth.token, 'test-token') # pylint: disable=protected-access
|
||||
|
||||
|
||||
@@ -2,10 +2,13 @@
|
||||
from __future__ import unicode_literals
|
||||
import logging
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.cache import cache
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from edx_rest_api_client.client import EdxRestApiClient
|
||||
from provider.oauth2.models import Client
|
||||
|
||||
from openedx.core.lib.token_utils import get_id_token
|
||||
from openedx.core.lib.token_utils import JwtBuilder
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@@ -48,7 +51,20 @@ def get_edx_api_data(api_config, user, resource,
|
||||
|
||||
try:
|
||||
if not api:
|
||||
jwt = get_id_token(user, api_config.OAUTH2_CLIENT_NAME)
|
||||
# TODO: Use the system's JWT_AUDIENCE and JWT_SECRET_KEY instead of client ID and name.
|
||||
client_name = api_config.OAUTH2_CLIENT_NAME
|
||||
|
||||
try:
|
||||
client = Client.objects.get(name=client_name)
|
||||
except Client.DoesNotExist:
|
||||
raise ImproperlyConfigured(
|
||||
'OAuth2 Client with name [{}] does not exist.'.format(client_name)
|
||||
)
|
||||
|
||||
scopes = ['email', 'profile']
|
||||
expires_in = settings.OAUTH_ID_TOKEN_EXPIRATION
|
||||
jwt = JwtBuilder(user, secret=client.client_secret).build_token(scopes, expires_in, aud=client.client_id)
|
||||
|
||||
api = EdxRestApiClient(api_config.internal_api_url, jwt=jwt)
|
||||
except: # pylint: disable=bare-except
|
||||
log.exception('Failed to initialize the %s API client.', api_config.API_NAME)
|
||||
|
||||
@@ -1,71 +1,60 @@
|
||||
"""Tests covering utilities for working with ID tokens."""
|
||||
import calendar
|
||||
import datetime
|
||||
|
||||
"""Tests covering JWT construction utilities."""
|
||||
import ddt
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.test import TestCase
|
||||
from django.test.utils import override_settings
|
||||
import freezegun
|
||||
import jwt
|
||||
from nose.plugins.attrib import attr
|
||||
from edx_oauth2_provider.tests.factories import ClientFactory
|
||||
from provider.constants import CONFIDENTIAL
|
||||
|
||||
from openedx.core.lib.token_utils import get_id_token
|
||||
from student.models import anonymous_id_for_user
|
||||
from lms.djangoapps.oauth_dispatch.tests import mixins
|
||||
from openedx.core.lib.token_utils import JwtBuilder
|
||||
from student.tests.factories import UserFactory, UserProfileFactory
|
||||
|
||||
|
||||
@attr('shard_2')
|
||||
@ddt.ddt
|
||||
class TestIdTokenGeneration(TestCase):
|
||||
"""Tests covering ID token generation."""
|
||||
client_name = 'edx-dummy-client'
|
||||
class TestJwtBuilder(mixins.AccessTokenMixin, TestCase):
|
||||
"""
|
||||
Test class for JwtBuilder.
|
||||
"""
|
||||
|
||||
expires_in = 10
|
||||
|
||||
def setUp(self):
|
||||
super(TestIdTokenGeneration, self).setUp()
|
||||
super(TestJwtBuilder, self).setUp()
|
||||
|
||||
self.oauth2_client = ClientFactory(name=self.client_name, client_type=CONFIDENTIAL)
|
||||
self.user = UserFactory()
|
||||
self.profile = UserProfileFactory(user=self.user)
|
||||
|
||||
self.user = UserFactory.build()
|
||||
self.user.save()
|
||||
@ddt.data(
|
||||
[],
|
||||
['email'],
|
||||
['profile'],
|
||||
['email', 'profile'],
|
||||
)
|
||||
def test_jwt_construction(self, scopes):
|
||||
"""
|
||||
Verify that a valid JWT is built, including claims for the requested scopes.
|
||||
"""
|
||||
token = JwtBuilder(self.user).build_token(scopes, self.expires_in)
|
||||
self.assert_valid_jwt_access_token(token, self.user, scopes)
|
||||
|
||||
@override_settings(OAUTH_OIDC_ISSUER='test-issuer', OAUTH_ID_TOKEN_EXPIRATION=1)
|
||||
@freezegun.freeze_time('2015-01-01 12:00:00')
|
||||
@ddt.data(True, False)
|
||||
def test_get_id_token(self, has_profile):
|
||||
"""Verify that ID tokens are signed with the correct secret and generated with the correct claims."""
|
||||
full_name = UserProfileFactory(user=self.user).name if has_profile else None
|
||||
def test_user_profile_missing(self):
|
||||
"""
|
||||
Verify that token construction succeeds if the UserProfile is missing.
|
||||
"""
|
||||
self.profile.delete() # pylint: disable=no-member
|
||||
|
||||
token = get_id_token(self.user, self.client_name)
|
||||
scopes = ['profile']
|
||||
token = JwtBuilder(self.user).build_token(scopes, self.expires_in)
|
||||
self.assert_valid_jwt_access_token(token, self.user, scopes)
|
||||
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
self.oauth2_client.client_secret,
|
||||
audience=self.oauth2_client.client_id,
|
||||
issuer=settings.OAUTH_OIDC_ISSUER,
|
||||
)
|
||||
def test_override_secret_and_audience(self):
|
||||
"""
|
||||
Verify that the signing key and audience can be overridden.
|
||||
"""
|
||||
secret = 'avoid-this'
|
||||
audience = 'avoid-this-too'
|
||||
scopes = []
|
||||
|
||||
now = datetime.datetime.utcnow()
|
||||
expiration = now + datetime.timedelta(seconds=settings.OAUTH_ID_TOKEN_EXPIRATION)
|
||||
token = JwtBuilder(self.user, secret=secret).build_token(scopes, self.expires_in, aud=audience)
|
||||
|
||||
expected_payload = {
|
||||
'preferred_username': self.user.username,
|
||||
'name': full_name,
|
||||
'email': self.user.email,
|
||||
'administrator': self.user.is_staff,
|
||||
'iss': settings.OAUTH_OIDC_ISSUER,
|
||||
'exp': calendar.timegm(expiration.utctimetuple()),
|
||||
'iat': calendar.timegm(now.utctimetuple()),
|
||||
'aud': self.oauth2_client.client_id,
|
||||
'sub': anonymous_id_for_user(self.user, None),
|
||||
}
|
||||
|
||||
self.assertEqual(payload, expected_payload)
|
||||
|
||||
def test_get_id_token_invalid_client(self):
|
||||
"""Verify that ImproperlyConfigured is raised when an invalid client name is provided."""
|
||||
with self.assertRaises(ImproperlyConfigured):
|
||||
get_id_token(self.user, 'does-not-exist')
|
||||
jwt.decode(token, secret, audience=audience)
|
||||
|
||||
@@ -1,120 +1,100 @@
|
||||
"""Utilities for working with ID tokens."""
|
||||
import datetime
|
||||
from time import time
|
||||
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.serialization import load_pem_private_key
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.utils.functional import cached_property
|
||||
import jwt
|
||||
from provider.oauth2.models import Client
|
||||
|
||||
from openedx.core.djangoapps.theming import helpers
|
||||
from student.models import UserProfile, anonymous_id_for_user
|
||||
|
||||
|
||||
def get_id_token(user, client_name, secret_key=None):
|
||||
"""Construct a JWT for use with the named client.
|
||||
class JwtBuilder(object):
|
||||
"""Utility for building JWTs.
|
||||
|
||||
The JWT is signed with the named client's secret, and includes the following claims:
|
||||
Unifies diverse approaches to JWT creation in a single class. This utility defaults to using the system's
|
||||
JWT configuration.
|
||||
|
||||
preferred_username (str): The user's username. The claim name is borrowed from edx-oauth2-provider.
|
||||
name (str): The user's full name.
|
||||
email (str): The user's email address.
|
||||
administrator (Boolean): Whether the user has staff permissions.
|
||||
iss (str): Registered claim. Identifies the principal that issued the JWT.
|
||||
exp (int): Registered claim. Identifies the expiration time on or after which
|
||||
the JWT must NOT be accepted for processing.
|
||||
iat (int): Registered claim. Identifies the time at which the JWT was issued.
|
||||
aud (str): Registered claim. Identifies the recipients that the JWT is intended for. This implementation
|
||||
uses the named client's ID.
|
||||
sub (int): Registered claim. Identifies the user. This implementation uses the raw user id.
|
||||
|
||||
Arguments:
|
||||
user (User): User for which to generate the JWT.
|
||||
client_name (unicode): Name of the OAuth2 Client for which the token is intended.
|
||||
secret_key (str): Optional secret key for signing the JWT. Defaults to the configured client secret
|
||||
if not provided.
|
||||
|
||||
Returns:
|
||||
str: the JWT
|
||||
|
||||
Raises:
|
||||
ImproperlyConfigured: If no OAuth2 Client with the provided name exists.
|
||||
"""
|
||||
try:
|
||||
client = Client.objects.get(name=client_name)
|
||||
except Client.DoesNotExist:
|
||||
raise ImproperlyConfigured('OAuth2 Client with name [%s] does not exist' % client_name)
|
||||
|
||||
try:
|
||||
# Service users may not have user profiles.
|
||||
full_name = UserProfile.objects.get(user=user).name
|
||||
except UserProfile.DoesNotExist:
|
||||
full_name = None
|
||||
|
||||
now = datetime.datetime.utcnow()
|
||||
expires_in = getattr(settings, 'OAUTH_ID_TOKEN_EXPIRATION', 30)
|
||||
|
||||
payload = {
|
||||
'preferred_username': user.username,
|
||||
'name': full_name,
|
||||
'email': user.email,
|
||||
'administrator': user.is_staff,
|
||||
'iss': settings.OAUTH_OIDC_ISSUER,
|
||||
'exp': now + datetime.timedelta(seconds=expires_in),
|
||||
'iat': now,
|
||||
'aud': client.client_id,
|
||||
'sub': anonymous_id_for_user(user, None),
|
||||
}
|
||||
|
||||
if secret_key is None:
|
||||
secret_key = client.client_secret
|
||||
|
||||
return jwt.encode(payload, secret_key)
|
||||
|
||||
|
||||
def get_asymmetric_token(user, client_id):
|
||||
"""Construct a JWT signed with this app's private key.
|
||||
|
||||
The JWT includes the following claims:
|
||||
|
||||
preferred_username (str): The user's username. The claim name is borrowed from edx-oauth2-provider.
|
||||
name (str): The user's full name.
|
||||
email (str): The user's email address.
|
||||
administrator (Boolean): Whether the user has staff permissions.
|
||||
iss (str): Registered claim. Identifies the principal that issued the JWT.
|
||||
exp (int): Registered claim. Identifies the expiration time on or after which
|
||||
the JWT must NOT be accepted for processing.
|
||||
iat (int): Registered claim. Identifies the time at which the JWT was issued.
|
||||
sub (int): Registered claim. Identifies the user. This implementation uses the raw user id.
|
||||
NOTE: This utility class will allow you to override the signing key and audience claim to support those
|
||||
clients which still require this. This approach to JWT creation is DEPRECATED. Avoid doing this for new clients.
|
||||
|
||||
Arguments:
|
||||
user (User): User for which to generate the JWT.
|
||||
|
||||
Returns:
|
||||
str: the JWT
|
||||
|
||||
Keyword Arguments:
|
||||
asymmetric (Boolean): Whether the JWT should be signed with this app's private key.
|
||||
secret (string): Overrides configured JWT secret (signing) key. Unused if an asymmetric signature is requested.
|
||||
"""
|
||||
private_key = load_pem_private_key(settings.PRIVATE_RSA_KEY, None, default_backend())
|
||||
def __init__(self, user, asymmetric=False, secret=None):
|
||||
self.user = user
|
||||
self.asymmetric = asymmetric
|
||||
self.secret = secret
|
||||
self.jwt_auth = helpers.get_value('JWT_AUTH', settings.JWT_AUTH)
|
||||
|
||||
try:
|
||||
# Service users may not have user profiles.
|
||||
full_name = UserProfile.objects.get(user=user).name
|
||||
except UserProfile.DoesNotExist:
|
||||
full_name = None
|
||||
def build_token(self, scopes, expires_in, aud=None):
|
||||
"""Returns a JWT access token.
|
||||
|
||||
now = datetime.datetime.utcnow()
|
||||
expires_in = getattr(settings, 'OAUTH_ID_TOKEN_EXPIRATION', 30)
|
||||
Arguments:
|
||||
scopes (list): Scopes controlling which optional claims are included in the token.
|
||||
expires_in (int): Time to token expiry, specified in seconds.
|
||||
|
||||
payload = {
|
||||
'preferred_username': user.username,
|
||||
'name': full_name,
|
||||
'email': user.email,
|
||||
'administrator': user.is_staff,
|
||||
'iss': settings.OAUTH_OIDC_ISSUER,
|
||||
'exp': now + datetime.timedelta(seconds=expires_in),
|
||||
'iat': now,
|
||||
'aud': client_id,
|
||||
'sub': anonymous_id_for_user(user, None),
|
||||
}
|
||||
Keyword Arguments:
|
||||
aud (string): Overrides configured JWT audience claim.
|
||||
"""
|
||||
now = int(time())
|
||||
payload = {
|
||||
'aud': aud if aud else self.jwt_auth['JWT_AUDIENCE'],
|
||||
'exp': now + expires_in,
|
||||
'iat': now,
|
||||
'iss': self.jwt_auth['JWT_ISSUER'],
|
||||
'preferred_username': self.user.username,
|
||||
'scopes': scopes,
|
||||
'sub': anonymous_id_for_user(self.user, None),
|
||||
}
|
||||
|
||||
return jwt.encode(payload, private_key, algorithm='RS512')
|
||||
for scope in scopes:
|
||||
handler = self.claim_handlers.get(scope)
|
||||
|
||||
if handler:
|
||||
handler(payload)
|
||||
|
||||
return self.encode(payload)
|
||||
|
||||
@cached_property
|
||||
def claim_handlers(self):
|
||||
"""Returns a dictionary mapping scopes to methods that will add claims to the JWT payload."""
|
||||
|
||||
return {
|
||||
'email': self.attach_email_claim,
|
||||
'profile': self.attach_profile_claim
|
||||
}
|
||||
|
||||
def attach_email_claim(self, payload):
|
||||
"""Add the email claim details to the JWT payload."""
|
||||
payload['email'] = self.user.email
|
||||
|
||||
def attach_profile_claim(self, payload):
|
||||
"""Add the profile claim details to the JWT payload."""
|
||||
try:
|
||||
# Some users (e.g., service users) may not have user profiles.
|
||||
name = UserProfile.objects.get(user=self.user).name
|
||||
except UserProfile.DoesNotExist:
|
||||
name = None
|
||||
|
||||
payload.update({
|
||||
'name': name,
|
||||
'administrator': self.user.is_staff,
|
||||
})
|
||||
|
||||
def encode(self, payload):
|
||||
"""Encode the provided payload."""
|
||||
if self.asymmetric:
|
||||
secret = load_pem_private_key(settings.PRIVATE_RSA_KEY, None, default_backend())
|
||||
algorithm = 'RS512'
|
||||
else:
|
||||
secret = self.secret if self.secret else self.jwt_auth['JWT_SECRET_KEY']
|
||||
algorithm = self.jwt_auth['JWT_ALGORITHM']
|
||||
|
||||
return jwt.encode(payload, secret, algorithm=algorithm)
|
||||
|
||||
Reference in New Issue
Block a user