diff --git a/lms/djangoapps/lti_provider/signature_validator.py b/lms/djangoapps/lti_provider/signature_validator.py index de9f497946..3353db18c1 100644 --- a/lms/djangoapps/lti_provider/signature_validator.py +++ b/lms/djangoapps/lti_provider/signature_validator.py @@ -5,8 +5,6 @@ Subclass of oauthlib's RequestValidator that checks an OAuth signature. from oauthlib.oauth1 import SignatureOnlyEndpoint from oauthlib.oauth1 import RequestValidator -from lti_provider.models import LtiConsumer - class SignatureValidator(RequestValidator): """ @@ -18,9 +16,10 @@ class SignatureValidator(RequestValidator): application-specific requirements. """ - def __init__(self): + def __init__(self, lti_consumer): super(SignatureValidator, self).__init__() self.endpoint = SignatureOnlyEndpoint(self) + self.lti_consumer = lti_consumer # The OAuth signature uses the endpoint URL as part of the request to be # hashed. By default, the oauthlib library rejects any URLs that do not @@ -77,7 +76,7 @@ class SignatureValidator(RequestValidator): :return: True if the key is valid, False if it is not. """ - return LtiConsumer.objects.filter(consumer_key=client_key).count() == 1 + return self.lti_consumer.consumer_key == client_key def get_client_secret(self, client_key, request): """ @@ -87,10 +86,7 @@ class SignatureValidator(RequestValidator): :return: the client secret that corresponds to the supplied key if present, or None if the key does not exist in the database. """ - try: - return LtiConsumer.objects.get(consumer_key=client_key).consumer_secret - except LtiConsumer.DoesNotExist: - return None + return self.lti_consumer.consumer_secret def verify(self, request): """ diff --git a/lms/djangoapps/lti_provider/tests/test_signature_validator.py b/lms/djangoapps/lti_provider/tests/test_signature_validator.py index 4e0e9c03bf..bb12d88d77 100644 --- a/lms/djangoapps/lti_provider/tests/test_signature_validator.py +++ b/lms/djangoapps/lti_provider/tests/test_signature_validator.py @@ -2,6 +2,7 @@ Tests for the SignatureValidator class. """ +import ddt from django.test import TestCase from django.test.client import RequestFactory from mock import patch @@ -10,100 +11,95 @@ from lti_provider.models import LtiConsumer from lti_provider.signature_validator import SignatureValidator -class SignatureValidatorTest(TestCase): +def get_lti_consumer(): """ - Tests for the custom SignatureValidator class that uses the oauthlib library - to check message signatures. Note that these tests mock out the library - itself, since we assume it to be correct. + Helper method for all Signature Validator tests to get an LtiConsumer object. """ + return LtiConsumer( + consumer_name='Consumer Name', + consumer_key='Consumer Key', + consumer_secret='Consumer Secret' + ) + + +@ddt.ddt +class ClientKeyValidatorTest(TestCase): + """ + Tests for the check_client_key method in the SignatureValidator class. + """ + + def setUp(self): + super(ClientKeyValidatorTest, self).setUp() + self.lti_consumer = get_lti_consumer() def test_valid_client_key(self): """ Verify that check_client_key succeeds with a valid key """ - key = 'valid_key' - self.assertTrue(SignatureValidator().check_client_key(key)) + key = self.lti_consumer.consumer_key + self.assertTrue(SignatureValidator(self.lti_consumer).check_client_key(key)) - def test_long_client_key(self): + @ddt.data( + ('0123456789012345678901234567890123456789',), + ('',), + (None,), + ) + @ddt.unpack + def test_invalid_client_key(self, key): """ - Verify that check_client_key fails with a key that is too long + Verify that check_client_key fails with a disallowed key """ - key = '0123456789012345678901234567890123456789' - self.assertFalse(SignatureValidator().check_client_key(key)) + self.assertFalse(SignatureValidator(self.lti_consumer).check_client_key(key)) - def test_empty_client_key(self): - """ - Verify that check_client_key fails with a key that is an empty string - """ - key = '' - self.assertFalse(SignatureValidator().check_client_key(key)) - def test_null_client_key(self): - """ - Verify that check_client_key fails with a key that is None - """ - key = None - self.assertFalse(SignatureValidator().check_client_key(key)) +@ddt.ddt +class NonceValidatorTest(TestCase): + """ + Tests for the check_nonce method in the SignatureValidator class. + """ + + def setUp(self): + super(NonceValidatorTest, self).setUp() + self.lti_consumer = get_lti_consumer() def test_valid_nonce(self): """ Verify that check_nonce succeeds with a key of maximum length """ nonce = '0123456789012345678901234567890123456789012345678901234567890123' - self.assertTrue(SignatureValidator().check_nonce(nonce)) + self.assertTrue(SignatureValidator(self.lti_consumer).check_nonce(nonce)) - def test_long_nonce(self): + @ddt.data( + ('01234567890123456789012345678901234567890123456789012345678901234',), + ('',), + (None,), + ) + @ddt.unpack + def test_invalid_nonce(self, nonce): """ - Verify that check_nonce fails with a key that is too long + Verify that check_nonce fails with badly formatted nonce """ - nonce = '01234567890123456789012345678901234567890123456789012345678901234' - self.assertFalse(SignatureValidator().check_nonce(nonce)) + self.assertFalse(SignatureValidator(self.lti_consumer).check_nonce(nonce)) - def test_empty_nonce(self): - """ - Verify that check_nonce fails with a key that is an empty string - """ - nonce = '' - self.assertFalse(SignatureValidator().check_nonce(nonce)) - def test_null_nonce(self): - """ - Verify that check_nonce fails with a key that is None - """ - nonce = None - self.assertFalse(SignatureValidator().check_nonce(nonce)) - - def test_validate_existing_key(self): - """ - Verify that validate_client_key succeeds if the client key exists in the - database - """ - LtiConsumer.objects.create(consumer_key='client_key', consumer_secret='client_secret') - self.assertTrue(SignatureValidator().validate_client_key('client_key', None)) - - def test_validate_missing_key(self): - """ - Verify that validate_client_key fails if the client key is not in the - database - """ - self.assertFalse(SignatureValidator().validate_client_key('client_key', None)) +class SignatureValidatorTest(TestCase): + """ + Tests for the custom SignatureValidator class that uses the oauthlib library + to check message signatures. Note that these tests mock out the library + itself, since we assume it to be correct. + """ + def setUp(self): + super(SignatureValidatorTest, self).setUp() + self.lti_consumer = get_lti_consumer() def test_get_existing_client_secret(self): """ - Verify that get_client_secret returns the right value if the key is in - the database + Verify that get_client_secret returns the right value for the correct + key """ - LtiConsumer.objects.create(consumer_key='client_key', consumer_secret='client_secret') - secret = SignatureValidator().get_client_secret('client_key', None) - self.assertEqual(secret, 'client_secret') - - def test_get_missing_client_secret(self): - """ - Verify that get_client_secret returns None if the key is not in the - database - """ - secret = SignatureValidator().get_client_secret('client_key', None) - self.assertIsNone(secret) + key = self.lti_consumer.consumer_key + secret = SignatureValidator(self.lti_consumer).get_client_secret(key, None) + self.assertEqual(secret, self.lti_consumer.consumer_secret) @patch('oauthlib.oauth1.SignatureOnlyEndpoint.validate_request', return_value=(True, None)) @@ -116,6 +112,6 @@ class SignatureValidatorTest(TestCase): content_type = 'application/x-www-form-urlencoded' request = RequestFactory().post('/url', body, content_type=content_type) headers = {'Content-Type': content_type} - SignatureValidator().verify(request) + SignatureValidator(self.lti_consumer).verify(request) verify_mock.assert_called_once_with( request.build_absolute_uri(), 'POST', body, headers) diff --git a/lms/djangoapps/lti_provider/tests/test_views.py b/lms/djangoapps/lti_provider/tests/test_views.py index 671175b2d2..44fc311332 100644 --- a/lms/djangoapps/lti_provider/tests/test_views.py +++ b/lms/djangoapps/lti_provider/tests/test_views.py @@ -57,17 +57,6 @@ def build_launch_request(authenticated=True): return request -def build_run_request(authenticated=True): - """ - Helper method to create a new request object - """ - request = RequestFactory().get('/') - request.user = UserFactory.create() - request.user.is_authenticated = MagicMock(return_value=authenticated) - request.session = {views.LTI_SESSION_KEY: ALL_PARAMS.copy()} - return request - - class LtiTestMixin(object): """ Mixin for LTI tests @@ -144,49 +133,6 @@ class LtiLaunchTest(LtiTestMixin, TestCase): response = views.lti_launch(request, None, None) self.assertEqual(response.status_code, 403) - @patch('lti_provider.views.lti_run') - @patch('lti_provider.views.authenticate_lti_user') - def test_session_contents_after_launch(self, _authenticate, _run): - """ - Verifies that the LTI parameters and the course and usage IDs are - properly stored in the session - """ - request = build_launch_request() - views.lti_launch(request, unicode(COURSE_KEY), unicode(USAGE_KEY)) - session = request.session[views.LTI_SESSION_KEY] - self.assertEqual(session['course_key'], COURSE_KEY, 'Course key not set in the session') - self.assertEqual(session['usage_key'], USAGE_KEY, 'Usage key not set in the session') - for key in views.REQUIRED_PARAMETERS: - self.assertEqual(session[key], request.POST[key], key + ' not set in the session') - - @patch('lti_provider.views.lti_run') - @patch('lti_provider.views.authenticate_lti_user') - def test_optional_parameters_in_session(self, _authenticate, _run): - """ - Verifies that the outcome-related optional LTI parameters are properly - stored in the session - """ - request = build_launch_request() - request.POST.update(LTI_OPTIONAL_PARAMS) - views.lti_launch( - request, - unicode(COURSE_PARAMS['course_key']), - unicode(COURSE_PARAMS['usage_key']) - ) - session = request.session[views.LTI_SESSION_KEY] - self.assertEqual( - session['lis_result_sourcedid'], u'result sourcedid', - 'Result sourcedid not set in the session' - ) - self.assertEqual( - session['lis_outcome_service_url'], u'outcome service URL', - 'Outcome service URL not set in the session' - ) - self.assertEqual( - session['tool_consumer_instance_guid'], u'consumer instance guid', - 'Consumer instance GUID not set in the session' - ) - def test_forbidden_if_signature_fails(self): """ Verifies that the view returns Forbidden if the LTI OAuth signature is @@ -198,71 +144,22 @@ class LtiLaunchTest(LtiTestMixin, TestCase): self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403) - -class LtiRunTest(LtiTestMixin, TestCase): - """ - Tests for the lti_run view - """ - @patch('lti_provider.views.render_courseware') - def test_valid_launch(self, render): - """ - Verifies that the view returns OK if called with the correct context - """ - request = build_run_request() - views.lti_run(request) - render.assert_called_with(request, ALL_PARAMS['usage_key']) - - def test_forbidden_if_session_key_missing(self): - """ - Verifies that the lti_run view returns a Forbidden status if the session - doesn't have an entry for the LTI parameters. - """ - request = build_run_request() - del request.session[views.LTI_SESSION_KEY] - response = views.lti_run(request) - self.assertEqual(response.status_code, 403) - - def test_forbidden_if_session_incomplete(self): - """ - Verifies that the lti_run view returns a Forbidden status if the session - is missing any of the required LTI parameters or course information. - """ - extra_keys = ['course_key', 'usage_key'] - for key in views.REQUIRED_PARAMETERS + extra_keys: - request = build_run_request() - del request.session[views.LTI_SESSION_KEY][key] - response = views.lti_run(request) - self.assertEqual( - response.status_code, - 403, - 'Expected Forbidden response when session is missing ' + key - ) - - @patch('lti_provider.views.render_courseware') - def test_session_cleared_in_view(self, _render): - """ - Verifies that the LTI parameters are cleaned out of the session after - launching the view to prevent a launch being replayed. - """ - request = build_run_request() - views.lti_run(request) - self.assertNotIn(views.LTI_SESSION_KEY, request.session) - @patch('lti_provider.views.render_courseware') def test_lti_consumer_record_supplemented_with_guid(self, _render): - request = build_run_request() - request.session[views.LTI_SESSION_KEY]['tool_consumer_instance_guid'] = 'instance_guid' + SignatureValidator.verify = MagicMock(return_value=False) + request = build_launch_request() + request.POST.update(LTI_OPTIONAL_PARAMS) with self.assertNumQueries(4): - views.lti_run(request) + views.lti_launch(request, None, None) consumer = models.LtiConsumer.objects.get( consumer_key=LTI_DEFAULT_PARAMS['oauth_consumer_key'] ) - self.assertEqual(consumer.instance_guid, 'instance_guid') + self.assertEqual(consumer.instance_guid, u'consumer instance guid') -class LtiRunTestRender(LtiTestMixin, RenderXBlockTestMixin, ModuleStoreTestCase): +class LtiLaunchTestRender(LtiTestMixin, RenderXBlockTestMixin, ModuleStoreTestCase): """ - Tests for the rendering returned by lti_run view. + Tests for the rendering returned by lti_launch view. This class overrides the get_response method, which is used by the tests defined in RenderXBlockTestMixin. """ diff --git a/lms/djangoapps/lti_provider/urls.py b/lms/djangoapps/lti_provider/urls.py index def0a70245..1e1843950d 100644 --- a/lms/djangoapps/lti_provider/urls.py +++ b/lms/djangoapps/lti_provider/urls.py @@ -14,5 +14,4 @@ urlpatterns = patterns( usage_id=settings.USAGE_ID_PATTERN ), 'lti_provider.views.lti_launch', name="lti_provider_launch"), - url(r'^lti_run$', 'lti_provider.views.lti_run', name="lti_provider_run"), ) diff --git a/lms/djangoapps/lti_provider/views.py b/lms/djangoapps/lti_provider/views.py index fc7a852aa3..2a461ce550 100644 --- a/lms/djangoapps/lti_provider/views.py +++ b/lms/djangoapps/lti_provider/views.py @@ -33,8 +33,6 @@ OPTIONAL_PARAMETERS = [ 'tool_consumer_instance_guid' ] -LTI_SESSION_KEY = 'lti_provider_parameters' - @csrf_exempt def lti_launch(request, course_id, usage_id): @@ -48,38 +46,32 @@ def lti_launch(request, course_id, usage_id): - The launch contains all the required parameters - The launch data is correctly signed using a known client key/secret pair - - The user is logged into the edX instance - - Authentication in this view is a little tricky, since clients use a POST - with parameters to fetch it. We can't just use @login_required since in the - case where a user is not logged in it will redirect back after login using a - GET request, which would lose all of our LTI parameters. - - Instead, we verify the LTI launch in this view before checking if the user - is logged in, and store the required LTI parameters in the session. Then we - do the authentication check, and if login is required we redirect back to - the lti_run view. If the user is already logged in, we just call that view - directly. """ - if not settings.FEATURES['ENABLE_LTI_PROVIDER']: return HttpResponseForbidden() - # Check the OAuth signature on the message - try: - if not SignatureValidator().verify(request): - return HttpResponseForbidden() - except LtiConsumer.DoesNotExist: - return HttpResponseForbidden() - + # Check the LTI parameters, and return 400 if any required parameters are + # missing params = get_required_parameters(request.POST) if not params: return HttpResponseBadRequest() params.update(get_optional_parameters(request.POST)) - # Store the course, and usage ID in the session to prevent privilege - # escalation if a staff member in one course tries to access material in - # another. + # Get the consumer information from either the instance GUID or the consumer + # key + try: + lti_consumer = LtiConsumer.get_or_supplement( + params.get('tool_consumer_instance_guid', None), + params['oauth_consumer_key'] + ) + except LtiConsumer.DoesNotExist: + return HttpResponseForbidden() + + # Check the OAuth signature on the message + if not SignatureValidator(lti_consumer).verify(request): + return HttpResponseForbidden() + + # Add the course and usage keys to the parameters array try: course_key, usage_key = parse_course_and_usage_keys(course_id, usage_id) except InvalidKeyError: @@ -93,57 +85,13 @@ def lti_launch(request, course_id, usage_id): params['course_key'] = course_key params['usage_key'] = usage_key - try: - lti_consumer = LtiConsumer.get_or_supplement( - params.get('tool_consumer_instance_guid', None), - params['oauth_consumer_key'] - ) - except LtiConsumer.DoesNotExist: - return HttpResponseForbidden() - # Create an edX account if the user identifed by the LTI launch doesn't have # one already, and log the edX account into the platform. authenticate_lti_user(request, params['user_id'], lti_consumer) - request.session[LTI_SESSION_KEY] = params - - return lti_run(request) - - -@login_required -def lti_run(request): - """ - This method can be reached in two ways, and must always follow a POST to - lti_launch: - - The user was logged in, so this method was called by lti_launch - - The user was not logged in, so the login process redirected them back here. - - In either case, the session was populated by lti_launch, so all the required - LTI parameters will be stored there. Note that the request passed here may - or may not contain the LTI parameters (depending on how the user got here), - and so we should only use LTI parameters from the session. - - Users should never call this view directly; if a user attempts to call it - without having first gone through lti_launch (and had the LTI parameters - stored in the session) they will get a 403 response. - """ - - # Check the parameters to make sure that the session is associated with a - # valid LTI launch - params = restore_params_from_session(request) - if not params: - # This view has been called without first setting the session - return HttpResponseForbidden() - # Remove the parameters from the session to prevent replay - del request.session[LTI_SESSION_KEY] - # Store any parameters required by the outcome service in order to report # scores back later. We know that the consumer exists, since the record was # used earlier to verify the oauth signature. - lti_consumer = LtiConsumer.get_or_supplement( - params.get('tool_consumer_instance_guid', None), - params['oauth_consumer_key'] - ) store_outcome_parameters(params, request.user, lti_consumer) return render_courseware(request, params['usage_key']) @@ -184,26 +132,6 @@ def get_optional_parameters(dictionary): return {key: dictionary[key] for key in OPTIONAL_PARAMETERS if key in dictionary} -def restore_params_from_session(request): - """ - Fetch the parameters that were stored in the session by an LTI launch, and - verify that all required parameters are present. Missing parameters could - indicate that a user has directly called the lti_run endpoint, rather than - going through the LTI launch. - - :return: A dictionary of all LTI parameters from the session, or None if - any parameters are missing. - """ - if LTI_SESSION_KEY not in request.session: - return None - session_params = request.session[LTI_SESSION_KEY] - additional_params = ['course_key', 'usage_key'] - for key in REQUIRED_PARAMETERS + additional_params: - if key not in session_params: - return None - return session_params - - def render_courseware(request, usage_key): """ Render the content requested for the LTI launch.