diff --git a/lms/djangoapps/courseware/model_data.py b/lms/djangoapps/courseware/model_data.py index 383b4032bc..5d535bfdf5 100644 --- a/lms/djangoapps/courseware/model_data.py +++ b/lms/djangoapps/courseware/model_data.py @@ -267,6 +267,15 @@ class LmsKeyValueStore(KeyValueStore): If the key isn't found in the expected table during a read or a delete, then a KeyError will be raised """ + + _allowed_scopes = ( + Scope.content, + Scope.settings, + Scope.student_state, + Scope.student_preferences, + Scope.student_info, + Scope.children, + ) def __init__(self, descriptor_model_data, model_data_cache): self._descriptor_model_data = descriptor_model_data self._model_data_cache = model_data_cache @@ -278,6 +287,9 @@ class LmsKeyValueStore(KeyValueStore): if key.scope == Scope.parent: return None + if key.scope not in self._allowed_scopes: + raise InvalidScopeError(key.scope) + field_object = self._model_data_cache.find(key) if field_object is None: raise KeyError(key.field_name) @@ -293,6 +305,9 @@ class LmsKeyValueStore(KeyValueStore): field_object = self._model_data_cache.find_or_create(key) + if key.scope not in self._allowed_scopes: + raise InvalidScopeError(key.scope) + if key.scope == Scope.student_state: state = json.loads(field_object.state) state[key.field_name] = value @@ -306,6 +321,9 @@ class LmsKeyValueStore(KeyValueStore): if key.field_name in self._descriptor_model_data: raise InvalidWriteError("Not allowed to deleted descriptor model data", key.field_name) + if key.scope not in self._allowed_scopes: + raise InvalidScopeError(key.scope) + field_object = self._model_data_cache.find(key) if field_object is None: raise KeyError(key.field_name) diff --git a/lms/djangoapps/courseware/tests/test_model_data.py b/lms/djangoapps/courseware/tests/test_model_data.py index da89412238..13ecf9429d 100644 --- a/lms/djangoapps/courseware/tests/test_model_data.py +++ b/lms/djangoapps/courseware/tests/test_model_data.py @@ -131,8 +131,8 @@ class TestStudentModuleStorage(TestCase): def setUp(self): self.desc_md = {} - self.mdc = Mock() - self.mdc.find.return_value.state = json.dumps({'a_field': 'a_value'}) + self.user = UserFactory.create() + self.mdc = ModelDataCache([mock_descriptor([mock_field(Scope.student_state, 'a_field')])], course_id, self.user) self.kvs = LmsKeyValueStore(self.desc_md, self.mdc) def test_get_existing_field(self):