EDUCATOR-4498 | Allow generate_jwt_signing_key to not include key prefixes.

This commit is contained in:
Alex Dusenbery
2019-07-22 10:56:42 -04:00
committed by Alex Dusenbery
parent 13681eb499
commit 8dbc1c2fcc
2 changed files with 43 additions and 21 deletions

View File

@@ -73,6 +73,12 @@ class Command(BaseCommand):
type=str,
help='Optional YML file in which output should be stored. Needs to be absolute path.',
)
parser.add_argument(
'--strip-key-prefix',
action='store_true',
dest='strip_key_prefix',
help='If set, will not include the "COMMON_" and "EDXAPP_" prefixes on key names.',
)
group = parser.add_mutually_exclusive_group()
group.add_argument(
@@ -95,8 +101,15 @@ class Command(BaseCommand):
options['key_size'],
options['key_id'] or self._generate_key_id(options['key_id_size']),
)
public_keys = self._output_public_keys(jwk_key, options['add_previous_public_keys'])
private_keys = self._output_private_keys(jwk_key)
public_keys = self._output_public_keys(
jwk_key,
options['add_previous_public_keys'],
options['strip_key_prefix']
)
private_keys = self._output_private_keys(
jwk_key,
options['strip_key_prefix']
)
if options['output_file']:
jwt_auth_data = {
'JWT_AUTH': public_keys,
@@ -114,13 +127,16 @@ class Command(BaseCommand):
rsa_jwk = jwk.RSAKey(kid=key_id, key=rsa_key)
return rsa_jwk
def _output_public_keys(self, jwk_key, add_previous):
def _output_public_keys(self, jwk_key, add_previous, strip_prefix):
public_keys = jwk.KEYS()
if add_previous:
self._add_previous_public_keys(public_keys)
public_keys.append(jwk_key)
serialized_public_keys = public_keys.dump_jwks()
prefix = '' if strip_prefix else 'COMMON_'
public_signing_key = '{}JWT_PUBLIC_SIGNING_JWK_SET'.format(prefix)
log.info('New JWT_PUBLIC_SIGNING_JWK_SET: %s.', serialized_public_keys)
print(" ")
print(" ")
@@ -133,10 +149,8 @@ class Command(BaseCommand):
"docs/decisions/0008-use-asymmetric-jwts.rst"
)
print(" ")
print(" COMMON_JWT_PUBLIC_SIGNING_JWK_SET: '{}'".format(serialized_public_keys))
return {
'COMMON_JWT_PUBLIC_SIGNING_JWK_SET': serialized_public_keys,
}
print(" {}: '{}'".format(public_signing_key, serialized_public_keys))
return {public_signing_key: serialized_public_keys}
def _add_previous_public_keys(self, public_keys):
previous_signing_keys = settings.JWT_AUTH.get('JWT_PUBLIC_SIGNING_JWK_SET')
@@ -144,9 +158,14 @@ class Command(BaseCommand):
log.info('Old JWT_PUBLIC_SIGNING_JWK_SET: %s.', previous_signing_keys)
public_keys.load_jwks(previous_signing_keys)
def _output_private_keys(self, jwk_key):
def _output_private_keys(self, jwk_key, strip_prefix):
serialized_keypair = jwk_key.serialize(private=True)
serialized_keypair_json = json.dumps(serialized_keypair)
prefix = '' if strip_prefix else 'EDXAPP_'
private_signing_key = '{}JWT_PRIVATE_SIGNING_JWK'.format(prefix)
algorithm_key = '{}JWT_SIGNING_ALGORITHM'.format(prefix)
print(" ")
print(" ")
print(" *** YAML to keep PRIVATE within a single authentication service (LMS) ***")
@@ -158,10 +177,10 @@ class Command(BaseCommand):
"docs/decisions/0008-use-asymmetric-jwts.rst"
)
print(" ")
print(" EDXAPP_JWT_PRIVATE_SIGNING_JWK: '{}'".format(serialized_keypair_json))
print(" {}: '{}'".format(private_signing_key, serialized_keypair_json))
print(" ")
print(" EDXAPP_JWT_SIGNING_ALGORITHM: 'RS512'")
print(" {}: 'RS512'".format(algorithm_key))
return {
'EDXAPP_JWT_PRIVATE_SIGNING_JWK': serialized_keypair_json,
'EDXAPP_JWT_SIGNING_ALGORITHM': 'RS512',
private_signing_key: serialized_keypair_json,
algorithm_key: 'RS512',
}

View File

@@ -46,10 +46,12 @@ class TestGenerateJwtSigningKey(TestCase):
)
self.assertEqual(log_message_exists, expected_to_exist)
def _assert_key_output(self, output_stream, filename):
expected_in_output = (
'EDXAPP_JWT_PRIVATE_SIGNING_JWK', 'EDXAPP_JWT_SIGNING_ALGORITHM', 'COMMON_JWT_PUBLIC_SIGNING_JWK_SET'
)
def _assert_key_output(self, output_stream, filename, strip_key_prefix):
expected_in_output = [
'{}JWT_PRIVATE_SIGNING_JWK'.format('' if strip_key_prefix else 'EDXAPP_'),
'{}JWT_SIGNING_ALGORITHM'.format('' if strip_key_prefix else 'EDXAPP_'),
'{}JWT_PUBLIC_SIGNING_JWK_SET'.format('' if strip_key_prefix else 'COMMON_'),
]
for expected in expected_in_output:
self.assertIn(expected, output_stream.getvalue())
@@ -70,12 +72,12 @@ class TestGenerateJwtSigningKey(TestCase):
self.assertEqual(len(key_id), key_id_size or 8)
@ddt.data(
dict(add_previous_public_keys=True, provide_key_id=False, key_id_size=None),
dict(add_previous_public_keys=True, provide_key_id=False, key_id_size=16),
dict(add_previous_public_keys=False, provide_key_id=True, key_id_size=None),
dict(add_previous_public_keys=True, provide_key_id=False, key_id_size=None, strip_key_prefix=True),
dict(add_previous_public_keys=True, provide_key_id=False, key_id_size=16, strip_key_prefix=False),
dict(add_previous_public_keys=False, provide_key_id=True, key_id_size=None, strip_key_prefix=False),
)
@ddt.unpack
def test_command(self, add_previous_public_keys, provide_key_id, key_id_size):
def test_command(self, add_previous_public_keys, provide_key_id, key_id_size, strip_key_prefix):
command_options = dict(add_previous_public_keys=add_previous_public_keys)
if provide_key_id:
command_options['key_id'] = TEST_KEY_IDENTIFIER
@@ -83,12 +85,13 @@ class TestGenerateJwtSigningKey(TestCase):
command_options['key_id_size'] = key_id_size
_, filename = tempfile.mkstemp(suffix='.yml')
command_options['output_file'] = filename
command_options['strip_key_prefix'] = strip_key_prefix
with self._captured_output() as (output_stream, _):
with patch(LOGGER) as mock_log:
call_command(COMMAND_NAME, **command_options)
self._assert_key_output(output_stream, filename)
self._assert_key_output(output_stream, filename, strip_key_prefix)
self._assert_presence_of_old_keys(mock_log, add_previous_public_keys)
self._assert_presence_of_key_id(mock_log, output_stream, provide_key_id, key_id_size)
os.remove(filename)