73 lines
3.4 KiB
Python
73 lines
3.4 KiB
Python
"""Tests for util.request module."""
|
|
|
|
import unittest
|
|
|
|
from django.conf import settings
|
|
from django.core.exceptions import SuspiciousOperation
|
|
from django.test.client import RequestFactory
|
|
from util.request import course_id_from_url, safe_get_host
|
|
|
|
|
|
class ResponseTestCase(unittest.TestCase):
    """Tests for the request utilities ``safe_get_host`` and ``course_id_from_url``.

    Some tests assign directly to the global Django ``settings`` object, so
    ``setUp`` records ``SITE_NAME`` and ``ALLOWED_HOSTS`` and ``tearDown``
    writes the recorded values back.
    """

    def setUp(self):
        """Record the settings values that tests in this class overwrite."""
        super(ResponseTestCase, self).setUp()
        self.old_site_name = settings.SITE_NAME
        self.old_allowed_hosts = settings.ALLOWED_HOSTS

    def tearDown(self):
        """Write back the settings values recorded in ``setUp``."""
        # Undo our own changes first, then let the base class tear down,
        # mirroring the order used in setUp.
        settings.SITE_NAME = self.old_site_name
        settings.ALLOWED_HOSTS = self.old_allowed_hosts
        super(ResponseTestCase, self).tearDown()

    def test_safe_get_host(self):
        """Tests that the safe_get_host function returns the desired host."""
        settings.SITE_NAME = 'siteName.com'
        request = RequestFactory().request()
        request.META['HTTP_HOST'] = 'www.userProvidedHost.com'

        # If ALLOWED_HOSTS is not set properly (unset, or containing the "*"
        # wildcard), safe_get_host should return SITE_NAME.
        for allowed_hosts in (None, ["*"], ["foo.com", "*"]):
            settings.ALLOWED_HOSTS = allowed_hosts
            self.assertEqual(
                safe_get_host(request),
                "siteName.com",
                "ALLOWED_HOSTS=%r" % (allowed_hosts,),
            )

        # If ALLOWED_HOSTS is set properly and the host is valid, the
        # user-provided host is returned unchanged.
        settings.ALLOWED_HOSTS = [request.META['HTTP_HOST']]
        self.assertEqual(safe_get_host(request), request.META['HTTP_HOST'])

        # If ALLOWED_HOSTS is set properly but the host is not listed, a
        # SuspiciousOperation is expected.
        settings.ALLOWED_HOSTS = ["the_valid_website.com"]
        with self.assertRaises(SuspiciousOperation):
            safe_get_host(request)

    def test_course_id_from_url(self):
        """Test course_id_from_url()."""
        # URLs for which course_id_from_url is expected to return None.
        urls_without_course_id = (
            '/login',
            '/course/edX/maths/2020',
            '/courses/edX/maths/',
            '/api/courses/v1/blocks/edX/maths/2020',
            '/api/courses/v1/blocks/course-v1:incidental+courseid+formatting',
            '/api/courses/v41/notcourses/course-v1:incidental+courseid+formatting',
        )
        for url in urls_without_course_id:
            # The URL is passed as the failure message to identify the case.
            self.assertIsNone(course_id_from_url(url), url)

        # URLs that are all expected to yield org "edX", course "maths",
        # run "2020".
        urls_with_course_id = (
            '/courses/course-v1:edX+maths+2020',
            '/courses/edX/maths/2020',
            '/api/courses/v1/courses/course-v1:edX+maths+2020',
            '/api/courses/v1/courses/edX/maths/2020',
        )
        for url in urls_with_course_id:
            course_id = course_id_from_url(url)
            self.assertCourseIdFieldsMatch(course_id=course_id, org='edX', course='maths', run='2020')

    def assertCourseIdFieldsMatch(self, course_id, org, course, run):
        """Assert that the passed-in course id matches the specified fields.

        Fails if ``course_id`` is None or if its ``org``, ``course`` or
        ``run`` attribute differs from the corresponding argument.
        """
        # course_id_from_url returns None for URLs it cannot parse; fail with
        # an assertion instead of an AttributeError on the lookups below.
        self.assertIsNotNone(course_id, "expected a course id, got None")
        self.assertEqual(course_id.org, org)
        self.assertEqual(course_id.course, course)
        self.assertEqual(course_id.run, run)
|