diff --git a/lms/djangoapps/learner_home/serializers.py b/lms/djangoapps/learner_home/serializers.py index 1154c125c5..4eb37e50e0 100644 --- a/lms/djangoapps/learner_home/serializers.py +++ b/lms/djangoapps/learner_home/serializers.py @@ -508,8 +508,10 @@ class UnfulfilledEntitlementSerializer(serializers.Serializer): """ If this entitlement is part of a program, include information about the program and related programs """ - programs = self.context['programs'].get(str(instance.course_uuid), []) - return ProgramsSerializer({"relatedPrograms": programs}, context=self.context).data + programs = self.context["programs"].get(str(instance.course_uuid), []) + return ProgramsSerializer( + {"relatedPrograms": programs}, context=self.context + ).data class SuggestedCourseSerializer(serializers.Serializer): diff --git a/lms/djangoapps/learner_home/test_serializers.py b/lms/djangoapps/learner_home/test_serializers.py index 74b90b0690..7778171a37 100644 --- a/lms/djangoapps/learner_home/test_serializers.py +++ b/lms/djangoapps/learner_home/test_serializers.py @@ -801,7 +801,7 @@ class TestUnfulfilledEntitlementSerializer(LearnerDashboardBaseTest): """High-level tests for UnfulfilledEntitlementSerializer""" def make_unfulfilled_entitlement(self): - """ Create an unfulflled entitlement, along with a pseudo session and available sessions""" + """Create an unfulflled entitlement, along with a pseudo session and available sessions""" unfulfilled_entitlement = CourseEntitlementFactory.create() pseudo_sessions = { str(unfulfilled_entitlement.uuid): CatalogCourseRunFactory.create() @@ -811,8 +811,10 @@ class TestUnfulfilledEntitlementSerializer(LearnerDashboardBaseTest): } return unfulfilled_entitlement, pseudo_sessions, available_sessions - def make_pseudo_session_course_overviews(self, unfulfilled_entitlement, pseudo_sessions): - """ Create course overview for course provider info """ + def make_pseudo_session_course_overviews( + self, unfulfilled_entitlement, pseudo_sessions + ): + """Create course overview for course provider info""" course_key_str = pseudo_sessions[str(unfulfilled_entitlement.uuid)]["key"] course_key = CourseKey.from_string(course_key_str) course_overview = CourseOverviewFactory.create(id=course_key) @@ -820,16 +822,19 @@ class TestUnfulfilledEntitlementSerializer(LearnerDashboardBaseTest): def test_happy_path(self): """Test that nothing breaks and the output fields look correct""" - unfulfilled_entitlement, pseudo_sessions, available_sessions = self.make_unfulfilled_entitlement() - pseudo_session_course_overviews = self.make_pseudo_session_course_overviews( + ( unfulfilled_entitlement, - pseudo_sessions + pseudo_sessions, + available_sessions, + ) = self.make_unfulfilled_entitlement() + pseudo_session_course_overviews = self.make_pseudo_session_course_overviews( + unfulfilled_entitlement, pseudo_sessions ) context = { "unfulfilled_entitlement_pseudo_sessions": pseudo_sessions, "course_entitlement_available_sessions": available_sessions, "pseudo_session_course_overviews": pseudo_session_course_overviews, - "programs": {} + "programs": {}, } output_data = UnfulfilledEntitlementSerializer( @@ -861,30 +866,32 @@ class TestUnfulfilledEntitlementSerializer(LearnerDashboardBaseTest): assert output_data["programs"] == {"relatedPrograms": []} def test_programs(self): - unfulfilled_entitlement, pseudo_sessions, available_sessions = self.make_unfulfilled_entitlement() - pseudo_session_course_overviews = self.make_pseudo_session_course_overviews( + ( unfulfilled_entitlement, - pseudo_sessions + pseudo_sessions, + available_sessions, + ) = self.make_unfulfilled_entitlement() + pseudo_session_course_overviews = self.make_pseudo_session_course_overviews( + unfulfilled_entitlement, pseudo_sessions ) related_programs = ProgramFactory.create_batch(3) - programs = { - str(unfulfilled_entitlement.course_uuid): related_programs - } + programs = {str(unfulfilled_entitlement.course_uuid): related_programs} context = { "unfulfilled_entitlement_pseudo_sessions": pseudo_sessions, "course_entitlement_available_sessions": available_sessions, "pseudo_session_course_overviews": pseudo_session_course_overviews, - "programs": programs + "programs": programs, } output_data = UnfulfilledEntitlementSerializer( unfulfilled_entitlement, context=context ).data - assert output_data["programs"] == ProgramsSerializer( - {"relatedPrograms": related_programs} - ).data + assert ( + output_data["programs"] + == ProgramsSerializer({"relatedPrograms": related_programs}).data + ) def test_static_enrollment_data(self): """ diff --git a/lms/djangoapps/learner_home/test_views.py b/lms/djangoapps/learner_home/test_views.py index ca78accec7..ec013e3f13 100644 --- a/lms/djangoapps/learner_home/test_views.py +++ b/lms/djangoapps/learner_home/test_views.py @@ -3,7 +3,8 @@ from contextlib import contextmanager import json from unittest import TestCase -from unittest.mock import patch +from unittest.mock import Mock, patch +from urllib.parse import urlencode from uuid import uuid4 import ddt @@ -26,6 +27,7 @@ from lms.djangoapps.learner_home.views import ( get_course_programs, get_email_settings_info, get_enrollments, + get_enterprise_customer, get_platform_settings, get_suggested_courses, get_user_account_confirmation_info, @@ -267,13 +269,13 @@ class TestGetEntitlements(SharedModuleStoreTestCase): with self.mock_get_filtered_course_entitlements([], {}, {}): ( fulfilled_entitlements_by_course_key, - unfulfulled_entitlements, + unfulfilled_entitlements, course_entitlement_available_sessions, unfulfilled_entitlement_pseudo_sessions, ) = get_entitlements(self.user, None, None) assert not fulfilled_entitlements_by_course_key - assert not unfulfulled_entitlements + assert not unfulfilled_entitlements assert not course_entitlement_available_sessions assert not unfulfilled_entitlement_pseudo_sessions @@ -394,8 +396,30 @@ class TestGetSuggestedCourses(SharedModuleStoreTestCase): self.assertDictEqual(return_data, self.EMPTY_SUGGESTED_COURSES) -class TestDashboardView(SharedModuleStoreTestCase, APITestCase): - """Tests for the dashboard view""" +@ddt.ddt +class TestGetEnterpriseCustomer(TestCase): + """Test for get_enterprise_customer""" + + @ddt.data(True, False) + @patch("lms.djangoapps.learner_home.views.get_enterprise_learner_data_from_db") + @patch( + "lms.djangoapps.learner_home.views.enterprise_customer_from_session_or_learner_data" + ) + def test_get_enterprise_customer( + self, is_masquerading, mock_get_from_session, mock_get_from_db + ): + """Don't load the user from session if we're masquerading, load directly from db""" + user, request = Mock(), Mock() + result = get_enterprise_customer(user, request, is_masquerading) + if is_masquerading: + assert not mock_get_from_session.called + assert result is mock_get_from_db.return_value[0]["enterprise_customer"] + else: + assert result is mock_get_from_session.return_value + + +class BaseTestDashboardView(SharedModuleStoreTestCase, APITestCase): + """Base class for test setup""" MODULESTORE = TEST_DATA_SPLIT_MODULESTORE @@ -413,9 +437,16 @@ class TestDashboardView(SharedModuleStoreTestCase, APITestCase): # Set up a user cls.username = "alan" cls.password = "enigma" - cls.user = UserFactory(username=cls.username, password=cls.password) + + cls.user = UserFactory( + username=cls.username, password=cls.password, is_staff=False + ) cls.site = SiteFactory() + +class TestDashboardView(BaseTestDashboardView): + """Tests for the dashboard view""" + def log_in(self): """Log in as a test user""" self.client.login(username=self.username, password=self.password) @@ -590,3 +621,132 @@ class TestDashboardView(SharedModuleStoreTestCase, APITestCase): assert len(data) == len(programs) assert programs[course_uuid][0] == program assert programs[course_uuid2][0] == program2 + + +class TestDashboardMasquerade(BaseTestDashboardView): + """Tests for the masquerade function for the learner home""" + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.staff_username = "sudo_alan" + cls.user_2_username = "Alan II" + cls.staff_user = UserFactory( + username=cls.staff_username, password=cls.password, is_staff=True + ) + cls.user_2 = UserFactory.create( + username=cls.user_2_username, password=cls.password, is_staff=False + ) + cls.user_1_enrollment = create_test_enrollment(cls.user) + cls.user_2_enrollment = create_test_enrollment(cls.user_2) + cls.staff_user_enrollment = create_test_enrollment(cls.staff_user) + + def log_in(self, user): + """Log in as the given user""" + self.client.login(username=user.username, password=self.password) + + def get_first_course_id(self, response): + """Get the first course id from a dashboard init response""" + return response.json()["courses"][0]["courseRun"]["courseId"] + + def get(self, user=None): + """Make a get request to the dashboard init view""" + if user: + params = {"user": user} + url_params = "/?" + urlencode(params) + else: + url_params = "" + url = self.view_url + url_params + return self.client.get(url) + + def test_no_student_access(self): + # If I log in as a student, not staff + self.log_in(self.user) + + # I get my own dashboard info while not masquerading + response = self.get() + assert response.status_code == 200 + assert self.get_first_course_id(response) == str( + self.user_1_enrollment.course_id + ) + + # If I try to masquerade as another user I get a 403 + response = self.get(self.user_2.username) + assert response.status_code == 403 + + # Even if I try to masquerade as myself I get a 403 + response = self.get(self.user.username) + assert response.status_code == 403 + + def test_staff_user(self): + # If I log in as site staff + self.log_in(self.staff_user) + + # I get my own dashboard info while not masquerading + response = self.get() + assert response.status_code == 200 + assert self.get_first_course_id(response) == str( + self.staff_user_enrollment.course_id + ) + + # I can also get other users' dashboard info by masquerading + response = self.get(self.user.username) + assert response.status_code == 200 + assert self.get_first_course_id(response) == str( + self.user_1_enrollment.course_id + ) + + response = self.get(self.user_2.username) + assert response.status_code == 200 + assert self.get_first_course_id(response) == str( + self.user_2_enrollment.course_id + ) + + def test_nonexistent_user__staff(self): + # If I log in as course staff + self.log_in(self.staff_user) + + # If I request to masquerade a nonexistent user I get a 404 + response = self.get(str(uuid4())) + assert response.status_code == 404 + + def test_nonexistent_user__student(self): + # If I log in as a non-staff user + self.log_in(self.user) + + # If I request to masquerade a nonexistent user I get a 403 + response = self.get(str(uuid4())) + assert response.status_code == 403 + + def test_get_user_by_email(self): + # If log in as a staff user + self.log_in(self.staff_user) + + # I can masquerade as a user by providing their email + response = self.get(self.user.email) + assert response.status_code == 200 + assert self.get_first_course_id(response) == str( + self.user_1_enrollment.course_id + ) + + response = self.get(self.user_2.email) + assert response.status_code == 200 + assert self.get_first_course_id(response) == str( + self.user_2_enrollment.course_id + ) + + def test_user_email_collision(self): + # If log in as a staff user + self.log_in(self.staff_user) + + # and we have a user whose username is the same as another user's email + user_3 = UserFactory(username=self.user_2.email) + assert user_3.username == self.user_2.email + user_3_enrollment = create_test_enrollment(user_3) + + # when a staff user masquerades as that value + response = self.get(user_3.username) + + # username has priority in the lookup + assert response.status_code == 200 + assert self.get_first_course_id(response) == str(user_3_enrollment.course_id) diff --git a/lms/djangoapps/learner_home/urls.py b/lms/djangoapps/learner_home/urls.py index 025c9a578d..3a57ac5e2e 100644 --- a/lms/djangoapps/learner_home/urls.py +++ b/lms/djangoapps/learner_home/urls.py @@ -1,6 +1,6 @@ """Learner home URL routing configuration""" -from django.urls import path +from django.urls import re_path from lms.djangoapps.learner_home import mock_views, views @@ -8,6 +8,8 @@ app_name = "learner_home" # Learner Dashboard Routing urlpatterns = [ - path("init", views.InitializeView.as_view(), name="initialize"), - path("mock/init", mock_views.InitializeView.as_view(), name="mock_initialize"), + re_path(r"init/?", views.InitializeView.as_view(), name="initialize"), + re_path( + r"mock/init/?", mock_views.InitializeView.as_view(), name="mock_initialize" + ), ] diff --git a/lms/djangoapps/learner_home/views.py b/lms/djangoapps/learner_home/views.py index fd06aecbb4..3fdfdffe8d 100644 --- a/lms/djangoapps/learner_home/views.py +++ b/lms/djangoapps/learner_home/views.py @@ -1,15 +1,20 @@ """ Views for the learner dashboard. """ +import logging from django.conf import settings +from django.contrib.auth import get_user_model +from django.core.exceptions import MultipleObjectsReturned from edx_django_utils import monitoring as monitoring_utils from opaque_keys.edx.keys import CourseKey +from rest_framework.exceptions import PermissionDenied, NotFound from rest_framework.response import Response from rest_framework.generics import RetrieveAPIView from common.djangoapps.course_modes.models import CourseMode from common.djangoapps.edxmako.shortcuts import marketing_link from common.djangoapps.student.helpers import cert_info, get_resume_urls_for_enrollments +from common.djangoapps.student.models import get_user_by_username_or_email from common.djangoapps.student.views.dashboard import ( complete_course_mode_info, get_course_enrollments, @@ -32,8 +37,12 @@ from openedx.core.djangoapps.site_configuration import helpers as configuration_ from openedx.core.djangoapps.programs.utils import ProgramProgressMeter from openedx.features.enterprise_support.api import ( enterprise_customer_from_session_or_learner_data, + get_enterprise_learner_data_from_db, ) +logger = logging.getLogger(__name__) +User = get_user_model() + def get_platform_settings(): """Get settings used for platform level connections: emails, url routes, etc.""" @@ -164,6 +173,18 @@ def get_email_settings_info(user, course_enrollments): return show_email_settings_for, course_optouts +def get_enterprise_customer(user, request, is_masquerading): + """ + If we are not masquerading, try to load the enterprise learner from session data, falling back to the db. + If we are masquerading, don't read or write to/from session data, go directly to db. + """ + if is_masquerading: + learner_data = get_enterprise_learner_data_from_db(user) + return learner_data[0]["enterprise_customer"] if learner_data else None + else: + return enterprise_customer_from_session_or_learner_data(request) + + def get_ecommerce_payment_page(user): """Determine the ecommerce payment page URL if enabled for this user""" ecommerce_service = EcommerceService() @@ -248,7 +269,9 @@ def get_course_programs(user, course_enrollments, site): } } """ - meter = ProgramProgressMeter(site, user, enrollments=course_enrollments, include_course_entitlements=True) + meter = ProgramProgressMeter( + site, user, enrollments=course_enrollments, include_course_entitlements=True + ) return meter.invert_programs() @@ -269,13 +292,44 @@ class InitializeView(RetrieveAPIView): # pylint: disable=unused-argument """List of courses a user is enrolled in or entitled to""" def get(self, request, *args, **kwargs): # pylint: disable=unused-argument - # Get user, determine if user needs to confirm email account - user = request.user - site = request.site + if request.GET.get("user"): + if not request.user.is_staff: + logger.info( + f"[Learner Home] {request.user.username} attempted to masquerade but is not staff" + ) + raise PermissionDenied() + + masquerade_identifier = request.GET.get("user") + try: + masquerade_user = get_user_by_username_or_email(masquerade_identifier) + except User.DoesNotExist: + raise NotFound() # pylint: disable=raise-missing-from + except MultipleObjectsReturned: + msg = ( + f"[Learner Home] {masquerade_identifier} could refer to multiple learners. " + " Defaulting to username." + ) + logger.info(msg) + masquerade_user = User.objects.get(username=masquerade_identifier) + + success_msg = ( + f"[Learner Home] {request.user.username} masquerades as " + f"{masquerade_user.username} - {masquerade_user.email}" + ) + logger.info(success_msg) + return self._initialize(masquerade_user, True) + else: + return self._initialize(request.user, False) + + def _initialize(self, user, is_masquerade): + """ + Load information required for displaying the learner home + """ + # Determine if user needs to confirm email account email_confirmation = get_user_account_confirmation_info(user) # Gather info for enterprise dashboard - enterprise_customer = enterprise_customer_from_session_or_learner_data(request) + enterprise_customer = get_enterprise_customer(user, self.request, is_masquerade) # Get the org whitelist or the org blacklist for the current site site_org_whitelist, site_org_blacklist = get_org_black_and_whitelist_for_site() @@ -308,7 +362,7 @@ class InitializeView(RetrieveAPIView): # pylint: disable=unused-argument course_access_checks = check_course_access(user, course_enrollments) # Get programs related to the courses the user is enrolled in - programs = get_course_programs(user, course_enrollments, site) + programs = get_course_programs(user, course_enrollments, self.request.site) # e-commerce info ecommerce_payment_page = get_ecommerce_payment_page(user)