#!/usr/bin/env python3
import argparse
import logging
import boto3   
import time 
import pprint

# Root handler format: timestamp, logger name, level, then the message.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
# Script-wide logger; INFO and above are emitted.
logger = logging.getLogger("Runner")
logger.setLevel(level=logging.INFO)


class EcsService:
    """Look up and scale the solver leader/worker services of an ECS cluster.

    Every AWS call goes through the ECS client passed to the constructor.
    """

    def __init__(self, client, cluster="SatCompCluster"):
        """
        Args:
            client: a boto3 ECS client (e.g. ``session.client('ecs')``).
            cluster: name of the ECS cluster that holds the solver services.
        """
        self.ecs = client
        self.cluster = cluster

    def get_ecs_service(self, service_name):
        """Return the ARN of the first service whose ARN contains ``service_name``.

        ``list_services`` is paginated via ``nextToken``, so every page is
        scanned rather than only the first one.

        Returns:
            The matching service ARN, or ``None`` if no service matches.

        Raises:
            Any error raised by the ECS client (logged, then re-raised).
        """
        request = {"cluster": self.cluster}
        try:
            while True:
                response = self.ecs.list_services(**request)
                for service in response['serviceArns']:
                    if service_name in service:
                        return service
                next_token = response.get('nextToken')
                if not next_token:
                    return None
                request['nextToken'] = next_token
        except Exception as e:
            logger.error(f"Failed to get ECS service names: {e}")
            raise

    def _leader_and_worker(self):
        """Resolve the (leader ARN, worker ARN) pair.

        Raises:
            LookupError: if either service cannot be found. Without this,
                ``None`` would be sent to the ECS API as the service name.
        """
        worker_service = self.get_ecs_service("SolverWorkerService")
        leader_service = self.get_ecs_service("SolverLeaderService")
        missing = [
            name
            for name, arn in (("SolverLeaderService", leader_service),
                              ("SolverWorkerService", worker_service))
            if arn is None
        ]
        if missing:
            raise LookupError(
                f"ECS service(s) not found in cluster {self.cluster}: {', '.join(missing)}"
            )
        return leader_service, worker_service

    def update_ecs_service(self, leader_node_count, worker_node_count):
        """Set the desired task count of the leader and worker services.

        Args:
            leader_node_count: desired task count for the leader service.
            worker_node_count: desired task count for the worker service.

        Raises:
            LookupError: if a service cannot be found.
            Any error raised by ``update_service`` (logged, then re-raised).
        """
        leader_service, worker_service = self._leader_and_worker()
        try:
            self.ecs.update_service(
                cluster=self.cluster,
                service=leader_service,
                desiredCount=leader_node_count,
            )
            self.ecs.update_service(
                cluster=self.cluster,
                service=worker_service,
                desiredCount=worker_node_count,
            )
        except Exception as e:
            logger.error(f"Failed to update ECS service: {e}")
            raise

    def describe_ecs_services(self):
        """Return the ``describe_services`` response for [leader, worker].

        The services are requested in that order (leader first, worker second).

        Raises:
            LookupError: if a service cannot be found.
            Any error raised by ``describe_services`` (logged, then re-raised).
        """
        leader_service, worker_service = self._leader_and_worker()
        try:
            return self.ecs.describe_services(
                cluster=self.cluster,
                services=[leader_service, worker_service],
            )
        except Exception as e:
            logger.error(f"Failed to describe ECS service: {e}")
            raise

class ScalingGroup:
    """Resize the ECS instance Auto Scaling group through a boto3 client."""

    def __init__(self, client) -> None:
        """
        Args:
            client: a boto3 Auto Scaling client (e.g. ``session.client('autoscaling')``).
        """
        self.asg_client = client

    def _find_asg_name(self) -> str:
        """Return the name of the Auto Scaling group containing 'EcsInstanceAsg'.

        Every ``describe_auto_scaling_groups`` page (``NextToken``) is scanned.
        If several groups match, the last one seen wins.

        Raises:
            LookupError: if no group name contains 'EcsInstanceAsg'.
        """
        asg_name = None
        request = {}
        while True:
            page = self.asg_client.describe_auto_scaling_groups(**request)
            for group in page['AutoScalingGroups']:
                if 'EcsInstanceAsg' in group["AutoScalingGroupName"]:
                    asg_name = group["AutoScalingGroupName"]
            next_token = page.get('NextToken')
            if not next_token:
                break
            request['NextToken'] = next_token
        if asg_name is None:
            raise LookupError("No Auto Scaling group with 'EcsInstanceAsg' in its name was found")
        return asg_name

    def update_asg(self, desired_count: "int | str") -> None:
        """Set both MaxSize and DesiredCapacity of the ECS instance group.

        Args:
            desired_count: target instance count, as an int or a numeric
                string. It is converted with ``int()`` because the Auto
                Scaling API requires integer sizes.

        Raises:
            ValueError: if ``desired_count`` is not an integer value.
            LookupError: if the group cannot be found.
            Any error raised by the Auto Scaling client.
            All errors are logged before being re-raised.
        """
        try:
            count = int(desired_count)
            asg_name = self._find_asg_name()
            self.asg_client.update_auto_scaling_group(
                AutoScalingGroupName=asg_name,
                MaxSize=count,
                DesiredCapacity=count,
            )
        except Exception as e:
            logger.error(f"Failed to update ASG: {e}")
            raise

    
def _primary_deployment(service):
    """Return the service's PRIMARY deployment, falling back to the first entry.

    During a rollout a service can list more than one deployment. The
    PRIMARY one carries the counts for the most recent desired state.
    """
    deployments = service["deployments"]
    for deployment in deployments:
        if deployment.get("status") == "PRIMARY":
            return deployment
    return deployments[0]


def await_completion(ecs_service, asg_client, poll_interval=30, timeout=None):
    """Block until the leader and worker services run their desired task counts.

    The ECS status is polled and printed every ``poll_interval`` seconds.

    Args:
        ecs_service: object whose ``describe_ecs_services()`` returns a
            ``describe_services`` response with the leader service first and
            the worker service second.
        asg_client: unused. It is kept for backward compatibility (see note below).
        poll_interval: seconds to sleep between polls.
        timeout: maximum seconds to wait. ``None`` waits indefinitely.

    Raises:
        TimeoutError: if ``timeout`` elapses before the counts converge.
    """
    # NOTE: Auto Scaling activity reporting (describe_scaling_activities) was
    # deliberately dropped. It usually reports a failure before eventually
    # succeeding, which misleads the user. That is why asg_client is unused.
    start = time.time()
    while True:
        status = ecs_service.describe_ecs_services()

        leader = _primary_deployment(status["services"][0])
        leader_desired = leader["desiredCount"]
        leader_pending = leader["pendingCount"]
        leader_running = leader["runningCount"]

        worker = _primary_deployment(status["services"][1])
        worker_desired = worker["desiredCount"]
        worker_pending = worker["pendingCount"]
        worker_running = worker["runningCount"]

        elapsed = time.time() - start
        print(f"Waiting for ECS ({elapsed/60.:3.1f} mins)")
        print(f"  leader: {leader_desired} desired, {leader_pending} pending, {leader_running} running")
        print(f" workers: {worker_desired} desired, {worker_pending} pending, {worker_running} running")
        print("")

        if leader_desired == leader_running and worker_desired == worker_running:
            print("ECS configuration complete")
            return

        if timeout is not None and elapsed >= timeout:
            raise TimeoutError(
                f"ECS services did not reach their desired counts within {timeout} seconds"
            )

        time.sleep(poll_interval)


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('mode', choices = ["setup", "shutdown"], help = "AddInstances / DeleteInstances.")
    parser.add_argument('--profile', default = "default", required = False, help = "AWS profile")
    # type=int so a non-numeric value is rejected by argparse instead of
    # crashing later in int().
    parser.add_argument('--workers', type = int, required = False, help = "Required Worker nodes count")

    args = parser.parse_args()

    if args.mode == 'setup':
        # Setup instances: one leader plus the requested number of workers.
        if args.workers is None:
            parser.error("--workers is required in setup mode")
        if args.workers < 0:
            parser.error("--workers must be a non-negative integer")
        worker_node_count = args.workers
        leader_node_count = 1
    else:
        # Shutdown instances: scale everything to zero.
        leader_node_count = worker_node_count = 0
    # One EC2 instance per leader/worker task.
    desired_count = leader_node_count + worker_node_count

    session = boto3.Session(profile_name=args.profile)
    ecs = session.client('ecs')
    ecs_service = EcsService(ecs)

    asg_client = session.client('autoscaling')
    asg = ScalingGroup(asg_client)

    asg.update_asg(desired_count)
    try:
        ecs_service.update_ecs_service(leader_node_count, worker_node_count)
    except Exception as e:
        logger.error(f"Failed to update ECS service. {e}")
        logger.info("Updating ASG")
        # Roll back the instance count. The ASG API needs an integer here.
        asg.update_asg(0)
        # The services never received consistent desired counts, so waiting
        # for them to converge could block forever. Exit with a failure status.
        raise SystemExit(1)

    # wait for ECS setup/teardown to complete
    await_completion(ecs_service, asg_client)