Merge pull request #8647 from mcgachey/lti-view-refactoring
[LTI Provider] Refactoring to remove the lti_run method
This commit is contained in:
@@ -5,8 +5,6 @@ Subclass of oauthlib's RequestValidator that checks an OAuth signature.
|
||||
from oauthlib.oauth1 import SignatureOnlyEndpoint
|
||||
from oauthlib.oauth1 import RequestValidator
|
||||
|
||||
from lti_provider.models import LtiConsumer
|
||||
|
||||
|
||||
class SignatureValidator(RequestValidator):
|
||||
"""
|
||||
@@ -18,9 +16,10 @@ class SignatureValidator(RequestValidator):
|
||||
application-specific requirements.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, lti_consumer):
|
||||
super(SignatureValidator, self).__init__()
|
||||
self.endpoint = SignatureOnlyEndpoint(self)
|
||||
self.lti_consumer = lti_consumer
|
||||
|
||||
# The OAuth signature uses the endpoint URL as part of the request to be
|
||||
# hashed. By default, the oauthlib library rejects any URLs that do not
|
||||
@@ -77,7 +76,7 @@ class SignatureValidator(RequestValidator):
|
||||
|
||||
:return: True if the key is valid, False if it is not.
|
||||
"""
|
||||
return LtiConsumer.objects.filter(consumer_key=client_key).count() == 1
|
||||
return self.lti_consumer.consumer_key == client_key
|
||||
|
||||
def get_client_secret(self, client_key, request):
|
||||
"""
|
||||
@@ -87,10 +86,7 @@ class SignatureValidator(RequestValidator):
|
||||
:return: the client secret that corresponds to the supplied key if
|
||||
present, or None if the key does not exist in the database.
|
||||
"""
|
||||
try:
|
||||
return LtiConsumer.objects.get(consumer_key=client_key).consumer_secret
|
||||
except LtiConsumer.DoesNotExist:
|
||||
return None
|
||||
return self.lti_consumer.consumer_secret
|
||||
|
||||
def verify(self, request):
|
||||
"""
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
Tests for the SignatureValidator class.
|
||||
"""
|
||||
|
||||
import ddt
|
||||
from django.test import TestCase
|
||||
from django.test.client import RequestFactory
|
||||
from mock import patch
|
||||
@@ -10,100 +11,95 @@ from lti_provider.models import LtiConsumer
|
||||
from lti_provider.signature_validator import SignatureValidator
|
||||
|
||||
|
||||
class SignatureValidatorTest(TestCase):
|
||||
def get_lti_consumer():
|
||||
"""
|
||||
Tests for the custom SignatureValidator class that uses the oauthlib library
|
||||
to check message signatures. Note that these tests mock out the library
|
||||
itself, since we assume it to be correct.
|
||||
Helper method for all Signature Validator tests to get an LtiConsumer object.
|
||||
"""
|
||||
return LtiConsumer(
|
||||
consumer_name='Consumer Name',
|
||||
consumer_key='Consumer Key',
|
||||
consumer_secret='Consumer Secret'
|
||||
)
|
||||
|
||||
|
||||
@ddt.ddt
|
||||
class ClientKeyValidatorTest(TestCase):
|
||||
"""
|
||||
Tests for the check_client_key method in the SignatureValidator class.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
super(ClientKeyValidatorTest, self).setUp()
|
||||
self.lti_consumer = get_lti_consumer()
|
||||
|
||||
def test_valid_client_key(self):
|
||||
"""
|
||||
Verify that check_client_key succeeds with a valid key
|
||||
"""
|
||||
key = 'valid_key'
|
||||
self.assertTrue(SignatureValidator().check_client_key(key))
|
||||
key = self.lti_consumer.consumer_key
|
||||
self.assertTrue(SignatureValidator(self.lti_consumer).check_client_key(key))
|
||||
|
||||
def test_long_client_key(self):
|
||||
@ddt.data(
|
||||
('0123456789012345678901234567890123456789',),
|
||||
('',),
|
||||
(None,),
|
||||
)
|
||||
@ddt.unpack
|
||||
def test_invalid_client_key(self, key):
|
||||
"""
|
||||
Verify that check_client_key fails with a key that is too long
|
||||
Verify that check_client_key fails with a disallowed key
|
||||
"""
|
||||
key = '0123456789012345678901234567890123456789'
|
||||
self.assertFalse(SignatureValidator().check_client_key(key))
|
||||
self.assertFalse(SignatureValidator(self.lti_consumer).check_client_key(key))
|
||||
|
||||
def test_empty_client_key(self):
|
||||
"""
|
||||
Verify that check_client_key fails with a key that is an empty string
|
||||
"""
|
||||
key = ''
|
||||
self.assertFalse(SignatureValidator().check_client_key(key))
|
||||
|
||||
def test_null_client_key(self):
|
||||
"""
|
||||
Verify that check_client_key fails with a key that is None
|
||||
"""
|
||||
key = None
|
||||
self.assertFalse(SignatureValidator().check_client_key(key))
|
||||
@ddt.ddt
|
||||
class NonceValidatorTest(TestCase):
|
||||
"""
|
||||
Tests for the check_nonce method in the SignatureValidator class.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
super(NonceValidatorTest, self).setUp()
|
||||
self.lti_consumer = get_lti_consumer()
|
||||
|
||||
def test_valid_nonce(self):
|
||||
"""
|
||||
Verify that check_nonce succeeds with a key of maximum length
|
||||
"""
|
||||
nonce = '0123456789012345678901234567890123456789012345678901234567890123'
|
||||
self.assertTrue(SignatureValidator().check_nonce(nonce))
|
||||
self.assertTrue(SignatureValidator(self.lti_consumer).check_nonce(nonce))
|
||||
|
||||
def test_long_nonce(self):
|
||||
@ddt.data(
|
||||
('01234567890123456789012345678901234567890123456789012345678901234',),
|
||||
('',),
|
||||
(None,),
|
||||
)
|
||||
@ddt.unpack
|
||||
def test_invalid_nonce(self, nonce):
|
||||
"""
|
||||
Verify that check_nonce fails with a key that is too long
|
||||
Verify that check_nonce fails with badly formatted nonce
|
||||
"""
|
||||
nonce = '01234567890123456789012345678901234567890123456789012345678901234'
|
||||
self.assertFalse(SignatureValidator().check_nonce(nonce))
|
||||
self.assertFalse(SignatureValidator(self.lti_consumer).check_nonce(nonce))
|
||||
|
||||
def test_empty_nonce(self):
|
||||
"""
|
||||
Verify that check_nonce fails with a key that is an empty string
|
||||
"""
|
||||
nonce = ''
|
||||
self.assertFalse(SignatureValidator().check_nonce(nonce))
|
||||
|
||||
def test_null_nonce(self):
|
||||
"""
|
||||
Verify that check_nonce fails with a key that is None
|
||||
"""
|
||||
nonce = None
|
||||
self.assertFalse(SignatureValidator().check_nonce(nonce))
|
||||
|
||||
def test_validate_existing_key(self):
|
||||
"""
|
||||
Verify that validate_client_key succeeds if the client key exists in the
|
||||
database
|
||||
"""
|
||||
LtiConsumer.objects.create(consumer_key='client_key', consumer_secret='client_secret')
|
||||
self.assertTrue(SignatureValidator().validate_client_key('client_key', None))
|
||||
|
||||
def test_validate_missing_key(self):
|
||||
"""
|
||||
Verify that validate_client_key fails if the client key is not in the
|
||||
database
|
||||
"""
|
||||
self.assertFalse(SignatureValidator().validate_client_key('client_key', None))
|
||||
class SignatureValidatorTest(TestCase):
|
||||
"""
|
||||
Tests for the custom SignatureValidator class that uses the oauthlib library
|
||||
to check message signatures. Note that these tests mock out the library
|
||||
itself, since we assume it to be correct.
|
||||
"""
|
||||
def setUp(self):
|
||||
super(SignatureValidatorTest, self).setUp()
|
||||
self.lti_consumer = get_lti_consumer()
|
||||
|
||||
def test_get_existing_client_secret(self):
|
||||
"""
|
||||
Verify that get_client_secret returns the right value if the key is in
|
||||
the database
|
||||
Verify that get_client_secret returns the right value for the correct
|
||||
key
|
||||
"""
|
||||
LtiConsumer.objects.create(consumer_key='client_key', consumer_secret='client_secret')
|
||||
secret = SignatureValidator().get_client_secret('client_key', None)
|
||||
self.assertEqual(secret, 'client_secret')
|
||||
|
||||
def test_get_missing_client_secret(self):
|
||||
"""
|
||||
Verify that get_client_secret returns None if the key is not in the
|
||||
database
|
||||
"""
|
||||
secret = SignatureValidator().get_client_secret('client_key', None)
|
||||
self.assertIsNone(secret)
|
||||
key = self.lti_consumer.consumer_key
|
||||
secret = SignatureValidator(self.lti_consumer).get_client_secret(key, None)
|
||||
self.assertEqual(secret, self.lti_consumer.consumer_secret)
|
||||
|
||||
@patch('oauthlib.oauth1.SignatureOnlyEndpoint.validate_request',
|
||||
return_value=(True, None))
|
||||
@@ -116,6 +112,6 @@ class SignatureValidatorTest(TestCase):
|
||||
content_type = 'application/x-www-form-urlencoded'
|
||||
request = RequestFactory().post('/url', body, content_type=content_type)
|
||||
headers = {'Content-Type': content_type}
|
||||
SignatureValidator().verify(request)
|
||||
SignatureValidator(self.lti_consumer).verify(request)
|
||||
verify_mock.assert_called_once_with(
|
||||
request.build_absolute_uri(), 'POST', body, headers)
|
||||
|
||||
@@ -57,17 +57,6 @@ def build_launch_request(authenticated=True):
|
||||
return request
|
||||
|
||||
|
||||
def build_run_request(authenticated=True):
|
||||
"""
|
||||
Helper method to create a new request object
|
||||
"""
|
||||
request = RequestFactory().get('/')
|
||||
request.user = UserFactory.create()
|
||||
request.user.is_authenticated = MagicMock(return_value=authenticated)
|
||||
request.session = {views.LTI_SESSION_KEY: ALL_PARAMS.copy()}
|
||||
return request
|
||||
|
||||
|
||||
class LtiTestMixin(object):
|
||||
"""
|
||||
Mixin for LTI tests
|
||||
@@ -144,49 +133,6 @@ class LtiLaunchTest(LtiTestMixin, TestCase):
|
||||
response = views.lti_launch(request, None, None)
|
||||
self.assertEqual(response.status_code, 403)
|
||||
|
||||
@patch('lti_provider.views.lti_run')
|
||||
@patch('lti_provider.views.authenticate_lti_user')
|
||||
def test_session_contents_after_launch(self, _authenticate, _run):
|
||||
"""
|
||||
Verifies that the LTI parameters and the course and usage IDs are
|
||||
properly stored in the session
|
||||
"""
|
||||
request = build_launch_request()
|
||||
views.lti_launch(request, unicode(COURSE_KEY), unicode(USAGE_KEY))
|
||||
session = request.session[views.LTI_SESSION_KEY]
|
||||
self.assertEqual(session['course_key'], COURSE_KEY, 'Course key not set in the session')
|
||||
self.assertEqual(session['usage_key'], USAGE_KEY, 'Usage key not set in the session')
|
||||
for key in views.REQUIRED_PARAMETERS:
|
||||
self.assertEqual(session[key], request.POST[key], key + ' not set in the session')
|
||||
|
||||
@patch('lti_provider.views.lti_run')
|
||||
@patch('lti_provider.views.authenticate_lti_user')
|
||||
def test_optional_parameters_in_session(self, _authenticate, _run):
|
||||
"""
|
||||
Verifies that the outcome-related optional LTI parameters are properly
|
||||
stored in the session
|
||||
"""
|
||||
request = build_launch_request()
|
||||
request.POST.update(LTI_OPTIONAL_PARAMS)
|
||||
views.lti_launch(
|
||||
request,
|
||||
unicode(COURSE_PARAMS['course_key']),
|
||||
unicode(COURSE_PARAMS['usage_key'])
|
||||
)
|
||||
session = request.session[views.LTI_SESSION_KEY]
|
||||
self.assertEqual(
|
||||
session['lis_result_sourcedid'], u'result sourcedid',
|
||||
'Result sourcedid not set in the session'
|
||||
)
|
||||
self.assertEqual(
|
||||
session['lis_outcome_service_url'], u'outcome service URL',
|
||||
'Outcome service URL not set in the session'
|
||||
)
|
||||
self.assertEqual(
|
||||
session['tool_consumer_instance_guid'], u'consumer instance guid',
|
||||
'Consumer instance GUID not set in the session'
|
||||
)
|
||||
|
||||
def test_forbidden_if_signature_fails(self):
|
||||
"""
|
||||
Verifies that the view returns Forbidden if the LTI OAuth signature is
|
||||
@@ -198,71 +144,22 @@ class LtiLaunchTest(LtiTestMixin, TestCase):
|
||||
self.assertEqual(response.status_code, 403)
|
||||
self.assertEqual(response.status_code, 403)
|
||||
|
||||
|
||||
class LtiRunTest(LtiTestMixin, TestCase):
|
||||
"""
|
||||
Tests for the lti_run view
|
||||
"""
|
||||
@patch('lti_provider.views.render_courseware')
|
||||
def test_valid_launch(self, render):
|
||||
"""
|
||||
Verifies that the view returns OK if called with the correct context
|
||||
"""
|
||||
request = build_run_request()
|
||||
views.lti_run(request)
|
||||
render.assert_called_with(request, ALL_PARAMS['usage_key'])
|
||||
|
||||
def test_forbidden_if_session_key_missing(self):
|
||||
"""
|
||||
Verifies that the lti_run view returns a Forbidden status if the session
|
||||
doesn't have an entry for the LTI parameters.
|
||||
"""
|
||||
request = build_run_request()
|
||||
del request.session[views.LTI_SESSION_KEY]
|
||||
response = views.lti_run(request)
|
||||
self.assertEqual(response.status_code, 403)
|
||||
|
||||
def test_forbidden_if_session_incomplete(self):
|
||||
"""
|
||||
Verifies that the lti_run view returns a Forbidden status if the session
|
||||
is missing any of the required LTI parameters or course information.
|
||||
"""
|
||||
extra_keys = ['course_key', 'usage_key']
|
||||
for key in views.REQUIRED_PARAMETERS + extra_keys:
|
||||
request = build_run_request()
|
||||
del request.session[views.LTI_SESSION_KEY][key]
|
||||
response = views.lti_run(request)
|
||||
self.assertEqual(
|
||||
response.status_code,
|
||||
403,
|
||||
'Expected Forbidden response when session is missing ' + key
|
||||
)
|
||||
|
||||
@patch('lti_provider.views.render_courseware')
|
||||
def test_session_cleared_in_view(self, _render):
|
||||
"""
|
||||
Verifies that the LTI parameters are cleaned out of the session after
|
||||
launching the view to prevent a launch being replayed.
|
||||
"""
|
||||
request = build_run_request()
|
||||
views.lti_run(request)
|
||||
self.assertNotIn(views.LTI_SESSION_KEY, request.session)
|
||||
|
||||
@patch('lti_provider.views.render_courseware')
|
||||
def test_lti_consumer_record_supplemented_with_guid(self, _render):
|
||||
request = build_run_request()
|
||||
request.session[views.LTI_SESSION_KEY]['tool_consumer_instance_guid'] = 'instance_guid'
|
||||
SignatureValidator.verify = MagicMock(return_value=False)
|
||||
request = build_launch_request()
|
||||
request.POST.update(LTI_OPTIONAL_PARAMS)
|
||||
with self.assertNumQueries(4):
|
||||
views.lti_run(request)
|
||||
views.lti_launch(request, None, None)
|
||||
consumer = models.LtiConsumer.objects.get(
|
||||
consumer_key=LTI_DEFAULT_PARAMS['oauth_consumer_key']
|
||||
)
|
||||
self.assertEqual(consumer.instance_guid, 'instance_guid')
|
||||
self.assertEqual(consumer.instance_guid, u'consumer instance guid')
|
||||
|
||||
|
||||
class LtiRunTestRender(LtiTestMixin, RenderXBlockTestMixin, ModuleStoreTestCase):
|
||||
class LtiLaunchTestRender(LtiTestMixin, RenderXBlockTestMixin, ModuleStoreTestCase):
|
||||
"""
|
||||
Tests for the rendering returned by lti_run view.
|
||||
Tests for the rendering returned by lti_launch view.
|
||||
This class overrides the get_response method, which is used by
|
||||
the tests defined in RenderXBlockTestMixin.
|
||||
"""
|
||||
|
||||
@@ -14,5 +14,4 @@ urlpatterns = patterns(
|
||||
usage_id=settings.USAGE_ID_PATTERN
|
||||
),
|
||||
'lti_provider.views.lti_launch', name="lti_provider_launch"),
|
||||
url(r'^lti_run$', 'lti_provider.views.lti_run', name="lti_provider_run"),
|
||||
)
|
||||
|
||||
@@ -33,8 +33,6 @@ OPTIONAL_PARAMETERS = [
|
||||
'tool_consumer_instance_guid'
|
||||
]
|
||||
|
||||
LTI_SESSION_KEY = 'lti_provider_parameters'
|
||||
|
||||
|
||||
@csrf_exempt
|
||||
def lti_launch(request, course_id, usage_id):
|
||||
@@ -48,38 +46,32 @@ def lti_launch(request, course_id, usage_id):
|
||||
- The launch contains all the required parameters
|
||||
- The launch data is correctly signed using a known client key/secret
|
||||
pair
|
||||
- The user is logged into the edX instance
|
||||
|
||||
Authentication in this view is a little tricky, since clients use a POST
|
||||
with parameters to fetch it. We can't just use @login_required since in the
|
||||
case where a user is not logged in it will redirect back after login using a
|
||||
GET request, which would lose all of our LTI parameters.
|
||||
|
||||
Instead, we verify the LTI launch in this view before checking if the user
|
||||
is logged in, and store the required LTI parameters in the session. Then we
|
||||
do the authentication check, and if login is required we redirect back to
|
||||
the lti_run view. If the user is already logged in, we just call that view
|
||||
directly.
|
||||
"""
|
||||
|
||||
if not settings.FEATURES['ENABLE_LTI_PROVIDER']:
|
||||
return HttpResponseForbidden()
|
||||
|
||||
# Check the OAuth signature on the message
|
||||
try:
|
||||
if not SignatureValidator().verify(request):
|
||||
return HttpResponseForbidden()
|
||||
except LtiConsumer.DoesNotExist:
|
||||
return HttpResponseForbidden()
|
||||
|
||||
# Check the LTI parameters, and return 400 if any required parameters are
|
||||
# missing
|
||||
params = get_required_parameters(request.POST)
|
||||
if not params:
|
||||
return HttpResponseBadRequest()
|
||||
params.update(get_optional_parameters(request.POST))
|
||||
|
||||
# Store the course, and usage ID in the session to prevent privilege
|
||||
# escalation if a staff member in one course tries to access material in
|
||||
# another.
|
||||
# Get the consumer information from either the instance GUID or the consumer
|
||||
# key
|
||||
try:
|
||||
lti_consumer = LtiConsumer.get_or_supplement(
|
||||
params.get('tool_consumer_instance_guid', None),
|
||||
params['oauth_consumer_key']
|
||||
)
|
||||
except LtiConsumer.DoesNotExist:
|
||||
return HttpResponseForbidden()
|
||||
|
||||
# Check the OAuth signature on the message
|
||||
if not SignatureValidator(lti_consumer).verify(request):
|
||||
return HttpResponseForbidden()
|
||||
|
||||
# Add the course and usage keys to the parameters array
|
||||
try:
|
||||
course_key, usage_key = parse_course_and_usage_keys(course_id, usage_id)
|
||||
except InvalidKeyError:
|
||||
@@ -93,57 +85,13 @@ def lti_launch(request, course_id, usage_id):
|
||||
params['course_key'] = course_key
|
||||
params['usage_key'] = usage_key
|
||||
|
||||
try:
|
||||
lti_consumer = LtiConsumer.get_or_supplement(
|
||||
params.get('tool_consumer_instance_guid', None),
|
||||
params['oauth_consumer_key']
|
||||
)
|
||||
except LtiConsumer.DoesNotExist:
|
||||
return HttpResponseForbidden()
|
||||
|
||||
# Create an edX account if the user identifed by the LTI launch doesn't have
|
||||
# one already, and log the edX account into the platform.
|
||||
authenticate_lti_user(request, params['user_id'], lti_consumer)
|
||||
|
||||
request.session[LTI_SESSION_KEY] = params
|
||||
|
||||
return lti_run(request)
|
||||
|
||||
|
||||
@login_required
|
||||
def lti_run(request):
|
||||
"""
|
||||
This method can be reached in two ways, and must always follow a POST to
|
||||
lti_launch:
|
||||
- The user was logged in, so this method was called by lti_launch
|
||||
- The user was not logged in, so the login process redirected them back here.
|
||||
|
||||
In either case, the session was populated by lti_launch, so all the required
|
||||
LTI parameters will be stored there. Note that the request passed here may
|
||||
or may not contain the LTI parameters (depending on how the user got here),
|
||||
and so we should only use LTI parameters from the session.
|
||||
|
||||
Users should never call this view directly; if a user attempts to call it
|
||||
without having first gone through lti_launch (and had the LTI parameters
|
||||
stored in the session) they will get a 403 response.
|
||||
"""
|
||||
|
||||
# Check the parameters to make sure that the session is associated with a
|
||||
# valid LTI launch
|
||||
params = restore_params_from_session(request)
|
||||
if not params:
|
||||
# This view has been called without first setting the session
|
||||
return HttpResponseForbidden()
|
||||
# Remove the parameters from the session to prevent replay
|
||||
del request.session[LTI_SESSION_KEY]
|
||||
|
||||
# Store any parameters required by the outcome service in order to report
|
||||
# scores back later. We know that the consumer exists, since the record was
|
||||
# used earlier to verify the oauth signature.
|
||||
lti_consumer = LtiConsumer.get_or_supplement(
|
||||
params.get('tool_consumer_instance_guid', None),
|
||||
params['oauth_consumer_key']
|
||||
)
|
||||
store_outcome_parameters(params, request.user, lti_consumer)
|
||||
|
||||
return render_courseware(request, params['usage_key'])
|
||||
@@ -184,26 +132,6 @@ def get_optional_parameters(dictionary):
|
||||
return {key: dictionary[key] for key in OPTIONAL_PARAMETERS if key in dictionary}
|
||||
|
||||
|
||||
def restore_params_from_session(request):
|
||||
"""
|
||||
Fetch the parameters that were stored in the session by an LTI launch, and
|
||||
verify that all required parameters are present. Missing parameters could
|
||||
indicate that a user has directly called the lti_run endpoint, rather than
|
||||
going through the LTI launch.
|
||||
|
||||
:return: A dictionary of all LTI parameters from the session, or None if
|
||||
any parameters are missing.
|
||||
"""
|
||||
if LTI_SESSION_KEY not in request.session:
|
||||
return None
|
||||
session_params = request.session[LTI_SESSION_KEY]
|
||||
additional_params = ['course_key', 'usage_key']
|
||||
for key in REQUIRED_PARAMETERS + additional_params:
|
||||
if key not in session_params:
|
||||
return None
|
||||
return session_params
|
||||
|
||||
|
||||
def render_courseware(request, usage_key):
|
||||
"""
|
||||
Render the content requested for the LTI launch.
|
||||
|
||||
Reference in New Issue
Block a user