From 99fbf4d6b2c165820f2fb8aa7382aeee565e399e Mon Sep 17 00:00:00 2001 From: Phil McGachey Date: Tue, 23 Jun 2015 14:04:25 -0400 Subject: [PATCH] [LTI Provider] Refactoring to remove the lti_run method This change is a follow-up to the chages in PR 8347, which removed the edX login page from the workflow for a new user. Where previously we redirected a user to the login page, PR 8347 instead creates a new user transparently and logs them in. The initial reason for splitting the LTI view between lti_launch and lti_run was so that there was a target for the GET request that followed the login page. Since we no longer use the login page, we no longer need the second view. We also don't need to store the LTI parameters in the session any more, since they are not persisting between calls. This simplifies the view logic significantly. The other change here is to fetch the LtiConsumer object early in the view, and pass it to the SignatureValidator and scoring system. When the views were split, this required multiple DB hits for the same data; we're now only fetching it once. --- .../lti_provider/signature_validator.py | 12 +- .../tests/test_signature_validator.py | 132 +++++++++--------- .../lti_provider/tests/test_views.py | 117 +--------------- lms/djangoapps/lti_provider/urls.py | 1 - lms/djangoapps/lti_provider/views.py | 106 +++----------- 5 files changed, 92 insertions(+), 276 deletions(-) diff --git a/lms/djangoapps/lti_provider/signature_validator.py b/lms/djangoapps/lti_provider/signature_validator.py index de9f497946..3353db18c1 100644 --- a/lms/djangoapps/lti_provider/signature_validator.py +++ b/lms/djangoapps/lti_provider/signature_validator.py @@ -5,8 +5,6 @@ Subclass of oauthlib's RequestValidator that checks an OAuth signature. from oauthlib.oauth1 import SignatureOnlyEndpoint from oauthlib.oauth1 import RequestValidator -from lti_provider.models import LtiConsumer - class SignatureValidator(RequestValidator): """ @@ -18,9 +16,10 @@ class SignatureValidator(RequestValidator): application-specific requirements. """ - def __init__(self): + def __init__(self, lti_consumer): super(SignatureValidator, self).__init__() self.endpoint = SignatureOnlyEndpoint(self) + self.lti_consumer = lti_consumer # The OAuth signature uses the endpoint URL as part of the request to be # hashed. By default, the oauthlib library rejects any URLs that do not @@ -77,7 +76,7 @@ class SignatureValidator(RequestValidator): :return: True if the key is valid, False if it is not. """ - return LtiConsumer.objects.filter(consumer_key=client_key).count() == 1 + return self.lti_consumer.consumer_key == client_key def get_client_secret(self, client_key, request): """ @@ -87,10 +86,7 @@ class SignatureValidator(RequestValidator): :return: the client secret that corresponds to the supplied key if present, or None if the key does not exist in the database. """ - try: - return LtiConsumer.objects.get(consumer_key=client_key).consumer_secret - except LtiConsumer.DoesNotExist: - return None + return self.lti_consumer.consumer_secret def verify(self, request): """ diff --git a/lms/djangoapps/lti_provider/tests/test_signature_validator.py b/lms/djangoapps/lti_provider/tests/test_signature_validator.py index 4e0e9c03bf..bb12d88d77 100644 --- a/lms/djangoapps/lti_provider/tests/test_signature_validator.py +++ b/lms/djangoapps/lti_provider/tests/test_signature_validator.py @@ -2,6 +2,7 @@ Tests for the SignatureValidator class. """ +import ddt from django.test import TestCase from django.test.client import RequestFactory from mock import patch @@ -10,100 +11,95 @@ from lti_provider.models import LtiConsumer from lti_provider.signature_validator import SignatureValidator -class SignatureValidatorTest(TestCase): +def get_lti_consumer(): """ - Tests for the custom SignatureValidator class that uses the oauthlib library - to check message signatures. Note that these tests mock out the library - itself, since we assume it to be correct. + Helper method for all Signature Validator tests to get an LtiConsumer object. """ + return LtiConsumer( + consumer_name='Consumer Name', + consumer_key='Consumer Key', + consumer_secret='Consumer Secret' + ) + + +@ddt.ddt +class ClientKeyValidatorTest(TestCase): + """ + Tests for the check_client_key method in the SignatureValidator class. + """ + + def setUp(self): + super(ClientKeyValidatorTest, self).setUp() + self.lti_consumer = get_lti_consumer() def test_valid_client_key(self): """ Verify that check_client_key succeeds with a valid key """ - key = 'valid_key' - self.assertTrue(SignatureValidator().check_client_key(key)) + key = self.lti_consumer.consumer_key + self.assertTrue(SignatureValidator(self.lti_consumer).check_client_key(key)) - def test_long_client_key(self): + @ddt.data( + ('0123456789012345678901234567890123456789',), + ('',), + (None,), + ) + @ddt.unpack + def test_invalid_client_key(self, key): """ - Verify that check_client_key fails with a key that is too long + Verify that check_client_key fails with a disallowed key """ - key = '0123456789012345678901234567890123456789' - self.assertFalse(SignatureValidator().check_client_key(key)) + self.assertFalse(SignatureValidator(self.lti_consumer).check_client_key(key)) - def test_empty_client_key(self): - """ - Verify that check_client_key fails with a key that is an empty string - """ - key = '' - self.assertFalse(SignatureValidator().check_client_key(key)) - def test_null_client_key(self): - """ - Verify that check_client_key fails with a key that is None - """ - key = None - self.assertFalse(SignatureValidator().check_client_key(key)) +@ddt.ddt +class NonceValidatorTest(TestCase): + """ + Tests for the check_nonce method in the SignatureValidator class. + """ + + def setUp(self): + super(NonceValidatorTest, self).setUp() + self.lti_consumer = get_lti_consumer() def test_valid_nonce(self): """ Verify that check_nonce succeeds with a key of maximum length """ nonce = '0123456789012345678901234567890123456789012345678901234567890123' - self.assertTrue(SignatureValidator().check_nonce(nonce)) + self.assertTrue(SignatureValidator(self.lti_consumer).check_nonce(nonce)) - def test_long_nonce(self): + @ddt.data( + ('01234567890123456789012345678901234567890123456789012345678901234',), + ('',), + (None,), + ) + @ddt.unpack + def test_invalid_nonce(self, nonce): """ - Verify that check_nonce fails with a key that is too long + Verify that check_nonce fails with badly formatted nonce """ - nonce = '01234567890123456789012345678901234567890123456789012345678901234' - self.assertFalse(SignatureValidator().check_nonce(nonce)) + self.assertFalse(SignatureValidator(self.lti_consumer).check_nonce(nonce)) - def test_empty_nonce(self): - """ - Verify that check_nonce fails with a key that is an empty string - """ - nonce = '' - self.assertFalse(SignatureValidator().check_nonce(nonce)) - def test_null_nonce(self): - """ - Verify that check_nonce fails with a key that is None - """ - nonce = None - self.assertFalse(SignatureValidator().check_nonce(nonce)) - - def test_validate_existing_key(self): - """ - Verify that validate_client_key succeeds if the client key exists in the - database - """ - LtiConsumer.objects.create(consumer_key='client_key', consumer_secret='client_secret') - self.assertTrue(SignatureValidator().validate_client_key('client_key', None)) - - def test_validate_missing_key(self): - """ - Verify that validate_client_key fails if the client key is not in the - database - """ - self.assertFalse(SignatureValidator().validate_client_key('client_key', None)) +class SignatureValidatorTest(TestCase): + """ + Tests for the custom SignatureValidator class that uses the oauthlib library + to check message signatures. Note that these tests mock out the library + itself, since we assume it to be correct. + """ + def setUp(self): + super(SignatureValidatorTest, self).setUp() + self.lti_consumer = get_lti_consumer() def test_get_existing_client_secret(self): """ - Verify that get_client_secret returns the right value if the key is in - the database + Verify that get_client_secret returns the right value for the correct + key """ - LtiConsumer.objects.create(consumer_key='client_key', consumer_secret='client_secret') - secret = SignatureValidator().get_client_secret('client_key', None) - self.assertEqual(secret, 'client_secret') - - def test_get_missing_client_secret(self): - """ - Verify that get_client_secret returns None if the key is not in the - database - """ - secret = SignatureValidator().get_client_secret('client_key', None) - self.assertIsNone(secret) + key = self.lti_consumer.consumer_key + secret = SignatureValidator(self.lti_consumer).get_client_secret(key, None) + self.assertEqual(secret, self.lti_consumer.consumer_secret) @patch('oauthlib.oauth1.SignatureOnlyEndpoint.validate_request', return_value=(True, None)) @@ -116,6 +112,6 @@ class SignatureValidatorTest(TestCase): content_type = 'application/x-www-form-urlencoded' request = RequestFactory().post('/url', body, content_type=content_type) headers = {'Content-Type': content_type} - SignatureValidator().verify(request) + SignatureValidator(self.lti_consumer).verify(request) verify_mock.assert_called_once_with( request.build_absolute_uri(), 'POST', body, headers) diff --git a/lms/djangoapps/lti_provider/tests/test_views.py b/lms/djangoapps/lti_provider/tests/test_views.py index 671175b2d2..44fc311332 100644 --- a/lms/djangoapps/lti_provider/tests/test_views.py +++ b/lms/djangoapps/lti_provider/tests/test_views.py @@ -57,17 +57,6 @@ def build_launch_request(authenticated=True): return request -def build_run_request(authenticated=True): - """ - Helper method to create a new request object - """ - request = RequestFactory().get('/') - request.user = UserFactory.create() - request.user.is_authenticated = MagicMock(return_value=authenticated) - request.session = {views.LTI_SESSION_KEY: ALL_PARAMS.copy()} - return request - - class LtiTestMixin(object): """ Mixin for LTI tests @@ -144,49 +133,6 @@ class LtiLaunchTest(LtiTestMixin, TestCase): response = views.lti_launch(request, None, None) self.assertEqual(response.status_code, 403) - @patch('lti_provider.views.lti_run') - @patch('lti_provider.views.authenticate_lti_user') - def test_session_contents_after_launch(self, _authenticate, _run): - """ - Verifies that the LTI parameters and the course and usage IDs are - properly stored in the session - """ - request = build_launch_request() - views.lti_launch(request, unicode(COURSE_KEY), unicode(USAGE_KEY)) - session = request.session[views.LTI_SESSION_KEY] - self.assertEqual(session['course_key'], COURSE_KEY, 'Course key not set in the session') - self.assertEqual(session['usage_key'], USAGE_KEY, 'Usage key not set in the session') - for key in views.REQUIRED_PARAMETERS: - self.assertEqual(session[key], request.POST[key], key + ' not set in the session') - - @patch('lti_provider.views.lti_run') - @patch('lti_provider.views.authenticate_lti_user') - def test_optional_parameters_in_session(self, _authenticate, _run): - """ - Verifies that the outcome-related optional LTI parameters are properly - stored in the session - """ - request = build_launch_request() - request.POST.update(LTI_OPTIONAL_PARAMS) - views.lti_launch( - request, - unicode(COURSE_PARAMS['course_key']), - unicode(COURSE_PARAMS['usage_key']) - ) - session = request.session[views.LTI_SESSION_KEY] - self.assertEqual( - session['lis_result_sourcedid'], u'result sourcedid', - 'Result sourcedid not set in the session' - ) - self.assertEqual( - session['lis_outcome_service_url'], u'outcome service URL', - 'Outcome service URL not set in the session' - ) - self.assertEqual( - session['tool_consumer_instance_guid'], u'consumer instance guid', - 'Consumer instance GUID not set in the session' - ) - def test_forbidden_if_signature_fails(self): """ Verifies that the view returns Forbidden if the LTI OAuth signature is @@ -198,71 +144,22 @@ class LtiLaunchTest(LtiTestMixin, TestCase): self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403) - -class LtiRunTest(LtiTestMixin, TestCase): - """ - Tests for the lti_run view - """ - @patch('lti_provider.views.render_courseware') - def test_valid_launch(self, render): - """ - Verifies that the view returns OK if called with the correct context - """ - request = build_run_request() - views.lti_run(request) - render.assert_called_with(request, ALL_PARAMS['usage_key']) - - def test_forbidden_if_session_key_missing(self): - """ - Verifies that the lti_run view returns a Forbidden status if the session - doesn't have an entry for the LTI parameters. - """ - request = build_run_request() - del request.session[views.LTI_SESSION_KEY] - response = views.lti_run(request) - self.assertEqual(response.status_code, 403) - - def test_forbidden_if_session_incomplete(self): - """ - Verifies that the lti_run view returns a Forbidden status if the session - is missing any of the required LTI parameters or course information. - """ - extra_keys = ['course_key', 'usage_key'] - for key in views.REQUIRED_PARAMETERS + extra_keys: - request = build_run_request() - del request.session[views.LTI_SESSION_KEY][key] - response = views.lti_run(request) - self.assertEqual( - response.status_code, - 403, - 'Expected Forbidden response when session is missing ' + key - ) - - @patch('lti_provider.views.render_courseware') - def test_session_cleared_in_view(self, _render): - """ - Verifies that the LTI parameters are cleaned out of the session after - launching the view to prevent a launch being replayed. - """ - request = build_run_request() - views.lti_run(request) - self.assertNotIn(views.LTI_SESSION_KEY, request.session) - @patch('lti_provider.views.render_courseware') def test_lti_consumer_record_supplemented_with_guid(self, _render): - request = build_run_request() - request.session[views.LTI_SESSION_KEY]['tool_consumer_instance_guid'] = 'instance_guid' + SignatureValidator.verify = MagicMock(return_value=False) + request = build_launch_request() + request.POST.update(LTI_OPTIONAL_PARAMS) with self.assertNumQueries(4): - views.lti_run(request) + views.lti_launch(request, None, None) consumer = models.LtiConsumer.objects.get( consumer_key=LTI_DEFAULT_PARAMS['oauth_consumer_key'] ) - self.assertEqual(consumer.instance_guid, 'instance_guid') + self.assertEqual(consumer.instance_guid, u'consumer instance guid') -class LtiRunTestRender(LtiTestMixin, RenderXBlockTestMixin, ModuleStoreTestCase): +class LtiLaunchTestRender(LtiTestMixin, RenderXBlockTestMixin, ModuleStoreTestCase): """ - Tests for the rendering returned by lti_run view. + Tests for the rendering returned by lti_launch view. This class overrides the get_response method, which is used by the tests defined in RenderXBlockTestMixin. """ diff --git a/lms/djangoapps/lti_provider/urls.py b/lms/djangoapps/lti_provider/urls.py index def0a70245..1e1843950d 100644 --- a/lms/djangoapps/lti_provider/urls.py +++ b/lms/djangoapps/lti_provider/urls.py @@ -14,5 +14,4 @@ urlpatterns = patterns( usage_id=settings.USAGE_ID_PATTERN ), 'lti_provider.views.lti_launch', name="lti_provider_launch"), - url(r'^lti_run$', 'lti_provider.views.lti_run', name="lti_provider_run"), ) diff --git a/lms/djangoapps/lti_provider/views.py b/lms/djangoapps/lti_provider/views.py index fc7a852aa3..2a461ce550 100644 --- a/lms/djangoapps/lti_provider/views.py +++ b/lms/djangoapps/lti_provider/views.py @@ -33,8 +33,6 @@ OPTIONAL_PARAMETERS = [ 'tool_consumer_instance_guid' ] -LTI_SESSION_KEY = 'lti_provider_parameters' - @csrf_exempt def lti_launch(request, course_id, usage_id): @@ -48,38 +46,32 @@ def lti_launch(request, course_id, usage_id): - The launch contains all the required parameters - The launch data is correctly signed using a known client key/secret pair - - The user is logged into the edX instance - - Authentication in this view is a little tricky, since clients use a POST - with parameters to fetch it. We can't just use @login_required since in the - case where a user is not logged in it will redirect back after login using a - GET request, which would lose all of our LTI parameters. - - Instead, we verify the LTI launch in this view before checking if the user - is logged in, and store the required LTI parameters in the session. Then we - do the authentication check, and if login is required we redirect back to - the lti_run view. If the user is already logged in, we just call that view - directly. """ - if not settings.FEATURES['ENABLE_LTI_PROVIDER']: return HttpResponseForbidden() - # Check the OAuth signature on the message - try: - if not SignatureValidator().verify(request): - return HttpResponseForbidden() - except LtiConsumer.DoesNotExist: - return HttpResponseForbidden() - + # Check the LTI parameters, and return 400 if any required parameters are + # missing params = get_required_parameters(request.POST) if not params: return HttpResponseBadRequest() params.update(get_optional_parameters(request.POST)) - # Store the course, and usage ID in the session to prevent privilege - # escalation if a staff member in one course tries to access material in - # another. + # Get the consumer information from either the instance GUID or the consumer + # key + try: + lti_consumer = LtiConsumer.get_or_supplement( + params.get('tool_consumer_instance_guid', None), + params['oauth_consumer_key'] + ) + except LtiConsumer.DoesNotExist: + return HttpResponseForbidden() + + # Check the OAuth signature on the message + if not SignatureValidator(lti_consumer).verify(request): + return HttpResponseForbidden() + + # Add the course and usage keys to the parameters array try: course_key, usage_key = parse_course_and_usage_keys(course_id, usage_id) except InvalidKeyError: @@ -93,57 +85,13 @@ def lti_launch(request, course_id, usage_id): params['course_key'] = course_key params['usage_key'] = usage_key - try: - lti_consumer = LtiConsumer.get_or_supplement( - params.get('tool_consumer_instance_guid', None), - params['oauth_consumer_key'] - ) - except LtiConsumer.DoesNotExist: - return HttpResponseForbidden() - # Create an edX account if the user identifed by the LTI launch doesn't have # one already, and log the edX account into the platform. authenticate_lti_user(request, params['user_id'], lti_consumer) - request.session[LTI_SESSION_KEY] = params - - return lti_run(request) - - -@login_required -def lti_run(request): - """ - This method can be reached in two ways, and must always follow a POST to - lti_launch: - - The user was logged in, so this method was called by lti_launch - - The user was not logged in, so the login process redirected them back here. - - In either case, the session was populated by lti_launch, so all the required - LTI parameters will be stored there. Note that the request passed here may - or may not contain the LTI parameters (depending on how the user got here), - and so we should only use LTI parameters from the session. - - Users should never call this view directly; if a user attempts to call it - without having first gone through lti_launch (and had the LTI parameters - stored in the session) they will get a 403 response. - """ - - # Check the parameters to make sure that the session is associated with a - # valid LTI launch - params = restore_params_from_session(request) - if not params: - # This view has been called without first setting the session - return HttpResponseForbidden() - # Remove the parameters from the session to prevent replay - del request.session[LTI_SESSION_KEY] - # Store any parameters required by the outcome service in order to report # scores back later. We know that the consumer exists, since the record was # used earlier to verify the oauth signature. - lti_consumer = LtiConsumer.get_or_supplement( - params.get('tool_consumer_instance_guid', None), - params['oauth_consumer_key'] - ) store_outcome_parameters(params, request.user, lti_consumer) return render_courseware(request, params['usage_key']) @@ -184,26 +132,6 @@ def get_optional_parameters(dictionary): return {key: dictionary[key] for key in OPTIONAL_PARAMETERS if key in dictionary} -def restore_params_from_session(request): - """ - Fetch the parameters that were stored in the session by an LTI launch, and - verify that all required parameters are present. Missing parameters could - indicate that a user has directly called the lti_run endpoint, rather than - going through the LTI launch. - - :return: A dictionary of all LTI parameters from the session, or None if - any parameters are missing. - """ - if LTI_SESSION_KEY not in request.session: - return None - session_params = request.session[LTI_SESSION_KEY] - additional_params = ['course_key', 'usage_key'] - for key in REQUIRED_PARAMETERS + additional_params: - if key not in session_params: - return None - return session_params - - def render_courseware(request, usage_key): """ Render the content requested for the LTI launch.