diff --git a/common/djangoapps/embargo/models.py b/common/djangoapps/embargo/models.py index 8b6d2c7eaa..f62aae8903 100644 --- a/common/djangoapps/embargo/models.py +++ b/common/djangoapps/embargo/models.py @@ -401,6 +401,8 @@ class CountryAccessRule(models.Model): CACHE_KEY = u"embargo.allowed_countries.{course_key}" + ALL_COUNTRIES = set(code[0] for code in list(countries)) + @classmethod def check_country_access(cls, course_id, country): """ @@ -415,6 +417,14 @@ class CountryAccessRule(models.Model): True if country found in allowed country otherwise check given country exists in list """ + # If the country code is not in the list of all countries, + # we don't want to automatically exclude the user. + # This can happen, for example, when GeoIP falls back + # to using a continent code because it cannot determine + # the specific country. + if country not in cls.ALL_COUNTRIES: + return True + cache_key = cls.CACHE_KEY.format(course_key=course_id) allowed_countries = cache.get(cache_key) if allowed_countries is None: @@ -454,7 +464,7 @@ class CountryAccessRule(models.Model): # If there are no whitelist countries, default to all countries if not whitelist_countries: - whitelist_countries = set(code[0] for code in list(countries)) + whitelist_countries = cls.ALL_COUNTRIES # Consolidate the rules into a single list of countries # that have access to the course. diff --git a/common/djangoapps/embargo/tests/test_api.py b/common/djangoapps/embargo/tests/test_api.py index bf1b349f78..4d96ddc029 100644 --- a/common/djangoapps/embargo/tests/test_api.py +++ b/common/djangoapps/embargo/tests/test_api.py @@ -2,6 +2,7 @@ Tests for EmbargoMiddleware """ +from contextlib import contextmanager import mock import unittest import pygeoip @@ -85,9 +86,7 @@ class EmbargoCheckAccessApiTests(ModuleStoreTestCase): self.user.profile.save() # Appear to make a request from an IP in a particular country - with mock.patch.object(pygeoip.GeoIP, 'country_code_by_addr') as mock_ip: - mock_ip.return_value = ip_country - + with self._mock_geoip(ip_country): # Call the API. Note that the IP address we pass in doesn't # matter, since we're injecting a mock for geo-location result = embargo_api.check_course_access(self.course.id, user=self.user, ip_address='0.0.0.0') @@ -113,9 +112,7 @@ class EmbargoCheckAccessApiTests(ModuleStoreTestCase): country=Country.objects.get(country='US') ) - with mock.patch.object(pygeoip.GeoIP, 'country_code_by_addr') as mock_ip: - mock_ip.return_value = 'US' - + with self._mock_geoip('US'): # The user is set to None, because the user has not been authenticated. result = embargo_api.check_course_access(self.course.id, ip_address='0.0.0.0') self.assertFalse(result) @@ -137,6 +134,14 @@ class EmbargoCheckAccessApiTests(ModuleStoreTestCase): result = embargo_api.check_course_access(self.course.id, user=self.user, ip_address='FE80::0202:B3FF:FE1E:8329') self.assertTrue(result) + def test_country_access_fallback_to_continent_code(self): + # Simulate PyGeoIP falling back to a continent code + # instead of a country code. In this case, we should + # allow the user access. + with self._mock_geoip('EU'): + result = embargo_api.check_course_access(self.course.id, user=self.user, ip_address='0.0.0.0') + self.assertTrue(result) + @mock.patch.dict(settings.FEATURES, {'EMBARGO': True}) def test_profile_country_db_null(self): # Django country fields treat NULL values inconsistently. @@ -156,15 +161,16 @@ class EmbargoCheckAccessApiTests(ModuleStoreTestCase): self.assertTrue(result) def test_caching(self): - # Test the scenario that will go through every check - # (restricted course, but pass all the checks) - # This is the worst case, so it will hit all of the - # caching code. - with self.assertNumQueries(3): - embargo_api.check_course_access(self.course.id, user=self.user, ip_address='0.0.0.0') + with self._mock_geoip('US'): + # Test the scenario that will go through every check + # (restricted course, but pass all the checks) + # This is the worst case, so it will hit all of the + # caching code. + with self.assertNumQueries(3): + embargo_api.check_course_access(self.course.id, user=self.user, ip_address='0.0.0.0') - with self.assertNumQueries(0): - embargo_api.check_course_access(self.course.id, user=self.user, ip_address='0.0.0.0') + with self.assertNumQueries(0): + embargo_api.check_course_access(self.course.id, user=self.user, ip_address='0.0.0.0') def test_caching_no_restricted_courses(self): RestrictedCourse.objects.all().delete() @@ -176,6 +182,11 @@ class EmbargoCheckAccessApiTests(ModuleStoreTestCase): with self.assertNumQueries(0): embargo_api.check_course_access(self.course.id, user=self.user, ip_address='0.0.0.0') + @contextmanager + def _mock_geoip(self, country_code): + with mock.patch.object(pygeoip.GeoIP, 'country_code_by_addr') as mock_ip: + mock_ip.return_value = country_code + yield @ddt.ddt @override_settings(MODULESTORE=MODULESTORE_CONFIG)