#!/usr/bin/env python3
"""Scale the solver ECS services and the ECS instance Auto Scaling group.

Run as a script with a mode of ``setup`` or ``shutdown``; see the argument
parser in the ``__main__`` section for the available options.
"""
import argparse
import logging
import pprint  # only referenced by the disabled AutoScaling debug output
import time
from typing import Union

import boto3

logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger("Runner")
logger.setLevel(logging.INFO)

# Names (or name fragments) of the AWS resources this script manages.
CLUSTER_NAME = "SatCompCluster"
LEADER_SERVICE = "SolverLeaderService"
WORKER_SERVICE = "SolverWorkerService"
ASG_NAME_FRAGMENT = "EcsInstanceAsg"


class EcsService:
    """Wrapper around a boto3 ECS client for the solver leader/worker services."""

    def __init__(self, client, cluster=CLUSTER_NAME):
        """Store the ECS client and the name of the cluster to operate on.

        Args:
            client: boto3 ECS client (anything exposing ``list_services``,
                ``update_service`` and ``describe_services``).
            cluster: ECS cluster name; defaults to ``CLUSTER_NAME``.
        """
        self.ecs = client
        self.cluster = cluster

    def get_ecs_service(self, service_name):
        """Return the first service ARN in the cluster containing *service_name*.

        All result pages of ``list_services`` are searched, not only the first.

        Args:
            service_name: Substring to look for in each service ARN.

        Returns:
            The matching service ARN, or ``None`` when no ARN matches.

        Raises:
            Exception: Whatever the ECS client raised; it is logged first.
        """
        request = {"cluster": self.cluster}
        try:
            while True:
                response = self.ecs.list_services(**request)
                for service_arn in response['serviceArns']:
                    if service_name in service_arn:
                        return service_arn
                next_token = response.get('nextToken')
                if not next_token:
                    return None
                request['nextToken'] = next_token
        except Exception as e:
            logger.error(f"Failed to get ECS service names: {e}")
            raise

    def _resolve_services(self):
        """Return ``(leader_arn, worker_arn)``, raising if either is missing."""
        worker_service = self.get_ecs_service(WORKER_SERVICE)
        leader_service = self.get_ecs_service(LEADER_SERVICE)
        missing = [
            name
            for name, arn in ((LEADER_SERVICE, leader_service),
                              (WORKER_SERVICE, worker_service))
            if arn is None
        ]
        if missing:
            # Fail early with a readable message instead of handing None to
            # the ECS API (and before any service has been modified).
            raise LookupError(
                f"ECS service(s) not found in cluster {self.cluster}: "
                + ", ".join(missing)
            )
        return leader_service, worker_service

    def update_ecs_service(self, leader_node_count, worker_node_count):
        """Set the desired task count of the leader and worker services.

        Args:
            leader_node_count: Desired task count for the leader service.
            worker_node_count: Desired task count for the worker service.

        Raises:
            LookupError: If either service cannot be found in the cluster.
            Exception: Whatever the ECS client raised; it is logged first.
        """
        leader_service, worker_service = self._resolve_services()
        try:
            # Use the client held by this instance rather than a module global.
            self.ecs.update_service(
                cluster=self.cluster,
                service=leader_service,
                desiredCount=leader_node_count,
            )
            self.ecs.update_service(
                cluster=self.cluster,
                service=worker_service,
                desiredCount=worker_node_count,
            )
        except Exception as e:
            logger.error(f"Failed to update ECS service: {e}")
            raise

    def describe_ecs_services(self):
        """Return the ``describe_services`` response for leader and worker.

        The services are requested in the order ``[leader, worker]``.

        Raises:
            LookupError: If either service cannot be found in the cluster.
            Exception: Whatever the ECS client raised; it is logged first.
        """
        leader_service, worker_service = self._resolve_services()
        try:
            return self.ecs.describe_services(
                cluster=self.cluster,
                services=[leader_service, worker_service],
            )
        except Exception as e:
            logger.error(f"Failed to describe ECS service: {e}")
            raise


class ScalingGroup:
    """Wrapper around a boto3 Auto Scaling client for the ECS instance group(s)."""

    def __init__(self, client) -> None:
        """Store the Auto Scaling client used for all calls."""
        self.asg_client = client

    def update_asg(self, desired_count: Union[int, str]):
        """Set MaxSize and DesiredCapacity of every matching Auto Scaling group.

        A group matches when its name contains ``ASG_NAME_FRAGMENT``.

        Args:
            desired_count: New size; an int or a string holding an integer.
                It is converted to ``int`` before being sent to the API.

        Raises:
            Exception: Conversion or client errors; they are logged first.
        """
        try:
            size = int(desired_count)
            groups = self.asg_client.describe_auto_scaling_groups()['AutoScalingGroups']
            matched = False
            for group in groups:
                asg_name = group["AutoScalingGroupName"]
                if ASG_NAME_FRAGMENT in asg_name:
                    self.asg_client.update_auto_scaling_group(
                        AutoScalingGroupName=asg_name,
                        MaxSize=size,
                        DesiredCapacity=size,
                    )
                    matched = True
            if not matched:
                # Previously this case passed silently and nothing was scaled.
                logger.warning(
                    f"No Auto Scaling group with '{ASG_NAME_FRAGMENT}' in its name was found"
                )
        except Exception as e:
            logger.error(f"Failed to update ASG: {e}")
            raise


def await_completion(ecs_service, asg_client, poll_interval=30, timeout=None):
    """Poll ECS until leader and worker running counts equal their desired counts.

    Progress is printed on every poll.

    Args:
        ecs_service: Object providing ``describe_ecs_services()``.
        asg_client: Auto Scaling client; currently unused because the
            AutoScaling activity report is disabled (kept for callers).
        poll_interval: Seconds to sleep between polls (default 30).
        timeout: Maximum seconds to wait, or ``None`` to wait indefinitely.

    Raises:
        TimeoutError: If *timeout* elapses before the counts match.
    """
    # wait for ECS setup/teardown to complete
    start = time.monotonic()
    while True:
        status = ecs_service.describe_ecs_services()
        # Assumes the response lists the leader first and the worker second
        # (the order they are requested in), and that the first deployment
        # entry is the one of interest.
        leader = status["services"][0]["deployments"][0]
        worker = status["services"][1]["deployments"][0]

        elapsed = time.monotonic() - start
        print(f"Waiting for ECS ({elapsed/60.:3.1f} mins)")
        print(f" leader: {leader['desiredCount']} desired, {leader['pendingCount']} pending, {leader['runningCount']} running")
        print(f" workers: {worker['desiredCount']} desired, {worker['pendingCount']} pending, {worker['runningCount']} running")
        print("")

        if (leader["desiredCount"] == leader["runningCount"]
                and worker["desiredCount"] == worker["runningCount"]):
            print("ECS configuration complete")
            return

        if timeout is not None and elapsed >= timeout:
            raise TimeoutError(
                f"ECS services did not reach their desired counts within {timeout} seconds"
            )

        time.sleep(poll_interval)

        # MWW: The AutoScaling activity report that used to run here is disabled,
        # since the output 'lies' in the sense that it reports
        # a failure in the usual case, before eventually succeeding.
# Disabled AutoScaling activity report for the polling loop in
# await_completion, kept for reference:
#
# put this after the first sleep, since it's usually delayed
# asg_status = asg_client.describe_scaling_activities()
# only display the most recent message
# asg_update = asg_status["Activities"][0]
# print(f"Most recent AutoScaling Activity Log")
# print(f" StatusCode: {asg_update['StatusCode']}")
# print(f" StatusCause: {asg_update['Cause']}")
# print(f" StatusDescription: {asg_update['Description']}")
# print("")
#pprint.pprint(asg_update)
#print("")


if __name__ == "__main__":
    # Command line: a mode ("setup" or "shutdown"), an optional AWS profile
    # and, for setup, the number of worker nodes.
    parser = argparse.ArgumentParser()
    parser.add_argument('mode', choices=["setup", "shutdown"], help="AddInstances / DeleteInstances.")
    parser.add_argument('--profile', default="default", required=False, help="AWS profile")
    parser.add_argument('--workers', type=int, required=False, help="Required Worker nodes count")
    args = parser.parse_args()

    if args.mode == 'setup':
        # Setup Instances
        if args.workers is None:
            # Without this check a missing --workers crashed with a TypeError.
            parser.error("--workers is required when mode is 'setup'")
        if args.workers < 0:
            parser.error("--workers must not be negative")
        worker_node_count = args.workers
        leader_node_count = 1
        # One instance per worker plus one for the leader.
        desired_count = worker_node_count + 1
    else:
        # Shutdown instances ('shutdown' is the only other allowed choice)
        leader_node_count = worker_node_count = desired_count = 0

    session = boto3.Session(profile_name=args.profile)
    ecs = session.client('ecs')
    ecs_service = EcsService(ecs)

    asg_client = session.client('autoscaling')
    asg = ScalingGroup(asg_client)

    # Resize the Auto Scaling group first, then set the service task counts.
    asg.update_asg(desired_count)
    try:
        ecs_service.update_ecs_service(leader_node_count, worker_node_count)
    except Exception as e:
        logger.error(f"Failed to update ECS service. {e}")
        logger.info("Updating ASG")
        # Scale the group back down; pass an int, since the size is forwarded
        # to integer API parameters.
        asg.update_asg(0)

    # wait for ECS setup/teardown to complete
    await_completion(ecs_service, asg_client)