From 8dbc1c2fcc94cd6da73ffd084353dcd9bb4eefde Mon Sep 17 00:00:00 2001 From: Alex Dusenbery Date: Mon, 22 Jul 2019 10:56:42 -0400 Subject: [PATCH] EDUCATOR-4498 | Allow generate_jwt_signing_key to not include key prefixes. --- .../commands/generate_jwt_signing_key.py | 43 +++++++++++++------ .../tests/test_generate_jwt_signing_key.py | 21 +++++---- 2 files changed, 43 insertions(+), 21 deletions(-) diff --git a/openedx/core/djangoapps/oauth_dispatch/management/commands/generate_jwt_signing_key.py b/openedx/core/djangoapps/oauth_dispatch/management/commands/generate_jwt_signing_key.py index 71f22cf020..c10c2663d8 100644 --- a/openedx/core/djangoapps/oauth_dispatch/management/commands/generate_jwt_signing_key.py +++ b/openedx/core/djangoapps/oauth_dispatch/management/commands/generate_jwt_signing_key.py @@ -73,6 +73,12 @@ class Command(BaseCommand): type=str, help='Optional YML file in which output should be stored. Needs to be absolute path.', ) + parser.add_argument( + '--strip-key-prefix', + action='store_true', + dest='strip_key_prefix', + help='If set, will not include the "COMMON_" and "EDXAPP_" prefixes on key names.', + ) group = parser.add_mutually_exclusive_group() group.add_argument( @@ -95,8 +101,15 @@ class Command(BaseCommand): options['key_size'], options['key_id'] or self._generate_key_id(options['key_id_size']), ) - public_keys = self._output_public_keys(jwk_key, options['add_previous_public_keys']) - private_keys = self._output_private_keys(jwk_key) + public_keys = self._output_public_keys( + jwk_key, + options['add_previous_public_keys'], + options['strip_key_prefix'] + ) + private_keys = self._output_private_keys( + jwk_key, + options['strip_key_prefix'] + ) if options['output_file']: jwt_auth_data = { 'JWT_AUTH': public_keys, @@ -114,13 +127,16 @@ class Command(BaseCommand): rsa_jwk = jwk.RSAKey(kid=key_id, key=rsa_key) return rsa_jwk - def _output_public_keys(self, jwk_key, add_previous): + def _output_public_keys(self, jwk_key, add_previous, strip_prefix): public_keys = jwk.KEYS() if add_previous: self._add_previous_public_keys(public_keys) public_keys.append(jwk_key) serialized_public_keys = public_keys.dump_jwks() + prefix = '' if strip_prefix else 'COMMON_' + public_signing_key = '{}JWT_PUBLIC_SIGNING_JWK_SET'.format(prefix) + log.info('New JWT_PUBLIC_SIGNING_JWK_SET: %s.', serialized_public_keys) print(" ") print(" ") @@ -133,10 +149,8 @@ class Command(BaseCommand): "docs/decisions/0008-use-asymmetric-jwts.rst" ) print(" ") - print(" COMMON_JWT_PUBLIC_SIGNING_JWK_SET: '{}'".format(serialized_public_keys)) - return { - 'COMMON_JWT_PUBLIC_SIGNING_JWK_SET': serialized_public_keys, - } + print(" {}: '{}'".format(public_signing_key, serialized_public_keys)) + return {public_signing_key: serialized_public_keys} def _add_previous_public_keys(self, public_keys): previous_signing_keys = settings.JWT_AUTH.get('JWT_PUBLIC_SIGNING_JWK_SET') @@ -144,9 +158,14 @@ class Command(BaseCommand): log.info('Old JWT_PUBLIC_SIGNING_JWK_SET: %s.', previous_signing_keys) public_keys.load_jwks(previous_signing_keys) - def _output_private_keys(self, jwk_key): + def _output_private_keys(self, jwk_key, strip_prefix): serialized_keypair = jwk_key.serialize(private=True) serialized_keypair_json = json.dumps(serialized_keypair) + + prefix = '' if strip_prefix else 'EDXAPP_' + private_signing_key = '{}JWT_PRIVATE_SIGNING_JWK'.format(prefix) + algorithm_key = '{}JWT_SIGNING_ALGORITHM'.format(prefix) + print(" ") print(" ") print(" *** YAML to keep PRIVATE within a single authentication service (LMS) ***") @@ -158,10 +177,10 @@ class Command(BaseCommand): "docs/decisions/0008-use-asymmetric-jwts.rst" ) print(" ") - print(" EDXAPP_JWT_PRIVATE_SIGNING_JWK: '{}'".format(serialized_keypair_json)) + print(" {}: '{}'".format(private_signing_key, serialized_keypair_json)) print(" ") - print(" EDXAPP_JWT_SIGNING_ALGORITHM: 'RS512'") + print(" {}: 'RS512'".format(algorithm_key)) return { - 'EDXAPP_JWT_PRIVATE_SIGNING_JWK': serialized_keypair_json, - 'EDXAPP_JWT_SIGNING_ALGORITHM': 'RS512', + private_signing_key: serialized_keypair_json, + algorithm_key: 'RS512', } diff --git a/openedx/core/djangoapps/oauth_dispatch/management/commands/tests/test_generate_jwt_signing_key.py b/openedx/core/djangoapps/oauth_dispatch/management/commands/tests/test_generate_jwt_signing_key.py index a1bbc93789..24df11d066 100644 --- a/openedx/core/djangoapps/oauth_dispatch/management/commands/tests/test_generate_jwt_signing_key.py +++ b/openedx/core/djangoapps/oauth_dispatch/management/commands/tests/test_generate_jwt_signing_key.py @@ -46,10 +46,12 @@ class TestGenerateJwtSigningKey(TestCase): ) self.assertEqual(log_message_exists, expected_to_exist) - def _assert_key_output(self, output_stream, filename): - expected_in_output = ( - 'EDXAPP_JWT_PRIVATE_SIGNING_JWK', 'EDXAPP_JWT_SIGNING_ALGORITHM', 'COMMON_JWT_PUBLIC_SIGNING_JWK_SET' - ) + def _assert_key_output(self, output_stream, filename, strip_key_prefix): + expected_in_output = [ + '{}JWT_PRIVATE_SIGNING_JWK'.format('' if strip_key_prefix else 'EDXAPP_'), + '{}JWT_SIGNING_ALGORITHM'.format('' if strip_key_prefix else 'EDXAPP_'), + '{}JWT_PUBLIC_SIGNING_JWK_SET'.format('' if strip_key_prefix else 'COMMON_'), + ] for expected in expected_in_output: self.assertIn(expected, output_stream.getvalue()) @@ -70,12 +72,12 @@ class TestGenerateJwtSigningKey(TestCase): self.assertEqual(len(key_id), key_id_size or 8) @ddt.data( - dict(add_previous_public_keys=True, provide_key_id=False, key_id_size=None), - dict(add_previous_public_keys=True, provide_key_id=False, key_id_size=16), - dict(add_previous_public_keys=False, provide_key_id=True, key_id_size=None), + dict(add_previous_public_keys=True, provide_key_id=False, key_id_size=None, strip_key_prefix=True), + dict(add_previous_public_keys=True, provide_key_id=False, key_id_size=16, strip_key_prefix=False), + dict(add_previous_public_keys=False, provide_key_id=True, key_id_size=None, strip_key_prefix=False), ) @ddt.unpack - def test_command(self, add_previous_public_keys, provide_key_id, key_id_size): + def test_command(self, add_previous_public_keys, provide_key_id, key_id_size, strip_key_prefix): command_options = dict(add_previous_public_keys=add_previous_public_keys) if provide_key_id: command_options['key_id'] = TEST_KEY_IDENTIFIER @@ -83,12 +85,13 @@ class TestGenerateJwtSigningKey(TestCase): command_options['key_id_size'] = key_id_size _, filename = tempfile.mkstemp(suffix='.yml') command_options['output_file'] = filename + command_options['strip_key_prefix'] = strip_key_prefix with self._captured_output() as (output_stream, _): with patch(LOGGER) as mock_log: call_command(COMMAND_NAME, **command_options) - self._assert_key_output(output_stream, filename) + self._assert_key_output(output_stream, filename, strip_key_prefix) self._assert_presence_of_old_keys(mock_log, add_previous_public_keys) self._assert_presence_of_key_id(mock_log, output_stream, provide_key_id, key_id_size) os.remove(filename)