diff --git a/common/djangoapps/student/roles.py b/common/djangoapps/student/roles.py index c829cf5b9d..321c769322 100644 --- a/common/djangoapps/student/roles.py +++ b/common/djangoapps/student/roles.py @@ -7,6 +7,7 @@ adding users, removing users, and listing members import logging from abc import ABCMeta, abstractmethod from collections import defaultdict +from contextlib import contextmanager from django.contrib.auth.models import User # lint-amnesty, pylint: disable=imported-auth-user from opaque_keys.edx.django.models import CourseKeyField @@ -44,6 +45,17 @@ def register_access_role(cls): return cls +@contextmanager +def strict_role_checking(): + """ + Context manager that temporarily disables role inheritance. + """ + OLD_ACCESS_ROLES_INHERITANCE = ACCESS_ROLES_INHERITANCE.copy() + ACCESS_ROLES_INHERITANCE.clear() + yield + ACCESS_ROLES_INHERITANCE.update(OLD_ACCESS_ROLES_INHERITANCE) + + class BulkRoleCache: # lint-amnesty, pylint: disable=missing-class-docstring CACHE_NAMESPACE = "student.roles.BulkRoleCache" CACHE_KEY = 'roles_by_user' diff --git a/lms/djangoapps/courseware/rules.py b/lms/djangoapps/courseware/rules.py index 8202418f6f..07cbbab902 100644 --- a/lms/djangoapps/courseware/rules.py +++ b/lms/djangoapps/courseware/rules.py @@ -18,7 +18,7 @@ from xblock.core import XBlock from openedx.core.djangoapps.content.course_overviews.models import CourseOverview from openedx.core.djangoapps.enrollments.api import is_enrollment_valid_for_proctoring from common.djangoapps.student.models import CourseAccessRole -from common.djangoapps.student.roles import CourseRole, OrgRole +from common.djangoapps.student.roles import CourseRole, OrgRole, strict_role_checking from xmodule.course_block import CourseBlock # lint-amnesty, pylint: disable=wrong-import-order from xmodule.error_block import ErrorBlock # lint-amnesty, pylint: disable=wrong-import-order @@ -47,10 +47,14 @@ class HasAccessRule(Rule): """ A rule that calls `has_access` to determine whether it passes """ - def __init__(self, action): + def __init__(self, action, strict=False): self.action = action + self.strict = strict def check(self, user, instance=None): + if self.strict: + with strict_role_checking(): + return has_access(user, self.action, instance) return has_access(user, self.action, instance) def query(self, user):