diff --git a/common/djangoapps/util/enterprise_helpers.py b/common/djangoapps/util/enterprise_helpers.py index 627c5cffb1..5d996b437a 100644 --- a/common/djangoapps/util/enterprise_helpers.py +++ b/common/djangoapps/util/enterprise_helpers.py @@ -3,9 +3,11 @@ Helpers to access the enterprise app """ from django.conf import settings from django.utils.translation import ugettext as _ +import logging try: from enterprise.models import EnterpriseCustomer + from enterprise import api as enterprise_api from enterprise.tpa_pipeline import ( active_provider_requests_data_sharing, active_provider_enforces_data_sharing, @@ -16,6 +18,8 @@ except ImportError: pass from openedx.core.djangoapps.site_configuration import helpers as configuration_helpers +ENTERPRISE_CUSTOMER_BRANDING_OVERRIDE_DETAILS = 'enterprise_customer_branding_override_details' +LOGGER = logging.getLogger("edx.enterprise_helpers") def enterprise_enabled(): @@ -120,3 +124,60 @@ def insert_enterprise_pipeline_elements(pipeline): for index, element in enumerate(additional_elements): pipeline.insert(insert_point + index, element) + + +def get_enterprise_customer_logo_url(request): + """ + Client API operation adapter/wrapper. + """ + + if not enterprise_enabled(): + return None + + parameter = get_enterprise_branding_filter_param(request) + if not parameter: + return None + + provider_id = parameter.get('provider_id', None) + ec_uuid = parameter.get('ec_uuid', None) + + if provider_id: + branding_info = enterprise_api.get_enterprise_branding_info_by_provider_id(provider_id=provider_id) + elif ec_uuid: + branding_info = enterprise_api.get_enterprise_branding_info_by_ec_uuid(ec_uuid=ec_uuid) + + logo_url = None + if branding_info and branding_info.logo: + logo_url = branding_info.logo.url + + return logo_url + + +def set_enterprise_branding_filter_param(request, provider_id): + """ + Setting 'ENTERPRISE_CUSTOMER_BRANDING_OVERRIDE_DETAILS' in session. 'ENTERPRISE_CUSTOMER_BRANDING_OVERRIDE_DETAILS' + either be provider_id or ec_uuid. e.g. {provider_id: 'xyz'} or {ec_src: enterprise_customer_uuid} + """ + ec_uuid = request.GET.get('ec_src', None) + if provider_id: + LOGGER.info( + "Session key 'ENTERPRISE_CUSTOMER_BRANDING_OVERRIDE_DETAILS' has been set with provider_id '%s'", + provider_id + ) + request.session[ENTERPRISE_CUSTOMER_BRANDING_OVERRIDE_DETAILS] = {'provider_id': provider_id} + + elif ec_uuid: + # we are assuming that none sso based enterprise will return Enterprise Customer uuid as 'ec_src' in query + # param e.g. edx.org/foo/bar?ec_src=6185ed46-68a4-45d6-8367-96c0bf70d1a6 + LOGGER.info( + "Session key 'ENTERPRISE_CUSTOMER_BRANDING_OVERRIDE_DETAILS' has been set with ec_uuid '%s'", ec_uuid + ) + request.session[ENTERPRISE_CUSTOMER_BRANDING_OVERRIDE_DETAILS] = {'ec_uuid': ec_uuid} + + +def get_enterprise_branding_filter_param(request): + """ + :return Filter parameter from session for enterprise customer branding information. + + """ + return request.session.get(ENTERPRISE_CUSTOMER_BRANDING_OVERRIDE_DETAILS, None) diff --git a/common/djangoapps/util/tests/test_enterprise_helpers.py b/common/djangoapps/util/tests/test_enterprise_helpers.py index 76ea516876..5ec40cc016 100644 --- a/common/djangoapps/util/tests/test_enterprise_helpers.py +++ b/common/djangoapps/util/tests/test_enterprise_helpers.py @@ -12,7 +12,10 @@ from util.enterprise_helpers import ( data_sharing_consent_required_at_login, data_sharing_consent_requirement_at_login, insert_enterprise_fields, - insert_enterprise_pipeline_elements + insert_enterprise_pipeline_elements, + set_enterprise_branding_filter_param, + get_enterprise_branding_filter_param, + get_enterprise_customer_logo_url ) @@ -144,3 +147,68 @@ class TestEnterpriseHelpers(unittest.TestCase): mock_ec.requests_data_sharing_consent = False insert_enterprise_fields(request, form_desc) form_desc.add_field.assert_not_called() + + def test_set_enterprise_branding_filter_param(self): + """ + Test that the enterprise customer branding parameters are setting correctly. + """ + ec_uuid = '97b4a894-cea9-4103-8f9f-2c5c95a58ba3' + provider_id = 'test-provider-idp' + + request = mock.MagicMock(session={}, GET={'ec_src': ec_uuid}) + set_enterprise_branding_filter_param(request, provider_id=None) + self.assertEqual(get_enterprise_branding_filter_param(request), {'ec_uuid': ec_uuid}) + + set_enterprise_branding_filter_param(request, provider_id=provider_id) + self.assertEqual(get_enterprise_branding_filter_param(request), {'provider_id': provider_id}) + + @mock.patch('util.enterprise_helpers.enterprise_enabled', mock.Mock(return_value=True)) + def test_get_enterprise_customer_logo_url(self): + """ + Test test_get_enterprise_customer_logo_url return the logo url as desired. + """ + ec_uuid = '97b4a894-cea9-4103-8f9f-2c5c95a58ba3' + provider_id = 'test-provider-idp' + request = mock.MagicMock(session={}, GET={'ec_src': ec_uuid}) + branding_info = mock.Mock( + logo=mock.Mock( + url='/test/image.png' + ) + ) + + set_enterprise_branding_filter_param(request, provider_id=None) + with mock.patch('enterprise.api.get_enterprise_branding_info_by_ec_uuid', return_value=branding_info): + logo_url = get_enterprise_customer_logo_url(request) + self.assertEqual(logo_url, '/test/image.png') + + set_enterprise_branding_filter_param(request, provider_id) + with mock.patch('enterprise.api.get_enterprise_branding_info_by_provider_id', return_value=branding_info): + logo_url = get_enterprise_customer_logo_url(request) + self.assertEqual(logo_url, '/test/image.png') + + @mock.patch('util.enterprise_helpers.enterprise_enabled', mock.Mock(return_value=False)) + def test_get_enterprise_customer_logo_url_return_none(self): + """ + Test get_enterprise_customer_logo_url return 'None' when enterprise application is not installed. + """ + request = mock.MagicMock(session={}) + branding_info = mock.Mock() + + set_enterprise_branding_filter_param(request, 'test-idp') + with mock.patch('enterprise.api.get_enterprise_branding_info_by_provider_id', return_value=branding_info): + logo_url = get_enterprise_customer_logo_url(request) + self.assertEqual(logo_url, None) + + @mock.patch('util.enterprise_helpers.enterprise_enabled', mock.Mock(return_value=True)) + @mock.patch('util.enterprise_helpers.get_enterprise_branding_filter_param', mock.Mock(return_value=None)) + def test_get_enterprise_customer_logo_url_return_none_when_param_missing(self): + """ + Test get_enterprise_customer_logo_url return 'None' when filter parameters are missing. + """ + request = mock.MagicMock(session={}) + branding_info = mock.Mock() + + set_enterprise_branding_filter_param(request, provider_id=None) + with mock.patch('enterprise.api.get_enterprise_branding_info_by_provider_id', return_value=branding_info): + logo_url = get_enterprise_customer_logo_url(request) + self.assertEqual(logo_url, None) diff --git a/lms/djangoapps/ccx/tests/test_field_override_performance.py b/lms/djangoapps/ccx/tests/test_field_override_performance.py index 6d7e8ddb91..1582c63211 100644 --- a/lms/djangoapps/ccx/tests/test_field_override_performance.py +++ b/lms/djangoapps/ccx/tests/test_field_override_performance.py @@ -60,6 +60,7 @@ class FieldOverridePerformanceTestCase(FieldOverrideTestMixin, ProceduralCourseT self.request_factory = RequestFactory() self.student = UserFactory.create() self.request = self.request_factory.get("foo") + self.request.session = {} self.request.user = self.student patcher = mock.patch('edxmako.request_context.get_current_request', return_value=self.request) diff --git a/lms/djangoapps/student_account/views.py b/lms/djangoapps/student_account/views.py index 5e8d459d74..06aafdfa2e 100644 --- a/lms/djangoapps/student_account/views.py +++ b/lms/djangoapps/student_account/views.py @@ -47,6 +47,7 @@ from third_party_auth import pipeline from third_party_auth.decorators import xframe_allow_whitelisted from util.bad_request_rate_limiter import BadRequestRateLimiter from util.date_utils import strftime_localized +from util.enterprise_helpers import set_enterprise_branding_filter_param AUDIT_LOG = logging.getLogger("audit") log = logging.getLogger(__name__) @@ -68,7 +69,6 @@ def login_and_registration_form(request, initial_mode="login"): """ # Determine the URL to redirect to following login/registration/third_party_auth redirect_to = get_next_url_for_login_page(request) - # If we're already logged in, redirect to the dashboard if request.user.is_authenticated(): return redirect(redirect_to) @@ -76,6 +76,21 @@ def login_and_registration_form(request, initial_mode="login"): # Retrieve the form descriptions from the user API form_descriptions = _get_form_descriptions(request) + # Our ?next= URL may itself contain a parameter 'tpa_hint=x' that we need to check. + # If present, we display a login page focused on third-party auth with that provider. + third_party_auth_hint = None + if '?' in redirect_to: + try: + next_args = urlparse.parse_qs(urlparse.urlparse(redirect_to).query) + provider_id = next_args['tpa_hint'][0] + if third_party_auth.provider.Registry.get(provider_id=provider_id): + third_party_auth_hint = provider_id + initial_mode = "hinted_login" + except (KeyError, ValueError, IndexError): + pass + + set_enterprise_branding_filter_param(request=request, provider_id=third_party_auth_hint) + # If this is a themed site, revert to the old login/registration pages. # We need to do this for now to support existing themes. # Themed sites can use the new logistration page by setting @@ -92,19 +107,6 @@ def login_and_registration_form(request, initial_mode="login"): if ext_auth_response is not None: return ext_auth_response - # Our ?next= URL may itself contain a parameter 'tpa_hint=x' that we need to check. - # If present, we display a login page focused on third-party auth with that provider. - third_party_auth_hint = None - if '?' in redirect_to: - try: - next_args = urlparse.parse_qs(urlparse.urlparse(redirect_to).query) - provider_id = next_args['tpa_hint'][0] - if third_party_auth.provider.Registry.get(provider_id=provider_id): - third_party_auth_hint = provider_id - initial_mode = "hinted_login" - except (KeyError, ValueError, IndexError): - pass - # Otherwise, render the combined login/registration page context = { 'data': { diff --git a/lms/static/sass/course/layout/_courseware_header.scss b/lms/static/sass/course/layout/_courseware_header.scss index 8b96ccbc65..c3a74b5346 100644 --- a/lms/static/sass/course/layout/_courseware_header.scss +++ b/lms/static/sass/course/layout/_courseware_header.scss @@ -145,6 +145,7 @@ img { height: 30px; + width: auto; } } diff --git a/lms/static/sass/shared/_header.scss b/lms/static/sass/shared/_header.scss index 1379b44179..55532b1d10 100644 --- a/lms/static/sass/shared/_header.scss +++ b/lms/static/sass/shared/_header.scss @@ -26,6 +26,11 @@ a { display: block; } + + img.ec-logo-size { + width: 84px; + height: 56px; + } } nav { diff --git a/lms/templates/navigation.html b/lms/templates/navigation.html index 95633ba471..27caee7d2d 100644 --- a/lms/templates/navigation.html +++ b/lms/templates/navigation.html @@ -14,6 +14,7 @@ from openedx.core.djangolib.markup import HTML, Text from branding import api as branding_api # app that handles site status messages from status.status import get_site_status_msg +from util.enterprise_helpers import get_enterprise_customer_logo_url %> ## Provide a hook for themes to inject branding on top. @@ -51,7 +52,14 @@ site_status_msg = get_site_status_msg(course_id)

<%block name="navigation_logo"> - ${_( + <% + logo_url = get_enterprise_customer_logo_url(request) + logo_size = 'ec-logo-size' + if logo_url is None: + logo_url = branding_api.get_logo_url(is_secure) + logo_size = '' + %> + ${_(

diff --git a/requirements/edx/base.txt b/requirements/edx/base.txt index 2fbb84c1f8..1e1106f5e1 100644 --- a/requirements/edx/base.txt +++ b/requirements/edx/base.txt @@ -46,7 +46,7 @@ edx-drf-extensions==1.2.1 edx-lint==0.4.3 edx-django-oauth2-provider==1.1.4 edx-django-sites-extensions==2.1.1 -edx-enterprise==0.6.0 +edx-enterprise==0.8.0 edx-oauth2-provider==1.2.0 edx-opaque-keys==0.4.0 edx-organizations==0.4.1