diff --git a/pavelib/tests.py b/pavelib/tests.py index de606d8af8..c19ebad930 100644 --- a/pavelib/tests.py +++ b/pavelib/tests.py @@ -69,6 +69,11 @@ __test__ = False # do not collect dest='disable_migrations', help="Create tables by applying migrations." ), + make_option( + '--xdist_ip_addresses', + dest='xdist_ip_addresses', + help="Space separated string of ip addresses to shard tests to via xdist." + ) ], share_with=['pavelib.utils.test.utils.clean_reports_dir']) @PassthroughTask @timed @@ -152,6 +157,11 @@ def test_system(options, passthrough_options): "--disable-coverage", action="store_false", dest="with_coverage", help="Run the unit tests directly through pytest, NOT coverage" ), + make_option( + '--xdist_ip_addresses', + dest='xdist_ip_addresses', + help="Space separated string of ip addresses to shard tests to via xdist." + ) ], share_with=['pavelib.utils.test.utils.clean_reports_dir']) @PassthroughTask @timed diff --git a/pavelib/utils/test/suites/pytest_suite.py b/pavelib/utils/test/suites/pytest_suite.py index 1b5962a3b4..ab320f43fe 100644 --- a/pavelib/utils/test/suites/pytest_suite.py +++ b/pavelib/utils/test/suites/pytest_suite.py @@ -7,7 +7,6 @@ from pavelib.utils.test import utils as test_utils from pavelib.utils.test.suites.suite import TestSuite from pavelib.utils.envs import Env - __test__ = False # do not collect @@ -112,6 +111,7 @@ class SystemTestSuite(PytestSuite): self.processes = kwargs.get('processes', None) self.randomize = kwargs.get('randomize', None) self.settings = kwargs.get('settings', Env.TEST_SETTINGS) + self.xdist_ip_addresses = kwargs.get('xdist_ip_addresses', None) if self.processes is None: # Don't use multiprocessing by default @@ -142,12 +142,25 @@ class SystemTestSuite(PytestSuite): if self.disable_capture: cmd.append("-s") - if self.processes == -1: - cmd.append('-n auto') - cmd.append('--dist=loadscope') - elif self.processes != 0: - cmd.append('-n {}'.format(self.processes)) + if self.xdist_ip_addresses: cmd.append('--dist=loadscope') + for ip in self.xdist_ip_addresses.split(' '): + xdist_string = '--tx ssh=ubuntu@{}//python="source /edx/app/edxapp/edxapp_env; ' \ + 'python"//chdir="/edx/app/edxapp/edx-platform"'.format(ip) + cmd.append(xdist_string) + already_synced_dirs = set() + for test_path in self.test_id.split(): + test_root_dir = test_path.split('/')[0] + if test_root_dir not in already_synced_dirs: + cmd.append('--rsyncdir {}'.format(test_root_dir)) + already_synced_dirs.add(test_root_dir) + else: + if self.processes == -1: + cmd.append('-n auto') + cmd.append('--dist=loadscope') + elif self.processes != 0: + cmd.append('-n {}'.format(self.processes)) + cmd.append('--dist=loadscope') if not self.randomize: cmd.append('-p no:randomly') @@ -212,6 +225,7 @@ class LibTestSuite(PytestSuite): self.append_coverage = kwargs.get('append_coverage', False) self.test_id = kwargs.get('test_id', self.root) self.eval_attr = kwargs.get('eval_attr', None) + self.xdist_ip_addresses = kwargs.get('xdist_ip_addresses', None) @property def cmd(self): @@ -235,8 +249,23 @@ class LibTestSuite(PytestSuite): cmd.append("--verbose") if self.disable_capture: cmd.append("-s") + + if self.xdist_ip_addresses: + cmd.append('--dist=loadscope') + for ip in self.xdist_ip_addresses.split(' '): + xdist_string = '--tx ssh=ubuntu@{}//python="source /edx/app/edxapp/edxapp_env; ' \ + 'python"//chdir="/edx/app/edxapp/edx-platform"'.format(ip) + cmd.append(xdist_string) + already_synced_dirs = set() + for test_path in self.test_id.split(): + test_root_dir = test_path.split('/')[0] + if test_root_dir not in already_synced_dirs: + cmd.append('--rsyncdir {}'.format(test_root_dir)) + already_synced_dirs.add(test_root_dir) + if self.eval_attr: cmd.append("-a '{}'".format(self.eval_attr)) + cmd.append(self.test_id) return self._under_coverage_cmd(cmd) diff --git a/scripts/xdist/pytest_container_manager.py b/scripts/xdist/pytest_container_manager.py new file mode 100644 index 0000000000..a430791a08 --- /dev/null +++ b/scripts/xdist/pytest_container_manager.py @@ -0,0 +1,190 @@ +import argparse +import logging +import time + +import boto3 +from botocore.exceptions import ClientError + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class PytestContainerManager(): + """ + Responsible for spinning up and terminating containers to be used with pytest-xdist + """ + + def __init__(self, region, cluster): + self.ecs = boto3.client('ecs', region) + self.cluster_name = cluster + + def spin_up_containers(self, number_of_containers, task_name, subnets, security_groups, public_ip_enabled, launch_type): + """ + Spins up containers and generates two .txt files, one containing the IP + addresses of the new containers, the other containing their task_arns. + """ + CONTAINER_RUN_TIME_OUT_MINUTES = 10 + MAX_RUN_TASK_RETRIES = 7 + + revision = self.ecs.describe_task_definition(taskDefinition=task_name)['taskDefinition']['revision'] + task_definition = "{}:{}".format(task_name, revision) + + logging.info("Spinning up {} containers based on task definition: {}".format(number_of_containers, task_definition)) + + remainder = number_of_containers % 10 + quotient = number_of_containers / 10 + + container_num_list = [10 for i in range(0, quotient)] + if remainder: + container_num_list.append(remainder) + + # Boot up containers. boto3's run_task only allows 10 containers to be launched at a time + task_arns = [] + for num in container_num_list: + for retry in range(1, MAX_RUN_TASK_RETRIES + 1): + try: + response = self.ecs.run_task( + count=num, + cluster=self.cluster_name, + launchType=launch_type, + networkConfiguration={ + 'awsvpcConfiguration': { + 'subnets': subnets, + 'securityGroups': security_groups, + 'assignPublicIp': public_ip_enabled + } + }, + taskDefinition=task_definition + ) + except ClientError as err: + # Handle AWS throttling with an exponential backoff + if retry == MAX_RUN_TASK_RETRIES: + raise StandardError( + "MAX_RUN_TASK_RETRIES ({}) reached while spinning up containers due to AWS throttling.".format(MAX_RUN_TASK_RETRIES) + ) + logger.info("Hit error: {}. Retrying".format(err)) + countdown = 2 ** retry + logger.info("Sleeping for {} seconds".format(countdown)) + time.sleep(countdown) + else: + break + + for task_response in response['tasks']: + task_arns.append(task_response['taskArn']) + + # Wait for containers to finish spinning up + not_running = task_arns[:] + ip_addresses = [] + all_running = False + for attempt in range(0, CONTAINER_RUN_TIME_OUT_MINUTES * 2): + time.sleep(30) + list_tasks_response = self.ecs.describe_tasks(cluster=self.cluster_name, tasks=not_running)['tasks'] + del not_running[:] + for task_response in list_tasks_response: + if task_response['lastStatus'] == 'RUNNING': + for container in task_response['containers']: + ip_addresses.append(container["networkInterfaces"][0]["privateIpv4Address"]) + else: + not_running.append(task_response['taskArn']) + + if not_running: + logger.info("Still waiting on {} containers to spin up".format(len(not_running))) + else: + logger.info("Finished spinning up containers") + all_running = True + break + + if not all_running: + raise StandardError( + "Timed out waiting to spin up all containers." + ) + + logger.info("Successfully booted up {} containers.".format(number_of_containers)) + + # Generate .txt files containing IP addresses and task arns + ip_list_string = " ".join(ip_addresses) + logger.info("Container IP list: {}".format(ip_list_string)) + ip_list_file = open("pytest_container_ip_list.txt", "w") + ip_list_file.write(ip_list_string) + ip_list_file.close() + + task_arn_list_string = " ".join(task_arns) + logger.info("Container task arn list: {}".format(task_arn_list_string)) + task_arn_file = open("pytest_container_task_arns.txt", "w") + task_arn_file.write(task_arn_list_string) + task_arn_file.close() + + def terminate_containers(self, task_arns, reason): + """ + Terminates containers based on a list of task_arns. + """ + for task_arn in task_arns: + response = self.ecs.stop_task( + cluster=self.cluster_name, + task=task_arn, + reason=reason + ) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description="PytestContainerManager, manages ECS containers in an AWS cluster." + ) + + parser.add_argument('--region', '-g', default='us-east-1', + help="AWS region where ECS infrastructure lives. Defaults to us-east-1") + + parser.add_argument('--cluster', '-c', default="jenkins-worker-containers", + help="AWS Cluster name where the containers live. Defaults to" + "the testeng cluster: jenkins-worker-containers") + + parser.add_argument('--action', '-a', choices=['up', 'down'], default=None, + help="Action for PytestContainerManager to perform. " + "Either up for spinning up AWS ECS containers or down for stopping them") + + # Spinning up containers + parser.add_argument('--num_containers', '-n', type=int, default=None, + help="Number of containers to spin up") + + parser.add_argument('--task_name', '-t', default=None, + help="Name of the task definition for spinning up workers") + + parser.add_argument('--subnets', '-s', nargs='+', default=None, + help="List of subnets for the containers to exist in") + + parser.add_argument('--security_groups', '-sg', nargs='+', default=None, + help="List of security groups to apply to the containers") + + parser.add_argument('--public_ip_enabled', choices=['ENABLED', 'DISABLED'], + default='DISABLED', help="Whether the containers should have a public IP") + + parser.add_argument('--launch_type', default='FARGATE', choices=['EC2', 'FARGATE'], + help="ECS launch type of container. Defaults to FARGATE") + + # Terminating containers + parser.add_argument('--task_arns', '-arns', nargs='+', default=None, + help="Task arns to terminate") + + parser.add_argument('--reason', '-r', default="Finished executing tests", + help="Reason for terminating containers") + + args = parser.parse_args() + containerManager = PytestContainerManager(args.region, args.cluster) + + if args.action == 'up': + containerManager.spin_up_containers( + args.num_containers, + args.task_name, + args.subnets, + args.security_groups, + args.public_ip_enabled, + args.launch_type + ) + elif args.action == 'down': + containerManager.terminate_containers( + args.task_arns, + args.reason + ) + else: + logger.info("No action specified for PytestContainerManager")