Move user retirement scripts code from the tubular repo (#34063)

* refactor: Migrate user retirement scripts code from the tubular repo
This commit is contained in:
Muhammad Farhan Khan
2024-02-22 21:09:00 +05:00
committed by GitHub
parent 20570ff417
commit 65ea55c8aa
48 changed files with 7808 additions and 1 deletions

View File

@@ -0,0 +1,33 @@
# Runs the user-retirement scripts' unit tests on every pull request and on
# pushes to master. Python is pinned to 3.8 to match the pinned requirements
# files under scripts/user_retirement/requirements/.
name: units-test-scripts-user-retirement

on:
  pull_request:
  push:
    branches:
      - master

jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [ '3.8' ]

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r scripts/user_retirement/requirements/testing.txt

      - name: Run pytest
        run: |
          pytest scripts/user_retirement

View File

@@ -137,7 +137,9 @@ REQ_FILES = \
requirements/edx/development \
requirements/edx/assets \
requirements/edx/semgrep \
scripts/xblock/requirements
scripts/xblock/requirements \
scripts/user_retirement/requirements/base \
scripts/user_retirement/requirements/testing
define COMMON_CONSTRAINTS_TEMP_COMMENT
# This is a temporary solution to override the real common_constraints.txt\n# In edx-lint, until the pyjwt constraint in edx-lint has been removed.\n# See BOM-2721 for more details.\n# Below is the copied and edited version of common_constraints\n

View File

@@ -0,0 +1,100 @@
User Retirement Scripts
=======================
`This <https://github.com/openedx/edx-platform/tree/master/scripts/user_retirement>`_ directory contains Python scripts which are migrated from the `tubular <https://github.com/openedx/tubular/tree/master/scripts>`_ repository.
These scripts are intended to drive the user retirement workflow which involves handling the deactivation or removal of user accounts as part of the platform's management process.
These scripts could be called from any automation/CD framework.
How to run the scripts
======================
Download the Scripts
--------------------
To download the scripts, you can perform a partial clone of the edx-platform repository to obtain only the required scripts. The following steps demonstrate how to achieve this. Alternatively, you may choose other utilities or libraries for the partial clone.
.. code-block:: bash
repo_url=git@github.com:openedx/edx-platform.git
branch=master
directory=scripts/user_retirement
git clone --branch $branch --single-branch --depth=1 --filter=tree:0 $repo_url
cd edx-platform
git sparse-checkout init --cone
git sparse-checkout set $directory
Create Python Virtual Environment
---------------------------------
Create a Python virtual environment using Python 3.8:
.. code-block:: bash
python3.8 -m venv ../venv
source ../venv/bin/activate
Install Pip Packages
--------------------
Install the required pip packages using the provided requirements file:
.. code-block:: bash
pip install -r scripts/user_retirement/requirements/base.txt
In-depth Documentation and Configuration Steps
----------------------------------------------
For in-depth documentation and essential configurations follow these docs
`Documentation <https://edx.readthedocs.io/projects/edx-installing-configuring-and-running/en/latest/configuration/user_retire/index.html>`_
`Configuration Docs <https://edx.readthedocs.io/projects/edx-installing-configuring-and-running/en/latest/configuration/user_retire/driver_setup.html>`_
Execute Script
--------------
Execute the following shell command to establish entry points for the scripts
.. code-block:: bash
chmod +x scripts/user_retirement/entry_points.sh
source scripts/user_retirement/entry_points.sh
To retire a specific learner, you can use the provided example script:
.. code-block:: bash
retire_one_learner.py \
--config_file=src/config.yml \
--username=user1
Make sure to replace ``src/config.yml`` with the actual path to your configuration file and ``user1`` with the actual username.
You can also execute Python scripts directly using the file path:
.. code-block:: bash
python scripts/user_retirement/retire_one_learner.py \
--config_file=src/config.yml \
--username=user1
Feel free to customize these steps according to your specific environment and requirements.
Run Test Cases
==============
Before running test cases, install the testing requirements:
.. code-block:: bash
pip install -r scripts/user_retirement/requirements/testing.txt
Run the test cases using pytest:
.. code-block:: bash
pytest scripts/user_retirement

View File

View File

@@ -0,0 +1,7 @@
#!/usr/bin/env bash
# Defines short-name entry points for the user-retirement scripts.
# Intended to be `source`d into an interactive shell (see the README);
# aliases do not expand in non-interactive scripts without `shopt -s expand_aliases`.
alias get_learners_to_retire.py='python scripts/user_retirement/get_learners_to_retire.py'
alias replace_usernames.py='python scripts/user_retirement/replace_usernames.py'
alias retire_one_learner.py='python scripts/user_retirement/retire_one_learner.py'
alias retirement_archive_and_cleanup.py='python scripts/user_retirement/retirement_archive_and_cleanup.py'
alias retirement_bulk_status_update.py='python scripts/user_retirement/retirement_bulk_status_update.py'
alias retirement_partner_report.py='python scripts/user_retirement/retirement_partner_report.py'

View File

@@ -0,0 +1,105 @@
#! /usr/bin/env python3
"""
Command-line script to retrieve list of learners that have requested to be retired.
The script calls the appropriate LMS endpoint to get this list of learners.
"""
import io
import logging
import sys
from os import path
import click
import yaml
# Add top-level project path to sys.path before importing scripts code
sys.path.append(path.abspath(path.join(path.dirname(__file__), '../..')))
from scripts.user_retirement.utils.edx_api import LmsApi
from scripts.user_retirement.utils.jenkins import export_learner_job_properties
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
LOG = logging.getLogger(__name__)
@click.command("get_learners_to_retire")
@click.option(
    '--config_file',
    help='File in which YAML config exists that overrides all other params.'
)
@click.option(
    '--cool_off_days',
    help='Number of days a learner should be in the retirement queue before being actually retired.',
    default=7
)
@click.option(
    '--output_dir',
    help="Directory in which to write the Jenkins properties files.",
    default='./jenkins_props'
)
@click.option(
    '--user_count_error_threshold',
    # BUGFIX: the two adjacent string literals previously concatenated into
    # "failsafeagainst" because neither ended/began with a space.
    help="If more users than this number are returned we will error out instead of retiring. This is a failsafe "
         "against attacks that somehow manage to add users to the retirement queue.",
    default=300
)
@click.option(
    '--max_user_batch_size',
    # BUGFIX: same missing-space concatenation ("user_count_error_thresholdsetting").
    help="This setting will only get at most X number of users. If this number is lower than the "
         "user_count_error_threshold setting then it will not error.",
    default=200
)
def get_learners_to_retire(config_file,
                           cool_off_days,
                           output_dir,
                           user_count_error_threshold,
                           max_user_batch_size):
    """
    Retrieves a JWT token as the retirement service user, then calls the LMS
    endpoint to retrieve the list of learners awaiting retirement.

    Writes one Jenkins property file per learner into ``output_dir`` via
    ``export_learner_job_properties``. Exits with -1 when no config file is
    given or when more learners are returned than ``user_count_error_threshold``.
    """
    if not config_file:
        click.echo('A config file is required.')
        sys.exit(-1)

    with io.open(config_file, 'r') as config:
        config_yaml = yaml.safe_load(config)

    # Defensive casts: these may arrive as strings depending on how the
    # values were supplied (click usually converts, but config paths may not).
    user_count_error_threshold = int(user_count_error_threshold)
    cool_off_days = int(cool_off_days)

    client_id = config_yaml['client_id']
    client_secret = config_yaml['client_secret']
    lms_base_url = config_yaml['base_urls']['lms']
    retirement_pipeline = config_yaml['retirement_pipeline']
    # Each pipeline entry's second element is its "complete" state; learners in
    # PENDING or any complete state are candidates for the next step.
    end_states = [state[1] for state in retirement_pipeline]
    states_to_request = ['PENDING'] + end_states

    api = LmsApi(lms_base_url, lms_base_url, client_id, client_secret)

    # Retrieve the learners to retire and export them to separate Jenkins property files.
    learners_to_retire = api.learners_to_retire(states_to_request, cool_off_days, max_user_batch_size)

    # Belt-and-braces truncation in case the API returned more than requested.
    if max_user_batch_size:
        learners_to_retire = learners_to_retire[:max_user_batch_size]

    learners_to_retire_cnt = len(learners_to_retire)

    if learners_to_retire_cnt > user_count_error_threshold:
        click.echo(
            'Too many learners to retire! Expected {} or fewer, got {}!'.format(
                user_count_error_threshold,
                learners_to_retire_cnt
            )
        )
        sys.exit(-1)

    export_learner_job_properties(
        learners_to_retire,
        output_dir
    )


if __name__ == "__main__":
    # pylint: disable=unexpected-keyword-arg, no-value-for-parameter
    # If using env vars to provide params, prefix them with "RETIREMENT_", e.g. RETIREMENT_CLIENT_ID
    get_learners_to_retire(auto_envvar_prefix='RETIREMENT')

View File

View File

@@ -0,0 +1,153 @@
#! /usr/bin/env python3
"""
Command-line script to replace the usernames for all passed in learners.
Accepts a list of current usernames and their preferred new username. This
script will call LMS first which generates a unique username if the passed in
new username is not unique. It then calls all other services to replace the
username in their DBs.
"""
import csv
import io
import logging
import sys
from os import path
import click
import yaml
# Add top-level project path to sys.path before importing scripts code
sys.path.append(path.abspath(path.join(path.dirname(__file__), '../..')))
from scripts.user_retirement.utils.edx_api import ( # pylint: disable=wrong-import-position
CredentialsApi,
DiscoveryApi,
EcommerceApi,
LmsApi
)
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
LOG = logging.getLogger(__name__)
def write_responses(writer, replacements, status):
    """
    Write one CSV row per username replacement.

    Each entry in ``replacements`` is a single-item dict mapping the original
    username to the new username; every row gets the same ``status`` column.
    """
    for mapping in replacements:
        old_name = next(iter(mapping.keys()))
        new_name = next(iter(mapping.values()))
        writer.writerow([old_name, new_name, status])
@click.command("replace_usernames")
@click.option(
    '--config_file',
    help='File in which YAML config exists that overrides all other params.'
)
@click.option(
    '--username_replacement_csv',
    # BUGFIX: help text was a copy-paste of the config_file help.
    help='CSV file of replacements, one "current_username,desired_username" pair per line.'
)
def replace_usernames(config_file, username_replacement_csv):
    """
    Replaces the usernames for all learners listed in the CSV file.

    Calls LMS first, which generates a unique username if a desired username
    is already taken, then propagates each successful replacement to the
    other services (Ecommerce, Discovery, Credentials, Forums). Results are
    written to ``username_replacement_results.csv`` and the script exits with
    -1 if any replacement failed anywhere.

    Config file example:
    ```
    client_id: xxx
    client_secret: xxx
    base_urls:
        lms: http://localhost:18000
        ecommerce: http://localhost:18130
        discovery: http://localhost:18381
        credentials: http://localhost:18150
    ```

    Username file example:
    ```
    current_un_1,desired_un_1
    current_un_2,desired_un_2
    current_un_3,desired_un_3
    ```
    """
    if not config_file:
        click.echo('A config file is required.')
        sys.exit(-1)

    if not username_replacement_csv:
        click.echo('A username replacement CSV file is required')
        sys.exit(-1)

    with io.open(config_file, 'r') as config:
        config_yaml = yaml.safe_load(config)

    with io.open(username_replacement_csv, 'r') as replacement_file:
        csv_reader = csv.reader(replacement_file)
        lms_username_mappings = [
            {current_username: desired_username}
            for (current_username, desired_username)
            in csv_reader
        ]

    client_id = config_yaml['client_id']
    client_secret = config_yaml['client_secret']
    lms_base_url = config_yaml['base_urls']['lms']
    ecommerce_base_url = config_yaml['base_urls']['ecommerce']
    discovery_base_url = config_yaml['base_urls']['discovery']
    credentials_base_url = config_yaml['base_urls']['credentials']

    # Note that though partially_failed sounds better than completely_failed,
    # it's actually worse since the user is not consistent across DBs.
    # Partially failed username replacements will need to be triaged so the
    # user isn't in a broken state.
    successful_replacements = []
    partially_failed_replacements = []
    fully_failed_replacements = []

    lms_api = LmsApi(lms_base_url, lms_base_url, client_id, client_secret)
    ecommerce_api = EcommerceApi(lms_base_url, ecommerce_base_url, client_id, client_secret)
    discovery_api = DiscoveryApi(lms_base_url, discovery_base_url, client_id, client_secret)
    credentials_api = CredentialsApi(lms_base_url, credentials_base_url, client_id, client_secret)

    # Call LMS with current and desired usernames
    response = lms_api.replace_lms_usernames(lms_username_mappings)
    fully_failed_replacements += response['failed_replacements']
    in_progress_replacements = response['successful_replacements']

    # Step through each service's endpoints with the list returned from LMS.
    # The LMS list has already verified usernames and made any duplicate
    # usernames unique (e.g. 'matt' => 'mattf56a'). We pass successful
    # replacements onto the next service and store all failed replacements.
    replacement_methods = [
        ecommerce_api.replace_usernames,
        discovery_api.replace_usernames,
        credentials_api.replace_usernames,
        lms_api.replace_forums_usernames,
    ]

    # Iterate through the endpoints above and if the APIs return any failures
    # capture these in partially_failed_replacements. Only successful
    # replacements will continue to be passed to the next service.
    for replacement_method in replacement_methods:
        response = replacement_method(in_progress_replacements)
        partially_failed_replacements += response['failed_replacements']
        in_progress_replacements = response['successful_replacements']

    # Whatever survived every service is a full success.
    successful_replacements = in_progress_replacements

    with open('username_replacement_results.csv', 'w', newline='') as output_file:
        csv_writer = csv.writer(output_file)
        # Write header
        csv_writer.writerow(['Original Username', 'New Username', 'Status'])
        write_responses(csv_writer, successful_replacements, "SUCCESS")
        write_responses(csv_writer, partially_failed_replacements, "PARTIALLY FAILED")
        write_responses(csv_writer, fully_failed_replacements, "FAILED")

    if partially_failed_replacements or fully_failed_replacements:
        sys.exit(-1)


if __name__ == "__main__":
    # pylint: disable=unexpected-keyword-arg, no-value-for-parameter
    # If using env vars to provide params, prefix them with "USERNAME_REPLACEMENT_",
    # e.g. USERNAME_REPLACEMENT_CONFIG_FILE
    replace_usernames(auto_envvar_prefix='USERNAME_REPLACEMENT')

View File

@@ -0,0 +1,11 @@
boto3
click
pyyaml
backoff
requests
edx-rest-api-client
jenkinsapi
unicodecsv
simplejson
simple-salesforce
google-api-python-client

View File

@@ -0,0 +1,178 @@
#
# This file is autogenerated by pip-compile with Python 3.8
# by the following command:
#
# make upgrade
#
asgiref==3.7.2
# via django
attrs==23.2.0
# via zeep
backoff==2.2.1
# via -r scripts/user_retirement/requirements/base.in
backports-zoneinfo==0.2.1
# via
# django
# pendulum
boto3==1.34.26
# via -r scripts/user_retirement/requirements/base.in
botocore==1.34.26
# via
# boto3
# s3transfer
cachetools==5.3.2
# via google-auth
certifi==2023.11.17
# via requests
cffi==1.16.0
# via
# cryptography
# pynacl
charset-normalizer==3.3.2
# via requests
click==8.1.7
# via
# -r scripts/user_retirement/requirements/base.in
# edx-django-utils
cryptography==42.0.0
# via simple-salesforce
django==4.2.9
# via
# django-crum
# django-waffle
# edx-django-utils
django-crum==0.7.9
# via edx-django-utils
django-waffle==4.1.0
# via edx-django-utils
edx-django-utils==5.10.1
# via edx-rest-api-client
edx-rest-api-client==5.6.1
# via -r scripts/user_retirement/requirements/base.in
google-api-core==2.15.0
# via google-api-python-client
google-api-python-client==2.115.0
# via -r scripts/user_retirement/requirements/base.in
google-auth==2.26.2
# via
# google-api-core
# google-api-python-client
# google-auth-httplib2
google-auth-httplib2==0.2.0
# via google-api-python-client
googleapis-common-protos==1.62.0
# via google-api-core
httplib2==0.22.0
# via
# google-api-python-client
# google-auth-httplib2
idna==3.6
# via requests
importlib-resources==6.1.1
# via pendulum
isodate==0.6.1
# via zeep
jenkinsapi==0.3.13
# via -r scripts/user_retirement/requirements/base.in
jmespath==1.0.1
# via
# boto3
# botocore
lxml==4.9.3
# via zeep
more-itertools==10.2.0
# via simple-salesforce
newrelic==9.5.0
# via edx-django-utils
pbr==6.0.0
# via stevedore
pendulum==3.0.0
# via simple-salesforce
platformdirs==4.1.0
# via zeep
protobuf==4.25.2
# via
# google-api-core
# googleapis-common-protos
psutil==5.9.8
# via edx-django-utils
pyasn1==0.5.1
# via
# pyasn1-modules
# rsa
pyasn1-modules==0.3.0
# via google-auth
pycparser==2.21
# via cffi
pyjwt==2.8.0
# via
# edx-rest-api-client
# simple-salesforce
pynacl==1.5.0
# via edx-django-utils
pyparsing==3.1.1
# via httplib2
python-dateutil==2.8.2
# via
# botocore
# pendulum
# time-machine
pytz==2023.3.post1
# via
# jenkinsapi
# zeep
pyyaml==6.0.1
# via -r scripts/user_retirement/requirements/base.in
requests==2.31.0
# via
# -r scripts/user_retirement/requirements/base.in
# edx-rest-api-client
# google-api-core
# jenkinsapi
# requests-file
# requests-toolbelt
# simple-salesforce
# slumber
# zeep
requests-file==1.5.1
# via zeep
requests-toolbelt==1.0.0
# via zeep
rsa==4.9
# via google-auth
s3transfer==0.10.0
# via boto3
simple-salesforce==1.12.5
# via -r scripts/user_retirement/requirements/base.in
simplejson==3.19.2
# via -r scripts/user_retirement/requirements/base.in
six==1.16.0
# via
# isodate
# jenkinsapi
# python-dateutil
# requests-file
slumber==0.7.1
# via edx-rest-api-client
sqlparse==0.4.4
# via django
stevedore==5.1.0
# via edx-django-utils
time-machine==2.13.0
# via pendulum
typing-extensions==4.9.0
# via asgiref
tzdata==2023.4
# via pendulum
unicodecsv==0.14.1
# via -r scripts/user_retirement/requirements/base.in
uritemplate==4.1.1
# via google-api-python-client
urllib3==1.26.18
# via
# botocore
# requests
zeep==4.2.1
# via simple-salesforce
zipp==3.17.0
# via importlib-resources

View File

@@ -0,0 +1,8 @@
-r base.txt
moto
pytest
requests_mock
responses
mock
ddt

View File

@@ -0,0 +1,316 @@
#
# This file is autogenerated by pip-compile with Python 3.8
# by the following command:
#
# make upgrade
#
asgiref==3.7.2
# via
# -r scripts/user_retirement/requirements/base.txt
# django
attrs==23.2.0
# via
# -r scripts/user_retirement/requirements/base.txt
# zeep
backoff==2.2.1
# via -r scripts/user_retirement/requirements/base.txt
backports-zoneinfo==0.2.1
# via
# -r scripts/user_retirement/requirements/base.txt
# django
# pendulum
boto3==1.34.26
# via
# -r scripts/user_retirement/requirements/base.txt
# moto
botocore==1.34.26
# via
# -r scripts/user_retirement/requirements/base.txt
# boto3
# moto
# s3transfer
cachetools==5.3.2
# via
# -r scripts/user_retirement/requirements/base.txt
# google-auth
certifi==2023.11.17
# via
# -r scripts/user_retirement/requirements/base.txt
# requests
cffi==1.16.0
# via
# -r scripts/user_retirement/requirements/base.txt
# cryptography
# pynacl
charset-normalizer==3.3.2
# via
# -r scripts/user_retirement/requirements/base.txt
# requests
click==8.1.7
# via
# -r scripts/user_retirement/requirements/base.txt
# edx-django-utils
cryptography==42.0.0
# via
# -r scripts/user_retirement/requirements/base.txt
# moto
# simple-salesforce
ddt==1.7.1
# via -r scripts/user_retirement/requirements/testing.in
django==4.2.9
# via
# -r scripts/user_retirement/requirements/base.txt
# django-crum
# django-waffle
# edx-django-utils
django-crum==0.7.9
# via
# -r scripts/user_retirement/requirements/base.txt
# edx-django-utils
django-waffle==4.1.0
# via
# -r scripts/user_retirement/requirements/base.txt
# edx-django-utils
edx-django-utils==5.10.1
# via
# -r scripts/user_retirement/requirements/base.txt
# edx-rest-api-client
edx-rest-api-client==5.6.1
# via -r scripts/user_retirement/requirements/base.txt
exceptiongroup==1.2.0
# via pytest
google-api-core==2.15.0
# via
# -r scripts/user_retirement/requirements/base.txt
# google-api-python-client
google-api-python-client==2.115.0
# via -r scripts/user_retirement/requirements/base.txt
google-auth==2.26.2
# via
# -r scripts/user_retirement/requirements/base.txt
# google-api-core
# google-api-python-client
# google-auth-httplib2
google-auth-httplib2==0.2.0
# via
# -r scripts/user_retirement/requirements/base.txt
# google-api-python-client
googleapis-common-protos==1.62.0
# via
# -r scripts/user_retirement/requirements/base.txt
# google-api-core
httplib2==0.22.0
# via
# -r scripts/user_retirement/requirements/base.txt
# google-api-python-client
# google-auth-httplib2
idna==3.6
# via
# -r scripts/user_retirement/requirements/base.txt
# requests
importlib-resources==6.1.1
# via
# -r scripts/user_retirement/requirements/base.txt
# pendulum
iniconfig==2.0.0
# via pytest
isodate==0.6.1
# via
# -r scripts/user_retirement/requirements/base.txt
# zeep
jenkinsapi==0.3.13
# via -r scripts/user_retirement/requirements/base.txt
jinja2==3.1.3
# via moto
jmespath==1.0.1
# via
# -r scripts/user_retirement/requirements/base.txt
# boto3
# botocore
lxml==4.9.3
# via
# -r scripts/user_retirement/requirements/base.txt
# zeep
markupsafe==2.1.4
# via
# jinja2
# werkzeug
mock==5.1.0
# via -r scripts/user_retirement/requirements/testing.in
more-itertools==10.2.0
# via
# -r scripts/user_retirement/requirements/base.txt
# simple-salesforce
moto==4.2.13
# via -r scripts/user_retirement/requirements/testing.in
newrelic==9.5.0
# via
# -r scripts/user_retirement/requirements/base.txt
# edx-django-utils
packaging==23.2
# via pytest
pbr==6.0.0
# via
# -r scripts/user_retirement/requirements/base.txt
# stevedore
pendulum==3.0.0
# via
# -r scripts/user_retirement/requirements/base.txt
# simple-salesforce
platformdirs==4.1.0
# via
# -r scripts/user_retirement/requirements/base.txt
# zeep
pluggy==1.3.0
# via pytest
protobuf==4.25.2
# via
# -r scripts/user_retirement/requirements/base.txt
# google-api-core
# googleapis-common-protos
psutil==5.9.8
# via
# -r scripts/user_retirement/requirements/base.txt
# edx-django-utils
pyasn1==0.5.1
# via
# -r scripts/user_retirement/requirements/base.txt
# pyasn1-modules
# rsa
pyasn1-modules==0.3.0
# via
# -r scripts/user_retirement/requirements/base.txt
# google-auth
pycparser==2.21
# via
# -r scripts/user_retirement/requirements/base.txt
# cffi
pyjwt==2.8.0
# via
# -r scripts/user_retirement/requirements/base.txt
# edx-rest-api-client
# simple-salesforce
pynacl==1.5.0
# via
# -r scripts/user_retirement/requirements/base.txt
# edx-django-utils
pyparsing==3.1.1
# via
# -r scripts/user_retirement/requirements/base.txt
# httplib2
pytest==7.4.4
# via -r scripts/user_retirement/requirements/testing.in
python-dateutil==2.8.2
# via
# -r scripts/user_retirement/requirements/base.txt
# botocore
# moto
# pendulum
# time-machine
pytz==2023.3.post1
# via
# -r scripts/user_retirement/requirements/base.txt
# jenkinsapi
# zeep
pyyaml==6.0.1
# via
# -r scripts/user_retirement/requirements/base.txt
# responses
requests==2.31.0
# via
# -r scripts/user_retirement/requirements/base.txt
# edx-rest-api-client
# google-api-core
# jenkinsapi
# moto
# requests-file
# requests-mock
# requests-toolbelt
# responses
# simple-salesforce
# slumber
# zeep
requests-file==1.5.1
# via
# -r scripts/user_retirement/requirements/base.txt
# zeep
requests-mock==1.11.0
# via -r scripts/user_retirement/requirements/testing.in
requests-toolbelt==1.0.0
# via
# -r scripts/user_retirement/requirements/base.txt
# zeep
responses==0.24.1
# via
# -r scripts/user_retirement/requirements/testing.in
# moto
rsa==4.9
# via
# -r scripts/user_retirement/requirements/base.txt
# google-auth
s3transfer==0.10.0
# via
# -r scripts/user_retirement/requirements/base.txt
# boto3
simple-salesforce==1.12.5
# via -r scripts/user_retirement/requirements/base.txt
simplejson==3.19.2
# via -r scripts/user_retirement/requirements/base.txt
six==1.16.0
# via
# -r scripts/user_retirement/requirements/base.txt
# isodate
# jenkinsapi
# python-dateutil
# requests-file
# requests-mock
slumber==0.7.1
# via
# -r scripts/user_retirement/requirements/base.txt
# edx-rest-api-client
sqlparse==0.4.4
# via
# -r scripts/user_retirement/requirements/base.txt
# django
stevedore==5.1.0
# via
# -r scripts/user_retirement/requirements/base.txt
# edx-django-utils
time-machine==2.13.0
# via
# -r scripts/user_retirement/requirements/base.txt
# pendulum
tomli==2.0.1
# via pytest
typing-extensions==4.9.0
# via
# -r scripts/user_retirement/requirements/base.txt
# asgiref
tzdata==2023.4
# via
# -r scripts/user_retirement/requirements/base.txt
# pendulum
unicodecsv==0.14.1
# via -r scripts/user_retirement/requirements/base.txt
uritemplate==4.1.1
# via
# -r scripts/user_retirement/requirements/base.txt
# google-api-python-client
urllib3==1.26.18
# via
# -r scripts/user_retirement/requirements/base.txt
# botocore
# requests
# responses
werkzeug==3.0.1
# via moto
xmltodict==0.13.0
# via moto
zeep==4.2.1
# via
# -r scripts/user_retirement/requirements/base.txt
# simple-salesforce
zipp==3.17.0
# via
# -r scripts/user_retirement/requirements/base.txt
# importlib-resources

View File

@@ -0,0 +1,224 @@
#! /usr/bin/env python3
"""
Command-line script to drive the user retirement workflow for a single user
To run this script you will need a username to run against and a YAML config file in the format:
client_id: <client id from LMS DOT>
client_secret: <client secret from LMS DOT>
base_urls:
lms: http://localhost:18000/
ecommerce: http://localhost:18130/
credentials: http://localhost:18150/
demographics: http://localhost:18360/
retirement_pipeline:
- ['RETIRING_CREDENTIALS', 'CREDENTIALS_COMPLETE', 'CREDENTIALS', 'retire_learner']
- ['RETIRING_ECOM', 'ECOM_COMPLETE', 'ECOMMERCE', 'retire_learner']
- ['RETIRING_DEMOGRAPHICS', 'DEMOGRAPHICS_COMPLETE', 'DEMOGRAPHICS', 'retire_learner']
- ['RETIRING_LICENSE_MANAGER', 'LICENSE_MANAGER_COMPLETE', 'LICENSE_MANAGER', 'retire_learner']
- ['RETIRING_FORUMS', 'FORUMS_COMPLETE', 'LMS', 'retirement_retire_forum']
- ['RETIRING_EMAIL_LISTS', 'EMAIL_LISTS_COMPLETE', 'LMS', 'retirement_retire_mailings']
- ['RETIRING_ENROLLMENTS', 'ENROLLMENTS_COMPLETE', 'LMS', 'retirement_unenroll']
- ['RETIRING_LMS', 'LMS_COMPLETE', 'LMS', 'retirement_lms_retire']
"""
import logging
import sys
from functools import partial
from os import path
from time import time
import click
# Add top-level project path to sys.path before importing scripts code
sys.path.append(path.abspath(path.join(path.dirname(__file__), '../..')))
from scripts.user_retirement.utils.exception import HttpDoesNotExistException
# pylint: disable=wrong-import-position
from scripts.user_retirement.utils.helpers import (
_config_or_exit,
_fail,
_fail_exception,
_get_error_str_from_exception,
_log,
_setup_all_apis_or_exit
)
# Return codes for various fail cases
ERR_SETUP_FAILED = -1
ERR_USER_AT_END_STATE = -2
ERR_USER_IN_WORKING_STATE = -3
ERR_WHILE_RETIRING = -4
ERR_BAD_LEARNER = -5
ERR_UNKNOWN_STATE = -6
ERR_BAD_CONFIG = -7
SCRIPT_SHORTNAME = 'Learner Retirement'
LOG = partial(_log, SCRIPT_SHORTNAME)
FAIL = partial(_fail, SCRIPT_SHORTNAME)
FAIL_EXCEPTION = partial(_fail_exception, SCRIPT_SHORTNAME)
CONFIG_OR_EXIT = partial(_config_or_exit, FAIL_EXCEPTION, ERR_BAD_CONFIG)
SETUP_ALL_APIS_OR_EXIT = partial(_setup_all_apis_or_exit, FAIL_EXCEPTION, ERR_SETUP_FAILED)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
# "Magic" states with special meaning, these are required to be in LMS
START_STATE = 'PENDING'
ERROR_STATE = 'ERRORED'
COMPLETE_STATE = 'COMPLETE'
ABORTED_STATE = 'ABORTED'
END_STATES = (ERROR_STATE, ABORTED_STATE, COMPLETE_STATE)
# We'll store the access token here once retrieved
AUTH_HEADER = {}
def _get_learner_state_index_or_exit(learner, config):
    """
    Returns the index in the ALL_STATES retirement state list, validating that it is in
    an appropriate state to work on.

    Args:
        learner: LMS retirement-status dict; must contain
            ``current_state.state_name``.
        config: script config dict with ``all_states`` and ``working_states``
            populated by ``_config_retirement_pipeline``.

    NOTE(review): the FAIL branches below are assumed to exit the process
    (they wrap ``_fail`` from utils.helpers), so no value is returned on
    those paths — confirm against the helper's implementation.
    """
    try:
        learner_state = learner['current_state']['state_name']
        # ValueError here (unknown state) is caught below.
        learner_state_index = config['all_states'].index(learner_state)
        if learner_state in END_STATES:
            FAIL(ERR_USER_AT_END_STATE, 'User already in end state: {}'.format(learner_state))
        if learner_state in config['working_states']:
            FAIL(ERR_USER_IN_WORKING_STATE, 'User is already in a working state! {}'.format(learner_state))
        return learner_state_index
    except KeyError:
        # Response was missing the expected nested keys.
        FAIL(ERR_BAD_LEARNER, 'Bad learner response missing current_state or state_name: {}'.format(learner))
    except ValueError:
        # learner_state was not present in config['all_states'].
        FAIL(ERR_UNKNOWN_STATE, 'Unknown learner retirement state for learner: {}'.format(learner))
def _config_retirement_pipeline(config):
    """
    Populate ``config`` with the derived retirement state lists.

    Sets ``config['working_states']`` (the in-progress state of every pipeline
    step) and ``config['all_states']`` (START_STATE, then each step's
    working/complete pair in pipeline order, then the terminal END_STATES).
    """
    pipeline = config['retirement_pipeline']

    # States meaning an API call is currently in progress.
    config['working_states'] = [step[0] for step in pipeline]

    # Build the full ordered list of states.
    all_states = [START_STATE]
    for step in pipeline:
        all_states.extend((step[0], step[1]))
    all_states.extend(END_STATES)
    config['all_states'] = all_states
def _get_learner_and_state_index_or_exit(config, username):
    """
    Double-checks the current learner state, contacting LMS, and maps that state to its
    index in the pipeline. Exits out if the learner is in an invalid state or not found
    in LMS.

    Returns:
        (learner, learner_state_index) on success; the FAIL branches are
        assumed to terminate the process (see utils.helpers) — confirm.
    """
    try:
        learner = config['LMS'].get_learner_retirement_state(username)
        learner_state_index = _get_learner_state_index_or_exit(learner, config)
        return learner, learner_state_index
    except HttpDoesNotExistException:
        # LMS reported no retirement-status record for this username.
        FAIL(ERR_BAD_LEARNER, 'Learner {} not found. Please check that the learner is present in '
                              'UserRetirementStatus, is not already retired, '
                              'and is in an appropriate state to be acted upon.'.format(username))
    except Exception as exc:  # pylint: disable=broad-except
        FAIL_EXCEPTION(ERR_SETUP_FAILED, 'Unexpected error fetching user state!', str(exc))
def _get_ecom_segment_id(config, learner):
    """
    Calls Ecommerce to get the ecom-specific Segment tracking id that we need to retire.
    This is only available from Ecommerce, unfortunately, and makes more sense to handle
    here than to pass all of the config down to SegmentApi.

    Returns:
        The tracking key, or None when the learner has no Ecommerce record.
    """
    try:
        return config['ECOMMERCE'].get_tracking_key(learner)
    except HttpDoesNotExistException:
        # Not every learner has an Ecommerce account; a 404 simply means
        # there is no Segment id to retire.
        LOG('Learner {} not found in Ecommerce. Setting Ecommerce Segment ID to None'.format(learner))
        return None
    except Exception as exc:  # pylint: disable=broad-except
        FAIL_EXCEPTION(ERR_SETUP_FAILED, 'Unexpected error fetching Ecommerce tracking id!', str(exc))
@click.command("retire_learner")
@click.option(
    '--username',
    help='The original username of the user to retire'
)
@click.option(
    '--config_file',
    help='File in which YAML config exists that overrides all other params.'
)
def retire_learner(
        username,
        config_file
):
    """
    Retrieves a JWT token as the retirement service learner, then performs the retirement process as
    defined in WORKING_STATE_ORDER.

    Walks the configured retirement_pipeline step by step, recording progress
    in LMS before and after each service call, so an interrupted run can be
    resumed (already-completed states are skipped by index comparison).
    """
    LOG('Starting learner retirement for {} using config file {}'.format(username, config_file))

    if not config_file:
        FAIL(ERR_BAD_CONFIG, 'No config file passed in.')

    config = CONFIG_OR_EXIT(config_file)
    _config_retirement_pipeline(config)
    SETUP_ALL_APIS_OR_EXIT(config)

    learner, learner_state_index = _get_learner_and_state_index_or_exit(config, username)

    # Optional: fetch the Ecommerce Segment id up front for later retirement steps.
    if config.get('fetch_ecommerce_segment_id', False):
        learner['ecommerce_segment_id'] = _get_ecom_segment_id(config, learner)

    # Tracked outside the loop so the except block can report which state failed.
    start_state = None
    try:
        for start_state, end_state, service, method in config['retirement_pipeline']:
            # Skip anything that has already been done
            if config['all_states'].index(start_state) < learner_state_index:
                LOG('State {} completed in previous run, skipping'.format(start_state))
                continue

            LOG('Starting state {}'.format(start_state))
            config['LMS'].update_learner_retirement_state(username, start_state, 'Starting: {}'.format(start_state))

            # This does the actual API call
            start_time = time()
            response = getattr(config[service], method)(learner)
            end_time = time()

            LOG('State {} completed in {} seconds'.format(start_state, end_time - start_time))
            config['LMS'].update_learner_retirement_state(
                username,
                end_state,
                'Ending: {} with response:\n{}'.format(end_state, response)
            )

            learner_state_index += 1
            LOG('Progressing to state {}'.format(end_state))

        config['LMS'].update_learner_retirement_state(username, COMPLETE_STATE, 'Learner retirement complete.')
        LOG('Retirement complete for learner {}'.format(username))
    except Exception as exc:  # pylint: disable=broad-except
        exc_msg = _get_error_str_from_exception(exc)
        try:
            LOG('Error in retirement state {}: {}'.format(start_state, exc_msg))
            # Best effort: record the error in LMS so the learner lands in ERRORED.
            config['LMS'].update_learner_retirement_state(username, ERROR_STATE, exc_msg)
        except Exception as update_exc:  # pylint: disable=broad-except
            # Even the state update failed; log it and still exit with the original error.
            LOG('Critical error attempting to change learner state to ERRORED: {}'.format(update_exc))
        FAIL_EXCEPTION(ERR_WHILE_RETIRING, 'Error encountered in state "{}"'.format(start_state), exc)


if __name__ == '__main__':
    # pylint: disable=unexpected-keyword-arg, no-value-for-parameter
    # If using env vars to provide params, prefix them with "RETIREMENT_", e.g. RETIREMENT_CLIENT_ID
    retire_learner(auto_envvar_prefix='RETIREMENT')

View File

@@ -0,0 +1,329 @@
#! /usr/bin/env python3
"""
Command-line script to bulk archive and cleanup retired learners from LMS
"""
import datetime
import gzip
import json
import logging
import sys
import time
from functools import partial
from os import path
import backoff
import boto3
import click
from botocore.exceptions import BotoCoreError, ClientError
from six import text_type
# Add top-level project path to sys.path before importing scripts code
sys.path.append(path.abspath(path.join(path.dirname(__file__), '../..')))
# pylint: disable=wrong-import-position
from scripts.user_retirement.utils.helpers import _config_or_exit, _fail, _fail_exception, _log, _setup_lms_api_or_exit
# Short name prepended to every log / failure message from this script.
SCRIPT_SHORTNAME = 'Archive and Cleanup'
# Return codes for various fail cases
ERR_NO_CONFIG = -1
ERR_BAD_CONFIG = -2
ERR_FETCHING = -3
ERR_ARCHIVING = -4
ERR_DELETING = -5
# NOTE(review): same value as ERR_DELETING (-5), so these two failure modes are
# indistinguishable by exit code. Confirm before renumbering -- external
# automation may depend on the existing codes.
ERR_SETUP_FAILED = -5
ERR_BAD_CLI_PARAM = -6
# Partials that bake the script short name / exit codes into the shared helpers.
LOG = partial(_log, SCRIPT_SHORTNAME)
FAIL = partial(_fail, SCRIPT_SHORTNAME)
FAIL_EXCEPTION = partial(_fail_exception, SCRIPT_SHORTNAME)
CONFIG_OR_EXIT = partial(_config_or_exit, FAIL_EXCEPTION, ERR_BAD_CONFIG)
SETUP_LMS_OR_EXIT = partial(_setup_lms_api_or_exit, FAIL, ERR_SETUP_FAILED)
# Seconds to sleep between batches so we don't hammer the LMS.
DELAY = 10
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
# Quiet boto's DEBUG chatter while keeping our own debug output.
logging.getLogger('boto').setLevel(logging.INFO)
def _fetch_learners_to_archive_or_exit(config, start_date, end_date, initial_state):
    """
    Fetch every learner whose retirement row is in ``initial_state`` and was
    created between ``start_date`` and ``end_date``. Returns the list from the
    LMS; any API error exits the script via FAIL_EXCEPTION.
    """
    LOG('Fetching users in state {} created from {} to {}'.format(initial_state, start_date, end_date))
    try:
        lms = config['LMS']
        fetched = lms.get_learners_by_date_and_status(initial_state, start_date, end_date)
        LOG('Successfully fetched {} learners'.format(str(len(fetched))))
        return fetched
    except Exception as exc:  # pylint: disable=broad-except
        FAIL_EXCEPTION(ERR_FETCHING, 'Unexpected error occurred fetching users to update!', exc)
def _batch_learners(learners=None, batch_size=None):
"""
To avoid potentially overwheling the LMS with a large number of user retirements to
delete, create a list of smaller batches of users to iterate over. This has the
added benefit of reducing the amount of user retirement archive requests that can
get into a bad state should this script experience an error.
Args:
learners (list): List of learners to portion into smaller batches (lists)
batch_size (int): The number of learners to portion into each batch. If this
parameter is not supplied, this function will return one batch containing
all of the learners supplied to it.
"""
if batch_size:
return [
learners[i:i + batch_size] for i, _ in list(enumerate(learners))[::batch_size]
]
else:
return [learners]
def _on_s3_backoff(details):
    """
    Handler invoked by the ``backoff`` decorator each time it sleeps before a
    retry; logs the wait time, try count and target callable.
    """
    message = "Backing off {wait:0.1f} seconds after {tries} tries calling function {target}".format(**details)
    LOG(message)
@backoff.on_exception(
    backoff.expo,
    (
        ClientError,
        BotoCoreError
    ),
    # Pass the handler directly -- wrapping it in a lambda added nothing and
    # required a pylint disable.
    on_backoff=_on_s3_backoff,
    max_time=120,  # 2 minutes
)
def _upload_to_s3(config, filename, dry_run=False):
    """
    Upload the archive file to S3, retrying with exponential backoff on
    transient boto errors for up to two minutes.

    The object key is ``raw/<YYYY/MM/><filename>`` in the configured bucket.
    Any exception is logged and re-raised so the backoff decorator (and
    ultimately the caller) can see it.
    """
    try:
        datestr = datetime.datetime.now().strftime('%Y/%m/')
        s3 = boto3.resource('s3')
        bucket_name = config['s3_archive']['bucket_name']
        # Dry runs of this script should only generate the retirement archive file, not push it to s3.
        bucket = s3.Bucket(bucket_name)
        key = 'raw/' + datestr + filename
        if dry_run:
            LOG('Dry run. Skipping the step to upload data to {}'.format(key))
            return
        else:
            bucket.upload_file(filename, key)
            LOG('Successfully uploaded retirement data to {}'.format(key))
    except Exception as exc:
        LOG(text_type(exc))
        raise
def _format_datetime_for_athena(timestamp):
"""
Takes a JSON serialized timestamp string and returns a format of it that is queryable as a datetime in Athena
"""
return timestamp.replace('T', ' ').rstrip('Z')
def _archive_retirements_or_exit(config, learners, dry_run=False):
    """
    Creates an archive file with all of the retirements and uploads it to S3.

    One JSON object is written per line (newline-delimited JSON) so the file is
    directly queryable via AWS Athena if we need to confirm learner deletion.
    Any failure exits the script via FAIL_EXCEPTION with ERR_ARCHIVING.

    The format of learners from LMS should be a list of these:
        {
            'id': 46,  # This is the UserRetirementStatus ID!
            'user': {
                'id': 5213599,  # THIS is the LMS User ID
                'username': 'retired__user_88ad587896920805c26041a2e75c767c75471ee9',
                'email': 'retired__user_d08919da55a0e03c032425567e4a33e860488a96@retired.invalid',
                'profile': {
                    'id': 2842382,
                    'name': ''
                }
            },
            'current_state': {
                'id': 41,
                'state_name': 'COMPLETE',
                'state_execution_order': 13
            },
            'last_state': {
                'id': 1,
                'state_name': 'PENDING',
                'state_execution_order': 1
            },
            'created': '2018-10-18T20:35:52.349757Z',  # UserRetirementStatus creation date
            'modified': '2018-10-18T20:35:52.350050Z',  # UserRetirementStatus last touched date
            'original_username': 'retirement_test',
            'original_email': 'orig@foo.invalid',
            'original_name': 'Retirement Test',
            'retired_username': 'retired__user_88ad587896920805c26041a2e75c767c75471ee9',
            'retired_email': 'retired__user_d08919da55a0e03c032425567e4a33e860488a96@retired.invalid'
        }
    """
    LOG('Archiving retirements for {} learners to {}'.format(len(learners), config['s3_archive']['bucket_name']))
    try:
        now = _get_utc_now()
        # NOTE(review): the strftime pattern is %Y_%d_%m (year_DAY_month), which
        # looks transposed -- but changing it would move existing Athena paths.
        # Confirm before "fixing".
        filename = 'retirement_archive_{}.json.gz'.format(now.strftime('%Y_%d_%m_%H_%M_%S'))
        LOG('Creating retirement archive file {}'.format(filename))
        # The file format is one JSON object per line with the newline as a separator. This allows for
        # easy queries via AWS Athena if we need to confirm learner deletion.
        with gzip.open(filename, 'wt') as out:
            for learner in learners:
                user = {
                    'user_id': learner['user']['id'],
                    'original_username': learner['original_username'],
                    'original_email': learner['original_email'],
                    'original_name': learner['original_name'],
                    'retired_username': learner['retired_username'],
                    'retired_email': learner['retired_email'],
                    # Timestamps reformatted so Athena can treat them as datetimes.
                    'retirement_request_date': _format_datetime_for_athena(learner['created']),
                    'last_modified_date': _format_datetime_for_athena(learner['modified']),
                }
                json.dump(user, out)
                out.write("\n")
        if dry_run:
            LOG('Dry run. Logging the contents of {} for debugging'.format(filename))
            # gzip.open in 'r' mode is binary, so each logged line is bytes.
            with gzip.open(filename, 'r') as archive_file:
                for line in archive_file.readlines():
                    LOG(line)
        # _upload_to_s3 itself short-circuits when dry_run is set.
        _upload_to_s3(config, filename, dry_run)
    except Exception as exc:  # pylint: disable=broad-except
        FAIL_EXCEPTION(ERR_ARCHIVING, 'Unexpected error occurred archiving retirements!', exc)
def _cleanup_retirements_or_exit(config, learners):
    """
    Ask the LMS to bulk-delete the retirement rows for this run's learners;
    exits the script on any error.
    """
    LOG('Cleaning up retirements for {} learners'.format(len(learners)))
    try:
        usernames = [learner['original_username'] for learner in learners]
        config['LMS'].bulk_cleanup_retirements(usernames)
    except Exception as exc:  # pylint: disable=broad-except
        FAIL_EXCEPTION(ERR_DELETING, 'Unexpected error occurred deleting retirements!', exc)
def _get_utc_now():
    """
    Helper function only used to make unit test mocking/patching easier.

    Returns a naive UTC ``datetime``. NOTE(review): ``utcnow()`` is deprecated
    as of Python 3.12 in favor of timezone-aware datetimes, but callers here
    compare the result against naive dates -- confirm before modernizing.
    """
    return datetime.datetime.utcnow()
@click.command("archive_and_cleanup")
@click.option(
    '--config_file',
    help='YAML file that contains retirement-related configuration for this environment.'
)
@click.option(
    '--cool_off_days',
    help='Number of days a retirement should exist before being archived and deleted.',
    type=int,
    default=37  # 7 days before retirement, 30 after
)
@click.option(
    '--dry_run',
    help='''
    Should this script be run in a dry-run mode, in which generated retirement
    archive files are not pushed to s3 and retirements are not cleaned up in the LMS
    ''',
    type=bool,
    default=False
)
@click.option(
    '--start_date',
    help='''
    Start of window used to select user retirements for archival. Only user retirements
    added to the retirement queue after this date will be processed.
    ''',
    type=click.DateTime(formats=['%Y-%m-%d'])
)
@click.option(
    '--end_date',
    help='''
    End of window used to select user retirments for archival. Only user retirments
    added to the retirement queue before this date will be processed. In the case that
    this date is more recent than the value specified in the `cool_off_days` parameter,
    an error will be thrown. If this parameter is not used, the script will default to
    using an end_date based upon the `cool_off_days` parameter.
    ''',
    type=click.DateTime(formats=['%Y-%m-%d'])
)
@click.option(
    '--batch_size',
    help='Number of user retirements to process',
    type=int
)
def archive_and_cleanup(config_file, cool_off_days, dry_run, start_date, end_date, batch_size):
    """
    Cleans up UserRetirementStatus rows in LMS by:
    1- Getting all rows currently in COMPLETE that were created --cool_off_days ago or more,
       unless a specific timeframe is specified
    2- Archiving them to S3 in an Athena-queryable format
    3- Deleting them from LMS (by username)

    Any unexpected exception is logged and re-raised so the calling automation
    sees a non-zero exit.
    """
    try:
        LOG('Starting bulk update script: Config: {}'.format(config_file))
        if not config_file:
            FAIL(ERR_NO_CONFIG, 'No config file passed in.')
        config = CONFIG_OR_EXIT(config_file)
        SETUP_LMS_OR_EXIT(config)
        if not start_date:
            # This date is just a bogus "earliest possible value" since the call requires one
            start_date = datetime.datetime.strptime('2018-01-01', '%Y-%m-%d')
        if end_date:
            # An explicit end date must still respect the cool-off window.
            if end_date > _get_utc_now() - datetime.timedelta(days=cool_off_days):
                FAIL(ERR_BAD_CLI_PARAM, 'End date cannot occur within the cool_off_days period')
        else:
            # Set an end_date of `cool_off_days` days before the time that this script is run
            end_date = _get_utc_now() - datetime.timedelta(days=cool_off_days)
        if start_date >= end_date:
            FAIL(ERR_BAD_CLI_PARAM, 'Conflicting start and end dates passed on CLI')
        LOG(
            'Fetching retirements for learners that have a COMPLETE status and were created '
            'between {} and {}.'.format(
                start_date, end_date
            )
        )
        learners = _fetch_learners_to_archive_or_exit(
            config, start_date, end_date, 'COMPLETE'
        )
        learners_to_process = _batch_learners(learners, batch_size)
        num_batches = len(learners_to_process)
        if learners_to_process:
            for index, batch in enumerate(learners_to_process):
                LOG(
                    'Processing batch {} out of {} of user retirement requests'.format(
                        str(index + 1), str(num_batches)
                    )
                )
                _archive_retirements_or_exit(config, batch, dry_run)
                if dry_run:
                    # NOTE(review): message says "Exiting" but the loop actually
                    # continues with the next batch -- only the cleanup is skipped.
                    LOG('This is a dry-run. Exiting before any retirements are cleaned up')
                else:
                    _cleanup_retirements_or_exit(config, batch)
                    LOG('Archive and cleanup complete for batch #{}'.format(str(index + 1)))
                # Pause between batches to avoid overwhelming the LMS.
                time.sleep(DELAY)
        else:
            LOG('No learners found!')
    except Exception as exc:
        LOG(text_type(exc))
        raise
# Click entry point; options may also come from RETIREMENT_-prefixed env vars.
if __name__ == '__main__':
    # pylint: disable=unexpected-keyword-arg, no-value-for-parameter
    archive_and_cleanup(auto_envvar_prefix='RETIREMENT')

View File

@@ -0,0 +1,146 @@
#! /usr/bin/env python3
"""
Command-line script to bulk update retirement states in LMS
"""
import logging
import sys
from datetime import datetime
from functools import partial
from os import path
import click
from six import text_type
# Add top-level project path to sys.path before importing scripts code
sys.path.append(path.abspath(path.join(path.dirname(__file__), '../..')))
# pylint: disable=wrong-import-position
from scripts.user_retirement.utils.helpers import _config_or_exit, _fail, _fail_exception, _log, _setup_lms_api_or_exit
# Short name prepended to every log / failure message from this script.
SCRIPT_SHORTNAME = 'Bulk Status'
# Return codes for various fail cases
ERR_NO_CONFIG = -1
ERR_BAD_CONFIG = -2
ERR_FETCHING = -3
ERR_UPDATING = -4
ERR_SETUP_FAILED = -5
# Partials that bake the script short name / exit codes into the shared helpers.
LOG = partial(_log, SCRIPT_SHORTNAME)
FAIL = partial(_fail, SCRIPT_SHORTNAME)
FAIL_EXCEPTION = partial(_fail_exception, SCRIPT_SHORTNAME)
CONFIG_OR_EXIT = partial(_config_or_exit, FAIL_EXCEPTION, ERR_BAD_CONFIG)
SETUP_LMS_OR_EXIT = partial(_setup_lms_api_or_exit, FAIL, ERR_SETUP_FAILED)
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
def validate_dates(_, __, value):
    """
    Click input validator for date options.

    - Validates string format
    - Transforms the string into a datetime.Date object
    - Validates the date is less than or equal to today
    - Returns the Date, or raises a click.BadParameter
    """
    try:
        date = datetime.strptime(value, '%Y-%m-%d').date()
        if date > datetime.now().date():
            raise ValueError()
        return date
    except (TypeError, ValueError):
        # TypeError covers a missing option (click invokes callbacks with
        # value=None), which previously escaped as a raw traceback instead of
        # a usage error. `from None` suppresses the unhelpful internal chain.
        raise click.BadParameter('Dates need to be in the format of YYYY-MM-DD and today or earlier.') from None
def _fetch_learners_to_update_or_exit(config, start_date, end_date, initial_state):
    """
    Query the LMS for all learners in ``initial_state`` whose retirement was
    created in the given date window, returning the list; exits the script on
    any API error.
    """
    LOG('Fetching users in state {} created from {} to {}'.format(initial_state, start_date, end_date))
    try:
        lms = config['LMS']
        return lms.get_learners_by_date_and_status(initial_state, start_date, end_date)
    except Exception as exc:  # pylint: disable=broad-except
        FAIL_EXCEPTION(ERR_FETCHING, 'Unexpected error occurred fetching users to update!', exc)
def _update_learners_or_exit(config, learners, new_state=None, rewind_state=False):
    """
    Force every learner in ``learners`` into a target retirement state,
    exiting the script on the first error. When ``rewind_state`` is True each
    learner is pushed back to its own previous state instead of the shared
    ``new_state``.
    """
    # Exactly one of new_state / rewind_state must be supplied.
    if bool(new_state) == bool(rewind_state):
        FAIL(ERR_BAD_CONFIG, "You must specify either the boolean rewind_state or a new state to set learners to.")
    LOG('Updating {} learners to {}'.format(len(learners), new_state))
    try:
        for record in learners:
            target_state = record['last_state']['state_name'] if rewind_state else new_state
            config['LMS'].update_learner_retirement_state(
                record['original_username'],
                target_state,
                'Force updated via retirement_bulk_status_update script',
                force=True
            )
    except Exception as exc:  # pylint: disable=broad-except
        FAIL_EXCEPTION(ERR_UPDATING, 'Unexpected error occurred updating users!', exc)
@click.command("update_statuses")
@click.option(
    '--config_file',
    help='YAML file that contains retirement-related configuration for this environment.'
)
@click.option(
    '--initial_state',
    help='Find learners in this retirement state. Use the state name ex: PENDING, COMPLETE'
)
@click.option(
    '--new_state',
    help='Set any found learners to this new state. Use the state name ex: PENDING, COMPLETE',
    default=None
)
@click.option(
    '--start_date',
    callback=validate_dates,
    help='(YYYY-MM-DD) Earliest creation date for retirements to act on.'
)
@click.option(
    '--end_date',
    callback=validate_dates,
    help='(YYYY-MM-DD) Latest creation date for retirements to act on.'
)
@click.option(
    '--rewind-state',
    help='Rewinds to the last_state for learners. Useful for resetting ERRORED users',
    default=False,
    is_flag=True
)
def update_statuses(config_file, initial_state, new_state, start_date, end_date, rewind_state):
    """
    Bulk-updates user retirement statuses which are in the specified state -and- retirement was
    requested between a start date and end date.

    Exits non-zero (via the FAIL/FAIL_EXCEPTION helpers) on config or API
    errors; any unexpected exception is logged and re-raised so the calling
    automation sees a failure.
    """
    try:
        LOG('Starting bulk update script: Config: {}'.format(config_file))
        if not config_file:
            FAIL(ERR_NO_CONFIG, 'No config file passed in.')
        config = CONFIG_OR_EXIT(config_file)
        SETUP_LMS_OR_EXIT(config)
        learners = _fetch_learners_to_update_or_exit(config, start_date, end_date, initial_state)
        _update_learners_or_exit(config, learners, new_state, rewind_state)
        LOG('Bulk update complete')
    except Exception as exc:
        # Was a bare print(); use the script-prefixed logger for consistency
        # with the other retirement scripts' error reporting.
        LOG(text_type(exc))
        raise
# Click entry point; options may also come from RETIREMENT_-prefixed env vars.
if __name__ == '__main__':
    # pylint: disable=unexpected-keyword-arg, no-value-for-parameter
    update_statuses(auto_envvar_prefix='RETIREMENT')

View File

@@ -0,0 +1,404 @@
#! /usr/bin/env python3
# coding=utf-8
"""
Command-line script to drive the partner reporting part of the retirement process
"""
import logging
import os
import sys
import unicodedata
from collections import OrderedDict, defaultdict
from datetime import date
from functools import partial
import click
import unicodecsv as csv
from six import text_type
# Add top-level project path to sys.path before importing scripts code
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
from scripts.user_retirement.utils.thirdparty_apis.google_api import DriveApi # pylint: disable=wrong-import-position
# pylint: disable=wrong-import-position
from scripts.user_retirement.utils.helpers import (
_config_with_drive_or_exit,
_fail,
_fail_exception,
_log,
_setup_lms_api_or_exit
)
# Return codes for various fail cases
ERR_SETUP_FAILED = -1
ERR_FETCHING_LEARNERS = -2
ERR_NO_CONFIG = -3
ERR_NO_SECRETS = -4
ERR_NO_OUTPUT_DIR = -5
ERR_BAD_CONFIG = -6
ERR_BAD_SECRETS = -7
ERR_UNKNOWN_ORG = -8
ERR_REPORTING = -9
ERR_DRIVE_UPLOAD = -10
ERR_CLEANUP = -11
ERR_DRIVE_LISTING = -12
# Short name prepended to every log / failure message from this script.
SCRIPT_SHORTNAME = 'Partner report'
LOG = partial(_log, SCRIPT_SHORTNAME)
FAIL = partial(_fail, SCRIPT_SHORTNAME)
FAIL_EXCEPTION = partial(_fail_exception, SCRIPT_SHORTNAME)
CONFIG_WITH_DRIVE_OR_EXIT = partial(_config_with_drive_or_exit, FAIL_EXCEPTION, ERR_BAD_CONFIG, ERR_BAD_SECRETS)
SETUP_LMS_OR_EXIT = partial(_setup_lms_api_or_exit, FAIL, ERR_SETUP_FAILED)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
# Prefix which starts all generated report filenames.
REPORTING_FILENAME_PREFIX = 'user_retirement'
# We'll store the access token here once retrieved
# NOTE(review): nothing in this script appears to read or write AUTH_HEADER --
# likely a leftover from the tubular original; confirm before removing.
AUTH_HEADER = {}
# This text template will be the comment body for all new CSV uploads. The
# following format variables need to be provided:
# tags: space delimited list of google user tags, e.g. "+user1@gmail.com +user2@gmail.com"
NOTIFICATION_MESSAGE_TEMPLATE = """
Hello from edX. Dear {tags}, a new report listing the learners enrolled in your institutions courses on edx.org that have requested deletion of their edX account and associated personal data within the last week has been published to Google Drive. Please access your folder to see the latest report.
""".strip()
LEARNER_CREATED_KEY = 'created'  # This key is currently required to exist in the learner
LEARNER_ORIGINAL_USERNAME_KEY = 'original_username'  # This key is currently required to exist in the learner
# Keys of the learner dicts returned by the LMS partner-report endpoint.
ORGS_KEY = 'orgs'
ORGS_CONFIG_KEY = 'orgs_config'
ORGS_CONFIG_ORG_KEY = 'org'
ORGS_CONFIG_FIELD_HEADINGS_KEY = 'field_headings'
ORGS_CONFIG_LEARNERS_KEY = 'learners'
# Default field headings for the CSV file
DEFAULT_FIELD_HEADINGS = ['user_id', 'original_username', 'original_email', 'original_name', 'deletion_completed']
def _check_all_learner_orgs_or_exit(config, learners):
    """
    Verify that every org attached to every learner has a partner mapping in
    the config. If any orgs are unmapped, print them all and exit the script.
    """
    known_orgs = config['org_partner_mapping']
    mismatched_orgs = set()
    for learner in learners:
        # Orgs using the standard report fields.
        for org in learner.get(ORGS_KEY, []):
            if org not in known_orgs:
                mismatched_orgs.add(org)
        # Orgs with custom configurations (orgs with custom fields).
        for org_config in learner.get(ORGS_CONFIG_KEY, []):
            custom_org = org_config[ORGS_CONFIG_ORG_KEY]
            if custom_org not in known_orgs:
                mismatched_orgs.add(custom_org)
    if mismatched_orgs:
        FAIL(
            ERR_UNKNOWN_ORG,
            'Partners for organizations {} do not exist in configuration.'.format(text_type(mismatched_orgs))
        )
def _get_orgs_and_learners_or_exit(config):
    """
    Contacts LMS to get the list of learners to report on and the orgs they belong to.
    Reformats them into dicts with keys of the orgs and lists of learners as the value
    and returns a tuple of that dict plus a list of all of the learner usernames.

    Exits the script (ERR_FETCHING_LEARNERS) on any error, including unmapped orgs.
    """
    try:
        LOG('Retrieving all learners on which to report from the LMS.')
        learners = config['LMS'].retirement_partner_report()
        LOG('Retrieved {} learners from the LMS.'.format(len(learners)))
        # Fail fast if any learner org lacks a partner mapping.
        _check_all_learner_orgs_or_exit(config, learners)
        # defaultdict() with no factory behaves like a plain dict here.
        orgs = defaultdict()
        usernames = []
        # Organize the learners, create separate dicts per partner, making sure each partner is in the mapping.
        # Learners can appear in more than one dict. It is assumed that each org has 1 and only 1 set of field headings.
        for learner in learners:
            usernames.append({'original_username': learner[LEARNER_ORIGINAL_USERNAME_KEY]})
            # Use the datetime upon which the record was 'created' in the partner reporting queue
            # as the approximate time upon which user retirement was completed ('deletion_completed')
            # for the record's user.
            learner['deletion_completed'] = learner[LEARNER_CREATED_KEY]
            # Create a list of orgs who should be notified about this user
            if ORGS_KEY in learner:
                for org_name in learner[ORGS_KEY]:
                    reporting_org_names = config['org_partner_mapping'][org_name]
                    _add_reporting_org(orgs, reporting_org_names, DEFAULT_FIELD_HEADINGS, learner)
            # Check for orgs with custom fields
            if ORGS_CONFIG_KEY in learner:
                for org_config in learner[ORGS_CONFIG_KEY]:
                    org_name = org_config[ORGS_CONFIG_ORG_KEY]
                    org_headings = org_config[ORGS_CONFIG_FIELD_HEADINGS_KEY]
                    reporting_org_names = config['org_partner_mapping'][org_name]
                    _add_reporting_org(orgs, reporting_org_names, org_headings, learner)
        return orgs, usernames
    except Exception as exc:  # pylint: disable=broad-except
        FAIL_EXCEPTION(ERR_FETCHING_LEARNERS, 'Unexpected exception occurred!', exc)
def _add_reporting_org(orgs, org_names, org_headings, learner):
    """
    Register ``learner`` under every reporting org in ``org_names``, creating
    each org's entry (field headings plus an empty learner list) on first use.
    """
    for org_name in org_names:
        # setdefault creates the entry only when the org is seen the first time.
        entry = orgs.setdefault(
            org_name,
            {
                ORGS_CONFIG_FIELD_HEADINGS_KEY: org_headings,
                ORGS_CONFIG_LEARNERS_KEY: []
            }
        )
        entry[ORGS_CONFIG_LEARNERS_KEY].append(learner)
def _generate_report_files_or_exit(config, report_data, output_dir):
    """
    Create one CSV file per partner and return {partner name: file path}.

    All filenames are collected before anything is pushed to Google so we only
    start uploading once every file has generated cleanly, minimizing the
    cases where files already on Drive would need to be overwritten.
    """
    partner_filenames = {}
    for partner_name, partner in report_data.items():
        try:
            partner_headings = partner[ORGS_CONFIG_FIELD_HEADINGS_KEY]
            partner_learners = partner[ORGS_CONFIG_LEARNERS_KEY]
            outfile = _generate_report_file_or_exit(
                config, output_dir, partner_name, partner_headings, partner_learners
            )
            partner_filenames[partner_name] = outfile
            LOG('Report complete for partner {}'.format(partner_name))
        except Exception as exc:  # pylint: disable=broad-except
            FAIL_EXCEPTION(ERR_REPORTING, 'Error reporting retirement for partner {}'.format(partner_name), exc)
    return partner_filenames
def _generate_report_file_or_exit(config, output_dir, partner, field_headings, field_values):
    """
    Create a CSV file for the partner.

    The filename embeds the platform name, partner and today's date; returns
    the full path of the file written.
    """
    LOG('Starting report for partner {}: {} learners to add. Field headings are {}'.format(
        partner,
        len(field_values),
        field_headings
    ))
    outfile = os.path.join(output_dir, '{}_{}_{}_{}.csv'.format(
        REPORTING_FILENAME_PREFIX, config['partner_report_platform_name'], partner, date.today().isoformat()
    ))
    # If there is already a file for this date, assume it is bad and replace it
    try:
        os.remove(outfile)
    except OSError:
        pass
    # Binary mode because unicodecsv writes encoded bytes; extrasaction='ignore'
    # silently drops learner keys that aren't in this partner's headings.
    with open(outfile, 'wb') as f:
        writer = csv.DictWriter(f, field_headings, dialect=csv.excel, extrasaction='ignore')
        writer.writeheader()
        writer.writerows(field_values)
    return outfile
def _config_drive_folder_map_or_exit(config):
    """
    Lists folders under our top level parent for this environment and returns
    a dict of {partner name: folder id}. Partner names should match the values
    in config['org_partner_mapping']

    Side effect: the mapping is stored on config['partner_folder_mapping'] --
    nothing is actually returned. Exits the script if the listing fails or
    comes back empty.
    """
    drive = DriveApi(config['google_secrets_file'])
    try:
        LOG('Attempting to find all partner sub-directories on Drive.')
        # Only immediate children of the partners folder; no recursion.
        folders = drive.walk_files(
            config['drive_partners_folder'],
            mimetype='application/vnd.google-apps.folder',
            recurse=False
        )
    except Exception as exc:  # pylint: disable=broad-except
        FAIL_EXCEPTION(ERR_DRIVE_LISTING, 'Finding partner directories on Drive failed.', exc)
    if not folders:
        FAIL(ERR_DRIVE_LISTING, 'Finding partner directories on Drive failed. Check your permissions.')
    # As in _config_or_exit we force normalize the unicode here to make sure the keys
    # match. Otherwise the name we get back from Google won't match what's in the YAML config.
    config['partner_folder_mapping'] = OrderedDict()
    for folder in folders:
        folder['name'] = unicodedata.normalize('NFKC', text_type(folder['name']))
        config['partner_folder_mapping'][folder['name']] = folder['id']
def _push_files_to_google(config, partner_filenames):
    """
    Upload each partner's CSV file to that partner's Google Drive folder.

    Returns:
        dict: mapping of partner name to the Drive file ID of the uploaded csv.
    """
    # First make sure we have Drive folders for all partners
    failed_partners = []
    for partner in partner_filenames:
        if partner not in config['partner_folder_mapping']:
            failed_partners.append(partner)
    if failed_partners:
        FAIL(ERR_BAD_CONFIG, 'These partners have retiring learners, but no Drive folder: {}'.format(failed_partners))
    file_ids = {}
    drive = DriveApi(config['google_secrets_file'])
    for partner in partner_filenames:
        # This is populated on the fly in _config_drive_folder_map_or_exit
        folder_id = config['partner_folder_mapping'][partner]
        file_id = None
        with open(partner_filenames[partner], 'rb') as f:
            try:
                # NOTE(review): if basename() itself raised, drive_filename would
                # be unbound in the except clause below -- harmless in practice.
                drive_filename = os.path.basename(partner_filenames[partner])
                LOG('Attempting to upload {} to {} Drive folder.'.format(drive_filename, partner))
                file_id = drive.create_file_in_folder(folder_id, drive_filename, f, "text/csv")
            except Exception as exc:  # pylint: disable=broad-except
                FAIL_EXCEPTION(ERR_DRIVE_UPLOAD, 'Drive upload failed for: {}'.format(drive_filename), exc)
        file_ids[partner] = file_id
    return file_ids
def _add_comments_to_files(config, file_ids):
    """
    Add comments to the uploaded csv files, triggering email notification.

    Args:
        config (dict): script config; reads 'google_secrets_file',
            'partner_folder_mapping' and 'denied_notification_domains'.
        file_ids (dict): Mapping of partner names to Drive file IDs corresponding to the newly uploaded csv files.
    """
    drive = DriveApi(config['google_secrets_file'])
    partner_folders_to_permissions = drive.list_permissions_for_files(
        config['partner_folder_mapping'].values(),
        fields='emailAddress',
    )
    # create a mapping of partners to a list of permissions dicts:
    permissions = {
        partner: partner_folders_to_permissions[config['partner_folder_mapping'][partner]]
        for partner in file_ids
    }
    # throw out all denied addresses, and flatten the permissions dicts to just the email:
    external_emails = {
        partner: [
            perm['emailAddress']
            for perm in permissions[partner]
            if not any(
                perm['emailAddress'].lower().endswith(denied_domain.lower())
                for denied_domain in config['denied_notification_domains']
            )
        ]
        for partner in permissions
    }
    file_ids_and_comments = []
    for partner in file_ids:
        if not external_emails[partner]:
            # No one left to notify; log loudly so a human can fix folder sharing.
            LOG(
                'WARNING: could not find a POC for the following partner: "{}". '
                'Double check the partner folder permissions in Google Drive.'
                .format(partner)
            )
        else:
            # Google emails the "+address"-tagged users when the comment lands.
            tag_string = ' '.join('+' + email for email in external_emails[partner])
            comment_content = NOTIFICATION_MESSAGE_TEMPLATE.format(tags=tag_string)
            file_ids_and_comments.append((file_ids[partner], comment_content))
    try:
        LOG('Adding notification comments to uploaded csv files.')
        drive.create_comments_for_files(file_ids_and_comments)
    except Exception as exc:  # pylint: disable=broad-except
        # do not fail the script here, since comment errors are non-critical
        LOG('WARNING: there was an error adding Google Drive comments to the csv files: {}'.format(exc))
@click.command("generate_report")
@click.option(
    '--config_file',
    help='YAML file that contains retirement related configuration for this environment.'
)
@click.option(
    '--google_secrets_file',
    help='JSON file with Google service account credentials for uploading.'
)
@click.option(
    '--output_dir',
    help='The local directory that the script will write the reports to.'
)
@click.option(
    '--comments/--no_comments',
    default=True,
    help='Do or skip adding notification comments to the reports.'
)
def generate_report(config_file, google_secrets_file, output_dir, comments):
    """
    Retrieves a JWT token as the retirement service learner, then performs the reporting process as that user.

    - Accepts the configuration file with all necessary credentials and URLs for a single environment
    - Gets the users in the LMS reporting queue and the partners they need to be reported to
    - Generates a single report per partner
    - Pushes the reports to Google Drive
    - On success tells LMS to remove the users who succeeded from the reporting queue
    """
    LOG('Starting partner report using config file {} and Google config {}'.format(config_file, google_secrets_file))
    try:
        if not config_file:
            FAIL(ERR_NO_CONFIG, 'No config file passed in.')
        if not google_secrets_file:
            FAIL(ERR_NO_SECRETS, 'No secrets file passed in.')
        # The Jenkins DSL is supposed to create this path for us
        if not output_dir or not os.path.exists(output_dir):
            FAIL(ERR_NO_OUTPUT_DIR, 'No output_dir passed in or path does not exist.')
        config = CONFIG_WITH_DRIVE_OR_EXIT(config_file, google_secrets_file)
        SETUP_LMS_OR_EXIT(config)
        # Populates config['partner_folder_mapping'] as a side effect.
        _config_drive_folder_map_or_exit(config)
        report_data, all_usernames = _get_orgs_and_learners_or_exit(config)
        # If no usernames were returned, then no reports need to be generated.
        if all_usernames:
            partner_filenames = _generate_report_files_or_exit(config, report_data, output_dir)
            # All files generated successfully, now push them to Google
            report_file_ids = _push_files_to_google(config, partner_filenames)
            if comments:
                # All files uploaded successfully, now add comments to them to trigger notifications
                _add_comments_to_files(config, report_file_ids)
            # Success, tell LMS to remove these users from the queue
            config['LMS'].retirement_partner_cleanup(all_usernames)
        LOG('All reports completed and uploaded to Google.')
    except Exception as exc:  # pylint: disable=broad-except
        FAIL_EXCEPTION(ERR_CLEANUP, 'Unexpected error occurred! Users may be stuck in the processing state!', exc)
# Click entry point; options may also come from RETIREMENT_-prefixed env vars.
if __name__ == '__main__':
    # pylint: disable=unexpected-keyword-arg, no-value-for-parameter
    generate_report(auto_envvar_prefix='RETIREMENT')

View File

@@ -0,0 +1,23 @@
from urllib.parse import urljoin
import responses
from scripts.user_retirement.utils import edx_api
FAKE_ACCESS_TOKEN = 'THIS_IS_A_JWT'
CONTENT_TYPE = 'application/json'
class OAuth2Mixin:
    """
    Test mixin that stubs out the edX OAuth2 access-token endpoint via the
    ``responses`` library.
    """
    @staticmethod
    def mock_access_token_response(status=200):
        """
        Mock POST requests to retrieve an access token for this site's service user.

        Args:
            status (int): HTTP status code the mocked token endpoint returns.
        """
        responses.add(
            responses.POST,
            urljoin('http://localhost:18000/', edx_api.OAUTH_ACCESS_TOKEN_URL),
            status=status,
            json={'access_token': FAKE_ACCESS_TOKEN, 'expires_in': 60},
            content_type=CONTENT_TYPE
        )

View File

@@ -0,0 +1,166 @@
# coding=utf-8
"""
Common functionality for retirement related tests
"""
import json
import unicodedata
from datetime import datetime
import yaml
# Each pipeline row is [start_state, end_state, service_name, method_name],
# mirroring the production 'retirement_pipeline' config shape.
TEST_RETIREMENT_PIPELINE = [
    ['RETIRING_FORUMS', 'FORUMS_COMPLETE', 'LMS', 'retirement_retire_forum'],
    ['RETIRING_EMAIL_LISTS', 'EMAIL_LISTS_COMPLETE', 'LMS', 'retirement_retire_mailings'],
    ['RETIRING_ENROLLMENTS', 'ENROLLMENTS_COMPLETE', 'LMS', 'retirement_unenroll'],
    ['RETIRING_LMS', 'LMS_COMPLETE', 'LMS', 'retirement_lms_retire']
]
# The '*_COMPLETE' end states, in pipeline order.
TEST_RETIREMENT_END_STATES = [state[1] for state in TEST_RETIREMENT_PIPELINE]
TEST_RETIREMENT_QUEUE_STATES = ['PENDING'] + TEST_RETIREMENT_END_STATES
TEST_RETIREMENT_STATE = 'PENDING'
FAKE_DATETIME_OBJECT = datetime(2022, 1, 1)
FAKE_DATETIME_STR = '2022-01-01'
FAKE_ORIGINAL_USERNAME = 'foo_username'
# Deliberately repeats the same username twice.
FAKE_USERNAMES = [FAKE_ORIGINAL_USERNAME, FAKE_ORIGINAL_USERNAME]
FAKE_RESPONSE_MESSAGE = 'fake response message'
FAKE_USERNAME_MAPPING = [
    {"fake_current_username_1": "fake_desired_username_1"},
    {"fake_current_username_2": "fake_desired_username_2"}
]
# org -> list of partner report names; one org may map to several partners.
FAKE_ORGS = {
    # Make sure unicode names, as they should come in from the yaml config, work
    'org1': [unicodedata.normalize('NFKC', u'TéstX')],
    'org2': ['Org2X'],
    'org3': ['Org3X', 'Org4X'],
}
TEST_PLATFORM_NAME = 'fakename'
# Email domains that must never receive partner-report notifications.
TEST_DENIED_NOTIFICATION_DOMAINS = {
    '@edx.org',
    '@partner-reporting-automation.iam.gserviceaccount.com',
}
def flatten_partner_list(partner_list):
    """
    Flattens a list of lists into a list.
    [["Org1X"], ["Org2X"], ["Org3X", "Org4X"]] => ["Org1X", "Org2X", "Org3X", "Org4X"]
    """
    flattened = []
    for sublist in partner_list:
        flattened.extend(sublist)
    return flattened
def fake_config_file(f, orgs=None, fetch_ecom_segment_id=False):
    """
    Write a YAML config file for a single test to the open handle ``f``.

    Combined with CliRunner.isolated_filesystem() to ensure the file lifetime
    is limited to the test. See _call_script for usage.
    """
    if orgs is None:
        orgs = FAKE_ORGS

    config = {
        'client_id': 'bogus id',
        'client_secret': 'supersecret',
    }
    config['base_urls'] = {
        'credentials': 'https://credentials.stage.edx.invalid/',
        'lms': 'https://stage-edx-edxapp.edx.invalid/',
        'ecommerce': 'https://ecommerce.stage.edx.invalid/',
        'segment': 'https://segment.invalid/graphql',
    }
    config.update({
        'retirement_pipeline': TEST_RETIREMENT_PIPELINE,
        'partner_report_platform_name': TEST_PLATFORM_NAME,
        'org_partner_mapping': orgs,
        'drive_partners_folder': 'FakeDriveID',
        'denied_notification_domains': TEST_DENIED_NOTIFICATION_DOMAINS,
        'sailthru_key': 'fake_sailthru_key',
        'sailthru_secret': 'fake_sailthru_secret',
        's3_archive': {
            'bucket_name': 'fake_test_bucket',
            'region': 'fake_region',
        },
        'segment_workspace_slug': 'test_slug',
        'segment_auth_token': 'fakeauthtoken',
    })
    if fetch_ecom_segment_id:
        config['fetch_ecommerce_segment_id'] = True
    yaml.safe_dump(config, f)
def get_fake_user_retirement(
        retirement_id=1,
        original_username="foo_username",
        original_email="foo@edx.invalid",
        original_name="Foo User",
        retired_username="retired_user__asdf123",
        retired_email="retired_user__asdf123",
        ecommerce_segment_id="ecommerce-90",
        user_id=9009,
        current_username="foo_username",
        current_email="foo@edx.invalid",
        current_name="Foo User",
        current_state_name="PENDING",
        last_state_name="PENDING",
):
    """
    Return a "learner" as used in retirement, in the serialized format we get
    from LMS.
    """
    current_state = {
        "id": 1,
        "state_name": current_state_name,
        "state_execution_order": 10,
    }
    last_state = {
        "id": 1,
        "state_name": last_state_name,
        "state_execution_order": 10,
    }
    user = {
        "id": user_id,
        "username": current_username,
        "email": current_email,
        "profile": {
            "id": 10009,
            "name": current_name,
        },
    }
    return {
        "id": retirement_id,
        "current_state": current_state,
        "last_state": last_state,
        "original_username": original_username,
        "original_email": original_email,
        "original_name": original_name,
        "retired_username": retired_username,
        "retired_email": retired_email,
        "ecommerce_segment_id": ecommerce_segment_id,
        "user": user,
        "created": "2018-10-18T20:08:03.724805",
        "modified": "2018-10-18T20:08:03.724805",
    }
def fake_google_secrets_file(f):
    """
    Write a fake Google service-account secrets file for a single test to ``f``.
    """
    # Reproduces the original triple-quoted literal exactly, including the
    # stray trailing "r" that preceded its closing quotes.
    fake_private_key = (
        "\n-----BEGIN PRIVATE KEY-----"
        "\n-----END PRIVATE KEY-----"
        "\nr"
    )
    secrets = {
        "type": "service_account",
        "project_id": "partner-reporting-automation",
        "private_key_id": "foo",
        "private_key": fake_private_key,
        "client_email": "bogus@serviceacct.invalid",
        "client_id": "411",
        "auth_uri": "https://accounts.google.com/o/oauth2/auth",
        "token_uri": "https://accounts.google.com/o/oauth2/token",
        "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
        "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/foo",
    }
    json.dump(secrets, f)

View File

@@ -0,0 +1 @@
Upload this file on s3 in tests.

View File

@@ -0,0 +1,159 @@
"""
Test the get_learners_to_retire.py script
"""
import os
from click.testing import CliRunner
from mock import DEFAULT, patch
from requests.exceptions import HTTPError
from scripts.user_retirement.get_learners_to_retire import get_learners_to_retire
from scripts.user_retirement.tests.retirement_helpers import fake_config_file, get_fake_user_retirement
def _call_script(expected_user_files, cool_off_days=1, output_dir='test', user_count_error_threshold=200,
                 max_user_batch_size=201):
    """
    Invoke the retired learner script against a generic, temporary config
    file and return the CliRunner.invoke result.
    """
    runner = CliRunner()
    with runner.isolated_filesystem():
        with open('test_config.yml', 'w') as config_handle:
            fake_config_file(config_handle)

        cli_args = [
            '--config_file', 'test_config.yml',
            '--cool_off_days', cool_off_days,
            '--output_dir', output_dir,
            '--user_count_error_threshold', user_count_error_threshold,
            '--max_user_batch_size', max_user_batch_size,
        ]
        result = runner.invoke(get_learners_to_retire, args=cli_args)
        print(result)
        print(result.output)

        # On success one file per user is written; on failure the output
        # directory should never have been created.
        if expected_user_files:
            assert len(os.listdir(output_dir)) == expected_user_files
        else:
            assert not os.path.exists(output_dir)
    return result
@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token')
@patch.multiple(
    'scripts.user_retirement.utils.edx_api.LmsApi',
    learners_to_retire=DEFAULT
)
def test_success(*args, **kwargs):
    """
    Happy path: two learners are returned, so two user files get written.
    """
    mock_get_access_token = args[0]
    mock_get_learners_to_retire = kwargs['learners_to_retire']
    mock_get_access_token.return_value = ('THIS_IS_A_JWT', None)
    mock_get_learners_to_retire.return_value = [
        get_fake_user_retirement(original_username='test_user1'),
        get_fake_user_retirement(original_username='test_user2'),
    ]
    result = _call_script(2)
    # Called once: only the LMS client fetches a token in this script.
    assert mock_get_access_token.call_count == 1
    mock_get_learners_to_retire.assert_called_once()
    assert result.exit_code == 0


@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token')
@patch.multiple(
    'scripts.user_retirement.utils.edx_api.LmsApi',
    learners_to_retire=DEFAULT
)
def test_lms_down(*args, **kwargs):
    """
    An HTTPError from the LMS aborts the run with exit code 1 and no output dir.
    """
    mock_get_access_token = args[0]
    mock_get_learners_to_retire = kwargs['learners_to_retire']
    mock_get_access_token.return_value = ('THIS_IS_A_JWT', None)
    mock_get_learners_to_retire.side_effect = HTTPError
    result = _call_script(0)
    # Called once: only the LMS client fetches a token in this script.
    assert mock_get_access_token.call_count == 1
    mock_get_learners_to_retire.assert_called_once()
    assert result.exit_code == 1


@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token')
@patch.multiple(
    'scripts.user_retirement.utils.edx_api.LmsApi',
    learners_to_retire=DEFAULT
)
def test_misconfigured(*args, **kwargs):
    """
    NOTE(review): this test body is identical to test_lms_down and does not
    actually exercise a misconfiguration — presumably it was meant to pass a
    broken config file. Confirm intent and differentiate or remove.
    """
    mock_get_access_token = args[0]
    mock_get_learners_to_retire = kwargs['learners_to_retire']
    mock_get_access_token.return_value = ('THIS_IS_A_JWT', None)
    mock_get_learners_to_retire.side_effect = HTTPError
    result = _call_script(0)
    # Called once: only the LMS client fetches a token in this script.
    assert mock_get_access_token.call_count == 1
    mock_get_learners_to_retire.assert_called_once()
    assert result.exit_code == 1


@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token')
@patch.multiple(
    'scripts.user_retirement.utils.edx_api.LmsApi',
    learners_to_retire=DEFAULT
)
def test_too_many_users(*args, **kwargs):
    """
    More learners than user_count_error_threshold aborts before writing files.
    """
    mock_get_access_token = args[0]
    mock_get_learners_to_retire = kwargs['learners_to_retire']
    mock_get_access_token.return_value = ('THIS_IS_A_JWT', None)
    mock_get_learners_to_retire.return_value = [
        get_fake_user_retirement(original_username='test_user1'),
        get_fake_user_retirement(original_username='test_user2'),
    ]
    result = _call_script(0, user_count_error_threshold=1)
    # Called once: only the LMS client fetches a token in this script.
    assert mock_get_access_token.call_count == 1
    mock_get_learners_to_retire.assert_called_once()
    assert result.exit_code == -1
    assert 'Too many learners' in result.output


@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token')
@patch.multiple(
    'scripts.user_retirement.utils.edx_api.LmsApi',
    learners_to_retire=DEFAULT
)
def test_users_limit(*args, **kwargs):
    """
    max_user_batch_size caps how many of the returned learners are processed.
    """
    mock_get_access_token = args[0]
    mock_get_learners_to_retire = kwargs['learners_to_retire']
    mock_get_access_token.return_value = ('THIS_IS_A_JWT', None)
    mock_get_learners_to_retire.return_value = [
        get_fake_user_retirement(original_username='test_user1'),
        get_fake_user_retirement(original_username='test_user2'),
    ]
    result = _call_script(1, user_count_error_threshold=200, max_user_batch_size=1)
    # Called once: only the LMS client fetches a token in this script.
    assert mock_get_access_token.call_count == 1
    mock_get_learners_to_retire.assert_called_once()
    assert result.exit_code == 0

View File

@@ -0,0 +1,412 @@
"""
Test the retire_one_learner.py script
"""
from click.testing import CliRunner
from mock import DEFAULT, patch
from scripts.user_retirement.retire_one_learner import (
END_STATES,
ERR_BAD_CONFIG,
ERR_BAD_LEARNER,
ERR_SETUP_FAILED,
ERR_UNKNOWN_STATE,
ERR_USER_AT_END_STATE,
ERR_USER_IN_WORKING_STATE,
retire_learner
)
from scripts.user_retirement.tests.retirement_helpers import fake_config_file, get_fake_user_retirement
from scripts.user_retirement.utils.exception import HttpDoesNotExistException
def _call_script(username, fetch_ecom_segment_id=False):
    """
    Invoke retire_learner for ``username`` against a generic, temporary
    config file and return the CliRunner.invoke result.
    """
    runner = CliRunner()
    with runner.isolated_filesystem():
        with open('test_config.yml', 'w') as config_handle:
            fake_config_file(config_handle, fetch_ecom_segment_id=fetch_ecom_segment_id)

        cli_args = ['--username', username, '--config_file', 'test_config.yml']
        result = runner.invoke(retire_learner, args=cli_args)
        print(result)
        print(result.output)
    return result
@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token')
@patch('scripts.user_retirement.utils.edx_api.EcommerceApi.get_tracking_key')
@patch.multiple(
    'scripts.user_retirement.utils.edx_api.LmsApi',
    get_learner_retirement_state=DEFAULT,
    update_learner_retirement_state=DEFAULT,
    retirement_retire_forum=DEFAULT,
    retirement_retire_mailings=DEFAULT,
    retirement_unenroll=DEFAULT,
    retirement_lms_retire=DEFAULT
)
def test_successful_retirement(*args, **kwargs):
    """
    Happy path: a PENDING learner is driven through all four pipeline steps.
    """
    username = 'test_username'
    # Positional mocks arrive in reverse decorator order: args[0] is
    # get_tracking_key (unused here), args[1] is get_access_token.
    mock_get_access_token = args[1]
    mock_get_retirement_state = kwargs['get_learner_retirement_state']
    mock_update_learner_state = kwargs['update_learner_retirement_state']
    mock_retire_forum = kwargs['retirement_retire_forum']
    mock_retire_mailings = kwargs['retirement_retire_mailings']
    mock_unenroll = kwargs['retirement_unenroll']
    mock_lms_retire = kwargs['retirement_lms_retire']
    mock_get_access_token.return_value = ('THIS_IS_A_JWT', None)
    mock_get_retirement_state.return_value = get_fake_user_retirement(original_username=username)
    result = _call_script(username, fetch_ecom_segment_id=True)
    # Called once per API we instantiate (LMS, ECommerce, Credentials)
    assert mock_get_access_token.call_count == 3
    mock_get_retirement_state.assert_called_once_with(username)
    # 9 updates — presumably start/end per pipeline step plus final states;
    # confirm against retire_one_learner's state machine.
    assert mock_update_learner_state.call_count == 9
    # Called once per retirement
    for mock_call in (
        mock_retire_forum,
        mock_retire_mailings,
        mock_unenroll,
        mock_lms_retire
    ):
        mock_call.assert_called_once_with(mock_get_retirement_state.return_value)
    assert result.exit_code == 0
    assert 'Retirement complete' in result.output


@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token')
@patch.multiple(
    'scripts.user_retirement.utils.edx_api.LmsApi',
    get_learner_retirement_state=DEFAULT,
    update_learner_retirement_state=DEFAULT
)
def test_user_does_not_exist(*args, **kwargs):
    """
    A failure fetching the learner's state aborts setup with ERR_SETUP_FAILED.

    NOTE(review): the side effect here is a bare Exception, not a 404 —
    despite the name this covers any fetch failure; the true does-not-exist
    case is test_bad_learner below.
    """
    username = 'test_username'
    mock_get_access_token = args[0]
    mock_get_retirement_state = kwargs['get_learner_retirement_state']
    mock_update_learner_state = kwargs['update_learner_retirement_state']
    mock_get_access_token.return_value = ('THIS_IS_A_JWT', None)
    mock_get_retirement_state.side_effect = Exception
    result = _call_script(username)
    assert mock_get_access_token.call_count == 3
    mock_get_retirement_state.assert_called_once_with(username)
    mock_update_learner_state.assert_not_called()
    assert result.exit_code == ERR_SETUP_FAILED
    assert 'Exception' in result.output


def test_bad_config():
    """
    A nonexistent config file exits with ERR_BAD_CONFIG and names the file.
    """
    username = 'test_username'
    runner = CliRunner()
    result = runner.invoke(retire_learner, args=['--username', username, '--config_file', 'does_not_exist.yml'])
    assert result.exit_code == ERR_BAD_CONFIG
    assert 'does_not_exist.yml' in result.output


@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token')
@patch.multiple(
    'scripts.user_retirement.utils.edx_api.LmsApi',
    get_learner_retirement_state=DEFAULT,
    update_learner_retirement_state=DEFAULT
)
def test_bad_learner(*args, **kwargs):
    """
    A 404 (HttpDoesNotExistException) looking up the learner exits with
    ERR_BAD_LEARNER.
    """
    username = 'test_username'
    mock_get_access_token = args[0]
    mock_get_retirement_state = kwargs['get_learner_retirement_state']
    mock_update_learner_state = kwargs['update_learner_retirement_state']
    mock_get_access_token.return_value = ('THIS_IS_A_JWT', None)
    # Broken API call, no state returned
    mock_get_retirement_state.side_effect = HttpDoesNotExistException
    result = _call_script(username)
    assert mock_get_access_token.call_count == 3
    mock_get_retirement_state.assert_called_once_with(username)
    mock_update_learner_state.assert_not_called()
    assert result.exit_code == ERR_BAD_LEARNER


@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token')
@patch.multiple(
    'scripts.user_retirement.utils.edx_api.LmsApi',
    get_learner_retirement_state=DEFAULT,
    update_learner_retirement_state=DEFAULT
)
def test_user_in_working_state(*args, **kwargs):
    """
    A learner already in an in-progress (RETIRING_*) state is refused.
    """
    username = 'test_username'
    mock_get_access_token = args[0]
    mock_get_retirement_state = kwargs['get_learner_retirement_state']
    mock_update_learner_state = kwargs['update_learner_retirement_state']
    mock_get_access_token.return_value = ('THIS_IS_A_JWT', None)
    mock_get_retirement_state.return_value = get_fake_user_retirement(
        original_username=username,
        current_state_name='RETIRING_FORUMS'
    )
    result = _call_script(username)
    assert mock_get_access_token.call_count == 3
    mock_get_retirement_state.assert_called_once_with(username)
    mock_update_learner_state.assert_not_called()
    assert result.exit_code == ERR_USER_IN_WORKING_STATE
    assert 'in a working state' in result.output


@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token')
@patch.multiple(
    'scripts.user_retirement.utils.edx_api.LmsApi',
    get_learner_retirement_state=DEFAULT,
    update_learner_retirement_state=DEFAULT
)
def test_user_in_bad_state(*args, **kwargs):
    """
    A state name that is not in the configured pipeline exits with
    ERR_UNKNOWN_STATE and echoes the bad state.
    """
    username = 'test_username'
    bad_state = 'BOGUS_STATE'
    mock_get_access_token = args[0]
    mock_get_retirement_state = kwargs['get_learner_retirement_state']
    mock_update_learner_state = kwargs['update_learner_retirement_state']
    mock_get_access_token.return_value = ('THIS_IS_A_JWT', None)
    mock_get_retirement_state.return_value = get_fake_user_retirement(
        original_username=username,
        current_state_name=bad_state
    )
    result = _call_script(username)
    assert mock_get_access_token.call_count == 3
    mock_get_retirement_state.assert_called_once_with(username)
    mock_update_learner_state.assert_not_called()
    assert result.exit_code == ERR_UNKNOWN_STATE
    assert bad_state in result.output


@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token')
@patch.multiple(
    'scripts.user_retirement.utils.edx_api.LmsApi',
    get_learner_retirement_state=DEFAULT,
    update_learner_retirement_state=DEFAULT
)
def test_user_in_end_state(*args, **kwargs):
    """
    Every terminal state is refused with ERR_USER_AT_END_STATE.
    """
    username = 'test_username'
    mock_get_access_token = args[0]
    mock_get_retirement_state = kwargs['get_learner_retirement_state']
    mock_update_learner_state = kwargs['update_learner_retirement_state']
    mock_get_access_token.return_value = ('THIS_IS_A_JWT', None)
    # pytest.parameterize doesn't play nicely with patch.multiple, this seemed more
    # readable than the alternatives.
    for end_state in END_STATES:
        mock_get_retirement_state.return_value = {
            'original_username': username,
            'current_state': {
                'state_name': end_state
            }
        }
        result = _call_script(username)
        assert mock_get_access_token.call_count == 3
        mock_get_retirement_state.assert_called_once_with(username)
        mock_update_learner_state.assert_not_called()
        assert result.exit_code == ERR_USER_AT_END_STATE
        assert end_state in result.output
        # Reset our call counts for the next test
        mock_get_access_token.reset_mock()
        mock_get_retirement_state.reset_mock()
@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token')
@patch.multiple(
    'scripts.user_retirement.utils.edx_api.LmsApi',
    get_learner_retirement_state=DEFAULT,
    update_learner_retirement_state=DEFAULT,
    retirement_retire_forum=DEFAULT,
    retirement_retire_mailings=DEFAULT,
    retirement_unenroll=DEFAULT,
    retirement_lms_retire=DEFAULT
)
def test_skipping_states(*args, **kwargs):
    """
    A learner who already completed the first two steps (EMAIL_LISTS_COMPLETE)
    resumes from the third step; earlier steps are skipped, not re-run.
    """
    username = 'test_username'
    mock_get_access_token = args[0]
    mock_get_retirement_state = kwargs['get_learner_retirement_state']
    mock_update_learner_state = kwargs['update_learner_retirement_state']
    mock_retire_forum = kwargs['retirement_retire_forum']
    mock_retire_mailings = kwargs['retirement_retire_mailings']
    mock_unenroll = kwargs['retirement_unenroll']
    mock_lms_retire = kwargs['retirement_lms_retire']
    mock_get_access_token.return_value = ('THIS_IS_A_JWT', None)
    mock_get_retirement_state.return_value = get_fake_user_retirement(
        original_username=username,
        current_state_name='EMAIL_LISTS_COMPLETE'
    )
    result = _call_script(username)
    # Called once per API we instantiate (LMS, ECommerce, Credentials)
    assert mock_get_access_token.call_count == 3
    mock_get_retirement_state.assert_called_once_with(username)
    # Fewer updates than the full run (9) because two steps were skipped.
    assert mock_update_learner_state.call_count == 5
    # Skipped
    for mock_call in (
        mock_retire_forum,
        mock_retire_mailings
    ):
        mock_call.assert_not_called()
    # Called once per retirement
    for mock_call in (
        mock_unenroll,
        mock_lms_retire
    ):
        mock_call.assert_called_once_with(mock_get_retirement_state.return_value)
    assert result.exit_code == 0
    for required_output in (
        'RETIRING_FORUMS completed in previous run',
        'RETIRING_EMAIL_LISTS completed in previous run',
        'Starting state RETIRING_ENROLLMENTS',
        'State RETIRING_ENROLLMENTS completed',
        'Starting state RETIRING_LMS',
        'State RETIRING_LMS completed',
        'Retirement complete'
    ):
        assert required_output in result.output
@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token')
@patch('scripts.user_retirement.utils.edx_api.EcommerceApi.get_tracking_key')
@patch.multiple(
    'scripts.user_retirement.utils.edx_api.LmsApi',
    get_learner_retirement_state=DEFAULT,
    update_learner_retirement_state=DEFAULT,
    retirement_retire_forum=DEFAULT,
    retirement_retire_mailings=DEFAULT,
    retirement_unenroll=DEFAULT,
    retirement_lms_retire=DEFAULT
)
def test_get_segment_id_success(*args, **kwargs):
    """
    With fetch_ecom_segment_id on, the tracking key from ECommerce is merged
    into the learner dict before the pipeline runs.
    """
    username = 'test_username'
    # Reverse decorator order: args[0] is get_tracking_key, args[1] the token mock.
    mock_get_tracking_key = args[0]
    mock_get_access_token = args[1]
    mock_get_retirement_state = kwargs['get_learner_retirement_state']
    mock_retirement_retire_forum = kwargs['retirement_retire_forum']
    mock_get_access_token.return_value = ('THIS_IS_A_JWT', None)
    mock_get_tracking_key.return_value = {'id': 1, 'ecommerce_tracking_id': 'ecommerce-1'}
    # The learner starts off with these values, 'ecommerce_segment_id' is added during script
    # startup
    mock_get_retirement_state.return_value = get_fake_user_retirement(
        original_username=username,
    )
    _call_script(username, fetch_ecom_segment_id=True)
    mock_get_tracking_key.assert_called_once_with(mock_get_retirement_state.return_value)
    config_after_get_segment_id = mock_get_retirement_state.return_value
    config_after_get_segment_id['ecommerce_segment_id'] = 'ecommerce-1'
    mock_retirement_retire_forum.assert_called_once_with(config_after_get_segment_id)


@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token')
@patch('scripts.user_retirement.utils.edx_api.EcommerceApi.get_tracking_key')
@patch.multiple(
    'scripts.user_retirement.utils.edx_api.LmsApi',
    get_learner_retirement_state=DEFAULT,
    update_learner_retirement_state=DEFAULT,
    retirement_retire_forum=DEFAULT,
    retirement_retire_mailings=DEFAULT,
    retirement_unenroll=DEFAULT,
    retirement_lms_retire=DEFAULT
)
def test_get_segment_id_not_found(*args, **kwargs):
    """
    A 404 from ECommerce is tolerated: the segment id is set to None and the
    run continues.
    """
    username = 'test_username'
    mock_get_tracking_key = args[0]
    mock_get_access_token = args[1]
    mock_get_retirement_state = kwargs['get_learner_retirement_state']
    mock_get_access_token.return_value = ('THIS_IS_A_JWT', None)
    mock_get_tracking_key.side_effect = HttpDoesNotExistException('{} not found'.format(username))
    mock_get_retirement_state.return_value = get_fake_user_retirement(
        original_username=username,
    )
    result = _call_script(username, fetch_ecom_segment_id=True)
    mock_get_tracking_key.assert_called_once_with(mock_get_retirement_state.return_value)
    assert 'Setting Ecommerce Segment ID to None' in result.output
    # Reset our call counts for the next test
    # NOTE(review): these resets look like a leftover from a loop-based
    # sibling test (see test_user_in_end_state); harmless at function end.
    mock_get_access_token.reset_mock()
    mock_get_retirement_state.reset_mock()
@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token')
@patch('scripts.user_retirement.utils.edx_api.EcommerceApi.get_tracking_key')
@patch.multiple(
    'scripts.user_retirement.utils.edx_api.LmsApi',
    get_learner_retirement_state=DEFAULT,
    update_learner_retirement_state=DEFAULT,
    retirement_retire_forum=DEFAULT,
    retirement_retire_mailings=DEFAULT,
    retirement_unenroll=DEFAULT,
    retirement_lms_retire=DEFAULT
)
def test_get_segment_id_error(*args, **kwargs):
    """
    An unexpected error fetching the ECommerce tracking id fails setup:
    exit is ERR_SETUP_FAILED and no state updates are attempted.
    """
    username = 'test_username'
    # Reverse decorator order: args[0] is get_tracking_key, args[1] the token mock.
    mock_get_tracking_key = args[0]
    mock_get_access_token = args[1]
    mock_get_retirement_state = kwargs['get_learner_retirement_state']
    mock_update_learner_state = kwargs['update_learner_retirement_state']
    mock_get_access_token.return_value = ('THIS_IS_A_JWT', None)
    test_exception_message = 'Test Exception!'
    mock_get_tracking_key.side_effect = Exception(test_exception_message)
    # Fix: the original assigned return_value twice; the first assignment
    # (a full get_fake_user_retirement() learner) was dead code, immediately
    # overwritten by this minimal dict.
    mock_get_retirement_state.return_value = {
        'original_username': username,
        'current_state': {
            'state_name': 'PENDING'
        }
    }
    result = _call_script(username, fetch_ecom_segment_id=True)
    mock_get_tracking_key.assert_called_once_with(mock_get_retirement_state.return_value)
    mock_update_learner_state.assert_not_called()
    assert result.exit_code == ERR_SETUP_FAILED
    assert 'Unexpected error fetching Ecommerce tracking id!' in result.output
    assert test_exception_message in result.output

View File

@@ -0,0 +1,277 @@
"""
Test the retirement_archive_and_cleanup.py script
"""
import datetime
import os
import boto3
import pytest
from botocore.exceptions import ClientError
from click.testing import CliRunner
from mock import DEFAULT, call, patch
from moto import mock_ec2, mock_s3
from scripts.user_retirement.retirement_archive_and_cleanup import (
ERR_ARCHIVING,
ERR_BAD_CLI_PARAM,
ERR_BAD_CONFIG,
ERR_DELETING,
ERR_FETCHING,
ERR_NO_CONFIG,
ERR_SETUP_FAILED,
_upload_to_s3,
archive_and_cleanup
)
from scripts.user_retirement.tests.retirement_helpers import fake_config_file, get_fake_user_retirement
# Must match the bucket name written by fake_config_file's s3_archive section.
FAKE_BUCKET_NAME = "fake_test_bucket"


def _call_script(cool_off_days=37, batch_size=None, dry_run=None, start_date=None, end_date=None):
    """
    Invoke the archive script with the given params and a generic config
    file; return the CliRunner.invoke result.
    """
    runner = CliRunner()
    with runner.isolated_filesystem():
        with open('test_config.yml', 'w') as config_handle:
            fake_config_file(config_handle)

        cli_args = [
            '--config_file', 'test_config.yml',
            '--cool_off_days', cool_off_days,
        ]
        # Append each optional flag only when a truthy value was supplied.
        for flag, value in (
                ('--batch_size', batch_size),
                ('--dry_run', dry_run),
                ('--start_date', start_date),
                ('--end_date', end_date),
        ):
            if value:
                cli_args += [flag, value]

        result = runner.invoke(archive_and_cleanup, args=cli_args)
        print(result)
        print(result.output)
    return result
def _fake_learner(ordinal):
    """
    Build a single fake COMPLETE learner whose fields embed ``ordinal``.
    """
    return get_fake_user_retirement(
        user_id=ordinal,
        original_username='test{}'.format(ordinal),
        original_email='test{}@edx.invalid'.format(ordinal),
        original_name='test {}'.format(ordinal),
        retired_username='retired_{}'.format(ordinal),
        retired_email='retired_test{}@edx.invalid'.format(ordinal),
        last_state_name='COMPLETE'
    )


def fake_learners_to_retire():
    """
    Return a hard-coded list of three fake learners.
    """
    return [_fake_learner(ordinal) for ordinal in (1, 2, 3)]
@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token', return_value=('THIS_IS_A_JWT', None))
@patch.multiple(
    'scripts.user_retirement.utils.edx_api.LmsApi',
    get_learners_by_date_and_status=DEFAULT,
    bulk_cleanup_retirements=DEFAULT
)
@mock_s3
def test_successful(*args, **kwargs):
    """
    Happy path: learners are archived to the mocked S3 bucket and their
    retirement rows cleaned up in LMS in one batch.
    """
    conn = boto3.resource('s3')
    conn.create_bucket(Bucket=FAKE_BUCKET_NAME)
    mock_get_access_token = args[0]
    mock_get_learners = kwargs['get_learners_by_date_and_status']
    mock_bulk_cleanup_retirements = kwargs['bulk_cleanup_retirements']
    mock_get_learners.return_value = fake_learners_to_retire()
    result = _call_script()
    # Called once to get the LMS token
    assert mock_get_access_token.call_count == 1
    mock_get_learners.assert_called_once()
    mock_bulk_cleanup_retirements.assert_called_once_with(
        ['test1', 'test2', 'test3'])
    assert result.exit_code == 0
    assert 'Archive and cleanup complete' in result.output


@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token', return_value=('THIS_IS_A_JWT', None))
@patch.multiple(
    'scripts.user_retirement.utils.edx_api.LmsApi',
    get_learners_by_date_and_status=DEFAULT,
    bulk_cleanup_retirements=DEFAULT
)
@mock_ec2
@mock_s3
def test_successful_with_batching(*args, **kwargs):
    """
    With batch_size=2, three learners are cleaned up in two batches.

    NOTE(review): @mock_ec2 looks unnecessary here — no EC2 calls are made
    in this test; confirm and drop if so.
    """
    conn = boto3.resource('s3')
    conn.create_bucket(Bucket=FAKE_BUCKET_NAME)
    mock_get_access_token = args[0]
    mock_get_learners = kwargs['get_learners_by_date_and_status']
    mock_bulk_cleanup_retirements = kwargs['bulk_cleanup_retirements']
    mock_get_learners.return_value = fake_learners_to_retire()
    result = _call_script(batch_size=2)
    # Called once to get the LMS token
    assert mock_get_access_token.call_count == 1
    mock_get_learners.assert_called_once()
    get_learner_calls = [call(['test1', 'test2']), call(['test3'])]
    mock_bulk_cleanup_retirements.assert_has_calls(get_learner_calls)
    assert result.exit_code == 0
    assert 'Archive and cleanup complete for batch #1' in result.output
    assert 'Archive and cleanup complete for batch #2' in result.output


@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token', return_value=('THIS_IS_A_JWT', None))
@patch.multiple(
    'scripts.user_retirement.utils.edx_api.LmsApi',
    get_learners_by_date_and_status=DEFAULT,
    bulk_cleanup_retirements=DEFAULT
)
@mock_s3
def test_successful_dry_run(*args, **kwargs):
    """
    Dry run: nothing is uploaded (no bucket is even created here) and no
    retirements are cleaned up.
    """
    mock_get_access_token = args[0]
    mock_get_learners = kwargs['get_learners_by_date_and_status']
    mock_bulk_cleanup_retirements = kwargs['bulk_cleanup_retirements']
    mock_get_learners.return_value = fake_learners_to_retire()
    result = _call_script(dry_run=True)
    # Called once to get the LMS token
    assert mock_get_access_token.call_count == 1
    mock_get_learners.assert_called_once()
    mock_bulk_cleanup_retirements.assert_not_called()
    assert result.exit_code == 0
    assert 'Dry run. Skipping the step to upload data to' in result.output
    assert 'This is a dry-run. Exiting before any retirements are cleaned up' in result.output
def test_no_config():
    """
    Omitting --config_file exits with ERR_NO_CONFIG.
    """
    runner = CliRunner()
    result = runner.invoke(
        archive_and_cleanup,
        args=[
            '--cool_off_days', 37
        ]
    )
    assert result.exit_code == ERR_NO_CONFIG
    assert 'No config file passed in.' in result.output


def test_bad_config():
    """
    A nonexistent config file exits with ERR_BAD_CONFIG and names the file.
    """
    runner = CliRunner()
    result = runner.invoke(
        archive_and_cleanup,
        args=[
            '--config_file', 'does_not_exist.yml',
            '--cool_off_days', 37
        ]
    )
    assert result.exit_code == ERR_BAD_CONFIG
    assert 'does_not_exist.yml' in result.output


@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token', return_value=('THIS_IS_A_JWT', None))
@patch('scripts.user_retirement.utils.edx_api.LmsApi.__init__', side_effect=Exception)
def test_setup_failed(*_):
    """
    A failure constructing the LMS client exits with ERR_SETUP_FAILED.
    """
    result = _call_script()
    assert result.exit_code == ERR_SETUP_FAILED


@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token', return_value=('THIS_IS_A_JWT', None))
@patch('scripts.user_retirement.utils.edx_api.LmsApi.get_learners_by_date_and_status', side_effect=Exception)
def test_bad_fetch(*_):
    """
    A failure fetching learners exits with ERR_FETCHING.
    """
    result = _call_script()
    assert result.exit_code == ERR_FETCHING
    assert 'Unexpected error occurred fetching users to update!' in result.output


@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token', return_value=('THIS_IS_A_JWT', None))
@patch('scripts.user_retirement.utils.edx_api.LmsApi.get_learners_by_date_and_status',
       return_value=fake_learners_to_retire())
@patch('scripts.user_retirement.utils.edx_api.LmsApi.bulk_cleanup_retirements', side_effect=Exception)
@patch('scripts.user_retirement.retirement_archive_and_cleanup._upload_to_s3')
def test_bad_lms_deletion(*_):
    """
    A failure deleting retirements (after a successful upload) exits with
    ERR_DELETING.
    """
    result = _call_script()
    assert result.exit_code == ERR_DELETING
    assert 'Unexpected error occurred deleting retirements!' in result.output


@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token', return_value=('THIS_IS_A_JWT', None))
@patch('scripts.user_retirement.utils.edx_api.LmsApi.get_learners_by_date_and_status',
       return_value=fake_learners_to_retire())
@patch('scripts.user_retirement.utils.edx_api.LmsApi.bulk_cleanup_retirements')
@patch('scripts.user_retirement.retirement_archive_and_cleanup._upload_to_s3', side_effect=Exception)
def test_bad_s3_upload(*_):
    """
    A failure uploading the archive to S3 exits with ERR_ARCHIVING.
    """
    result = _call_script()
    assert result.exit_code == ERR_ARCHIVING
    assert 'Unexpected error occurred archiving retirements!' in result.output


@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token', return_value=('THIS_IS_A_JWT', None))
def test_conflicting_dates(*_):
    """
    start_date after end_date is rejected with ERR_BAD_CLI_PARAM.
    """
    result = _call_script(start_date=datetime.datetime(
        2021, 10, 10), end_date=datetime.datetime(2018, 10, 10))
    assert result.exit_code == ERR_BAD_CLI_PARAM
    assert 'Conflicting start and end dates passed on CLI' in result.output


@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token', return_value=('THIS_IS_A_JWT', None))
@patch(
    'scripts.user_retirement.retirement_archive_and_cleanup._get_utc_now',
    return_value=datetime.datetime(2021, 2, 2, 0, 0)
)
def test_conflicting_cool_off_date(*_):
    """
    An end_date inside the cool_off_days window (relative to the frozen
    "now") is rejected with ERR_BAD_CLI_PARAM.
    """
    result = _call_script(
        cool_off_days=10,
        start_date=datetime.datetime(2021, 1, 1), end_date=datetime.datetime(2021, 2, 1)
    )
    assert result.exit_code == ERR_BAD_CLI_PARAM
    assert 'End date cannot occur within the cool_off_days period' in result.output
@mock_s3
def test_s3_upload_data():
    """
    Test case to verify s3 upload and download.

    Uploads the checked-in test_data/uploading.txt fixture and reads it back
    from the mocked bucket.
    """
    s3 = boto3.client("s3")
    s3.create_bucket(Bucket=FAKE_BUCKET_NAME)
    config = {'s3_archive': {'bucket_name': FAKE_BUCKET_NAME}}
    filename = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'test_data', 'uploading.txt')
    # NOTE(review): the key embeds the absolute local path of `filename`;
    # presumably this mirrors _upload_to_s3's key scheme — confirm there.
    key = 'raw/' + datetime.datetime.now().strftime('%Y/%m/') + filename
    # first try dry run without uploading. Try to get object should raise error
    with pytest.raises(ClientError) as exc_info:
        _upload_to_s3(config, filename, True)
        s3.get_object(Bucket=FAKE_BUCKET_NAME, Key=key)
    assert exc_info.value.response['Error']['Code'] == 'NoSuchKey'
    # upload a file, download and compare its content.
    _upload_to_s3(config, filename, False)
    resp = s3.get_object(Bucket=FAKE_BUCKET_NAME, Key=key)
    data = resp["Body"].read()
    assert data.decode() == "Upload this file on s3 in tests."

View File

@@ -0,0 +1,182 @@
"""
Test the retirement_bulk_status_update.py script
"""
from click.testing import CliRunner
from mock import DEFAULT, patch
from scripts.user_retirement.retirement_bulk_status_update import (
ERR_BAD_CONFIG,
ERR_FETCHING,
ERR_NO_CONFIG,
ERR_SETUP_FAILED,
ERR_UPDATING,
update_statuses
)
from scripts.user_retirement.tests.retirement_helpers import fake_config_file, get_fake_user_retirement
def _call_script(initial_state='COMPLETE', new_state='PENDING', start_date='2018-01-01', end_date='2018-01-15',
                 rewind_state=False):
    """
    Call the bulk update statuses script with the given params and a generic
    config file.

    Returns the CliRunner.invoke results.
    """
    runner = CliRunner()
    with runner.isolated_filesystem():
        with open('test_config.yml', 'w') as f:
            fake_config_file(f)
        args = [
            '--config_file', 'test_config.yml',
            '--initial_state', initial_state,
            '--start_date', start_date,
            '--end_date', end_date
        ]
        # Idiom fix: the original used side-effecting conditional expressions
        # (`args.extend(...) if new_state else None`); plain `if` statements
        # are equivalent and readable.
        if new_state:
            args.extend(['--new_state', new_state])
        if rewind_state:
            args.append('--rewind-state')
        result = runner.invoke(
            update_statuses,
            args=args
        )
        print(result)
        print(result.output)
    return result
def fake_learners_to_retire(**overrides):
    """
    Return three hard-coded fake learners ("user1".."user3"), applying any
    keyword ``overrides`` on top of each username.
    """
    learners = []
    for username in ("user1", "user2", "user3"):
        kwargs = {"original_username": username, **overrides}
        learners.append(get_fake_user_retirement(**kwargs))
    return learners
@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token', return_value=('THIS_IS_A_JWT', None))
@patch.multiple(
    'scripts.user_retirement.utils.edx_api.LmsApi',
    get_learners_by_date_and_status=DEFAULT,
    update_learner_retirement_state=DEFAULT
)
def test_successful_update(*args, **kwargs):
    """
    Happy path: three learners are fetched and each one's state is updated.
    """
    # @patch decorators are applied bottom-up; patch.multiple mocks arrive as kwargs.
    mock_get_access_token = args[0]
    mock_get_learners = kwargs['get_learners_by_date_and_status']
    mock_update_learner_state = kwargs['update_learner_retirement_state']
    mock_get_learners.return_value = fake_learners_to_retire()
    result = _call_script()
    # Called once to get the LMS token
    assert mock_get_access_token.call_count == 1
    mock_get_learners.assert_called_once()
    # One state update per fake learner.
    assert mock_update_learner_state.call_count == 3
    assert result.exit_code == 0
    assert 'Bulk update complete' in result.output
def test_no_config():
    """Omitting --config_file entirely should exit with ERR_NO_CONFIG."""
    cli_args = [
        '--initial_state', 'COMPLETE',
        '--new_state', 'PENDING',
        '--start_date', '2018-01-01',
        '--end_date', '2018-01-15',
    ]
    result = CliRunner().invoke(update_statuses, args=cli_args)
    assert result.exit_code == ERR_NO_CONFIG
    assert 'No config file passed in.' in result.output
def test_bad_config():
    """A --config_file path that does not exist should exit with ERR_BAD_CONFIG."""
    cli_args = [
        '--config_file', 'does_not_exist.yml',
        '--initial_state', 'COMPLETE',
        '--new_state', 'PENDING',
        '--start_date', '2018-01-01',
        '--end_date', '2018-01-15',
    ]
    result = CliRunner().invoke(update_statuses, args=cli_args)
    assert result.exit_code == ERR_BAD_CONFIG
    assert 'does_not_exist.yml' in result.output
@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token', return_value=('THIS_IS_A_JWT', None))
@patch.multiple(
    'scripts.user_retirement.utils.edx_api.LmsApi',
    get_learners_by_date_and_status=DEFAULT,
    update_learner_retirement_state=DEFAULT
)
def test_successful_rewind(*args, **kwargs):
    """
    Rewind path: with --rewind-state (and no --new_state) each ERRORED
    learner's state update is performed.
    """
    # @patch decorators are applied bottom-up; patch.multiple mocks arrive as kwargs.
    mock_get_access_token = args[0]
    mock_get_learners = kwargs['get_learners_by_date_and_status']
    mock_update_learner_state = kwargs['update_learner_retirement_state']
    mock_get_learners.return_value = fake_learners_to_retire(current_state_name='ERRORED')
    result = _call_script(new_state=None, rewind_state=True)
    # Called once to get the LMS token
    assert mock_get_access_token.call_count == 1
    mock_get_learners.assert_called_once()
    # One rewind per fake learner.
    assert mock_update_learner_state.call_count == 3
    assert result.exit_code == 0
    assert 'Bulk update complete' in result.output
@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token', return_value=('THIS_IS_A_JWT', None))
@patch.multiple(
    'scripts.user_retirement.utils.edx_api.LmsApi',
    get_learners_by_date_and_status=DEFAULT,
    update_learner_retirement_state=DEFAULT
)
def test_rewind_bad_args(*args, **kwargs):
    """
    Passing --rewind-state together with a new state is a config error.
    """
    mock_get_access_token = args[0]
    mock_get_learners = kwargs['get_learners_by_date_and_status']
    mock_get_learners.return_value = fake_learners_to_retire(current_state_name='ERRORED')
    # Default new_state ('PENDING') plus rewind_state=True triggers the conflict.
    result = _call_script(rewind_state=True)
    # Called once to get the LMS token
    assert mock_get_access_token.call_count == 1
    mock_get_learners.assert_called_once()
    assert result.exit_code == ERR_BAD_CONFIG
    assert 'boolean rewind_state or a new state to set learners to' in result.output
@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token', return_value=('THIS_IS_A_JWT', None))
@patch('scripts.user_retirement.utils.edx_api.LmsApi.__init__', side_effect=Exception)
def test_setup_failed(*_):
    """An exception while constructing the LMS client exits with ERR_SETUP_FAILED."""
    result = _call_script()
    assert result.exit_code == ERR_SETUP_FAILED
@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token', return_value=('THIS_IS_A_JWT', None))
@patch('scripts.user_retirement.utils.edx_api.LmsApi.get_learners_by_date_and_status', side_effect=Exception)
def test_bad_fetch(*_):
    """An exception while fetching learners to update exits with ERR_FETCHING."""
    result = _call_script()
    assert result.exit_code == ERR_FETCHING
    assert 'Unexpected error occurred fetching users to update!' in result.output
@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token', return_value=('THIS_IS_A_JWT', None))
@patch('scripts.user_retirement.utils.edx_api.LmsApi.get_learners_by_date_and_status',
       return_value=fake_learners_to_retire())
@patch('scripts.user_retirement.utils.edx_api.LmsApi.update_learner_retirement_state', side_effect=Exception)
def test_bad_update(*_):
    """An exception while updating a learner's state exits with ERR_UPDATING."""
    result = _call_script()
    assert result.exit_code == ERR_UPDATING
    assert 'Unexpected error occurred updating users!' in result.output

View File

@@ -0,0 +1,818 @@
# coding=utf-8
"""
Test the retire_one_learner.py script
"""
import csv
import os
import time
import unicodedata
from datetime import date
from click.testing import CliRunner
from mock import DEFAULT, patch
from six import PY2, itervalues
from scripts.user_retirement.retirement_partner_report import \
_generate_report_files_or_exit # pylint: disable=protected-access
from scripts.user_retirement.retirement_partner_report import \
_get_orgs_and_learners_or_exit # pylint: disable=protected-access
from scripts.user_retirement.retirement_partner_report import (
DEFAULT_FIELD_HEADINGS,
ERR_BAD_CONFIG,
ERR_BAD_SECRETS,
ERR_CLEANUP,
ERR_DRIVE_LISTING,
ERR_FETCHING_LEARNERS,
ERR_NO_CONFIG,
ERR_NO_OUTPUT_DIR,
ERR_NO_SECRETS,
ERR_REPORTING,
ERR_SETUP_FAILED,
ERR_UNKNOWN_ORG,
LEARNER_CREATED_KEY,
LEARNER_ORIGINAL_USERNAME_KEY,
ORGS_CONFIG_FIELD_HEADINGS_KEY,
ORGS_CONFIG_KEY,
ORGS_CONFIG_LEARNERS_KEY,
ORGS_CONFIG_ORG_KEY,
ORGS_KEY,
REPORTING_FILENAME_PREFIX,
SETUP_LMS_OR_EXIT,
generate_report
)
from scripts.user_retirement.tests.retirement_helpers import (
FAKE_ORGS,
TEST_PLATFORM_NAME,
fake_config_file,
fake_google_secrets_file,
flatten_partner_list
)
# File names written inside CliRunner's isolated filesystem.
TEST_CONFIG_YML_NAME = 'test_config.yml'
TEST_GOOGLE_SECRETS_FILENAME = 'test_google_secrets.json'
# Timestamp used as the fake deletion/created time in report rows.
DELETION_TIME = time.strftime("%Y-%m-%dT%H:%M:%S")
# Non-ASCII name used to exercise unicode handling in report generation.
UNICODE_NAME_CONSTANT = '阿碧'
USER_ID = '12345'
# Orgs that use custom (non-default) CSV field headings.
TEST_ORGS_CONFIG = [
    {
        ORGS_CONFIG_ORG_KEY: 'orgCustom',
        ORGS_CONFIG_FIELD_HEADINGS_KEY: ['heading_1', 'heading_2', 'heading_3']
    },
    {
        ORGS_CONFIG_ORG_KEY: 'otherCustomOrg',
        ORGS_CONFIG_FIELD_HEADINGS_KEY: ['unique_id']
    }
]
# Value fragments expected in each row of a default-headings report.
DEFAULT_FIELD_VALUES = {
    'user_id': USER_ID,
    LEARNER_ORIGINAL_USERNAME_KEY: 'username',
    'original_email': 'invalid',
    'original_name': UNICODE_NAME_CONSTANT,
    'deletion_completed': DELETION_TIME
}
def _call_script(expect_success=True, expected_num_rows=10, config_orgs=None, expected_fields=None):
    """
    Call the retirement partner report script with a generic, temporary config file.

    Args:
        expect_success: when True, assert a zero exit code and validate the CSVs.
        expected_num_rows: learner rows each report file should contain.
        config_orgs: org -> partner-list mapping for the config (FAKE_ORGS by default).
        expected_fields: field -> value fragments each CSV row must contain.

    Returns the CliRunner.invoke results.
    """
    if expected_fields is None:
        expected_fields = DEFAULT_FIELD_VALUES
    if config_orgs is None:
        config_orgs = FAKE_ORGS
    runner = CliRunner()
    with runner.isolated_filesystem():
        with open(TEST_CONFIG_YML_NAME, 'w') as config_f:
            fake_config_file(config_f, config_orgs)
        with open(TEST_GOOGLE_SECRETS_FILENAME, 'w') as secrets_f:
            fake_google_secrets_file(secrets_f)
        tmp_output_dir = 'test_output_dir'
        os.mkdir(tmp_output_dir)
        result = runner.invoke(
            generate_report,
            args=[
                '--config_file',
                TEST_CONFIG_YML_NAME,
                '--google_secrets_file',
                TEST_GOOGLE_SECRETS_FILENAME,
                '--output_dir',
                tmp_output_dir
            ]
        )
        print(result)
        print(result.output)
        if expect_success:
            assert result.exit_code == 0
            # config_orgs was defaulted to FAKE_ORGS above, so the original
            # `if config_orgs is None` re-check here was dead code; both
            # branches computed the same value.
            config_org_vals = flatten_partner_list(config_orgs.values())
            # Normalize the unicode as the script does. (The old PY2 decode
            # branch was dead on Python 3 and has been removed.)
            config_org_vals = [unicodedata.normalize('NFKC', org) for org in config_org_vals]
            for org in config_org_vals:
                outfile = os.path.join(tmp_output_dir, '{}_{}_{}_{}.csv'.format(
                    REPORTING_FILENAME_PREFIX, TEST_PLATFORM_NAME, org, date.today().isoformat()
                ))
                with open(outfile, 'r') as csvfile:
                    reader = csv.DictReader(csvfile)
                    rows = []
                    for row in reader:
                        for field_key in expected_fields:
                            # Each expected fragment must appear in the row's value.
                            field_value = expected_fields[field_key]
                            assert field_value in row[field_key]
                        rows.append(row)
                # Confirm the number of rows
                assert len(rows) == expected_num_rows
        return result
def _fake_retirement_report_user(seed_val, user_orgs=None, user_orgs_config=None):
    """
    Build one distinct fake learner row for a retirement report.

    - seed_val: unique value woven into the username/email/name so rows differ.
    - user_orgs: optional list of orgs associated with the user.
    - user_orgs_config: optional list of dicts mapping orgs to their custom
      field headings; those orgs are also associated with the user.
    """
    row = {
        'user_id': USER_ID,
        LEARNER_ORIGINAL_USERNAME_KEY: f'username_{seed_val}',
        'original_email': f'user_{seed_val}@foo.invalid',
        'original_name': f'{UNICODE_NAME_CONSTANT} {seed_val}',
        LEARNER_CREATED_KEY: DELETION_TIME,
    }
    if user_orgs is not None:
        row[ORGS_KEY] = user_orgs
    if user_orgs_config is not None:
        row[ORGS_CONFIG_KEY] = user_orgs_config
    return row
def _fake_retirement_report(num_users=10, user_orgs=None, user_orgs_config=None):
    """
    Produce a fake retirement report containing ``num_users`` distinct users.
    """
    report = []
    for seed in range(num_users):
        report.append(_fake_retirement_report_user(seed, user_orgs, user_orgs_config))
    return report
@patch('scripts.user_retirement.utils.edx_api.LmsApi.retirement_partner_report')
@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token')
def test_report_generation_multiple_partners(*args, **kwargs):
    """
    One org mapped to multiple partner names (Org2X / Org2Xb) should yield
    identical report data for each partner name.
    """
    # @patch decorators apply bottom-up: token mock first, report mock second.
    mock_get_access_token = args[0]
    mock_retirement_report = args[1]
    org_1_users = [_fake_retirement_report_user(i, user_orgs=['org1']) for i in range(1, 3)]
    org_2_users = [_fake_retirement_report_user(i, user_orgs=['org2']) for i in range(3, 5)]
    mock_get_access_token.return_value = ('THIS_IS_A_JWT', None)
    mock_retirement_report.return_value = org_1_users + org_2_users
    config = {
        'client_id': 'bogus id',
        'client_secret': 'supersecret',
        'base_urls': {
            'lms': 'https://stage-edx-edxapp.edx.invalid/',
        },
        'org_partner_mapping': {
            'org1': ['Org1X'],
            'org2': ['Org2X', 'Org2Xb'],
        }
    }
    SETUP_LMS_OR_EXIT(config)
    orgs, usernames = _get_orgs_and_learners_or_exit(config)
    assert usernames == [{'original_username': 'username_{}'.format(username)} for username in range(1, 5)]

    def _get_learner_usernames(org_data):
        # Convenience: pull just the usernames out of one org's report data.
        return [learner['original_username'] for learner in org_data['learners']]

    assert _get_learner_usernames(orgs['Org1X']) == ['username_1', 'username_2']
    # Org2X and Org2Xb should have the same learners in their report data
    assert _get_learner_usernames(orgs['Org2X']) == _get_learner_usernames(orgs['Org2Xb']) == ['username_3',
                                                                                               'username_4']
    # Org2X and Org2Xb report data should match
    assert orgs['Org2X'] == orgs['Org2Xb']
@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.__init__')
@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.create_file_in_folder')
@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.walk_files')
@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.list_permissions_for_files')
@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.create_comments_for_files')
@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token')
@patch.multiple(
    'scripts.user_retirement.utils.edx_api.LmsApi',
    retirement_partner_report=DEFAULT,
    retirement_partner_cleanup=DEFAULT
)
def test_successful_report(*args, **kwargs):
    """
    Happy path: reports are generated, uploaded to Drive, POC comments added
    where a POC exists, and the retirement queue is cleaned up.
    """
    # @patch decorators apply bottom-up, so args[0] is the innermost patch.
    mock_get_access_token = args[0]
    mock_create_comments = args[1]
    mock_list_permissions = args[2]
    mock_walk_files = args[3]
    mock_create_files = args[4]
    mock_driveapi = args[5]
    mock_retirement_report = kwargs['retirement_partner_report']
    mock_retirement_cleanup = kwargs['retirement_partner_cleanup']
    mock_get_access_token.return_value = ('THIS_IS_A_JWT', None)
    mock_create_comments.return_value = None
    fake_partners = list(itervalues(FAKE_ORGS))
    # Generate the list_permissions return value.
    # The first few have POCs.
    mock_list_permissions.return_value = {
        'folder' + partner: [
            {'emailAddress': 'some.contact@example.com'},  # The POC.
            {'emailAddress': 'another.contact@edx.org'},
        ]
        for partner in flatten_partner_list(fake_partners[:2])
    }
    # The last one does not have any POCs.
    mock_list_permissions.return_value.update({
        'folder' + partner: [
            {'emailAddress': 'another.contact@edx.org'},
        ]
        for partner in fake_partners[2]
    })
    mock_walk_files.return_value = [{'name': partner, 'id': 'folder' + partner} for partner in
                                    flatten_partner_list(FAKE_ORGS.values())]
    mock_create_files.side_effect = ['foo', 'bar', 'baz', 'qux']
    mock_driveapi.return_value = None
    mock_retirement_report.return_value = _fake_retirement_report(user_orgs=list(FAKE_ORGS.keys()))
    result = _call_script()
    # Make sure we're getting the LMS token
    mock_get_access_token.assert_called_once()
    # Make sure that we get the report
    mock_retirement_report.assert_called_once()
    # Make sure we tried to upload the files
    assert mock_create_files.call_count == 4
    # Make sure we tried to add comments to the files
    assert mock_create_comments.call_count == 1
    # First [0] returns all positional args, second [0] gets the first positional arg.
    create_comments_file_ids, create_comments_messages = zip(*mock_create_comments.call_args[0][0])
    assert set(create_comments_file_ids).issubset(set(['foo', 'bar', 'baz', 'qux']))
    assert len(create_comments_file_ids) == 2  # only two comments created, the third didn't have a POC.
    assert all('+some.contact@example.com' in msg for msg in create_comments_messages)
    assert all('+another.contact@edx.org' not in msg for msg in create_comments_messages)
    assert 'WARNING: could not find a POC' in result.output
    # Make sure we tried to remove the users from the queue
    mock_retirement_cleanup.assert_called_with(
        [{'original_username': user[LEARNER_ORIGINAL_USERNAME_KEY]} for user in mock_retirement_report.return_value]
    )
    assert 'All reports completed and uploaded to Google.' in result.output
@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.__init__')
@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.create_file_in_folder')
@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.walk_files')
@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.list_permissions_for_files')
@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.create_comments_for_files')
@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token')
@patch.multiple(
    'scripts.user_retirement.utils.edx_api.LmsApi',
    retirement_partner_report=DEFAULT,
    retirement_partner_cleanup=DEFAULT
)
def test_successful_report_org_config(*args, **kwargs):
    """
    Orgs with a custom orgs_config should produce report rows using their
    custom field headings instead of the defaults.
    """
    # @patch decorators apply bottom-up, so args[0] is the innermost patch.
    mock_get_access_token = args[0]
    mock_create_comments = args[1]
    mock_list_permissions = args[2]
    mock_walk_files = args[3]
    mock_create_files = args[4]
    mock_driveapi = args[5]
    mock_retirement_report = kwargs['retirement_partner_report']
    mock_retirement_cleanup = kwargs['retirement_partner_cleanup']
    mock_get_access_token.return_value = ('THIS_IS_A_JWT', None)
    mock_create_comments.return_value = None
    fake_custom_orgs = {
        'orgCustom': ['firstBlah']
    }
    fake_partners = list(itervalues(fake_custom_orgs))
    mock_list_permissions.return_value = {
        'folder' + partner: [
            {'emailAddress': 'some.contact@example.com'},  # The POC.
            {'emailAddress': 'another.contact@edx.org'},
        ]
        for partner in flatten_partner_list(fake_partners[:2])
    }
    mock_walk_files.return_value = [{'name': partner, 'id': 'folder' + partner} for partner in
                                    flatten_partner_list(fake_custom_orgs.values())]
    mock_create_files.side_effect = ['foo', 'bar', 'baz']
    mock_driveapi.return_value = None
    expected_num_users = 1
    orgs_config = [
        {
            ORGS_CONFIG_ORG_KEY: 'orgCustom',
            ORGS_CONFIG_FIELD_HEADINGS_KEY: ['heading_1', 'heading_2', 'heading_3']
        }
    ]
    # Input from the LMS
    report_data = [
        {
            'heading_1': 'h1val',
            'heading_2': 'h2val',
            'heading_3': 'h3val',
            LEARNER_ORIGINAL_USERNAME_KEY: 'blah',
            LEARNER_CREATED_KEY: DELETION_TIME,
            ORGS_CONFIG_KEY: orgs_config
        }
    ]
    # Resulting csv file content
    expected_fields = {
        'heading_1': 'h1val',
        'heading_2': 'h2val',
        'heading_3': 'h3val',
    }
    mock_retirement_report.return_value = report_data
    result = _call_script(expected_num_rows=expected_num_users, config_orgs=fake_custom_orgs,
                          expected_fields=expected_fields)
    # Make sure we're getting the LMS token
    mock_get_access_token.assert_called_once()
    # Make sure that we get the report
    mock_retirement_report.assert_called_once()
    # Make sure we tried to remove the users from the queue
    mock_retirement_cleanup.assert_called_with(
        [{'original_username': user[LEARNER_ORIGINAL_USERNAME_KEY]} for user in mock_retirement_report.return_value]
    )
    assert 'All reports completed and uploaded to Google.' in result.output
def test_no_config():
    """Invoking with no arguments at all should fail for lack of a config file."""
    result = CliRunner().invoke(generate_report)
    print(result.output)
    assert result.exit_code == ERR_NO_CONFIG
    assert 'No config file' in result.output
def test_no_secrets():
    """Providing a config but no Google secrets file should exit with ERR_NO_SECRETS."""
    result = CliRunner().invoke(generate_report, args=['--config_file', 'does_not_exist.yml'])
    print(result.output)
    assert result.exit_code == ERR_NO_SECRETS
    assert 'No secrets file' in result.output
def test_no_output_dir():
    """Omitting --output_dir should exit with ERR_NO_OUTPUT_DIR."""
    runner = CliRunner()
    with runner.isolated_filesystem():
        # File contents don't matter here; the test expects failure on the
        # missing output dir, so both files just need to exist.
        for name in (TEST_CONFIG_YML_NAME, TEST_GOOGLE_SECRETS_FILENAME):
            with open(name, 'w') as handle:
                handle.write('irrelevant')
        result = runner.invoke(
            generate_report,
            args=[
                '--config_file', TEST_CONFIG_YML_NAME,
                '--google_secrets_file', TEST_GOOGLE_SECRETS_FILENAME,
            ]
        )
        print(result.output)
        assert result.exit_code == ERR_NO_OUTPUT_DIR
        assert 'No output_dir' in result.output
def test_bad_config():
    """Unparseable YAML config should exit with ERR_BAD_CONFIG."""
    runner = CliRunner()
    with runner.isolated_filesystem():
        with open(TEST_CONFIG_YML_NAME, 'w') as config_f:
            config_f.write(']this is bad yaml')
        with open(TEST_GOOGLE_SECRETS_FILENAME, 'w') as secrets_f:
            secrets_f.write('{this is bad json but we should not get to parsing it')
        output_dir = 'test_output_dir'
        os.mkdir(output_dir)
        cli_args = [
            '--config_file', TEST_CONFIG_YML_NAME,
            '--google_secrets_file', TEST_GOOGLE_SECRETS_FILENAME,
            '--output_dir', output_dir,
        ]
        result = runner.invoke(generate_report, args=cli_args)
        print(result.output)
        assert result.exit_code == ERR_BAD_CONFIG
        assert 'Failed to read' in result.output
def test_bad_secrets():
    """Unparseable JSON secrets should exit with ERR_BAD_SECRETS."""
    runner = CliRunner()
    with runner.isolated_filesystem():
        with open(TEST_CONFIG_YML_NAME, 'w') as config_f:
            fake_config_file(config_f)
        with open(TEST_GOOGLE_SECRETS_FILENAME, 'w') as secrets_f:
            secrets_f.write('{this is bad json')
        output_dir = 'test_output_dir'
        os.mkdir(output_dir)
        cli_args = [
            '--config_file', TEST_CONFIG_YML_NAME,
            '--google_secrets_file', TEST_GOOGLE_SECRETS_FILENAME,
            '--output_dir', output_dir,
        ]
        result = runner.invoke(generate_report, args=cli_args)
        print(result.output)
        assert result.exit_code == ERR_BAD_SECRETS
        assert 'Failed to read' in result.output
def test_bad_output_dir():
    """A nonexistent --output_dir path should exit with ERR_NO_OUTPUT_DIR."""
    runner = CliRunner()
    with runner.isolated_filesystem():
        with open(TEST_CONFIG_YML_NAME, 'w') as config_f:
            fake_config_file(config_f)
        with open(TEST_GOOGLE_SECRETS_FILENAME, 'w') as secrets_f:
            fake_google_secrets_file(secrets_f)
        cli_args = [
            '--config_file', TEST_CONFIG_YML_NAME,
            '--google_secrets_file', TEST_GOOGLE_SECRETS_FILENAME,
            '--output_dir', 'does_not_exist/at_all',
        ]
        result = runner.invoke(generate_report, args=cli_args)
        print(result.output)
        assert result.exit_code == ERR_NO_OUTPUT_DIR
        assert 'or path does not exist' in result.output
@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token')
def test_setup_failed(*args):
    """A token-fetch failure during setup exits with ERR_SETUP_FAILED."""
    mock_get_access_token = args[0]
    mock_get_access_token.side_effect = Exception('boom')
    result = _call_script(expect_success=False)
    mock_get_access_token.assert_called_once()
    assert result.exit_code == ERR_SETUP_FAILED
@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.__init__')
@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.walk_files')
@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token')
@patch.multiple(
    'scripts.user_retirement.utils.edx_api.LmsApi',
    retirement_partner_report=DEFAULT)
def test_fetching_learners_failed(*args, **kwargs):
    """An exception from the report endpoint exits with ERR_FETCHING_LEARNERS."""
    # Decorators apply bottom-up: token, walk_files, then Drive __init__.
    mock_get_access_token = args[0]
    mock_walk_files = args[1]
    mock_drive_init = args[2]
    mock_retirement_report = kwargs['retirement_partner_report']
    mock_get_access_token.return_value = ('THIS_IS_A_JWT', None)
    mock_walk_files.return_value = [{'name': 'dummy_file_name', 'id': 'dummy_file_id'}]
    mock_drive_init.return_value = None
    mock_retirement_report.side_effect = Exception('failed to get learners')
    result = _call_script(expect_success=False)
    assert result.exit_code == ERR_FETCHING_LEARNERS
    assert 'failed to get learners' in result.output
@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.__init__')
@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.walk_files')
@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token')
def test_listing_folders_failed(*args):
    """Both an empty Drive listing and a listing exception exit with ERR_DRIVE_LISTING."""
    mock_get_access_token = args[0]
    mock_walk_files = args[1]
    mock_drive_init = args[2]
    mock_get_access_token.return_value = ('THIS_IS_A_JWT', None)
    # A side_effect list: first call returns [], second call raises.
    mock_walk_files.side_effect = [[], Exception()]
    mock_drive_init.return_value = None
    # call it once; this time walk_files will return an empty list.
    result = _call_script(expect_success=False)
    assert result.exit_code == ERR_DRIVE_LISTING
    assert 'Finding partner directories on Drive failed' in result.output
    # call it a second time; this time walk_files will throw an exception.
    result = _call_script(expect_success=False)
    assert result.exit_code == ERR_DRIVE_LISTING
    assert 'Finding partner directories on Drive failed' in result.output
@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.__init__')
@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.walk_files')
@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token')
@patch.multiple(
    'scripts.user_retirement.utils.edx_api.LmsApi',
    retirement_partner_report=DEFAULT)
def test_unknown_org(*args, **kwargs):
    """Learners tied to orgs missing from the config exit with ERR_UNKNOWN_ORG."""
    mock_get_access_token = args[0]
    # args[1] (the walk_files mock) is intentionally left unconfigured here.
    mock_drive_init = args[2]
    mock_retirement_report = kwargs['retirement_partner_report']
    mock_drive_init.return_value = None
    mock_get_access_token.return_value = ('THIS_IS_A_JWT', None)
    orgs = ['orgA', 'orgB']
    mock_retirement_report.return_value = [_fake_retirement_report_user(i, orgs, TEST_ORGS_CONFIG) for i in range(10)]
    result = _call_script(expect_success=False)
    assert result.exit_code == ERR_UNKNOWN_ORG
    # All unknown orgs — plain and custom-config alike — should be reported.
    assert 'orgA' in result.output
    assert 'orgB' in result.output
    assert 'orgCustom' in result.output
    assert 'otherCustomOrg' in result.output
@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.__init__')
@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.walk_files')
@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token')
@patch.multiple(
    'scripts.user_retirement.utils.edx_api.LmsApi',
    retirement_partner_report=DEFAULT)
def test_unknown_org_custom(*args, **kwargs):
    """An unknown org appearing only via a user's custom orgs_config also exits with ERR_UNKNOWN_ORG."""
    mock_get_access_token = args[0]
    # args[1] (the walk_files mock) is intentionally left unconfigured here.
    mock_drive_init = args[2]
    mock_retirement_report = kwargs['retirement_partner_report']
    mock_drive_init.return_value = None
    mock_get_access_token.return_value = ('THIS_IS_A_JWT', None)
    custom_orgs_config = [
        {
            ORGS_CONFIG_ORG_KEY: 'singleCustomOrg',
            ORGS_CONFIG_FIELD_HEADINGS_KEY: ['first_heading', 'second_heading']
        }
    ]
    mock_retirement_report.return_value = [_fake_retirement_report_user(i, None, custom_orgs_config) for i in range(2)]
    result = _call_script(expect_success=False)
    assert result.exit_code == ERR_UNKNOWN_ORG
    assert 'organizations {\'singleCustomOrg\'} do not exist' in result.output
@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.__init__')
@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.walk_files')
@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token')
@patch('unicodecsv.DictWriter')
@patch('scripts.user_retirement.utils.edx_api.LmsApi.retirement_partner_report')
def test_reporting_error(*args):
    """A CSV-writing failure should surface as ERR_REPORTING with the error text."""
    # Decorators apply bottom-up; args[3] (walk_files) is unused here.
    mock_retirement_report = args[0]
    mock_dictwriter = args[1]
    mock_get_access_token = args[2]
    mock_drive_init = args[4]
    error_msg = 'Fake unable to write csv'
    mock_get_access_token.return_value = ('THIS_IS_A_JWT', None)
    mock_dictwriter.side_effect = Exception(error_msg)
    mock_drive_init.return_value = None
    mock_retirement_report.return_value = _fake_retirement_report(user_orgs=list(FAKE_ORGS.keys()))
    result = _call_script(expect_success=False)
    assert result.exit_code == ERR_REPORTING
    assert error_msg in result.output
@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.list_permissions_for_files')
@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.create_comments_for_files')
@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.walk_files')
@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.__init__')
@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.create_file_in_folder')
@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token')
@patch.multiple(
    'scripts.user_retirement.utils.edx_api.LmsApi',
    retirement_partner_report=DEFAULT,
    retirement_partner_cleanup=DEFAULT
)
def test_cleanup_error(*args, **kwargs):
    """
    If queue cleanup raises after reports upload, the script exits with
    ERR_CLEANUP and warns that users may be stuck in the processing state.
    """
    # NOTE: the decorator order here differs from the other tests in this
    # file, so the args indices map differently (still bottom-up).
    mock_get_access_token = args[0]
    mock_create_files = args[1]
    mock_driveapi = args[2]
    mock_walk_files = args[3]
    mock_create_comments = args[4]
    mock_list_permissions = args[5]
    mock_retirement_report = kwargs['retirement_partner_report']
    mock_retirement_cleanup = kwargs['retirement_partner_cleanup']
    mock_get_access_token.return_value = ('THIS_IS_A_JWT', None)
    mock_create_files.return_value = True
    mock_driveapi.return_value = None
    mock_walk_files.return_value = [{'name': partner, 'id': 'folder' + partner} for partner in
                                    flatten_partner_list(FAKE_ORGS.values())]
    fake_partners = list(itervalues(FAKE_ORGS))
    # Generate the list_permissions return value.
    mock_list_permissions.return_value = {
        'folder' + partner: [
            {'emailAddress': 'some.contact@example.com'},  # The POC.
            {'emailAddress': 'another.contact@edx.org'},
            {'emailAddress': 'third@edx.org'}
        ]
        for partner in flatten_partner_list(fake_partners)
    }
    mock_create_comments.return_value = None
    mock_retirement_report.return_value = _fake_retirement_report(user_orgs=list(FAKE_ORGS.keys()))
    mock_retirement_cleanup.side_effect = Exception('Mock cleanup exception')
    result = _call_script(expect_success=False)
    mock_retirement_cleanup.assert_called_with(
        [{'original_username': user[LEARNER_ORIGINAL_USERNAME_KEY]} for user in mock_retirement_report.return_value]
    )
    assert result.exit_code == ERR_CLEANUP
    assert 'Users may be stuck in the processing state!' in result.output
@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.__init__')
@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.create_file_in_folder')
@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.walk_files')
@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.list_permissions_for_files')
@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.create_comments_for_files')
@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token')
@patch.multiple(
    'scripts.user_retirement.utils.edx_api.LmsApi',
    retirement_partner_report=DEFAULT,
    retirement_partner_cleanup=DEFAULT
)
def test_google_unicode_folder_names(*args, **kwargs):
    """
    Config org names given in different unicode normalization forms
    (NFKC/NFD/NFKD) must still match Drive folder names held in NFKC form.
    """
    # @patch decorators apply bottom-up, so args[0] is the innermost patch.
    mock_get_access_token = args[0]
    mock_create_comments = args[1]
    mock_list_permissions = args[2]
    mock_walk_files = args[3]
    mock_create_files = args[4]
    mock_driveapi = args[5]
    mock_retirement_report = kwargs['retirement_partner_report']
    mock_retirement_cleanup = kwargs['retirement_partner_cleanup']
    mock_get_access_token.return_value = ('THIS_IS_A_JWT', None)
    mock_list_permissions.return_value = {
        'folder' + partner: [
            {'emailAddress': 'some.contact@example.com'},
            {'emailAddress': 'another.contact@edx.org'},
        ]
        for partner in [
            unicodedata.normalize('NFKC', u'TéstX'),
            unicodedata.normalize('NFKC', u'TéstX2'),
            unicodedata.normalize('NFKC', u'TéstX3'),
        ]
    }
    mock_walk_files.return_value = [
        {'name': partner, 'id': 'folder' + partner}
        for partner in [
            unicodedata.normalize('NFKC', u'TéstX'),
            unicodedata.normalize('NFKC', u'TéstX2'),
            unicodedata.normalize('NFKC', u'TéstX3'),
        ]
    ]
    mock_create_files.side_effect = ['foo', 'bar', 'baz']
    mock_driveapi.return_value = None
    mock_retirement_report.return_value = _fake_retirement_report(user_orgs=list(FAKE_ORGS.keys()))
    # Deliberately mixed normalization forms in the config side.
    config_orgs = {
        'org1': [unicodedata.normalize('NFKC', u'TéstX')],
        'org2': [unicodedata.normalize('NFD', u'TéstX2')],
        'org3': [unicodedata.normalize('NFKD', u'TéstX3')],
    }
    result = _call_script(config_orgs=config_orgs)
    # Make sure we're getting the LMS token
    mock_get_access_token.assert_called_once()
    # Make sure that we get the report
    mock_retirement_report.assert_called_once()
    # Make sure we tried to upload the files
    assert mock_create_files.call_count == 3
    # Make sure we tried to add comments to the files
    assert mock_create_comments.call_count == 1
    # First [0] returns all positional args, second [0] gets the first positional arg.
    create_comments_file_ids, create_comments_messages = zip(*mock_create_comments.call_args[0][0])
    assert set(create_comments_file_ids) == set(['foo', 'bar', 'baz'])
    assert all('+some.contact@example.com' in msg for msg in create_comments_messages)
    assert all('+another.contact@edx.org' not in msg for msg in create_comments_messages)
    # Make sure we tried to remove the users from the queue
    mock_retirement_cleanup.assert_called_with(
        [{'original_username': user[LEARNER_ORIGINAL_USERNAME_KEY]} for user in mock_retirement_report.return_value]
    )
    assert 'All reports completed and uploaded to Google.' in result.output
def test_file_content_custom_headings():
    """
    _generate_report_files_or_exit should emit only the org's custom field
    headings and values, never the default ones.
    """
    runner = CliRunner()
    with runner.isolated_filesystem():
        config = {'partner_report_platform_name': 'fake_platform_name'}
        tmp_output_dir = 'test_output_dir'
        os.mkdir(tmp_output_dir)
        # Custom headings and values
        ch1 = 'special_id'
        ch1v = '134456765432'
        ch2 = 'alternate_heading_for_email'
        ch2v = 'zxcvbvcxz@blah.com'
        custom_field_headings = [ch1, ch2]
        org_name = 'my_delightful_org'
        username = 'unique_user'
        learner_data = [
            {
                ch1: ch1v,
                ch2: ch2v,
                LEARNER_ORIGINAL_USERNAME_KEY: username,
                LEARNER_CREATED_KEY: DELETION_TIME,
            }
        ]
        report_data = {
            org_name: {
                ORGS_CONFIG_FIELD_HEADINGS_KEY: custom_field_headings,
                ORGS_CONFIG_LEARNERS_KEY: learner_data
            }
        }
        partner_filenames = _generate_report_files_or_exit(config, report_data, tmp_output_dir)
        # One org in -> one report file out.
        assert len(partner_filenames) == 1
        filename = partner_filenames[org_name]
        with open(filename) as f:
            file_content = f.read()
            # Custom field headings
            for ch in custom_field_headings:
                # Verify custom field headings are present
                assert ch in file_content
            # Verify custom field values are present
            assert ch1v in file_content
            assert ch2v in file_content
            # Default field headings
            for h in DEFAULT_FIELD_HEADINGS:
                # Verify default field headings are not present
                assert h not in file_content
            # Verify default field values are not present
            assert username not in file_content
            assert DELETION_TIME not in file_content

View File

@@ -0,0 +1,584 @@
"""
Tests for edX API calls.
"""
import unittest
from urllib.parse import urljoin
import requests
import responses
from ddt import data, ddt, unpack
from mock import DEFAULT, patch
from requests.exceptions import ConnectionError, HTTPError
from responses import GET, PATCH, POST, matchers
from responses.registries import OrderedRegistry
from scripts.user_retirement.tests.mixins import OAuth2Mixin
from scripts.user_retirement.tests.retirement_helpers import (
FAKE_DATETIME_OBJECT,
FAKE_ORIGINAL_USERNAME,
FAKE_RESPONSE_MESSAGE,
FAKE_USERNAME_MAPPING,
FAKE_USERNAMES,
TEST_RETIREMENT_QUEUE_STATES,
TEST_RETIREMENT_STATE,
get_fake_user_retirement
)
from scripts.user_retirement.utils import edx_api
class BackoffTriedException(Exception):
    """
    Sentinel exception raised from a backoff handler to signal that the
    backoff path was actually exercised during a test.
    """
@ddt
class TestLmsApi(OAuth2Mixin, unittest.TestCase):
"""
Test the edX LMS API client.
"""
    @responses.activate(registry=OrderedRegistry)
    def setUp(self):
        """
        Build an LmsApi client pointed at a local LMS URL, with the OAuth
        access-token endpoint mocked so client construction succeeds.
        """
        super().setUp()
        self.mock_access_token_response()
        self.lms_base_url = 'http://localhost:18000/'
        # The same base URL serves as both the LMS and the auth endpoint here.
        # NOTE(review): responses activation via this decorator ends when
        # setUp returns — confirm individual tests re-activate as needed.
        self.lms_api = edx_api.LmsApi(
            self.lms_base_url,
            self.lms_base_url,
            'the_client_id',
            'the_client_secret'
        )
def tearDown(self):
super().tearDown()
responses.reset()
@patch.object(edx_api.LmsApi, 'learners_to_retire')
def test_learners_to_retire(self, mock_method):
params = {
'states': TEST_RETIREMENT_QUEUE_STATES,
'cool_off_days': 365,
}
responses.add(
GET,
urljoin(self.lms_base_url, 'api/user/v1/accounts/retirement_queue/'),
match=[matchers.query_param_matcher(params)],
)
self.lms_api.learners_to_retire(
TEST_RETIREMENT_QUEUE_STATES, cool_off_days=365)
mock_method.assert_called_once_with(
TEST_RETIREMENT_QUEUE_STATES, cool_off_days=365)
@patch.object(edx_api.LmsApi, 'get_learners_by_date_and_status')
def test_get_learners_by_date_and_status(self, mock_method):
query_params = {
'start_date': FAKE_DATETIME_OBJECT.strftime('%Y-%m-%d'),
'end_date': FAKE_DATETIME_OBJECT.strftime('%Y-%m-%d'),
'state': TEST_RETIREMENT_STATE,
}
responses.add(
GET,
urljoin(self.lms_base_url, 'api/user/v1/accounts/retirements_by_status_and_date/'),
match=[matchers.query_param_matcher(query_params)]
)
self.lms_api.get_learners_by_date_and_status(
state_to_request=TEST_RETIREMENT_STATE,
start_date=FAKE_DATETIME_OBJECT,
end_date=FAKE_DATETIME_OBJECT
)
mock_method.assert_called_once_with(
state_to_request=TEST_RETIREMENT_STATE,
start_date=FAKE_DATETIME_OBJECT,
end_date=FAKE_DATETIME_OBJECT
)
@patch.object(edx_api.LmsApi, 'get_learner_retirement_state')
def test_get_learner_retirement_state(self, mock_method):
responses.add(
GET,
urljoin(self.lms_base_url, f'api/user/v1/accounts/{FAKE_ORIGINAL_USERNAME}/retirement_status/'),
)
self.lms_api.get_learner_retirement_state(
username=FAKE_ORIGINAL_USERNAME
)
mock_method.assert_called_once_with(
username=FAKE_ORIGINAL_USERNAME
)
@patch.object(edx_api.LmsApi, 'update_learner_retirement_state')
def test_update_leaner_retirement_state(self, mock_method):
json_data = {
'username': FAKE_ORIGINAL_USERNAME,
'new_state': TEST_RETIREMENT_STATE,
'response': FAKE_RESPONSE_MESSAGE,
}
responses.add(
PATCH,
urljoin(self.lms_base_url, 'api/user/v1/accounts/update_retirement_status/'),
match=[matchers.json_params_matcher(json_data)]
)
self.lms_api.update_learner_retirement_state(
username=FAKE_ORIGINAL_USERNAME,
new_state_name=TEST_RETIREMENT_STATE,
message=FAKE_RESPONSE_MESSAGE
)
mock_method.assert_called_once_with(
username=FAKE_ORIGINAL_USERNAME,
new_state_name=TEST_RETIREMENT_STATE,
message=FAKE_RESPONSE_MESSAGE
)
@data(
{
'api_url': 'api/user/v1/accounts/deactivate_logout/',
'mock_method': 'retirement_deactivate_logout',
'method': 'POST',
},
{
'api_url': 'api/discussion/v1/accounts/retire_forum/',
'mock_method': 'retirement_retire_forum',
'method': 'POST',
},
{
'api_url': 'api/user/v1/accounts/retire_mailings/',
'mock_method': 'retirement_retire_mailings',
'method': 'POST',
},
{
'api_url': 'api/enrollment/v1/unenroll/',
'mock_method': 'retirement_unenroll',
'method': 'POST',
},
{
'api_url': 'api/edxnotes/v1/retire_user/',
'mock_method': 'retirement_retire_notes',
'method': 'POST',
},
{
'api_url': 'api/user/v1/accounts/retire_misc/',
'mock_method': 'retirement_lms_retire_misc',
'method': 'POST',
},
{
'api_url': 'api/user/v1/accounts/retire/',
'mock_method': 'retirement_lms_retire',
'method': 'POST',
},
{
'api_url': 'api/user/v1/accounts/retirement_partner_report/',
'mock_method': 'retirement_partner_queue',
'method': 'PUT',
},
)
@unpack
@patch.multiple(
'scripts.user_retirement.utils.edx_api.LmsApi',
retirement_deactivate_logout=DEFAULT,
retirement_retire_forum=DEFAULT,
retirement_retire_mailings=DEFAULT,
retirement_unenroll=DEFAULT,
retirement_retire_notes=DEFAULT,
retirement_lms_retire_misc=DEFAULT,
retirement_lms_retire=DEFAULT,
retirement_partner_queue=DEFAULT,
)
def test_learner_retirement(self, api_url, mock_method, method, **kwargs):
json_data = {
'username': FAKE_ORIGINAL_USERNAME,
}
responses.add(
method,
urljoin(self.lms_base_url, api_url),
match=[matchers.json_params_matcher(json_data)]
)
getattr(self.lms_api, mock_method)(get_fake_user_retirement(original_username=FAKE_ORIGINAL_USERNAME))
kwargs[mock_method].assert_called_once_with(get_fake_user_retirement(original_username=FAKE_ORIGINAL_USERNAME))
@patch.object(edx_api.LmsApi, 'retirement_partner_report')
def test_retirement_partner_report(self, mock_method):
responses.add(
POST,
urljoin(self.lms_base_url, 'api/user/v1/accounts/retirement_partner_report/')
)
self.lms_api.retirement_partner_report(
learner=get_fake_user_retirement(
original_username=FAKE_ORIGINAL_USERNAME
)
)
mock_method.assert_called_once_with(
learner=get_fake_user_retirement(
original_username=FAKE_ORIGINAL_USERNAME
)
)
@patch.object(edx_api.LmsApi, 'retirement_partner_cleanup')
def test_retirement_partner_cleanup(self, mock_method):
json_data = FAKE_USERNAMES
responses.add(
POST,
urljoin(self.lms_base_url, 'api/user/v1/accounts/retirement_partner_report_cleanup/'),
match=[matchers.json_params_matcher(json_data)]
)
self.lms_api.retirement_partner_cleanup(
usernames=FAKE_USERNAMES
)
mock_method.assert_called_once_with(
usernames=FAKE_USERNAMES
)
@patch.object(edx_api.LmsApi, 'retirement_retire_proctoring_data')
def test_retirement_retire_proctoring_data(self, mock_method):
learner = get_fake_user_retirement()
responses.add(
POST,
urljoin(self.lms_base_url, f"api/edx_proctoring/v1/retire_user/{learner['user']['id']}/"),
)
self.lms_api.retirement_retire_proctoring_data()
mock_method.assert_called_once()
@patch.object(edx_api.LmsApi, 'retirement_retire_proctoring_backend_data')
def test_retirement_retire_proctoring_backend_data(self, mock_method):
learner = get_fake_user_retirement()
responses.add(
POST,
urljoin(self.lms_base_url, f"api/edx_proctoring/v1/retire_backend_user/{learner['user']['id']}/"),
)
self.lms_api.retirement_retire_proctoring_backend_data()
mock_method.assert_called_once()
@patch.object(edx_api.LmsApi, 'replace_lms_usernames')
def test_replace_lms_usernames(self, mock_method):
json_data = {
'username_mappings': FAKE_USERNAME_MAPPING
}
responses.add(
POST,
urljoin(self.lms_base_url, 'api/user/v1/accounts/replace_usernames/'),
match=[matchers.json_params_matcher(json_data)]
)
self.lms_api.replace_lms_usernames(
username_mappings=FAKE_USERNAME_MAPPING
)
mock_method.assert_called_once_with(
username_mappings=FAKE_USERNAME_MAPPING
)
@patch.object(edx_api.LmsApi, 'replace_forums_usernames')
def test_replace_forums_usernames(self, mock_method):
json_data = {
'username_mappings': FAKE_USERNAME_MAPPING
}
responses.add(
POST,
urljoin(self.lms_base_url, 'api/discussion/v1/accounts/replace_usernames/'),
match=[matchers.json_params_matcher(json_data)]
)
self.lms_api.replace_forums_usernames(
username_mappings=FAKE_USERNAME_MAPPING
)
mock_method.assert_called_once_with(
username_mappings=FAKE_USERNAME_MAPPING
)
@data(504, 500)
@patch('scripts.user_retirement.utils.edx_api._backoff_handler')
@patch.object(edx_api.LmsApi, 'learners_to_retire')
def test_retrieve_learner_queue_backoff(
self,
svr_status_code,
mock_backoff_handler,
mock_learners_to_retire
):
mock_backoff_handler.side_effect = BackoffTriedException
params = {
'states': TEST_RETIREMENT_QUEUE_STATES,
'cool_off_days': 365,
}
response = requests.Response()
response.status_code = svr_status_code
responses.add(
GET,
urljoin(self.lms_base_url, 'api/user/v1/accounts/retirement_queue/'),
status=200,
match=[matchers.query_param_matcher(params)],
)
mock_learners_to_retire.side_effect = HTTPError(response=response)
with self.assertRaises(BackoffTriedException):
self.lms_api.learners_to_retire(
TEST_RETIREMENT_QUEUE_STATES, cool_off_days=365)
@data(104)
@responses.activate
@patch('scripts.user_retirement.utils.edx_api._backoff_handler')
@patch.object(edx_api.LmsApi, 'retirement_partner_cleanup')
def test_retirement_partner_cleanup_backoff_on_connection_error(
self,
svr_status_code,
mock_backoff_handler,
mock_retirement_partner_cleanup
):
mock_backoff_handler.side_effect = BackoffTriedException
response = requests.Response()
response.status_code = svr_status_code
mock_retirement_partner_cleanup.retirement_partner_cleanup.side_effect = ConnectionError(
response=response
)
with self.assertRaises(BackoffTriedException):
self.lms_api.retirement_partner_cleanup([{'original_username': 'test'}])
class TestEcommerceApi(OAuth2Mixin, unittest.TestCase):
    """
    Test the edX Ecommerce API client.

    Each test patches the EcommerceApi method under test and verifies that the
    wrapper forwards its arguments unchanged; the ``responses.add`` calls
    document the endpoint/payload each wrapper is expected to hit.
    """

    @responses.activate(registry=OrderedRegistry)
    def setUp(self):
        super().setUp()
        # Stub the OAuth2 token endpoint so client construction succeeds.
        self.mock_access_token_response()
        self.lms_base_url = 'http://localhost:18000/'
        self.ecommerce_base_url = 'http://localhost:18130/'
        self.ecommerce_api = edx_api.EcommerceApi(
            self.lms_base_url,
            self.ecommerce_base_url,
            'the_client_id',
            'the_client_secret'
        )

    def tearDown(self):
        super().tearDown()
        responses.reset()

    @patch.object(edx_api.EcommerceApi, 'retire_learner')
    def test_retire_learner(self, mock_method):
        # Renamed from ``test_retirement_partner_report``: it exercises
        # ``retire_learner``, not the partner report.
        json_data = {
            'username': FAKE_ORIGINAL_USERNAME,
        }
        responses.add(
            POST,
            urljoin(self.lms_base_url, 'api/v2/user/retire/'),
            match=[matchers.json_params_matcher(json_data)]
        )
        self.ecommerce_api.retire_learner(
            learner=get_fake_user_retirement(original_username=FAKE_ORIGINAL_USERNAME)
        )
        mock_method.assert_called_once_with(
            learner=get_fake_user_retirement(original_username=FAKE_ORIGINAL_USERNAME)
        )

    @patch.object(edx_api.EcommerceApi, 'get_tracking_key')
    def test_get_tracking_key(self, mock_method):
        # Fixed: this method was previously named ``get_tracking_key`` (so
        # pytest never collected it) and patched ``retire_learner`` instead of
        # the method it actually calls.
        responses.add(
            GET,
            urljoin(self.lms_base_url, f"api/v2/retirement/tracking_id/{FAKE_ORIGINAL_USERNAME}/"),
        )
        self.ecommerce_api.get_tracking_key(
            learner=get_fake_user_retirement(original_username=FAKE_ORIGINAL_USERNAME)
        )
        mock_method.assert_called_once_with(
            learner=get_fake_user_retirement(original_username=FAKE_ORIGINAL_USERNAME)
        )

    @patch.object(edx_api.EcommerceApi, 'replace_usernames')
    def test_replace_usernames(self, mock_method):
        json_data = {
            "username_mappings": FAKE_USERNAME_MAPPING
        }
        responses.add(
            POST,
            urljoin(self.lms_base_url, 'api/v2/user_management/replace_usernames/'),
            match=[matchers.json_params_matcher(json_data)]
        )
        self.ecommerce_api.replace_usernames(
            username_mappings=FAKE_USERNAME_MAPPING
        )
        mock_method.assert_called_once_with(
            username_mappings=FAKE_USERNAME_MAPPING
        )
class TestCredentialApi(OAuth2Mixin, unittest.TestCase):
    """
    Tests for the edX Credentials service API wrapper.
    """

    @responses.activate(registry=OrderedRegistry)
    def setUp(self):
        super().setUp()
        # Stub the OAuth2 token endpoint used during client construction.
        self.mock_access_token_response()
        self.lms_base_url = 'http://localhost:18000/'
        self.credentials_base_url = 'http://localhost:18150/'
        self.credentials_api = edx_api.CredentialsApi(
            self.lms_base_url,
            self.credentials_base_url,
            'the_client_id',
            'the_client_secret'
        )

    def tearDown(self):
        super().tearDown()
        responses.reset()

    @patch.object(edx_api.CredentialsApi, 'retire_learner')
    def test_retire_learner(self, mock_method):
        """The wrapper forwards the learner record to the retire endpoint."""
        expected_payload = {'username': FAKE_ORIGINAL_USERNAME}
        responses.add(
            POST,
            urljoin(self.credentials_base_url, 'user/retire/'),
            match=[matchers.json_params_matcher(expected_payload)],
        )
        fake_learner = get_fake_user_retirement(original_username=FAKE_ORIGINAL_USERNAME)
        self.credentials_api.retire_learner(learner=fake_learner)
        mock_method.assert_called_once_with(learner=fake_learner)

    @patch.object(edx_api.CredentialsApi, 'replace_usernames')
    def test_replace_usernames(self, mock_method):
        """The wrapper forwards the username mapping unchanged."""
        expected_payload = {"username_mappings": FAKE_USERNAME_MAPPING}
        responses.add(
            POST,
            urljoin(self.credentials_base_url, 'api/v2/replace_usernames/'),
            match=[matchers.json_params_matcher(expected_payload)],
        )
        self.credentials_api.replace_usernames(username_mappings=FAKE_USERNAME_MAPPING)
        mock_method.assert_called_once_with(username_mappings=FAKE_USERNAME_MAPPING)
class TestDiscoveryApi(OAuth2Mixin, unittest.TestCase):
    """
    Tests for the edX Discovery service API wrapper.
    """

    @responses.activate(registry=OrderedRegistry)
    def setUp(self):
        super().setUp()
        # Stub the OAuth2 token endpoint used during client construction.
        self.mock_access_token_response()
        self.lms_base_url = 'http://localhost:18000/'
        self.discovery_base_url = 'http://localhost:18150/'
        self.discovery_api = edx_api.DiscoveryApi(
            self.lms_base_url,
            self.discovery_base_url,
            'the_client_id',
            'the_client_secret'
        )

    def tearDown(self):
        super().tearDown()
        responses.reset()

    @patch.object(edx_api.DiscoveryApi, 'replace_usernames')
    def test_replace_usernames(self, mock_method):
        """The wrapper forwards the username mapping unchanged."""
        expected_payload = {"username_mappings": FAKE_USERNAME_MAPPING}
        responses.add(
            POST,
            urljoin(self.discovery_base_url, 'api/v1/replace_usernames/'),
            match=[matchers.json_params_matcher(expected_payload)],
        )
        self.discovery_api.replace_usernames(username_mappings=FAKE_USERNAME_MAPPING)
        mock_method.assert_called_once_with(username_mappings=FAKE_USERNAME_MAPPING)
class TestDemographicsApi(OAuth2Mixin, unittest.TestCase):
    """
    Tests for the edX Demographics service API wrapper.
    """

    @responses.activate(registry=OrderedRegistry)
    def setUp(self):
        super().setUp()
        # Stub the OAuth2 token endpoint used during client construction.
        self.mock_access_token_response()
        self.lms_base_url = 'http://localhost:18000/'
        self.demographics_base_url = 'http://localhost:18360/'
        self.demographics_api = edx_api.DemographicsApi(
            self.lms_base_url,
            self.demographics_base_url,
            'the_client_id',
            'the_client_secret'
        )

    def tearDown(self):
        super().tearDown()
        responses.reset()

    @patch.object(edx_api.DemographicsApi, 'retire_learner')
    def test_retire_learner(self, mock_method):
        """The wrapper forwards the learner record; the payload is keyed by LMS user id."""
        fake_learner = get_fake_user_retirement()
        expected_payload = {'lms_user_id': fake_learner['user']['id']}
        responses.add(
            POST,
            urljoin(self.demographics_base_url, 'demographics/api/v1/retire_demographics/'),
            match=[matchers.json_params_matcher(expected_payload)],
        )
        self.demographics_api.retire_learner(learner=fake_learner)
        mock_method.assert_called_once_with(learner=fake_learner)
class TestLicenseManagerApi(OAuth2Mixin, unittest.TestCase):
    """
    Tests for the edX License Manager service API wrapper.
    """

    @responses.activate(registry=OrderedRegistry)
    def setUp(self):
        super().setUp()
        # Stub the OAuth2 token endpoint used during client construction.
        self.mock_access_token_response()
        self.lms_base_url = 'http://localhost:18000/'
        self.license_manager_base_url = 'http://localhost:18170/'
        self.license_manager_api = edx_api.LicenseManagerApi(
            self.lms_base_url,
            self.license_manager_base_url,
            'the_client_id',
            'the_client_secret'
        )

    def tearDown(self):
        super().tearDown()
        responses.reset()

    @patch.object(edx_api.LicenseManagerApi, 'retire_learner')
    def test_retire_learner(self, mock_method):
        """The wrapper forwards the learner; the payload carries id and username."""
        fake_learner = get_fake_user_retirement(original_username=FAKE_ORIGINAL_USERNAME)
        expected_payload = {
            'lms_user_id': get_fake_user_retirement()['user']['id'],
            'original_username': FAKE_ORIGINAL_USERNAME,
        }
        responses.add(
            POST,
            urljoin(self.license_manager_base_url, 'api/v1/retire_user/'),
            match=[matchers.json_params_matcher(expected_payload)],
        )
        self.license_manager_api.retire_learner(learner=fake_learner)
        mock_method.assert_called_once_with(learner=fake_learner)

View File

@@ -0,0 +1,193 @@
"""
Tests for triggering a Jenkins job.
"""
import json
import re
import unittest
from itertools import islice
import backoff
import ddt
import requests_mock
from mock import Mock, call, mock_open, patch
import scripts.user_retirement.utils.jenkins as jenkins
from scripts.user_retirement.utils.exception import BackendError
# Jenkins connection fixtures shared by the tests below.
# (Modernized: dropped Python-2-era u'' prefixes and .format() in favor of
# f-strings, matching the rest of this migrated test suite; values unchanged.)
BASE_URL = 'https://test-jenkins'
USER_ID = 'foo'
USER_TOKEN = '12345678901234567890123456789012'
JOB = 'test-job'
TOKEN = 'asdf'
BUILD_NUM = 456
JOBS_URL = f'{BASE_URL}/job/{JOB}/'
JOB_URL = f'{JOBS_URL}{BUILD_NUM}'

# Canned payloads mirroring the shapes jenkinsapi expects from the server.
MOCK_BUILD = {'number': BUILD_NUM, 'url': JOB_URL}
MOCK_JENKINS_DATA = {'jobs': [{'name': JOB, 'url': JOBS_URL, 'color': 'blue'}]}
MOCK_BUILDS_DATA = {
    'actions': [
        {'parameterDefinitions': [
            {'defaultParameterValue': {'value': '0'}, 'name': 'EXIT_CODE', 'type': 'StringParameterDefinition'}
        ]}
    ],
    'builds': [MOCK_BUILD],
    'lastBuild': MOCK_BUILD
}
MOCK_QUEUE_DATA = {
    'id': 123,
    'task': {'name': JOB, 'url': JOBS_URL},
    'executable': {'number': BUILD_NUM, 'url': JOB_URL}
}
MOCK_BUILD_DATA = {
    'actions': [{}],
    'fullDisplayName': 'foo',
    'number': BUILD_NUM,
    'result': 'SUCCESS',
    'url': JOB_URL,
}
MOCK_CRUMB_DATA = {
    'crumbRequestField': 'Jenkins-Crumb',
    'crumb': '1234567890'
}
class TestProperties(unittest.TestCase):
    """
    Tests for the helpers that write per-learner Jenkins property files.
    """

    def test_properties_files(self):
        """Each learner gets its own properties file containing its username."""
        learners = [
            {'original_username': 'learnerA'},
            {'original_username': 'learnerB'},
        ]
        mocked_open = mock_open()
        with patch('scripts.user_retirement.utils.jenkins.open', mocked_open, create=True):
            jenkins._recreate_directory = Mock()  # pylint: disable=protected-access
            jenkins.export_learner_job_properties(learners, "tmpdir")
            jenkins._recreate_directory.assert_called_once()  # pylint: disable=protected-access
            # One file per learner, named after the lowercased username.
            self.assertIn(call('tmpdir/learner_retire_learnera', 'w'), mocked_open.call_args_list)
            self.assertIn(call('tmpdir/learner_retire_learnerb', 'w'), mocked_open.call_args_list)
            file_handle = mocked_open()
            self.assertIn(call('RETIREMENT_USERNAME=learnerA\n'), file_handle.write.call_args_list)
            self.assertIn(call('RETIREMENT_USERNAME=learnerB\n'), file_handle.write.call_args_list)
@ddt.ddt
class TestBackoff(unittest.TestCase):
    u"""
    Test of custom backoff code (wait time generator and max_tries)
    """

    @ddt.data(
        # (base, factor, timeout, expected_max_tries, expected_waits)
        # expected_waits must always sum exactly to timeout — the generator's
        # final wait is truncated so the total never exceeds the budget.
        (2, 1, 1, 2, [1]),
        (2, 1, 2, 3, [1, 1]),
        (2, 1, 3, 3, [1, 2]),
        (2, 100, 90, 2, [90]),
        (2, 1, 90, 8, [1, 2, 4, 8, 16, 32, 27]),
        (3, 5, 1000, 7, [5, 15, 45, 135, 405, 395]),
        (2, 1, 3600, 13, [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 1553]),
    )
    @ddt.unpack
    def test_max_timeout(self, base, factor, timeout, expected_max_tries, expected_waits):
        # pylint: disable=protected-access
        wait_gen, max_tries = jenkins._backoff_timeout(timeout, base, factor)
        self.assertEqual(expected_max_tries, max_tries)
        # Use max_tries-1, because we only wait that many times
        waits = list(islice(wait_gen(), max_tries - 1))
        self.assertEqual(expected_waits, waits)
        self.assertEqual(timeout, sum(waits))

    def test_backoff_call(self):
        """Drive backoff.on_predicate with the custom generator and count attempts."""
        # pylint: disable=protected-access
        # Tiny factor keeps total sleep well under a second while still
        # exercising the full retry schedule (13 attempts, per the table above).
        wait_gen, max_tries = jenkins._backoff_timeout(timeout=.36, base=2, factor=.0001)
        always_false = Mock(return_value=False)
        count_retries = backoff.on_predicate(
            wait_gen,
            max_tries=max_tries,
            on_backoff=print,
            jitter=None,
        )(always_false.__call__)
        count_retries()
        self.assertEqual(always_false.call_count, 13)
@ddt.ddt
class TestJenkinsAPI(unittest.TestCase):
    """
    Tests for interacting with the Jenkins API
    """

    @requests_mock.Mocker()
    def test_failure(self, mock):
        """
        Test the failure condition when triggering a jenkins job
        """
        # Mock all network interactions
        mock.get(
            re.compile(".*"),
            status_code=404,
        )
        with self.assertRaises(BackendError):
            jenkins.trigger_build(BASE_URL, USER_ID, USER_TOKEN, JOB, TOKEN, None, ())

    @ddt.data(
        # (build cause, build parameters) combinations.
        (None, ()),
        ('my cause', ()),
        (None, ((u'FOO', u'bar'),)),
        (None, ((u'FOO', u'bar'), (u'BAZ', u'biz'))),
        ('my cause', ((u'FOO', u'bar'),)),
    )
    @ddt.unpack
    @requests_mock.Mocker()
    def test_success(self, cause, param, mock):
        u"""
        Test triggering a jenkins job
        """
        def text_callback(request, context):
            u""" What to return from the mock. """
            # This is the initial call that jenkinsapi uses to
            # establish connectivity to Jenkins
            # https://test-jenkins/api/python?tree=jobs[name,color,url]
            context.status_code = 200
            # NOTE: these prefix checks are order-sensitive — the build URL
            # ('.../test-job/456') must be matched before the shorter job URL
            # ('.../test-job'), which would otherwise shadow it.
            if request.url.startswith(u'https://test-jenkins/api/python'):
                return json.dumps(MOCK_JENKINS_DATA)
            elif request.url.startswith(u'https://test-jenkins/job/test-job/456'):
                return json.dumps(MOCK_BUILD_DATA)
            elif request.url.startswith(u'https://test-jenkins/job/test-job'):
                return json.dumps(MOCK_BUILDS_DATA)
            elif request.url.startswith(u'https://test-jenkins/queue/item/123/api/python'):
                return json.dumps(MOCK_QUEUE_DATA)
            elif request.url.startswith(u'https://test-jenkins/crumbIssuer/api/python'):
                return json.dumps(MOCK_CRUMB_DATA)
            else:
                # We should never get here, unless the jenkinsapi implementation changes.
                # This response will catch that condition.
                context.status_code = 500
                return None

        # Mock all network interactions
        mock.get(
            re.compile('.*'),
            text=text_callback
        )
        mock.post(
            '{}/job/test-job/buildWithParameters'.format(BASE_URL),
            status_code=201,  # Jenkins responds with a 201 Created on success
            headers={'location': '{}/queue/item/123'.format(BASE_URL)}
        )
        # Make the call to the Jenkins API
        result = jenkins.trigger_build(BASE_URL, USER_ID, USER_TOKEN, JOB, TOKEN, cause, param)
        self.assertEqual(result, 'SUCCESS')

View File

@@ -0,0 +1,86 @@
"""
Tests for the Amplitude API functionality
"""
import logging
import os
import unittest
from unittest import mock
import ddt
import requests_mock
from scripts.user_retirement.utils.thirdparty_apis.amplitude_api import (
    AmplitudeApi,
    AmplitudeException,
    AmplitudeRecoverableException
)

# Fixed: this assignment previously sat between import statements (PEP 8
# violation); imports now come first. The number of delete attempts expected
# from the retry logic — presumably must match the RETRY_MAX_ATTEMPTS
# configuration read by amplitude_api (TODO confirm against that module).
MAX_ATTEMPTS = int(os.environ.get("RETRY_MAX_ATTEMPTS", 5))
@ddt.ddt
@requests_mock.Mocker()
class TestAmplitude(unittest.TestCase):
    """
    Tests covering the code that talks to Amplitude.
    """

    def setUp(self):
        super().setUp()
        self.user = {"user": {"id": "1234"}}
        self.amplitude = AmplitudeApi("test-api-key", "test-secret-key")

    def _mock_delete(self, req_mock, status_code, message=None):
        """
        Register a canned response for Amplitude's user-deletion endpoint.
        """
        req_mock.post(
            "https://amplitude.com/api/2/deletions/users",
            headers={"Content-Type": "application/json"},
            json={},
            status_code=status_code,
        )

    def test_delete_happy_path(self, req_mock):
        """
        A 200 response is logged as a success and sends the expected payload.
        """
        self._mock_delete(req_mock, 200)
        amplitude_logger = logging.getLogger("scripts.user_retirement.utils.thirdparty_apis.amplitude_api")
        with mock.patch.object(amplitude_logger, "info") as mock_info:
            self.amplitude.delete_user(self.user)
        self.assertEqual(mock_info.call_args, [("Amplitude user deletion succeeded",)])
        self.assertEqual(len(req_mock.request_history), 1)
        sent_request = req_mock.request_history[0]
        expected_body = {"user_ids": ["1234"], 'ignore_invalid_id': 'true', "requester": "user-retirement-pipeline"}
        self.assertEqual(sent_request.json(), expected_body)

    def test_delete_fatal_error(self, req_mock):
        """
        A 404 response is fatal: it is logged and raised as AmplitudeException.
        """
        self._mock_delete(req_mock, 404)
        message = None
        amplitude_logger = logging.getLogger("scripts.user_retirement.utils.thirdparty_apis.amplitude_api")
        with mock.patch.object(amplitude_logger, "error") as mock_error:
            with self.assertRaises(AmplitudeException) as exc:
                self.amplitude.delete_user(self.user)
        expected_error = "Amplitude user deletion failed due to {message}".format(message=message)
        self.assertEqual(mock_error.call_args, [(expected_error,)])
        self.assertEqual(str(exc.exception), expected_error)

    @ddt.data(429, 500)
    def test_delete_recoverable_error(self, status_code, req_mock):
        """
        429/500 responses are retried, then surface as AmplitudeRecoverableException.
        """
        self._mock_delete(req_mock, status_code)
        with self.assertRaises(AmplitudeRecoverableException):
            self.amplitude.delete_user(self.user)
        # One recorded request per retry attempt.
        self.assertEqual(len(req_mock.request_history), MAX_ATTEMPTS)

View File

@@ -0,0 +1,66 @@
"""
Tests for the Braze API functionality
"""
import logging
import unittest
from unittest import mock
import ddt
import requests_mock
from scripts.user_retirement.utils.thirdparty_apis.braze_api import BrazeApi, BrazeException, BrazeRecoverableException
@ddt.ddt
@requests_mock.Mocker()
class TestBraze(unittest.TestCase):
    """
    Tests covering the code that talks to Braze.
    """

    def setUp(self):
        super().setUp()
        self.learner = {'user': {'id': 1234}}
        self.braze = BrazeApi('test-key', 'test-instance')

    def _mock_delete(self, req_mock, status_code, message=None):
        """Register a canned Braze user-delete response, optionally with a message body."""
        response_body = {'message': message} if message else {}
        req_mock.post(
            'https://rest.test-instance.braze.com/users/delete',
            request_headers={'Authorization': 'Bearer test-key'},
            json=response_body,
            status_code=status_code,
        )

    def test_delete_happy_path(self, req_mock):
        """A 200 response logs success and sends the learner's external id."""
        self._mock_delete(req_mock, 200)
        braze_logger = logging.getLogger('scripts.user_retirement.utils.thirdparty_apis.braze_api')
        with mock.patch.object(braze_logger, 'info') as mock_info:
            self.braze.delete_user(self.learner)
        self.assertEqual(mock_info.call_args, [('Braze user deletion succeeded',)])
        self.assertEqual(len(req_mock.request_history), 1)
        sent_request = req_mock.request_history[0]
        self.assertEqual(sent_request.json(), {'external_ids': [1234]})

    def test_delete_fatal_error(self, req_mock):
        """A 404 with an error body raises BrazeException carrying that message."""
        self._mock_delete(req_mock, 404, message='Test Error Message')
        braze_logger = logging.getLogger('scripts.user_retirement.utils.thirdparty_apis.braze_api')
        with mock.patch.object(braze_logger, 'error') as mock_error:
            with self.assertRaises(BrazeException) as exc:
                self.braze.delete_user(self.learner)
        expected_error = 'Braze user deletion failed due to Test Error Message'
        self.assertEqual(mock_error.call_args, [(expected_error,)])
        self.assertEqual(str(exc.exception), expected_error)

    @ddt.data(429, 500)
    def test_delete_recoverable_error(self, status_code, req_mock):
        """429/500 are retried and finally surface as BrazeRecoverableException."""
        self._mock_delete(req_mock, status_code)
        with self.assertRaises(BrazeRecoverableException):
            self.braze.delete_user(self.learner)
        # One recorded request per retry attempt.
        self.assertEqual(len(req_mock.request_history), 5)

View File

@@ -0,0 +1,159 @@
"""
Tests for the Hubspot API functionality
"""
import logging
import os
import unittest
from unittest import mock
import requests_mock
from six.moves import reload_module
# This module is imported separately solely so it can be re-loaded below.
from scripts.user_retirement.utils.thirdparty_apis import hubspot_api
# This HubspotAPI class will be used without being re-loaded.
from scripts.user_retirement.utils.thirdparty_apis.hubspot_api import HubspotAPI
# Change the number of retries for Hubspot API's delete_user call to 1.
# Then reload hubspot_api so only a single retry is performed.
os.environ['RETRY_HUBSPOT_MAX_ATTEMPTS'] = "1"
reload_module(hubspot_api) # pylint: disable=too-many-function-args
@requests_mock.Mocker()
@mock.patch.object(HubspotAPI, 'send_marketing_alert')
class TestHubspot(unittest.TestCase):
    """
    Class containing tests of all code interacting with Hubspot.

    The class decorators inject ``req_mock`` (HTTP mocking) and ``mock_alert``
    (patched marketing-alert sender) into every test method.
    """

    def setUp(self):
        super(TestHubspot, self).setUp()
        self.test_learner = {'original_email': 'foo@bar.com'}
        self.api_key = 'example_key'
        self.test_vid = 12345
        self.test_region = 'test-east-1'
        self.from_address = 'no-reply@example.com'
        self.alert_email = 'marketing@example.com'

    def _mock_get_vid(self, req_mock, status_code):
        # Stub the email -> vid lookup endpoint.
        req_mock.get(
            hubspot_api.GET_VID_FROM_EMAIL_URL_TEMPLATE.format(
                email=self.test_learner['original_email']
            ),
            json={'vid': self.test_vid},
            status_code=status_code
        )

    def _mock_delete(self, req_mock, status_code):
        # Stub the delete-by-vid endpoint.
        req_mock.delete(
            hubspot_api.DELETE_USER_FROM_VID_TEMPLATE.format(
                vid=self.test_vid
            ),
            json={},
            status_code=status_code
        )

    def test_delete_no_email(self, req_mock, mock_alert):  # pylint: disable=unused-argument
        # A learner record with no email must be rejected before any HTTP call.
        with self.assertRaises(TypeError) as exc:
            HubspotAPI(
                self.api_key,
                self.test_region,
                self.from_address,
                self.alert_email
            ).delete_user({})
            # NOTE(review): unreachable — the line above raises, so this
            # assertion inside the assertRaises block never runs. Also,
            # str(exc) on the unittest context object would not contain the
            # message; should likely be str(exc.exception) outside the block.
            self.assertIn('Expected an email address for user to delete, but received None.', str(exc))
        mock_alert.assert_not_called()

    def test_delete_success(self, req_mock, mock_alert):
        self._mock_get_vid(req_mock, 200)
        self._mock_delete(req_mock, 200)
        logger = logging.getLogger('scripts.user_retirement.utils.thirdparty_apis.hubspot_api')
        with mock.patch.object(logger, 'info') as mock_info:
            HubspotAPI(
                self.api_key,
                self.test_region,
                self.from_address,
                self.alert_email
            ).delete_user(self.test_learner)
        mock_info.assert_called_once_with("User successfully deleted from Hubspot")
        # A successful deletion also triggers the marketing alert for that vid.
        mock_alert.assert_called_once_with(12345)

    def test_delete_email_does_not_exist(self, req_mock, mock_alert):
        # 404 on the vid lookup means the user is unknown to Hubspot: no-op.
        self._mock_get_vid(req_mock, 404)
        logger = logging.getLogger('scripts.user_retirement.utils.thirdparty_apis.hubspot_api')
        with mock.patch.object(logger, 'info') as mock_info:
            HubspotAPI(
                self.api_key,
                self.test_region,
                self.from_address,
                self.alert_email
            ).delete_user(self.test_learner)
        mock_info.assert_called_once_with("No action taken because no user was found in Hubspot.")
        mock_alert.assert_not_called()

    def test_delete_server_failure_on_user_retrieval(self, req_mock, mock_alert):
        self._mock_get_vid(req_mock, 500)
        with self.assertRaises(hubspot_api.HubspotException) as exc:
            HubspotAPI(
                self.api_key,
                self.test_region,
                self.from_address,
                self.alert_email
            ).delete_user(self.test_learner)
            # NOTE(review): unreachable (see test_delete_no_email).
            self.assertIn("Error attempted to get user_vid from Hubspot", str(exc))
        mock_alert.assert_not_called()

    def test_delete_unauthorized_deletion(self, req_mock, mock_alert):
        self._mock_get_vid(req_mock, 200)
        self._mock_delete(req_mock, 401)
        with self.assertRaises(hubspot_api.HubspotException) as exc:
            HubspotAPI(
                self.api_key,
                self.test_region,
                self.from_address,
                self.alert_email
            ).delete_user(self.test_learner)
            # NOTE(review): unreachable (see test_delete_no_email).
            self.assertIn("Hubspot user deletion failed due to authorized API call", str(exc))
        mock_alert.assert_not_called()

    def test_delete_vid_not_found(self, req_mock, mock_alert):
        self._mock_get_vid(req_mock, 200)
        self._mock_delete(req_mock, 404)
        with self.assertRaises(hubspot_api.HubspotException) as exc:
            HubspotAPI(
                self.api_key,
                self.test_region,
                self.from_address,
                self.alert_email
            ).delete_user(self.test_learner)
            # NOTE(review): unreachable (see test_delete_no_email).
            self.assertIn("Hubspot user deletion failed because vid doesn't match user", str(exc))
        mock_alert.assert_not_called()

    def test_delete_server_failure_on_deletion(self, req_mock, mock_alert):
        self._mock_get_vid(req_mock, 200)
        self._mock_delete(req_mock, 500)
        with self.assertRaises(hubspot_api.HubspotException) as exc:
            HubspotAPI(
                self.api_key,
                self.test_region,
                self.from_address,
                self.alert_email
            ).delete_user(self.test_learner)
            # NOTE(review): unreachable (see test_delete_no_email).
            self.assertIn("Hubspot user deletion failed due to server-side (Hubspot) issues", str(exc))
        mock_alert.assert_not_called()

    def test_delete_catch_all_on_deletion(self, req_mock, mock_alert):
        self._mock_get_vid(req_mock, 200)
        # Testing 403 as it's not a response type per the Hubspot API docs, so it doesn't have its own error.
        self._mock_delete(req_mock, 403)
        with self.assertRaises(hubspot_api.HubspotException) as exc:
            HubspotAPI(
                self.api_key,
                self.test_region,
                self.from_address,
                self.alert_email
            ).delete_user(self.test_learner)
            # NOTE(review): unreachable (see test_delete_no_email).
            self.assertIn("Hubspot user deletion failed due to unknown reasons", str(exc))
        mock_alert.assert_not_called()

View File

@@ -0,0 +1,155 @@
"""
Tests for the Salesforce API functionality
"""
import logging
from contextlib import contextmanager
import mock
import pytest
from simple_salesforce import SalesforceError
from scripts.user_retirement.utils.thirdparty_apis import salesforce_api
@pytest.fixture
def test_learner():
    """A minimal learner record carrying only the email field the API reads."""
    return {'original_email': 'foo@bar.com'}


def make_api():
    """
    Build a SalesforceApi wired with throwaway credentials for testing.
    """
    return salesforce_api.SalesforceApi("user", "pass", "key", "domain", "user")


@contextmanager
def mock_get_user():
    """
    Patch SalesforceApi.get_user_id (called during construction) to return a
    fixed assignee user id, for the duration of the context.
    """
    patch_target = 'scripts.user_retirement.utils.thirdparty_apis.salesforce_api.SalesforceApi.get_user_id'
    with mock.patch(patch_target) as getuser:
        getuser.return_value = "userid"
        yield
def test_no_assignee_email():
    # Construction must fail loudly when the assignee username can't be
    # resolved to a Salesforce user id.
    with mock.patch(
        'scripts.user_retirement.utils.thirdparty_apis.salesforce_api.SalesforceApi.get_user_id'
    ) as getuser:
        getuser.return_value = None
        with mock.patch('scripts.user_retirement.utils.thirdparty_apis.salesforce_api.Salesforce'):
            with pytest.raises(Exception) as exc:
                make_api()
            print(str(exc))
            assert 'Could not find Salesforce user with username user' in str(exc)


def test_retire_no_email():
    # A learner record without an email must be rejected with TypeError.
    with mock_get_user():
        with mock.patch('scripts.user_retirement.utils.thirdparty_apis.salesforce_api.Salesforce'):
            with pytest.raises(TypeError) as exc:
                make_api().retire_learner({})
            assert 'Expected an email address for user to delete, but received None.' in str(exc)


def test_retire_get_id_error(test_learner):  # pylint: disable=redefined-outer-name
    # A SalesforceError from the lead-id query must propagate unhandled.
    with mock_get_user():
        with mock.patch('scripts.user_retirement.utils.thirdparty_apis.salesforce_api.Salesforce'):
            api = make_api()
            api._sf.query.side_effect = SalesforceError("", "", "", "")  # pylint: disable=protected-access
            with pytest.raises(SalesforceError):
                api.retire_learner(test_learner)
# pylint: disable=protected-access
def test_escape_email():
    """
    Single quotes in an email must be escaped before interpolation into SOQL.
    """
    with mock.patch('scripts.user_retirement.utils.thirdparty_apis.salesforce_api.Salesforce'):
        api = make_api()
        mock_response = {'totalSize': 0, 'records': []}
        api._sf.query.return_value = mock_response
        api.get_lead_ids_by_email("Robert'); DROP TABLE students;--")
        api._sf.query.assert_called_with(
            "SELECT Id FROM Lead WHERE Email = 'Robert\\'); DROP TABLE students;--'"
        )
# pylint: disable=protected-access
def test_escape_username():
    """
    Single quotes in a username must be escaped before interpolation into SOQL.
    """
    with mock.patch('scripts.user_retirement.utils.thirdparty_apis.salesforce_api.Salesforce'):
        api = make_api()
        mock_response = {'totalSize': 0, 'records': []}
        api._sf.query.return_value = mock_response
        api.get_user_id("Robert'); DROP TABLE students;--")
        api._sf.query.assert_called_with(
            "SELECT Id FROM User WHERE Username = 'Robert\\'); DROP TABLE students;--'"
        )
def test_retire_learner_not_found(test_learner, caplog):  # pylint: disable=redefined-outer-name
    """
    When no lead matches the learner's email, no task is created and the
    no-op is logged.
    """
    caplog.set_level(logging.INFO)
    with mock_get_user():
        with mock.patch('scripts.user_retirement.utils.thirdparty_apis.salesforce_api.Salesforce'):
            api = make_api()
            mock_response = {'totalSize': 0, 'records': []}
            api._sf.query.return_value = mock_response  # pylint: disable=protected-access
            api.retire_learner(test_learner)
            assert not api._sf.Task.create.called  # pylint: disable=protected-access
            assert 'No action taken because no lead was found in Salesforce.' in caplog.text
def test_retire_task_error(test_learner, caplog):  # pylint: disable=redefined-outer-name
    """
    A task-creation response with success=False should log the errors and raise.
    """
    with mock_get_user():
        with mock.patch('scripts.user_retirement.utils.thirdparty_apis.salesforce_api.Salesforce'):
            api = make_api()
            mock_query_response = {'totalSize': 1, 'records': [{'Id': 1}]}
            api._sf.query.return_value = mock_query_response  # pylint: disable=protected-access
            mock_task_response = {'success': False, 'errors': ["This is an error!"]}
            api._sf.Task.create.return_value = mock_task_response  # pylint: disable=protected-access
            with pytest.raises(Exception) as exc:
                api.retire_learner(test_learner)
            assert "Errors while creating task:" in caplog.text
            assert "This is an error!" in caplog.text
            assert "Unable to create retirement task for email foo@bar.com" in str(exc)
def test_retire_task_exception(test_learner):  # pylint: disable=redefined-outer-name
    """
    A SalesforceError raised while creating the task should propagate.
    """
    with mock_get_user():
        with mock.patch('scripts.user_retirement.utils.thirdparty_apis.salesforce_api.Salesforce'):
            api = make_api()
            mock_query_response = {'totalSize': 1, 'records': [{'Id': 1}]}
            api._sf.query.return_value = mock_query_response  # pylint: disable=protected-access
            api._sf.Task.create.side_effect = SalesforceError("", "", "", "")  # pylint: disable=protected-access
            with pytest.raises(SalesforceError):
                api.retire_learner(test_learner)
def test_retire_success(test_learner, caplog):  # pylint: disable=redefined-outer-name
    """
    Happy path: a single matching lead results in one retirement task and a
    success log line.
    """
    caplog.set_level(logging.INFO)
    with mock_get_user():
        with mock.patch('scripts.user_retirement.utils.thirdparty_apis.salesforce_api.Salesforce'):
            api = make_api()
            mock_query_response = {'totalSize': 1, 'records': [{'Id': 1}]}
            api._sf.query.return_value = mock_query_response  # pylint: disable=protected-access
            mock_task_response = {'success': True, 'id': 'task-id'}
            api._sf.Task.create.return_value = mock_task_response  # pylint: disable=protected-access
            api.retire_learner(test_learner)
            assert "Successfully salesforce task created task task-id" in caplog.text
def test_retire_multiple_learners(test_learner, caplog):  # pylint: disable=redefined-outer-name
    """
    When several leads share the email, a warning is logged and the created
    task's description lists all leads to retire.
    """
    caplog.set_level(logging.INFO)
    with mock_get_user():
        with mock.patch('scripts.user_retirement.utils.thirdparty_apis.salesforce_api.Salesforce'):
            api = make_api()
            mock_response = {'totalSize': 2, 'records': [{'Id': 1}, {'Id': 2}]}
            api._sf.query.return_value = mock_response  # pylint: disable=protected-access
            mock_task_response = {'success': True, 'id': 'task-id'}
            api._sf.Task.create.return_value = mock_task_response  # pylint: disable=protected-access
            api.retire_learner(test_learner)
            assert "Multiple Ids returned for Lead with email foo@bar.com" in caplog.text
            assert "Successfully salesforce task created task task-id" in caplog.text
            note = "Notice: Multiple leads were identified with the same email. Please retire all following leads:"
            assert note in api._sf.Task.create.call_args[0][0]['Description']  # pylint: disable=protected-access

View File

@@ -0,0 +1,170 @@
"""
Tests for the Segment API functionality
"""
import json
import mock
import pytest
import requests
from six import text_type
from scripts.user_retirement.tests.retirement_helpers import get_fake_user_retirement
from scripts.user_retirement.utils.thirdparty_apis.segment_api import BULK_REGULATE_URL, SegmentApi
# Token used by the fake Segment client; referenced both when constructing the
# API and when asserting on the Authorization header below.
FAKE_AUTH_TOKEN = 'FakeToken'
# Shared fixture data used by all tests in this module.
TEST_SEGMENT_CONFIG = {
    'projects_to_retire': ['project_1', 'project_2'],
    'learner': [get_fake_user_retirement(), ],
    'fake_base_url': 'https://segment.invalid/',
    'fake_auth_token': FAKE_AUTH_TOKEN,
    'fake_workspace': 'FakeEdx',
    'headers': {"Authorization": "Bearer {}".format(FAKE_AUTH_TOKEN), "Content-Type": "application/json"}
}
class FakeResponse:
    """
    Stand-in for a successful ``requests.post`` response.
    """

    def json(self):
        """Return a minimal, well-formed Segment regulation payload."""
        return {'regulate_id': 1}

    def raise_for_status(self):
        """Succeed silently, mirroring a 2xx response."""
class FakeErrorResponse:
    """
    Fakes an error response
    """
    status_code = 500  # Simulated HTTP status surfaced via raise_for_status().
    # NOTE(review): this is a Python-dict repr, not valid JSON (single quotes),
    # so json() below would raise json.JSONDecodeError if it were ever called —
    # the error-path tests appear to rely only on .text. Confirm before reuse.
    text = "{'error': 'Test error message'}"

    def json(self):
        """
        Returns fake Segment retirement response error in the correct format
        """
        return json.loads(self.text)

    def raise_for_status(self):
        # Mirror requests' behavior on a 5xx: raise an HTTPError that carries
        # this object as its .response.
        raise requests.exceptions.HTTPError("", response=self)
@pytest.fixture
def setup_regulation_api():
    """
    Fixture to setup common bulk delete items.

    Yields a (mock_post, segment) pair: the patched requests.post and a
    SegmentApi built from the shared test config.
    """
    config_keys = ('fake_base_url', 'fake_auth_token', 'fake_workspace')
    with mock.patch('requests.post') as mock_post:
        segment = SegmentApi(*(TEST_SEGMENT_CONFIG[key] for key in config_keys))
        yield mock_post, segment
def test_bulk_delete_success(setup_regulation_api):  # pylint: disable=redefined-outer-name
    """
    Test simple success case
    """
    mock_post, segment = setup_regulation_api
    mock_post.return_value = FakeResponse()
    learner = TEST_SEGMENT_CONFIG['learner']
    segment.delete_and_suppress_learners(learner, 1000)
    assert mock_post.call_count == 1
    expected_learner = get_fake_user_retirement()
    # The request body is expected to carry the LMS user id, username, and
    # ecommerce segment id for each learner being suppressed/deleted.
    learners_vals = [
        text_type(expected_learner['user']['id']),
        expected_learner['original_username'],
        expected_learner['ecommerce_segment_id'],
    ]
    fake_json = {
        "regulation_type": "Suppress_With_Delete",
        "attributes": {
            "name": "userId",
            "values": learners_vals
        }
    }
    url = TEST_SEGMENT_CONFIG['fake_base_url'] + BULK_REGULATE_URL.format(TEST_SEGMENT_CONFIG['fake_workspace'])
    mock_post.assert_any_call(
        url, json=fake_json, headers=TEST_SEGMENT_CONFIG['headers']
    )
def test_bulk_delete_error(setup_regulation_api, caplog):  # pylint: disable=redefined-outer-name
    """
    Test simple error case
    """
    mock_post, segment = setup_regulation_api
    mock_post.return_value = FakeErrorResponse()
    learner = TEST_SEGMENT_CONFIG['learner']
    with pytest.raises(Exception):
        segment.delete_and_suppress_learners(learner, 1000)
    # NOTE(review): 4 calls suggests one attempt plus three retries inside
    # SegmentApi — confirm against segment_api's retry configuration.
    assert mock_post.call_count == 4
    assert "Error was encountered for params:" in caplog.text
    assert "9009" in caplog.text
    assert "foo_username" in caplog.text
    assert "ecommerce-90" in caplog.text
    assert "Suppress_With_Delete" in caplog.text
    assert "Test error message" in caplog.text
def test_bulk_unsuppress_success(setup_regulation_api):  # pylint: disable=redefined-outer-name
    """
    Test simple success case
    """
    mock_post, segment = setup_regulation_api
    mock_post.return_value = FakeResponse()
    learner = TEST_SEGMENT_CONFIG['learner']
    segment.unsuppress_learners_by_key('original_username', learner, 100)
    assert mock_post.call_count == 1
    expected_learner = get_fake_user_retirement()
    # Unsuppress requests send only the single keyed value per learner.
    fake_json = {
        "regulation_type": "Unsuppress",
        "attributes": {
            "name": "userId",
            "values": [expected_learner['original_username'], ]
        }
    }
    url = TEST_SEGMENT_CONFIG['fake_base_url'] + BULK_REGULATE_URL.format(TEST_SEGMENT_CONFIG['fake_workspace'])
    mock_post.assert_any_call(
        url, json=fake_json, headers=TEST_SEGMENT_CONFIG['headers']
    )
def test_bulk_unsuppress_error(setup_regulation_api, caplog):  # pylint: disable=redefined-outer-name
    """
    Test simple error case
    """
    mock_post, segment = setup_regulation_api
    mock_post.return_value = FakeErrorResponse()
    learner = TEST_SEGMENT_CONFIG['learner']
    with pytest.raises(Exception):
        segment.unsuppress_learners_by_key('original_username', learner, 100)
    # NOTE(review): 4 calls suggests one attempt plus three retries inside
    # SegmentApi — confirm against segment_api's retry configuration.
    assert mock_post.call_count == 4
    assert "Error was encountered for params:" in caplog.text
    # Only the keyed value (username) should appear; other learner fields must not.
    assert "9009" not in caplog.text
    assert "foo_username" in caplog.text
    assert "ecommerce-90" not in caplog.text
    assert "Unsuppress" in caplog.text
    assert "Test error message" in caplog.text

View File

@@ -0,0 +1,522 @@
"""
edX API classes which call edX service REST API endpoints using the edx-rest-api-client module.
"""
import logging
from urllib.parse import urljoin
import backoff
import requests
from edx_rest_api_client.auth import SuppliedJwtAuth
from edx_rest_api_client.client import REQUEST_CONNECT_TIMEOUT, REQUEST_READ_TIMEOUT
from requests.exceptions import ConnectionError, HTTPError, Timeout
from scripts.user_retirement.utils.exception import HttpDoesNotExistException
LOG = logging.getLogger(__name__)
OAUTH_ACCESS_TOKEN_URL = "/oauth2/access_token"
class EdxGatewayTimeoutError(Exception):
    """
    Exception used to indicate a 504 server error was returned.

    Differentiates from other 5xx errors so a different backoff can be applied.
    """
class BaseApiClient:
    """
    API client base class used to submit API requests to a particular web service.
    """
    append_slash = True  # When True, get_api_url() appends a trailing slash to paths.
    _access_token = None  # JWT access token fetched from the LMS OAuth endpoint.

    def __init__(self, lms_base_url, api_base_url, client_id, client_secret):
        """
        Retrieves OAuth access token from the LMS and creates REST API client instance.

        Args:
            lms_base_url (str): Base URL of the LMS used for the OAuth handshake.
            api_base_url (str): Base URL of the service this client talks to.
            client_id (str): OAuth client id.
            client_secret (str): OAuth client secret.
        """
        self.api_base_url = api_base_url
        self._access_token = self.get_access_token(lms_base_url, client_id, client_secret)

    def get_api_url(self, path):
        """
        Construct the full API URL using the api_base_url and path.

        Args:
            path (str): API endpoint path.
        """
        path = path.strip('/')
        if self.append_slash:
            path += '/'
        return urljoin(f'{self.api_base_url}/', path)

    def _request(self, method, url, log_404_as_error=True, **kwargs):
        # Send an authenticated request and translate HTTP failures into the
        # exception types the retry decorators in this module understand.
        if 'headers' not in kwargs:
            kwargs['headers'] = {'Content-type': 'application/json'}
        try:
            response = requests.request(method, url, auth=SuppliedJwtAuth(self._access_token), **kwargs)
            response.raise_for_status()
            if response.status_code != 204:
                return response.json()
        except HTTPError as exc:
            status_code = exc.response.status_code
            if status_code == 404 and not log_404_as_error:
                # Immediately raise the error so that a 404 isn't logged as an API error in this case.
                raise HttpDoesNotExistException(str(exc))
            LOG.error(f'API Error: {str(exc)} with status code: {status_code}')
            if status_code == 504:
                # Differentiate gateway errors so different backoff can be used.
                raise EdxGatewayTimeoutError(str(exc))
            if status_code == 404:
                raise HttpDoesNotExistException(str(exc))
            raise
        except Timeout:
            LOG.error("The request is timed out.")
            raise
        # A 204 (No Content) response has no JSON body; return the raw response.
        return response

    @staticmethod
    def get_access_token(oauth_base_url, client_id, client_secret):
        """
        Returns an access token for this site's service user.

        Returns:
            str: JWT access token
        """
        oauth_access_token_url = urljoin(f'{oauth_base_url}/', OAUTH_ACCESS_TOKEN_URL)
        data = {
            'grant_type': 'client_credentials',
            'client_id': client_id,
            'client_secret': client_secret,
            'token_type': 'jwt',
        }
        try:
            response = requests.post(
                oauth_access_token_url,
                data=data,
                headers={
                    'User-Agent': 'scripts.user_retirement',
                },
                timeout=(REQUEST_CONNECT_TIMEOUT, REQUEST_READ_TIMEOUT)
            )
            response.raise_for_status()
            return response.json()['access_token']
        except KeyError as exc:
            # The token endpoint responded 2xx but without an 'access_token' key.
            LOG.error(f'Failed to get token. {str(exc)} does not exist.')
            raise
        except HTTPError as exc:
            LOG.error(
                f'API Error: {str(exc)} with status code: {exc.response.status_code} fetching access token: {client_id}'
            )
            raise
def _backoff_handler(details):
    """
    Log a short message each time a backoff retry is scheduled.
    """
    msg = 'Trying again in {wait:0.1f} seconds after {tries} tries calling {target}'.format(**details)
    LOG.info(msg)
def _wait_one_minute():
    """
    Backoff generator that waits for 60 seconds.

    Suitable for use as the ``wait_gen`` argument of ``backoff.on_exception``.
    """
    return backoff.constant(interval=60)
def _giveup_on_unexpected_exception(exc):
    """
    Giveup method that gives up backoff upon any unexpected exception.

    Returns True (stop retrying) unless the exception is retryable:
    a ConnectionError, any 5xx status except 504 Gateway Timeout, or the
    (unreserved) status code 104 that has been observed to be retryable.
    """
    # Treat a ConnectionError as retryable.
    if isinstance(exc, ConnectionError):
        return False
    # Fix: an HTTPError constructed without a response (or with response=None)
    # previously raised AttributeError here; treat it as unexpected and give up.
    response = getattr(exc, 'response', None)
    if response is None:
        return True
    status_code = response.status_code
    keep_retrying = (
        # All 5xx status codes are retryable except for 504 Gateway Timeout.
        (500 <= status_code < 600 and status_code != 504)
        # Status code 104 is unreserved, but we must have added this because we observed retryable 104 responses.
        or status_code == 104
    )
    return not keep_retrying
def _retry_lms_api():
    """
    Decorator which enables retries with sane backoff defaults for LMS APIs.
    """
    def inner(func):  # pylint: disable=missing-docstring
        # Exponential backoff (up to 10 minutes total) for retryable HTTP and
        # connection errors, as decided by _giveup_on_unexpected_exception.
        func_with_backoff = backoff.on_exception(
            backoff.expo,
            (HTTPError, ConnectionError),
            max_time=600,  # 10 minutes
            giveup=_giveup_on_unexpected_exception,
            # Wrap the actual _backoff_handler so that we can patch the real one in unit tests. Otherwise, the func
            # will get decorated on import, embedding this handler as a python object reference, precluding our ability
            # to patch it in tests.
            on_backoff=lambda details: _backoff_handler(details)  # pylint: disable=unnecessary-lambda
        )
        # 504 Gateway Timeouts get a single fixed one-minute wait instead of
        # exponential backoff (see _wait_one_minute).
        func_with_timeout_backoff = backoff.on_exception(
            _wait_one_minute,
            EdxGatewayTimeoutError,
            max_tries=2,
            # Wrap the actual _backoff_handler so that we can patch the real one in unit tests. Otherwise, the func
            # will get decorated on import, embedding this handler as a python object reference, precluding our ability
            # to patch it in tests.
            on_backoff=lambda details: _backoff_handler(details)  # pylint: disable=unnecessary-lambda
        )
        return func_with_backoff(func_with_timeout_backoff(func))
    return inner
class LmsApi(BaseApiClient):
    """
    LMS API client with convenience methods for making API calls.

    All retirement endpoints are wrapped in _retry_lms_api() so transient
    failures are retried with backoff.
    """

    @_retry_lms_api()
    def learners_to_retire(self, states_to_request, cool_off_days=7, limit=None):
        """
        Retrieves a list of learners awaiting retirement actions.

        :param states_to_request: Retirement state names to include.
        :param cool_off_days: Minimum age, in days, of retirement requests returned.
        :param limit: Optional maximum number of learners to return.
        """
        params = {
            'cool_off_days': cool_off_days,
            'states': states_to_request
        }
        if limit:
            params['limit'] = limit
        api_url = self.get_api_url('api/user/v1/accounts/retirement_queue')
        return self._request('GET', api_url, params=params)

    @_retry_lms_api()
    def get_learners_by_date_and_status(self, state_to_request, start_date, end_date):
        """
        Retrieves a list of learners in the given retirement state that were
        created in the retirement queue between the dates given. Date range
        is inclusive, so to get one day you would set both dates to that day.

        :param state_to_request: String LMS UserRetirementState state name (ex. COMPLETE)
        :param start_date: Date or Datetime object
        :param end_date: Date or Datetime
        """
        params = {
            'start_date': start_date.strftime('%Y-%m-%d'),
            'end_date': end_date.strftime('%Y-%m-%d'),
            'state': state_to_request
        }
        api_url = self.get_api_url('api/user/v1/accounts/retirements_by_status_and_date')
        return self._request('GET', api_url, params=params)

    @_retry_lms_api()
    def get_learner_retirement_state(self, username):
        """
        Retrieves the given learner's retirement state.
        """
        api_url = self.get_api_url(f'api/user/v1/accounts/{username}/retirement_status')
        return self._request('GET', api_url)

    @_retry_lms_api()
    def update_learner_retirement_state(self, username, new_state_name, message, force=False):
        """
        Updates the given learner's retirement state to the retirement state name new_string
        with the additional string information in message (for logging purposes).

        :param force: When True, asks LMS to apply the transition even if it would
            not normally be allowed from the current state.
        """
        data = {
            'username': username,
            'new_state': new_state_name,
            'response': message
        }
        if force:
            data['force'] = True
        api_url = self.get_api_url('api/user/v1/accounts/update_retirement_status')
        return self._request('PATCH', api_url, json=data)

    @_retry_lms_api()
    def retirement_deactivate_logout(self, learner):
        """
        Performs the user deactivation and forced logout step of learner retirement
        """
        data = {'username': learner['original_username']}
        api_url = self.get_api_url('api/user/v1/accounts/deactivate_logout')
        return self._request('POST', api_url, json=data)

    @_retry_lms_api()
    def retirement_retire_forum(self, learner):
        """
        Performs the forum retirement step of learner retirement
        """
        # api/discussion/
        data = {'username': learner['original_username']}
        try:
            api_url = self.get_api_url('api/discussion/v1/accounts/retire_forum')
            return self._request('POST', api_url, json=data)
        except HttpDoesNotExistException:
            # A 404 here means there is nothing to retire; treat as success.
            LOG.info("No information about learner retirement")
            return True

    @_retry_lms_api()
    def retirement_retire_mailings(self, learner):
        """
        Performs the email list retirement step of learner retirement
        """
        data = {'username': learner['original_username']}
        api_url = self.get_api_url('api/user/v1/accounts/retire_mailings')
        return self._request('POST', api_url, json=data)

    @_retry_lms_api()
    def retirement_unenroll(self, learner):
        """
        Unenrolls the user from all courses
        """
        data = {'username': learner['original_username']}
        api_url = self.get_api_url('api/enrollment/v1/unenroll')
        return self._request('POST', api_url, json=data)

    # This endpoint additionally returns 500 when the EdxNotes backend service is unavailable.
    @_retry_lms_api()
    def retirement_retire_notes(self, learner):
        """
        Deletes all the user's notes (aka. annotations)
        """
        data = {'username': learner['original_username']}
        api_url = self.get_api_url('api/edxnotes/v1/retire_user')
        return self._request('POST', api_url, json=data)

    @_retry_lms_api()
    def retirement_lms_retire_misc(self, learner):
        """
        Deletes, blanks, or one-way hashes personal information in LMS as
        defined in EDUCATOR-2802 and sub-tasks.
        """
        data = {'username': learner['original_username']}
        api_url = self.get_api_url('api/user/v1/accounts/retire_misc')
        return self._request('POST', api_url, json=data)

    @_retry_lms_api()
    def retirement_lms_retire(self, learner):
        """
        Deletes, blanks, or one-way hashes all remaining personal information in LMS
        """
        data = {'username': learner['original_username']}
        api_url = self.get_api_url('api/user/v1/accounts/retire')
        return self._request('POST', api_url, json=data)

    @_retry_lms_api()
    def retirement_partner_queue(self, learner):
        """
        Calls LMS to add the given user to the retirement reporting queue
        """
        data = {'username': learner['original_username']}
        api_url = self.get_api_url('api/user/v1/accounts/retirement_partner_report')
        return self._request('PUT', api_url, json=data)

    @_retry_lms_api()
    def retirement_partner_report(self):
        """
        Retrieves the list of users to create partner reports for and set their status to
        processing
        """
        api_url = self.get_api_url('api/user/v1/accounts/retirement_partner_report')
        return self._request('POST', api_url)

    @_retry_lms_api()
    def retirement_partner_cleanup(self, usernames):
        """
        Removes the given users from the partner reporting queue
        """
        api_url = self.get_api_url('api/user/v1/accounts/retirement_partner_report_cleanup')
        return self._request('POST', api_url, json=usernames)

    @_retry_lms_api()
    def retirement_retire_proctoring_data(self, learner):
        """
        Deletes or hashes learner data from edx-proctoring
        """
        api_url = self.get_api_url(f"api/edx_proctoring/v1/retire_user/{learner['user']['id']}")
        return self._request('POST', api_url)

    @_retry_lms_api()
    def retirement_retire_proctoring_backend_data(self, learner):
        """
        Removes the given learner from 3rd party proctoring backends
        """
        api_url = self.get_api_url(f"api/edx_proctoring/v1/retire_backend_user/{learner['user']['id']}")
        return self._request('POST', api_url)

    @_retry_lms_api()
    def bulk_cleanup_retirements(self, usernames):
        """
        Deletes the retirements for all given usernames
        """
        data = {'usernames': usernames}
        api_url = self.get_api_url('api/user/v1/accounts/retirement_cleanup')
        return self._request('POST', api_url, json=data)

    def replace_lms_usernames(self, username_mappings):
        """
        Calls LMS API to replace usernames.

        Param:
            username_mappings: list of dicts where key is current username and value is new desired username
            [{current_un_1: desired_un_1}, {current_un_2: desired_un_2}]
        """
        data = {"username_mappings": username_mappings}
        api_url = self.get_api_url('api/user/v1/accounts/replace_usernames')
        return self._request('POST', api_url, json=data)

    def replace_forums_usernames(self, username_mappings):
        """
        Calls the discussion forums API inside of LMS to replace usernames.

        Param:
            username_mappings: list of dicts where key is current username and value is new unique username
            [{current_un_1: new_un_1}, {current_un_2: new_un_2}]
        """
        data = {"username_mappings": username_mappings}
        api_url = self.get_api_url('api/discussion/v1/accounts/replace_usernames')
        return self._request('POST', api_url, json=data)
class EcommerceApi(BaseApiClient):
    """
    Ecommerce API client with convenience methods for making API calls.
    """

    @_retry_lms_api()
    def retire_learner(self, learner):
        """
        Performs the learner retirement step for Ecommerce
        """
        data = {'username': learner['original_username']}
        api_url = self.get_api_url('api/v2/user/retire')
        return self._request('POST', api_url, json=data)

    @_retry_lms_api()
    def get_tracking_key(self, learner):
        """
        Fetches the ecommerce tracking id used for Segment tracking when
        ecommerce doesn't have access to the LMS user id.
        """
        api_url = self.get_api_url(f"api/v2/retirement/tracking_id/{learner['original_username']}")
        return self._request('GET', api_url)['ecommerce_tracking_id']

    def replace_usernames(self, username_mappings):
        """
        Calls the ecommerce API to replace usernames.

        Param:
            username_mappings: list of dicts where key is current username and value is new unique username
            [{current_un_1: new_un_1}, {current_un_2: new_un_2}]
        """
        # NOTE: not wrapped in _retry_lms_api(); a transient failure surfaces immediately.
        data = {"username_mappings": username_mappings}
        api_url = self.get_api_url('api/v2/user_management/replace_usernames')
        return self._request('POST', api_url, json=data)
class CredentialsApi(BaseApiClient):
    """
    Credentials API client with convenience methods for making API calls.
    """

    @_retry_lms_api()
    def retire_learner(self, learner):
        """
        Performs the learner retirement step for Credentials
        """
        data = {'username': learner['original_username']}
        api_url = self.get_api_url('user/retire')
        return self._request('POST', api_url, json=data)

    def replace_usernames(self, username_mappings):
        """
        Calls the credentials API to replace usernames.

        Param:
            username_mappings: list of dicts where key is current username and value is new unique username
            [{current_un_1: new_un_1}, {current_un_2: new_un_2}]
        """
        # NOTE: not wrapped in _retry_lms_api(); a transient failure surfaces immediately.
        data = {"username_mappings": username_mappings}
        api_url = self.get_api_url('api/v2/replace_usernames')
        return self._request('POST', api_url, json=data)
class DiscoveryApi(BaseApiClient):
    """
    Discovery API client with convenience methods for making API calls.
    """

    def replace_usernames(self, username_mappings):
        """
        Calls the discovery API to replace usernames.

        Param:
            username_mappings: list of dicts where key is current username and value is new unique username
            [{current_un_1: new_un_1}, {current_un_2: new_un_2}]
        """
        payload = {"username_mappings": username_mappings}
        endpoint = self.get_api_url('api/v1/replace_usernames')
        return self._request('POST', endpoint, json=payload)
class DemographicsApi(BaseApiClient):
    """
    Demographics API client.
    """

    @_retry_lms_api()
    def retire_learner(self, learner):
        """
        Performs the learner retirement step for Demographics. Passes the learner's LMS User Id instead of username.
        """
        data = {'lms_user_id': learner['user']['id']}
        # If the user we are retiring has no data in the Demographics DB the request will return a 404. We
        # catch the HTTPError and return True in order to prevent this error getting raised and
        # incorrectly causing the learner to enter an ERROR state during retirement.
        # (BaseApiClient._request converts a 404 into HttpDoesNotExistException;
        # log_404_as_error=False keeps it out of the error log.)
        try:
            api_url = self.get_api_url('demographics/api/v1/retire_demographics')
            return self._request('POST', api_url, log_404_as_error=False, json=data)
        except HttpDoesNotExistException:
            LOG.info("No demographics data found for user")
            return True
class LicenseManagerApi(BaseApiClient):
    """
    License Manager API client.
    """

    @_retry_lms_api()
    def retire_learner(self, learner):
        """
        Performs the learner retirement step for License manager. Passes the learner's LMS User Id in addition to
        username.
        """
        data = {
            'lms_user_id': learner['user']['id'],
            'original_username': learner['original_username'],
        }
        # If the user we are retiring has no data in the License Manager DB the request will return a 404. We
        # catch the HTTPError and return True in order to prevent this error getting raised and
        # incorrectly causing the learner to enter an ERROR state during retirement.
        # (BaseApiClient._request converts a 404 into HttpDoesNotExistException;
        # log_404_as_error=False keeps it out of the error log.)
        try:
            api_url = self.get_api_url('api/v1/retire_user')
            return self._request('POST', api_url, log_404_as_error=False, json=data)
        except HttpDoesNotExistException:
            LOG.info("No license manager data found for user")
            return True

View File

@@ -0,0 +1,85 @@
"""
Convenience functions using boto and AWS SES to send email.
"""
import logging
import backoff
import boto3
from scripts.user_retirement.utils.exception import BackendError
from scripts.user_retirement.utils.utils import envvar_get_int
LOG = logging.getLogger(__name__)
# Default maximum number of attempts to send email.
MAX_EMAIL_TRIES_DEFAULT = 10
def _poll_giveup(results):
    """
    Raise an error when the polling tries are exceeded.

    ``results`` is the backoff "details" dict for the decorated call.
    """
    # ``args`` holds the positional arguments of _send_email_with_retry; index 3
    # is its ``subject`` parameter — keep in sync if that signature changes.
    orig_args = results['args']
    msg = 'Timed out after {tries} attempts to send email with subject "{subject}".'.format(
        tries=results['tries'],
        subject=orig_args[3]
    )
    raise BackendError(msg)
# NOTE: the MAX_EMAIL_TRIES environment variable is read once at import time,
# when the decorator is applied — not on every call.
@backoff.on_exception(backoff.expo,
                      Exception,
                      max_tries=envvar_get_int("MAX_EMAIL_TRIES", MAX_EMAIL_TRIES_DEFAULT),
                      on_giveup=_poll_giveup)
def _send_email_with_retry(ses_conn,
                           from_address,
                           to_addresses,
                           subject,
                           body):
    """
    Send email, retrying upon exception.

    Args:
        ses_conn: boto3 SES client used to send the message.
        from_address (str): verified SES sender address.
        to_addresses (list(str)): recipient addresses.
        subject (str): email subject line.
        body (str): plain-text message body.
    """
    ses_conn.send_email(
        Source=from_address,
        Message={
            "Body": {
                "Text": {
                    "Charset": "UTF-8",
                    "Data": body,
                },
            },
            "Subject": {
                "Charset": "UTF-8",
                "Data": subject,
            },
        },
        Destination={
            "ToAddresses": to_addresses,
        },
    )
def send_email(aws_region,
               from_address,
               to_addresses,
               subject,
               body):
    """
    Send an email via AWS SES using boto with the specified subject/body/recipients.

    Retries are handled by _send_email_with_retry's backoff decorator; on final
    failure a BackendError is raised by _poll_giveup.

    Args:
        aws_region (str): AWS region whose SES service will be used, e.g. "us-east-1".
        from_address (str): Email address to use as the From: address. Must be an SES verified address.
        to_addresses (list(str)): List of email addresses to which to send the email.
        subject (str): Subject to use in the email.
        body (str): Body to use in the email - text format.
    """
    ses_conn = boto3.client("ses", region_name=aws_region)
    _send_email_with_retry(
        ses_conn,
        from_address,
        to_addresses,
        subject,
        body
    )

View File

@@ -0,0 +1,14 @@
"""
Exceptions used by various utilities.
"""
class BackendError(Exception):
    """
    Raised when an operation against a backend service ultimately fails
    (e.g. after retries are exhausted).
    """
    pass
class HttpDoesNotExistException(Exception):
    """
    Raised when the server responds with an HTTP 404 (not found) error.
    """

View File

@@ -0,0 +1,244 @@
"""
Common helper methods to use in user retirement scripts.
"""
# NOTE: Make sure that all non-ascii text written to standard output (including
# print statements and logging) is manually encoded to bytes using a utf-8 or
# other encoding. We currently make use of this library within a context that
# does NOT tolerate unicode text on sys.stdout, namely python 2 on Build
# Jenkins. PLAT-2287 tracks this Tech Debt.
import io
import json
import sys
import traceback
import unicodedata
import yaml
from six import text_type
from scripts.user_retirement.utils.edx_api import LmsApi # pylint: disable=wrong-import-position
from scripts.user_retirement.utils.edx_api import CredentialsApi, DemographicsApi, EcommerceApi, LicenseManagerApi
from scripts.user_retirement.utils.thirdparty_apis.amplitude_api import \
AmplitudeApi # pylint: disable=wrong-import-position
from scripts.user_retirement.utils.thirdparty_apis.braze_api import BrazeApi # pylint: disable=wrong-import-position
from scripts.user_retirement.utils.thirdparty_apis.hubspot_api import \
HubspotAPI # pylint: disable=wrong-import-position
from scripts.user_retirement.utils.thirdparty_apis.salesforce_api import \
SalesforceApi # pylint: disable=wrong-import-position
from scripts.user_retirement.utils.thirdparty_apis.segment_api import \
SegmentApi # pylint: disable=wrong-import-position
def _log(kind, message):
    """
    Print a prefixed log line; the "kind" prefix makes entries easy to grep.
    """
    line = u'{}: {}'.format(kind, message)
    # See the note at the top of this file: stdout must receive utf-8 bytes.
    print(line.encode('utf-8'))
def _fail(kind, code, message):
    """
    Convenience method to fail out of the command with a message and traceback.

    Logs ``message`` (and any active traceback), then exits the process with
    ``code``.
    """
    _log(kind, message)
    # Try to get a traceback, if there is one. On Python 3.4 this raises an AttributeError
    # if there is no current exception, so we eat that here.
    try:
        _log(kind, traceback.format_exc())
    except AttributeError:
        pass
    sys.exit(code)
def _fail_exception(kind, code, message, exc):
    """
    A version of fail that takes an exception to be utf-8 decoded
    """
    detail = _get_error_str_from_exception(exc)
    _fail(kind, code, message + '\n' + detail)
def _get_error_str_from_exception(exc):
    """
    Return a string from an exception that may or may not have a .content (Slumber)
    """
    err = text_type(exc)
    if hasattr(exc, 'content'):
        # Slumber inconveniently discards the decoded .text attribute from the Response object,
        # and instead gives us the raw encoded .content attribute, so we need to decode it first.
        # Python 2 needs the decode; on Python 3 str has no .decode, so the
        # AttributeError branch applies.
        content = exc.content
        try:
            decoded = str(content).decode('utf-8')
        except AttributeError:
            decoded = str(content)
        err += '\n' + decoded
    return err
def _config_or_exit(fail_func, fail_code, config_file):
    """
    Returns the config values from the given file, allows overriding of passed in values.
    """
    try:
        # Use a distinct name for the file handle to avoid shadowing the result.
        with io.open(config_file, 'r') as config_handle:
            return yaml.safe_load(config_handle)
    except Exception as exc:  # pylint: disable=broad-except
        fail_func(fail_code, 'Failed to read config file {}'.format(config_file), exc)
def _config_with_drive_or_exit(fail_func, config_fail_code, google_fail_code, config_file, google_secrets_file):
    """
    Returns the config values from the given file, allows overriding of passed in values.

    Also validates the Google Drive related settings: required config keys must
    be present, partner names are normalized, and the secrets file must parse
    as JSON. Exits via ``fail_func`` on any failure.
    """
    try:
        with io.open(config_file, 'r') as config:
            config = yaml.safe_load(config)
        # Check required values
        for var in ('org_partner_mapping', 'drive_partners_folder'):
            if var not in config or not config[var]:
                fail_func(config_fail_code, 'No {} in config, or it is empty!'.format(var), ValueError())
        # Force the partner names into NFKC here and when we get the folders to ensure
        # they are using the same characters. Otherwise accented characters will not match.
        for org in config['org_partner_mapping']:
            partner = config['org_partner_mapping'][org]
            config['org_partner_mapping'][org] = [unicodedata.normalize('NFKC', text_type(partner)) for partner in
                                                  config['org_partner_mapping'][org]]
    except Exception as exc:  # pylint: disable=broad-except
        fail_func(config_fail_code, 'Failed to read config file {}'.format(config_file), exc)
    try:
        # Just load and parse the file to make sure it's legit JSON before doing
        # all of the work to get the users.
        with open(google_secrets_file, 'r') as secrets_f:
            json.load(secrets_f)
        config['google_secrets_file'] = google_secrets_file
        return config
    except Exception as exc:  # pylint: disable=broad-except
        fail_func(google_fail_code, 'Failed to read secrets file {}'.format(google_secrets_file), exc)
def _setup_lms_api_or_exit(fail_func, fail_code, config):
    """
    Instantiate an LmsApi client from ``config`` and store it at config['LMS'].

    Exits via ``fail_func`` with ``fail_code`` if required config values are
    missing or client construction fails. Returns None.
    """
    try:
        lms_base_url = config['base_urls']['lms']
        client_id = config['client_id']
        client_secret = config['client_secret']
        config['LMS'] = LmsApi(lms_base_url, lms_base_url, client_id, client_secret)
    except Exception as exc:  # pylint: disable=broad-except
        fail_func(fail_code, text_type(exc))
def _setup_all_apis_or_exit(fail_func, fail_code, config):
    """
    Set up API client instances for every service required by the retirement
    pipeline (LMS, Braze, Amplitude, Salesforce, Hubspot, E-Commerce,
    Credentials, Demographics, License Manager, Segment).

    Each client is stored in ``config`` under an upper-cased service key
    (e.g. ``config['LMS']``); this function returns nothing and mutates
    ``config`` in place. Any exception is routed to ``fail_func(fail_code, ...)``,
    which is expected to terminate the script.
    """
    try:
        # LMS is the only unconditionally required base URL; every other
        # service is optional and configured via .get() with a None default.
        lms_base_url = config['base_urls']['lms']
        ecommerce_base_url = config['base_urls'].get('ecommerce', None)
        credentials_base_url = config['base_urls'].get('credentials', None)
        segment_base_url = config['base_urls'].get('segment', None)
        demographics_base_url = config['base_urls'].get('demographics', None)
        license_manager_base_url = config['base_urls'].get('license_manager', None)
        client_id = config['client_id']
        client_secret = config['client_secret']
        braze_api_key = config.get('braze_api_key', None)
        braze_instance = config.get('braze_instance', None)
        amplitude_api_key = config.get('amplitude_api_key', None)
        amplitude_secret_key = config.get('amplitude_secret_key', None)
        salesforce_user = config.get('salesforce_user', None)
        salesforce_password = config.get('salesforce_password', None)
        salesforce_token = config.get('salesforce_token', None)
        salesforce_domain = config.get('salesforce_domain', None)
        salesforce_assignee = config.get('salesforce_assignee', None)
        segment_auth_token = config.get('segment_auth_token', None)
        segment_workspace_slug = config.get('segment_workspace_slug', None)
        hubspot_api_key = config.get('hubspot_api_key', None)
        hubspot_aws_region = config.get('hubspot_aws_region', None)
        hubspot_from_address = config.get('hubspot_from_address', None)
        hubspot_alert_email = config.get('hubspot_alert_email', None)

        # Fail fast if the pipeline references a service whose credentials/URL
        # are missing. state[2] is compared against the service name, so each
        # pipeline state's third element holds the service key.
        for state in config['retirement_pipeline']:
            for service, service_url in (
                ('BRAZE', braze_api_key),
                ('AMPLITUDE', amplitude_api_key),
                ('ECOMMERCE', ecommerce_base_url),
                ('CREDENTIALS', credentials_base_url),
                ('SEGMENT', segment_base_url),
                ('HUBSPOT', hubspot_api_key),
                ('DEMOGRAPHICS', demographics_base_url)
            ):
                if state[2] == service and service_url is None:
                    fail_func(fail_code, 'Service URL is not configured, but required for state {}'.format(state))

        # The LMS URL is used both as the OAuth base and the service base.
        config['LMS'] = LmsApi(lms_base_url, lms_base_url, client_id, client_secret)

        # Optional services: only instantiate clients whose config is present.
        if braze_api_key:
            config['BRAZE'] = BrazeApi(
                braze_api_key,
                braze_instance,
            )

        if amplitude_api_key and amplitude_secret_key:
            config['AMPLITUDE'] = AmplitudeApi(
                amplitude_api_key,
                amplitude_secret_key,
            )

        if salesforce_user and salesforce_password and salesforce_token:
            config['SALESFORCE'] = SalesforceApi(
                salesforce_user,
                salesforce_password,
                salesforce_token,
                salesforce_domain,
                salesforce_assignee
            )

        if hubspot_api_key:
            config['HUBSPOT'] = HubspotAPI(
                hubspot_api_key,
                hubspot_aws_region,
                hubspot_from_address,
                hubspot_alert_email
            )

        if ecommerce_base_url:
            config['ECOMMERCE'] = EcommerceApi(lms_base_url, ecommerce_base_url, client_id, client_secret)

        if credentials_base_url:
            config['CREDENTIALS'] = CredentialsApi(lms_base_url, credentials_base_url, client_id, client_secret)

        if demographics_base_url:
            config['DEMOGRAPHICS'] = DemographicsApi(lms_base_url, demographics_base_url, client_id, client_secret)

        if license_manager_base_url:
            config['LICENSE_MANAGER'] = LicenseManagerApi(
                lms_base_url,
                license_manager_base_url,
                client_id,
                client_secret,
            )

        if segment_base_url:
            config['SEGMENT'] = SegmentApi(
                segment_base_url,
                segment_auth_token,
                segment_workspace_slug
            )
    except Exception as exc:  # pylint: disable=broad-except
        fail_func(fail_code, 'Unexpected error occurred!', exc)

View File

@@ -0,0 +1,201 @@
"""
Methods to interact with the Jenkins API to perform various tasks.
"""
import logging
import math
import os.path
import shutil
import sys
import backoff
from jenkinsapi.custom_exceptions import JenkinsAPIException
from jenkinsapi.jenkins import Jenkins
from jenkinsapi.utils.crumb_requester import CrumbRequester
from requests.exceptions import HTTPError
from scripts.user_retirement.utils.exception import BackendError
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
LOG = logging.getLogger(__name__)
def _recreate_directory(directory):
"""
Deletes an existing directory recursively (if exists) and (re-)creates it.
"""
if os.path.exists(directory):
shutil.rmtree(directory)
os.mkdir(directory)
def export_learner_job_properties(learners, directory):
    """
    Write one Jenkins properties file per learner so that a retirement
    slave job can be spawned for each of them.

    Args:
        learners (list of dicts): Learners to generate properties files for.
        directory (str): Target directory for the files; recreated empty first.
    """
    _recreate_directory(directory)

    for learner in learners:
        username = learner['original_username']
        # File name uses the lowercased username; the property value keeps
        # the original casing.
        prop_path = os.path.join(directory, 'learner_retire_{}'.format(username.lower()))
        with open(prop_path, 'w') as prop_file:
            prop_file.write('RETIREMENT_USERNAME={}\n'.format(username))
def _poll_giveup(data):
    u"""Raise an error when the polling tries are exceeded."""
    # The Build object was the only positional argument to the polled
    # function, so it is the first (and only) element of the recorded args.
    build = data.get(u'args')[0]
    raise BackendError(u'Timed out waiting for build {} to finish.'.format(build.name))
def _backoff_timeout(timeout, base=2, factor=1):
u"""
Return a tuple of (wait_gen, max_tries) so that backoff will only try up to `timeout` seconds.
|timeout (s)|max attempts|wait durations |
|----------:|-----------:|---------------------:|
|1 |2 |1 |
|5 |4 |1, 2, 2 |
|10 |5 |1, 2, 4, 3 |
|30 |6 |1, 2, 4, 8, 13 |
|60 |8 |1, 2, 4, 8, 16, 32, 37|
|300 |10 |1, 2, 4, 8, 16, 32, 64|
| | |128, 44 |
|600 |11 |1, 2, 4, 8, 16, 32, 64|
| | |128, 256, 89 |
|3600 |13 |1, 2, 4, 8, 16, 32, 64|
| | |128, 256, 512, 1024, |
| | |1553 |
"""
# Total duration of sum(factor * base ** n for n in range(K)) = factor*(base**K - 1)/(base - 1),
# where K is the number of retries, or max_tries - 1 (since the first try doesn't require a wait)
#
# Solving for K, K = log(timeout * (base - 1) / factor + 1, base)
#
# Using the next smallest integer K will give us a number of elements from
# the exponential sequence to take and still be less than the timeout.
tries = int(math.log(timeout * (base - 1) / factor + 1, base))
remainder = timeout - (factor * (base ** tries - 1)) / (base - 1)
def expo():
u"""Compute an exponential backoff wait period, but capped to an expected max timeout"""
# pylint: disable=invalid-name
n = 0
while True:
a = factor * base ** n
if n >= tries:
yield remainder
else:
yield a
n += 1
# tries tells us the largest standard wait using the standard progression (before being capped)
# tries + 1 because backoff waits one fewer times than max_tries (the first attempt has no wait time).
# If a remainder, then we need to make one last attempt to get the target timeout (so tries + 2)
if remainder == 0:
return expo, tries + 1
else:
return expo, tries + 2
def trigger_build(base_url, user_name, user_token, job_name, job_token,
                  job_cause=None, job_params=None, timeout=60 * 30):
    u"""
    Trigger a jenkins job/project (note that jenkins uses these terms interchangeably)

    Args:
        base_url (str): The base URL for the jenkins server, e.g. https://test-jenkins.testeng.edx.org
        user_name (str): The jenkins username
        user_token (str): API token for the user. Available at {base_url}/user/{user_name)/configure
        job_name (str): The Jenkins job name, e.g. test-project
        job_token (str): Jobs must be configured with the option "Trigger builds remotely" selected.
            Under this option, you must provide an authorization token (configured in the job)
            in the form of a string so that only those who know it would be able to remotely
            trigger this project's builds.
        job_cause (str): Text that will be included in the recorded build cause
        job_params (set of tuples): Parameter names and their values to pass to the job.
            May be None, which is treated as "no parameters".
        timeout (int): The maximum number of seconds to wait for the jenkins build to complete (measured
            from when the job is triggered.)

    Returns:
        The status of the build that was triggered

    Raises:
        BackendError: if the Jenkins job could not be triggered successfully
    """
    @backoff.on_predicate(
        backoff.constant,
        interval=60,
        # Cast to int: `timeout / 60` is a float under true division, and some
        # backoff versions compare the try count to max_tries with equality,
        # so a float here could retry past the intended limit.
        max_tries=int(timeout / 60) + 1,
        on_giveup=_poll_giveup,
        # We aren't worried about concurrent access, so turn off jitter
        jitter=None,
    )
    def poll_build_for_result(build):
        u"""
        Poll for the build running, with exponential backoff, capped to ``timeout`` seconds.
        The on_predicate decorator is used to retry when the return value
        of the target function is True.
        """
        return not build.is_running()

    # Create a dict with key/value pairs from the job_params
    # that were passed in like this: --param FOO bar --param BAZ biz
    # These will get passed to the job as string parameters like this:
    # {u'FOO': u'bar', u'BAX': u'biz'}
    # Guard against the default job_params=None, which is not iterable.
    request_params = {}
    for param in job_params or ():
        request_params[param[0]] = param[1]

    # Contact jenkins, log in, and get the base data on the system.
    try:
        crumb_requester = CrumbRequester(
            baseurl=base_url, username=user_name, password=user_token,
            ssl_verify=True
        )
        jenkins = Jenkins(
            base_url, username=user_name, password=user_token,
            requester=crumb_requester
        )
    except (JenkinsAPIException, HTTPError) as err:
        # Chain the original exception for easier debugging.
        raise BackendError(str(err)) from err

    if not jenkins.has_job(job_name):
        msg = u'Job not found: {}.'.format(job_name)
        msg += u' Verify that you have permissions for the job and double check the spelling of its name.'
        raise BackendError(msg)

    # This will start the job and will return a QueueItem object which can be used to get build results
    job = jenkins[job_name]
    queue_item = job.invoke(securitytoken=job_token, build_params=request_params, cause=job_cause)
    LOG.info(u'Added item to jenkins. Server: {} Job: {} '.format(
        jenkins.base_server_url(), queue_item
    ))

    # Block this script until we are through the queue and the job has begun to build.
    queue_item.block_until_building()
    build = queue_item.get_build()
    LOG.info(u'Created build {}'.format(build))
    LOG.info(u'See {}'.format(build.baseurl))

    # Now block until you get a result back from the build.
    poll_build_for_result(build)

    # Update the build's internal state, so that the final status is available
    build.poll()
    status = build.get_status()
    LOG.info(u'Build status: {status}'.format(status=status))
    return status

View File

@@ -0,0 +1,91 @@
"""
Amplitude API class that is used to delete user from Amplitude.
"""
import logging
import os
import backoff
import requests
logger = logging.getLogger(__name__)
MAX_ATTEMPTS = int(os.environ.get("RETRY_MAX_ATTEMPTS", 5))
class AmplitudeException(Exception):
    """
    Raised when an Amplitude request fails with a fatal, non-recoverable error.
    """
class AmplitudeRecoverableException(AmplitudeException):
    """
    Raised for transient Amplitude errors (e.g. throttling or 5xx responses)
    where retrying the request may succeed.
    """
class AmplitudeApi:
    """
    Thin client for the Amplitude HTTP API, used to submit user deletion requests.
    """

    def __init__(self, amplitude_api_key, amplitude_secret_key):
        self.amplitude_api_key = amplitude_api_key
        self.amplitude_secret_key = amplitude_secret_key
        self.base_url = "https://amplitude.com/"
        self.delete_user_path = "api/2/deletions/users"

    def auth(self):
        """
        Return the (api_key, secret_key) pair used for HTTP basic auth.

        Returns:
            Tuple: Returns authorization tuple.
        """
        return (self.amplitude_api_key, self.amplitude_secret_key)

    @backoff.on_exception(
        backoff.expo,
        AmplitudeRecoverableException,
        max_tries=MAX_ATTEMPTS,
    )
    def delete_user(self, user):
        """
        Ask Amplitude to delete the given user, retrying on recoverable errors.

        Args:
            user (dict): raw data of user to delete.

        Returns:
            None

        Raises:
            AmplitudeException: if the error from amplitude is unrecoverable/unretryable.
            AmplitudeRecoverableException: if the error from amplitude is recoverable/retryable.
        """
        payload = {
            "user_ids": [user["user"]["id"]],
            'ignore_invalid_id': 'true',  # When true, the job ignores users that don't exist in the project.
            "requester": "user-retirement-pipeline",
        }
        response = requests.post(
            self.base_url + self.delete_user_path,
            headers={"Content-Type": "application/json"},
            json=payload,
            auth=self.auth()
        )

        if response.status_code == 200:
            logger.info("Amplitude user deletion succeeded")
            return

        # We have some sort of error. Parse it, log it, and retry as needed.
        error_msg = "Amplitude user deletion failed due to {reason}".format(reason=response.reason)
        logger.error(error_msg)

        # 429 (too many requests) and any 5xx are worth retrying; anything else is fatal.
        if response.status_code == 429 or 500 <= response.status_code < 600:
            raise AmplitudeRecoverableException(error_msg)
        raise AmplitudeException(error_msg)

View File

@@ -0,0 +1,85 @@
"""
Helper API classes for calling Braze APIs.
"""
import logging
import os
import backoff
import requests
LOG = logging.getLogger(__name__)
MAX_ATTEMPTS = int(os.environ.get('RETRY_BRAZE_MAX_ATTEMPTS', 5))
class BrazeException(Exception):
    """
    Raised when a Braze API call fails with a non-recoverable error.
    """
class BrazeRecoverableException(BrazeException):
    """
    Raised for transient Braze errors (throttling or 5xx) that may succeed on retry.
    """
class BrazeApi:
    """
    Braze API client used to make calls to Braze
    """

    def __init__(self, braze_api_key, braze_instance):
        self.api_key = braze_api_key
        # https://www.braze.com/docs/api/basics/#endpoints
        self.base_url = 'https://rest.{instance}.braze.com'.format(instance=braze_instance)

    def auth_headers(self):
        """Returns authorization headers suitable for passing to the requests library"""
        return {'Authorization': 'Bearer ' + self.api_key}

    @staticmethod
    def get_error_message(response):
        """Returns a string suitable for logging"""
        # https://www.braze.com/docs/api/errors
        try:
            body = response.json()
        except ValueError:
            body = {}
        return body.get('message') or response.reason

    def process_response(self, response, action):
        """Log response status and raise an error as needed"""
        if response.ok:
            LOG.info('Braze {action} succeeded'.format(action=action))
            return

        # We have some sort of error. Parse it, log it, and retry as needed.
        error_msg = 'Braze {action} failed due to {msg}'.format(action=action, msg=self.get_error_message(response))
        LOG.error(error_msg)

        if response.status_code == 429 or 500 <= response.status_code < 600:
            raise BrazeRecoverableException(error_msg)
        raise BrazeException(error_msg)

    @backoff.on_exception(
        backoff.expo,
        BrazeRecoverableException,
        max_tries=MAX_ATTEMPTS,
    )
    def delete_user(self, learner):
        """
        Delete a learner from Braze.
        """
        # https://www.braze.com/docs/help/gdpr_compliance/#the-right-to-erasure
        # https://www.braze.com/docs/api/endpoints/user_data/post_user_delete
        external_ids = [learner['user']['id']]  # Braze external ids are LMS user ids
        response = requests.post(
            self.base_url + '/users/delete',
            headers=self.auth_headers(),
            json={'external_ids': external_ids},
        )
        self.process_response(response, 'user deletion')

View File

@@ -0,0 +1,530 @@
"""
Helper API classes for calling google APIs.
DriveApi is for managing files in google drive.
"""
# NOTE: Make sure that all non-ascii text written to standard output (including
# print statements and logging) is manually encoded to bytes using a utf-8 or
# other encoding. We currently make use of this library within a context that
# does NOT tolerate unicode text on sys.stdout, namely python 2 on Build
# Jenkins PLAT-2287 tracks this Tech Debt..
import json
import logging
from itertools import count
import backoff
from dateutil.parser import parse
from google.oauth2 import service_account
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
# I'm not super happy about this since the function is protected with a leading
# underscore, but the next best thing is literally copying this ~40 line
# function verbatim.
from googleapiclient.http import MediaIoBaseUpload, _should_retry_response
from six import iteritems, text_type
from scripts.user_retirement.utils.utils import batch
LOG = logging.getLogger(__name__)
# The maximum number of requests per batch is 100, according to the google API docs.
# However, cap our number lower than that maximum to avoid throttling errors and backoff.
GOOGLE_API_MAX_BATCH_SIZE = 10
# Mimetype used for Google Drive folders.
FOLDER_MIMETYPE = 'application/vnd.google-apps.folder'
# Fields to be extracted from OAuth2 JSON token files
OAUTH2_TOKEN_FIELDS = [
'client_id', 'client_secret', 'refresh_token',
'token_uri', 'id_token', 'scopes', 'access_token'
]
class BatchRequestError(Exception):
    """
    Exception which indicates one or more failed requests inside of a batch request.

    Raised after a whole batch completes, when fewer responses than requests were
    collected (i.e. at least one item failed for good).
    """
class TriggerRetryException(Exception):
    """
    Exception which indicates one or more throttled requests inside of a batch request.

    Used internally as a control-flow signal so that backoff retries only the
    throttled subset of a batch.
    """
class BaseApiClient:
    """
    Base API client for google services.

    To add a new service, extend this class and override these class variables:
      _api_name (e.g. "drive")
      _api_version (e.g. "v3")
      _api_scopes
    """
    # Subclasses must override these three values.
    _api_name = None
    _api_version = None
    _api_scopes = None

    def __init__(self, client_secrets_file_path, **kwargs):
        # All construction work happens in _build_client; extra kwargs are
        # forwarded to it (and ultimately to googleapiclient's build()).
        self._build_client(client_secrets_file_path, **kwargs)

    def _build_client(self, client_secrets_file_path, **kwargs):
        """
        Build the google API client, specific to a single google service.
        """
        # as_user_account is an indicator that the authentication
        # is using a user account.
        # If not true, assume a service account. Otherwise, read in the JSON
        # file, set the scope, and use the info to instantiate Credentials.
        # For more information about user account authentication, go to
        # https://google-auth.readthedocs.io/en/master/user-guide.html#user-credentials
        as_user_account = kwargs.pop('as_user_account', False)

        if not as_user_account:
            credentials = service_account.Credentials.from_service_account_file(
                client_secrets_file_path, scopes=self._api_scopes
            )
        else:
            with open(client_secrets_file_path) as fh:
                token_info = json.load(fh)
            # Keep only the fields Credentials understands.
            token_info = {k: token_info.get(k) for k in OAUTH2_TOKEN_FIELDS}
            # Take the access_token field and change it to token
            token = token_info.pop('access_token', None)
            token_info['token'] = token
            # Set the scopes
            token_info['scopes'] = self._api_scopes
            credentials = Credentials(**token_info)

        self._client = build(self._api_name, self._api_version, credentials=credentials, **kwargs)
        LOG.info("Client built.")

    def _batch_with_retry(self, requests):
        """
        Send the given Google API requests in a single batch requests, and retry only requests that are throttled.

        Args:
            requests (list of googleapiclient.http.HttpRequest): The requests to send.

        Returns:
            dict mapping of request object to response
        """
        # Mapping of request object to the corresponding response.
        responses = {}

        # This is our working "request queue". Initially, populate the request queue with all the given requests.
        try_requests = []
        try_requests.extend(requests)

        # This is the queue of requests that are to be retried, populated by the batch callback function.
        retry_requests = []

        # Generate arbitrary (but unique in this batch request) IDs for each request, so that we can recall the
        # corresponding response within a batch response.
        request_object_to_request_id = dict(zip(
            requests,
            (text_type(n) for n in count()),
        ))
        # Create a flipped mapping for convenience.
        request_id_to_request_object = {v: k for k, v in iteritems(request_object_to_request_id)}

        def batch_callback(request_id, response, exception):  # pylint: disable=unused-argument,missing-docstring
            """
            Handle individual responses in the batch request.

            Side effects: appends throttled requests to retry_requests and
            records successful responses in the enclosing `responses` dict.
            """
            request_object = request_id_to_request_object[request_id]
            if exception:
                if _should_retry_google_api(exception):
                    LOG.error(u'Request throttled, adding to the retry queue: {}'.format(exception).encode('utf-8'))
                    retry_requests.append(request_object)
                else:
                    # In this case, probably nothing can be done, so we just give up on this particular request and
                    # do not include it in the responses dict.
                    LOG.error(u'Error processing request {}'.format(request_object).encode('utf-8'))
                    LOG.error(text_type(exception).encode('utf-8'))
            else:
                responses[request_object] = response
                LOG.info(u'Successfully processed request {}.'.format(request_object).encode('utf-8'))

        # Retry on API throttling at the HTTP request level.
        @backoff.on_exception(
            backoff.expo,
            HttpError,
            max_time=600,  # 10 minutes
            giveup=lambda e: not _should_retry_google_api(e),
            on_backoff=lambda details: _backoff_handler(details),  # pylint: disable=unnecessary-lambda
        )
        # Retry on API throttling at the BATCH ITEM request level.
        @backoff.on_exception(
            backoff.expo,
            TriggerRetryException,
            max_time=600,  # 10 minutes
            on_backoff=lambda details: _backoff_handler(details),  # pylint: disable=unnecessary-lambda
        )
        def func():
            """
            Core function which constitutes the retry loop. It has no inputs or outputs, only side-effects which
            populates the `responses` variable within the scope of _batch_with_retry().
            """
            # Construct a new batch request object containing the current iteration of requests to "try".
            batch_request = self._client.new_batch_http_request(callback=batch_callback)  # pylint: disable=no-member
            for request_object in try_requests:
                batch_request.add(
                    request_object,
                    request_id=request_object_to_request_id[request_object]
                )

            # Empty the retry queue in preparation of filling it back up with requests that need to be retried.
            del retry_requests[:]

            # Send the batch request. If the API responds with HTTP 403 or some other retryable error, we should
            # immediately retry this function func() with the same requests in the try_requests queue. If the response
            # is HTTP 200, we *still* may raise TriggerRetryException and retry a subset of requests if some, but not
            # all requests need to be retried.
            batch_request.execute()

            # If the API throttled some requests, batch_callback would have populated the retry queue. Reset the
            # try_requests queue and indicate to backoff that there are requests to retry.
            if retry_requests:
                del try_requests[:]
                try_requests.extend(retry_requests)
                raise TriggerRetryException()

        # func()'s side-effect is that it indirectly calls batch_callback which populates the responses dict.
        func()
        return responses
def _backoff_handler(details):
    """
    Log a message each time backoff decides to wait and retry.
    """
    message = 'Trying again in {wait:0.1f} seconds after {tries} tries calling {target}'
    LOG.info(message.format(**details))
def _should_retry_google_api(exc):
    """
    General logic for determining if a google API response is retryable.

    Args:
        exc (googleapiclient.errors.HttpError): The exception thrown by googleapiclient.

    Returns:
        bool: True if the caller should retry the API call.
    """
    # bizarre and disappointing that sometimes `resp` doesn't exist.
    resp = getattr(exc, 'resp', None)
    if not resp:
        return False
    return _should_retry_response(resp.status, exc.content)
class DriveApi(BaseApiClient):
    """
    Google Drive API client.

    Single-item operations retry on throttling via backoff decorators; the
    multi-item operations (delete_files, create_comments_for_files,
    list_permissions_for_files) instead rely on the batch-with-retry logic
    inherited from BaseApiClient.
    """
    _api_name = 'drive'
    _api_version = 'v3'
    _api_scopes = [
        # basic file read-write functionality.
        # 'https://www.googleapis.com/auth/drive.file',
        # Full read write functionality
        'https://www.googleapis.com/auth/drive',
        # additional scope for being able to see folders not owned by this account.
        'https://www.googleapis.com/auth/drive.metadata',
    ]

    @backoff.on_exception(
        backoff.expo,
        HttpError,
        max_time=600,  # 10 minutes
        giveup=lambda e: not _should_retry_google_api(e),
        on_backoff=lambda details: _backoff_handler(details),  # pylint: disable=unnecessary-lambda
    )
    def create_file_in_folder(self, folder_id, filename, file_stream, mimetype):
        """
        Creates a new file in the specified folder.

        Args:
            folder_id (str): google resource ID for the drive folder to put the file into.
            filename (str): name of the uploaded file.
            file_stream (file-like/stream): contents of the file to upload.
            mimetype (str): mimetype of the given file.

        Returns: file ID (str).

        Throws:
            googleapiclient.errors.HttpError:
                For some non-retryable 4xx or 5xx error. See the full list here:
                https://developers.google.com/drive/api/v3/handle-errors
        """
        file_metadata = {
            'name': filename,
            'parents': [folder_id],
        }
        media = MediaIoBaseUpload(file_stream, mimetype=mimetype)
        uploaded_file = self._client.files().create(  # pylint: disable=no-member
            body=file_metadata,
            media_body=media,
            fields='id'
        ).execute()
        LOG.info(u'File uploaded: ID="{}", name="{}"'.format(uploaded_file.get('id'), filename).encode('utf-8'))
        return uploaded_file.get('id')

    # NOTE: Do not decorate this function with backoff since it already calls retryable methods.
    def delete_files(self, file_ids):
        """
        Delete multiple files forever, bypassing the "trash".

        This function takes advantage of request batching to reduce request volume.

        Args:
            file_ids (list of str): list of IDs for files to delete.

        Returns: nothing

        Throws:
            BatchRequestError:
                One or more files could not be deleted (could even mean the file does not exist).
        """
        if len(set(file_ids)) != len(file_ids):
            raise ValueError('duplicates detected in the file_ids list.')

        # mapping of request object to the new comment resource returned in the response.
        responses = {}

        # process the list of file ids in batches of size GOOGLE_API_MAX_BATCH_SIZE.
        for file_ids_batch in batch(file_ids, batch_size=GOOGLE_API_MAX_BATCH_SIZE):
            request_objects = []
            for file_id in file_ids_batch:
                request_objects.append(self._client.files().delete(fileId=file_id))  # pylint: disable=no-member

            # this generic helper function will handle the retry logic
            responses_batch = self._batch_with_retry(request_objects)
            responses.update(responses_batch)

        # Any request that permanently failed is absent from `responses`.
        if len(responses) != len(file_ids):
            raise BatchRequestError('Error deleting one or more files/folders.')

    def delete_files_older_than(self, top_level, delete_before_dt, mimetype=None, prefix=None):
        """
        Delete all files beneath a given top level folder that are older than a certain datetime.

        Optionally, specify a file mimetype and a filename prefix.

        Args:
            top_level (str): ID of top level folder.
            delete_before_dt (datetime.datetime): Datetime to use for file age. All files created before this datetime
                will be permanently deleted. Should be timezone offset-aware.
            mimetype (str): Mimetype of files to delete. If not specified, all non-folders will be found.
            prefix (str): Filename prefix - only files started with this prefix will be deleted.
        """
        LOG.info("Walking files...")
        all_files = self.walk_files(
            top_level, 'id, name, createdTime', mimetype
        )
        LOG.info("Files walked. {} files found before filtering.".format(len(all_files)))
        file_ids_to_delete = []
        for file in all_files:
            # Keep only files matching the prefix (if any) created before the cutoff.
            if (not prefix or file['name'].startswith(prefix)) and parse(file['createdTime']) < delete_before_dt:
                file_ids_to_delete.append(file['id'])
        if file_ids_to_delete:
            LOG.info("{} files remaining after filtering.".format(len(file_ids_to_delete)))
            self.delete_files(file_ids_to_delete)

    @backoff.on_exception(
        backoff.expo,
        HttpError,
        max_time=600,  # 10 minutes
        giveup=lambda e: not _should_retry_google_api(e),
        on_backoff=lambda details: _backoff_handler(details),  # pylint: disable=unnecessary-lambda
    )
    def walk_files(self, top_folder_id, file_fields='id, name', mimetype=None, recurse=True):
        """
        List all files of a particular mimetype within a given top level folder, traversing all folders recursively.

        This function may make multiple HTTP requests depending on how many pages the response contains. The default
        page size for the python google API client is 100 items.

        Args:
            top_folder_id (str): ID of top level folder.
            file_fields (str): Comma-separated list of metadata fields to return for each folder/file.
                For a full list of file metadata fields, see https://developers.google.com/drive/api/v3/reference/files
            mimetype (str): Mimetype of files to find. If not specified, all items will be returned, including folders.
            recurse (bool): True to recurse into all found folders for items, False to only return top-level items.

        Returns: List of dicts, where each dict contains file metadata and each dict key corresponds to fields
            specified in the `file_fields` arg.

        Throws:
            googleapiclient.errors.HttpError:
                For some non-retryable 4xx or 5xx error. See the full list here:
                https://developers.google.com/drive/api/v3/handle-errors
        """
        # Sent to list() call and used only for sending the pageToken.
        extra_kwargs = {}
        # Cumulative list of file metadata dicts for found files.
        results = []
        # List of IDs of all visited folders.
        visited_folders = []
        # List of IDs of all found files.
        found_ids = []
        # List of folder IDs remaining to be listed.
        folders_to_visit = [top_folder_id]
        # Mimetype part of file-listing query.
        mimetype_clause = ""
        if mimetype:
            # Return both folders and the specified mimetype.
            mimetype_clause = "( mimeType = '{}' or mimeType = '{}') and ".format(FOLDER_MIMETYPE, mimetype)

        # Iterative depth-style traversal: pop a folder, page through its
        # children, queue any sub-folders, collect matching files.
        while folders_to_visit:
            current_folder = folders_to_visit.pop()
            LOG.info("Current folder: {}".format(current_folder))
            visited_folders.append(current_folder)
            extra_kwargs = {}
            while True:
                resp = self._client.files().list(  # pylint: disable=no-member
                    q="{}'{}' in parents".format(mimetype_clause, current_folder),
                    fields='nextPageToken, files({})'.format(
                        file_fields + ', mimeType, parents'
                    ),
                    **extra_kwargs
                ).execute()
                page_results = resp.get('files', [])
                LOG.info("walk_files: Returned %s results.", len(page_results))
                # Examine returned results to separate folders from non-folders.
                for result in page_results:
                    LOG.info(u"walk_files: Result: {}".format(result).encode('utf-8'))
                    # Folders contain files - and get special treatment.
                    if result['mimeType'] == FOLDER_MIMETYPE:
                        if recurse and result['id'] not in visited_folders:
                            # Add any undiscovered folders to the list of folders to check.
                            folders_to_visit.append(result['id'])
                    # Determine if this result is a file to return.
                    if result['id'] not in found_ids and (not mimetype or result['mimeType'] == mimetype):
                        found_ids.append(result['id'])
                        # Return only the fields specified in file_fields.
                        results.append({k.strip(): result.get(k.strip(), None) for k in file_fields.split(',')})
                LOG.info("walk_files: %s files found and %s folders to check.", len(results), len(folders_to_visit))
                if page_results and 'nextPageToken' in resp and resp['nextPageToken']:
                    # Only call for more result pages if results were actually returned -and
                    # a nextPageToken is returned.
                    extra_kwargs['pageToken'] = resp['nextPageToken']
                else:
                    break
        return results

    # NOTE: Do not decorate this function with backoff since it already calls retryable methods.
    def create_comments_for_files(self, file_ids_and_content, fields='id'):
        """
        Create comments for files.

        This function is NOT idempotent. It will blindly create the comments it was asked to create, regardless of the
        existence of other identical comments.

        Args:
            file_ids_and_content (list of tuple(str, str)): list of (file_id, content) tuples.
            fields (str): comma separated list of fields to describe each comment resource in the response.

        Returns: dict mapping of file_id to comment resource (dict). The contents of the comment resources are dictated
            by the `fields` arg.

        Throws:
            googleapiclient.errors.HttpError:
                For some non-retryable 4xx or 5xx error. See the full list here:
                https://developers.google.com/drive/api/v3/handle-errors
            BatchRequestError:
                One or more files resulted in an error when adding comments.
        """
        file_ids, _ = zip(*file_ids_and_content)
        if len(set(file_ids)) != len(file_ids):
            raise ValueError('Duplicates detected in the file_ids_and_content list.')

        # Mapping of file_id to the new comment resource returned in the response.
        responses = {}

        # Process the list of file IDs in batches of size GOOGLE_API_MAX_BATCH_SIZE.
        for file_ids_and_content_batch in batch(file_ids_and_content, batch_size=GOOGLE_API_MAX_BATCH_SIZE):
            request_objects_to_file_id = {}
            for file_id, content in file_ids_and_content_batch:
                request_object = self._client.comments().create(  # pylint: disable=no-member
                    fileId=file_id,
                    body={u'content': content},
                    fields=fields
                )
                request_objects_to_file_id[request_object] = file_id

            # This generic helper function will handle the retry logic
            responses_batch = self._batch_with_retry(request_objects_to_file_id.keys())

            # Transform the mapping FROM request objects -> comment resource TO file IDs -> comment resources.
            responses_batch = {
                request_objects_to_file_id[request_object]: resp
                for request_object, resp in responses_batch.items()
            }
            responses.update(responses_batch)

        if len(responses) != len(file_ids_and_content):
            raise BatchRequestError('Error creating comments for one or more files/folders.')
        return responses

    # NOTE: Do not decorate this function with backoff since it already calls retryable methods.
    def list_permissions_for_files(self, file_ids, fields='emailAddress, role'):
        """
        List permissions for files.

        Args:
            file_ids (list of str): list of Drive file IDs for which to list permissions.
            fields (str): comma separated list of fields to describe each permissions resource in the response.

        Returns: dict mapping of file_id to permission resource list (list of dict). The contents of the permission
            resources are dictated by the `fields` arg.

        Throws:
            googleapiclient.errors.HttpError:
                For some non-retryable 4xx or 5xx error. See the full list here:
                https://developers.google.com/drive/api/v3/handle-errors
            BatchRequestError:
                One or more files resulted in an error when having their permissions listed.
        """
        if len(set(file_ids)) != len(file_ids):
            raise ValueError('duplicates detected in the file_ids list.')

        # mapping of file_id to the new comment resource returned in the response.
        responses = {}

        # process the list of file ids in batches of size GOOGLE_API_MAX_BATCH_SIZE.
        for file_ids_batch in batch(file_ids, batch_size=GOOGLE_API_MAX_BATCH_SIZE):
            request_objects_to_file_id = {}
            for file_id in file_ids_batch:
                request_object = self._client.permissions().list(  # pylint: disable=no-member
                    fileId=file_id,
                    fields='permissions({})'.format(fields)
                )
                request_objects_to_file_id[request_object] = file_id

            # this generic helper function will handle the retry logic
            responses_batch = self._batch_with_retry(request_objects_to_file_id.keys())

            # transform the mapping from request objects -> response dicts to file ids -> permissions resource lists.
            responses_batch_transformed = {}
            for request_object, resp in responses_batch.items():
                permissions = None
                if resp and 'permissions' in resp:
                    permissions = resp['permissions']
                responses_batch_transformed[request_objects_to_file_id[request_object]] = permissions
            responses.update(responses_batch_transformed)

        if len(responses) != len(file_ids):
            raise BatchRequestError('Error listing permissions for one or more files/folders.')
        return responses

View File

@@ -0,0 +1,123 @@
"""
Helper API classes for calling Hubspot APIs.
"""
import logging
import os
import backoff
import requests
from scripts.user_retirement.utils.email_utils import send_email
LOG = logging.getLogger(__name__)
MAX_ATTEMPTS = int(os.environ.get('RETRY_HUBSPOT_MAX_ATTEMPTS', 5))
GET_VID_FROM_EMAIL_URL_TEMPLATE = "https://api.hubapi.com/contacts/v1/contact/email/{email}/profile"
DELETE_USER_FROM_VID_TEMPLATE = "https://api.hubapi.com/contacts/v1/contact/vid/{vid}"
class HubspotException(Exception):
    """Raised when a Hubspot API call fails or returns an unexpected status."""
class HubspotAPI:
    """
    Hubspot API client used to make calls to Hubspot.

    Deletes contacts via the Hubspot Contacts v1 REST API and notifies
    marketing by email after each successful deletion.
    """

    def __init__(
            self,
            hubspot_api_key,
            aws_region,
            from_address,
            alert_email
    ):
        # Token used as the Bearer credential on every Hubspot request.
        self.api_key = hubspot_api_key
        # SES region and addresses used for the post-deletion marketing alert email.
        self.aws_region = aws_region
        self.from_address = from_address
        self.alert_email = alert_email

    def _auth_headers(self):
        """Return the standard JSON + Bearer-token headers for Hubspot calls."""
        return {
            'content-type': 'application/json',
            'authorization': f'Bearer {self.api_key}'
        }

    @backoff.on_exception(
        backoff.expo,
        HubspotException,
        max_tries=MAX_ATTEMPTS
    )
    def delete_user(self, learner):
        """
        Delete a learner from hubspot using their email address.

        Args:
            learner (dict): learner row; must contain an 'original_email' key.

        Raises:
            TypeError: if the learner has no email address.
            HubspotException: if a Hubspot call fails (retried with backoff).
        """
        email = learner.get('original_email', None)
        if not email:
            raise TypeError('Expected an email address for user to delete, but received None.')
        user_vid = self.get_user_vid(email)
        if user_vid:
            self.delete_user_by_vid(user_vid)

    def delete_user_by_vid(self, vid):
        """
        Delete a learner from hubspot using their Hubspot `vid` (unique identifier).

        Args:
            vid (int): Hubspot unique identifier of the contact to delete.

        Raises:
            HubspotException: if the deletion call returns a non-200 status.
        """
        req = requests.delete(DELETE_USER_FROM_VID_TEMPLATE.format(
            vid=vid
        ), headers=self._auth_headers())
        error_msg = ""
        if req.status_code == 200:
            LOG.info("User successfully deleted from Hubspot")
            self.send_marketing_alert(vid)
        elif req.status_code == 401:
            # Fixed wording: a 401 means the API call was UNauthorized.
            error_msg = "Hubspot user deletion failed due to unauthorized API call"
        elif req.status_code == 404:
            error_msg = "Hubspot user deletion failed because vid doesn't match user"
        elif req.status_code == 500:
            error_msg = "Hubspot user deletion failed due to server-side (Hubspot) issues"
        else:
            error_msg = "Hubspot user deletion failed due to unknown reasons"
        if error_msg:
            LOG.error(error_msg)
            raise HubspotException(error_msg)

    def get_user_vid(self, email):
        """
        Get a user's `vid` from Hubspot. `vid` is the terminology that hubspot
        uses for a user id.

        Returns:
            The `vid` value from Hubspot's response, or None if no user was found (404).

        Raises:
            HubspotException: for any other non-200 response.
        """
        req = requests.get(GET_VID_FROM_EMAIL_URL_TEMPLATE.format(
            email=email
        ), headers=self._auth_headers())
        if req.status_code == 200:
            req_data = req.json()
            return req_data.get('vid')
        elif req.status_code == 404:
            LOG.info("No action taken because no user was found in Hubspot.")
            return
        else:
            error_msg = "Error attempted to get user_vid from Hubspot. Error: {}".format(
                req.text
            )
            LOG.error(error_msg)
            raise HubspotException(error_msg)

    def send_marketing_alert(self, vid):
        """
        Notify marketing with user's Hubspot `vid` upon successful deletion.
        """
        subject = "Alert: Hubspot Deletion"
        body = "Learner with the VID \"{}\" has been deleted from Hubspot.".format(vid)
        send_email(
            self.aws_region,
            self.from_address,
            [self.alert_email],
            subject,
            body
        )

View File

@@ -0,0 +1,137 @@
"""
Salesforce API class that will call the Salesforce REST API using simple-salesforce.
"""
import logging
import os
import backoff
from requests.exceptions import ConnectionError as RequestsConnectionError
from simple_salesforce import Salesforce, format_soql
LOG = logging.getLogger(__name__)
MAX_ATTEMPTS = int(os.environ.get('RETRY_SALESFORCE_MAX_ATTEMPTS', 5))
RETIREMENT_TASK_DESCRIPTION = (
"A user data retirement request has been made for "
"{email} who has been identified as a lead in Salesforce. "
"Please manually retire the user data for this lead."
)
class SalesforceApi:
    """
    Class for making Salesforce API calls.

    Looks up Leads by email and creates manual-retirement Tasks assigned
    to a designated Salesforce user.
    """

    def __init__(self, username, password, security_token, domain, assignee_username):
        """
        Create API with credentials.

        Raises:
            Exception: if no Salesforce User exists with username `assignee_username`.
        """
        self._sf = self._get_salesforce_client(
            username=username,
            password=password,
            security_token=security_token,
            domain=domain
        )
        # Every retirement Task created by this client is assigned to this user.
        self.assignee_id = self.get_user_id(assignee_username)
        if not self.assignee_id:
            raise Exception("Could not find Salesforce user with username " + assignee_username)

    @backoff.on_exception(
        backoff.expo,
        RequestsConnectionError,
        max_tries=MAX_ATTEMPTS
    )
    def _get_salesforce_client(self, username, password, security_token, domain):
        """
        Returns a constructed Salesforce client and retries upon failure.
        """
        return Salesforce(
            username=username,
            password=password,
            security_token=security_token,
            domain=domain
        )

    @backoff.on_exception(
        backoff.expo,
        RequestsConnectionError,
        max_tries=MAX_ATTEMPTS
    )
    def get_lead_ids_by_email(self, email):
        """
        Given an email, query for Leads with that email.

        Returns a list of ids that have that email or None if none are found.
        """
        id_query = self._sf.query(format_soql("SELECT Id FROM Lead WHERE Email = {email}", email=email))
        total_size = int(id_query['totalSize'])
        if total_size == 0:
            return None
        ids = [record['Id'] for record in id_query['records']]
        if len(ids) > 1:
            LOG.warning("Multiple Ids returned for Lead with email {}".format(email))
        return ids

    @backoff.on_exception(
        backoff.expo,
        RequestsConnectionError,
        max_tries=MAX_ATTEMPTS
    )
    def get_user_id(self, username):
        """
        Given a username, returns the user id for the User with that username
        or None if no user is found.

        Used to get the user id of the user we will assign the retirement task to.
        """
        id_query = self._sf.query(format_soql("SELECT Id FROM User WHERE Username = {username}", username=username))
        total_size = int(id_query['totalSize'])
        if total_size == 0:
            return None
        return id_query['records'][0]['Id']

    @backoff.on_exception(
        backoff.expo,
        RequestsConnectionError,
        max_tries=MAX_ATTEMPTS
    )
    def _create_retirement_task(self, email, lead_ids):
        """
        Creates a Salesforce Task instructing a user to manually retire the
        given lead(s).

        Raises:
            Exception: if Salesforce reports the Task could not be created.
        """
        task_params = {
            'Description': RETIREMENT_TASK_DESCRIPTION.format(email=email),
            'Subject': "GDPR Request: " + email,
            'WhoId': lead_ids[0],
            'OwnerId': self.assignee_id
        }
        if len(lead_ids) > 1:
            # Call out every duplicate lead so all of them get retired, not just the first.
            note = "\nNotice: Multiple leads were identified with the same email. Please retire all following leads:"
            for lead_id in lead_ids:
                note += "\n{}".format(lead_id)
            task_params['Description'] += note
        created_task = self._sf.Task.create(task_params)
        if created_task['success']:
            # Fixed previously-garbled wording of this log message.
            LOG.info("Successfully created Salesforce task %s", created_task['id'])
        else:
            LOG.error("Errors while creating task:")
            for error in created_task['errors']:
                LOG.error(error)
            raise Exception("Unable to create retirement task for email " + email)

    def retire_learner(self, learner):
        """
        Given a learner email, check if that learner exists as a lead in Salesforce.
        If they do, create a Salesforce Task instructing someone to manually retire
        their personal information.

        Raises:
            TypeError: if the learner dict has no 'original_email'.
        """
        email = learner.get('original_email', None)
        if not email:
            raise TypeError('Expected an email address for user to delete, but received None.')
        lead_ids = self.get_lead_ids_by_email(email)
        if lead_ids is None:
            LOG.info("No action taken because no lead was found in Salesforce.")
            return
        self._create_retirement_task(email, lead_ids)

View File

@@ -0,0 +1,283 @@
"""
Segment API call wrappers
"""
import logging
import sys
import traceback
import backoff
import requests
from simplejson.errors import JSONDecodeError
from six import text_type
# Maximum number of tries on Segment API calls
MAX_TRIES = 4
# These are the required/optional keys in the learner dict that contain IDs we need to retire from Segment.
REQUIRED_IDENTIFYING_KEYS = [('user', 'id'), 'original_username']
OPTIONAL_IDENTIFYING_KEYS = ['ecommerce_segment_id']
# The Segment Config API for bulk deleting users for a particular workspace
BULK_REGULATE_URL = 'v1beta/workspaces/{}/regulations'
# The Segment Config API for querying the status of a bulk user deletion request for a particular workspace
BULK_REGULATE_STATUS_URL = 'v1beta/workspaces/{}/regulations/{}'
# According to Segment this represents the maximum limits of the bulk regulation call.
# https://reference.segmentapis.com/?version=latest#57a69434-76cc-43cc-a547-98c319182247
MAXIMUM_USERS_IN_REGULATION_REQUEST = 5000
LOG = logging.getLogger(__name__)
def _backoff_handler(details):
    """
    Log a retry notice each time backoff fires, plus any HTTP error details.
    """
    LOG.error('Trying again in {wait:0.1f} seconds after {tries} tries calling {target}'.format(**details))
    # Best effort: if an HTTPError is currently in flight, surface its
    # traceback, status code, and response body. Never let logging raise.
    try:
        LOG.error(traceback.format_exc())
        current_exc = sys.exc_info()[1]
        LOG.error("HTTPError code {}: {}".format(current_exc.response.status_code, current_exc.response.text))
    except Exception:  # pylint: disable=broad-except
        pass
def _wait_30_seconds():
    """
    Backoff generator that waits for 30 seconds.

    Passed as the wait generator to `backoff.on_exception` (see
    `_retry_segment_api` below) so Timeout retries use a constant
    30-second interval instead of exponential growth.
    """
    return backoff.constant(interval=30)
def _http_status_giveup(exc):
    """
    Giveup predicate for backoff: stop retrying unless the response status is
    429 (rate limited) or a 5xx server error.

    Args:
        exc: exception carrying a `response` attribute with a `status_code`
            (e.g. requests.exceptions.HTTPError).

    Returns:
        bool: True to give up (non-retryable status), False to keep retrying.
    """
    status = exc.response.status_code
    # Retry on 429 and any 5xx; every other status is considered permanent.
    return status != 429 and not 500 <= status < 600
def _retry_segment_api():
    """
    Decorator factory which enables retries with sane backoff defaults:
    - JSON decode errors retry with exponential backoff.
    - HTTP errors retry with exponential backoff, giving up on statuses
      other than 429/5xx (see `_http_status_giveup`).
    - Timeouts retry on a constant 30-second interval.
    """
    def inner(func):  # pylint: disable=missing-docstring
        # Pass `_backoff_handler` directly instead of wrapping it in an
        # unnecessary lambda (previously pylint-suppressed).
        func_with_decode_backoff = backoff.on_exception(
            backoff.expo,
            JSONDecodeError,
            max_tries=MAX_TRIES,
            on_backoff=_backoff_handler
        )
        func_with_backoff = backoff.on_exception(
            backoff.expo,
            requests.exceptions.HTTPError,
            max_tries=MAX_TRIES,
            giveup=_http_status_giveup,
            on_backoff=_backoff_handler
        )
        func_with_timeout_backoff = backoff.on_exception(
            _wait_30_seconds,
            requests.exceptions.Timeout,
            max_tries=MAX_TRIES,
            on_backoff=_backoff_handler
        )
        return func_with_decode_backoff(func_with_backoff(func_with_timeout_backoff(func)))
    return inner
class SegmentApi:
    """
    Segment API client with convenience methods
    """

    def __init__(self, base_url, auth_token, workspace_slug):
        self.base_url = base_url
        self.auth_token = auth_token
        self.workspace_slug = workspace_slug

    @_retry_segment_api()
    def _call_segment_post(self, url, params):
        """
        Actually makes the Segment REST POST call.
        5xx errors and timeouts will be retried via _retry_segment_api,
        all others will bubble up.
        """
        resp = requests.post(
            self.base_url + url,
            json=params,
            headers={
                "Authorization": "Bearer {}".format(self.auth_token),
                "Content-Type": "application/json"
            }
        )
        resp.raise_for_status()
        return resp

    @_retry_segment_api()
    def _call_segment_get(self, url):
        """
        Actually makes the Segment REST GET call.
        5xx errors and timeouts will be retried via _retry_segment_api,
        all others will bubble up.
        """
        resp = requests.get(
            self.base_url + url,
            headers={"Authorization": "Bearer {}".format(self.auth_token)}
        )
        resp.raise_for_status()
        return resp

    def _get_value_from_learner(self, learner, key):
        """
        Return the value from a learner dict for the given key or 2-tuple of keys.
        Allows us to map things like learner['user']['id'] in a single entry in
        REQUIRED_IDENTIFYING_KEYS.
        """
        if isinstance(key, tuple):
            outer_key, inner_key = key
            return text_type(learner[outer_key][inner_key])
        return text_type(learner[key])

    def _send_regulation_request(self, params):
        """
        Make the call to the Segment Regulate API, cleanly report any errors
        """
        resp_json = ""
        try:
            resp = self._call_segment_post(BULK_REGULATE_URL.format(self.workspace_slug), params)
            try:
                resp_json = resp.json()
                bulk_user_delete_id = resp_json['regulate_id']
                LOG.info('Bulk user regulation queued. Id: {}'.format(bulk_user_delete_id))
            except JSONDecodeError:
                # Keep the raw text around for the error report below, then re-raise.
                resp_json = resp.text
                raise
        # If we get here we got some kind of JSON response from Segment, we'll try to get
        # the data we need. If it doesn't exist we'll bubble up the error from Segment and
        # eat the TypeError / KeyError since they won't be relevant.
        except (TypeError, KeyError, requests.exceptions.HTTPError, JSONDecodeError) as exc:
            LOG.exception(exc)
            err = u'Error was encountered for params: {} \n\n Response: {}'.format(
                params,
                text_type(resp_json)
            ).encode('utf-8')
            LOG.error(err)
            raise Exception(err)

    def delete_and_suppress_learner(self, learner):
        """
        Delete AND suppress a single Segment user using the bulk user deletion REST API.

        :param learner: Single user retirement status row with its fields.
        """
        # Reuse the batched deletion path with a one-element batch.
        return self.delete_and_suppress_learners([learner], 1)

    def unsuppress_learners_by_key(self, key, learners, chunk_size, beginning_idx=0):
        """
        Sets up the Segment REST API calls to UNSUPPRESS users in chunks.

        :param key: Key in the learner dict to pull the ID we care about from.
        :param learners: List of learner dicts to be worked on. We only use the key passed in.
        :param chunk_size: How many learners should be retired in this batch.
        :param beginning_idx: Index into learners where this batch should start.
        """
        cursor = beginning_idx
        while cursor < len(learners):
            start_idx = cursor
            end_idx = min(start_idx + chunk_size - 1, len(learners) - 1)
            LOG.info(
                "Attempting unsuppress for key '%s', start index %s, end index %s for learners '%s' through '%s'",
                key,
                start_idx, end_idx,
                learners[start_idx]['original_username'],
                learners[end_idx]['original_username']
            )
            learner_vals = [
                self._get_value_from_learner(learners[idx], key)
                for idx in range(start_idx, end_idx + 1)
            ]
            if len(learner_vals) >= MAXIMUM_USERS_IN_REGULATION_REQUEST:
                LOG.error(
                    'Attempting to UNSUPPRESS too many user values (%s) at once in bulk request - decrease chunk_size.',
                    len(learner_vals)
                )
                return
            self._send_regulation_request({
                "regulation_type": "Unsuppress",
                "attributes": {
                    "name": "userId",
                    "values": learner_vals
                }
            })
            cursor += chunk_size

    def delete_and_suppress_learners(self, learners, chunk_size, beginning_idx=0):
        """
        Sets up the Segment REST API calls to GDPR-delete users in chunks.

        :param learners: List of learner dicts returned from LMS, should contain all we need
            to retire this learner.
        :param chunk_size: How many learners should be retired in this batch.
        :param beginning_idx: Index into learners where this batch should start.
        """
        cursor = beginning_idx
        while cursor < len(learners):
            start_idx = cursor
            end_idx = min(start_idx + chunk_size - 1, len(learners) - 1)
            LOG.info(
                "Attempting Segment deletion with start index %s, end index %s for learners (%s, %s) through (%s, %s)",
                start_idx, end_idx,
                learners[start_idx]['user']['id'], learners[start_idx]['original_username'],
                learners[end_idx]['user']['id'], learners[end_idx]['original_username']
            )
            learner_vals = []
            for idx in range(start_idx, end_idx + 1):
                record = learners[idx]
                # Required IDs are always present; optional ones only when the learner has them.
                for id_key in REQUIRED_IDENTIFYING_KEYS:
                    learner_vals.append(self._get_value_from_learner(record, id_key))
                learner_vals.extend(
                    self._get_value_from_learner(record, id_key)
                    for id_key in OPTIONAL_IDENTIFYING_KEYS
                    if id_key in record
                )
            if len(learner_vals) >= MAXIMUM_USERS_IN_REGULATION_REQUEST:
                LOG.error(
                    'Attempting to delete too many user values (%s) at once in bulk request - decrease chunk_size.',
                    len(learner_vals)
                )
                return
            self._send_regulation_request({
                "regulation_type": "Suppress_With_Delete",
                "attributes": {
                    "name": "userId",
                    "values": learner_vals
                }
            })
            cursor += chunk_size

    def get_bulk_delete_status(self, bulk_delete_id):
        """
        Queries the status of a previously submitted bulk delete request.

        :param bulk_delete_id: ID returned from a previously-submitted bulk delete request.
        """
        status_url = BULK_REGULATE_STATUS_URL.format(self.workspace_slug, bulk_delete_id)
        resp = self._call_segment_get(status_url)
        LOG.info(text_type(resp.json()))

View File

@@ -0,0 +1,25 @@
import os
def envvar_get_int(var_name, default):
    """
    Read the environment variable `var_name` and return it as an integer.

    If the environment variable does not exist, return the default
    (also coerced to an integer).
    """
    raw_value = os.environ.get(var_name)
    if raw_value is None:
        return int(default)
    return int(raw_value)
def batch(batchable, batch_size=1):
    """
    Utility to facilitate batched iteration over a list.

    Arguments:
        batchable (list): The list (or other iterable) to break into batches.
        batch_size (int): Maximum number of items per yielded batch; the final
            batch may be smaller. Defaults to 1. Must be positive
            (`range` raises ValueError for a step of 0).

    Yields:
        list: Successive slices of `batchable` with at most `batch_size` items.
    """
    # Materialize once so arbitrary iterables (e.g. generators) can be sliced.
    batchable_list = list(batchable)
    length = len(batchable_list)
    for index in range(0, length, batch_size):
        yield batchable_list[index:index + batch_size]