#!/usr/bin/env python3
import argparse
import json
import logging
import subprocess
import sys

import boto3
from botocore.exceptions import ProfileNotFound

# Record layout for every log line: timestamp, logger name, severity, message.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT)

# Script-wide logger; INFO level so the progress messages below are emitted.
logger = logging.getLogger("Quickstart Run")
logger.setLevel(logging.INFO)

# Worker count passed via --workers to both ecs-config and send_message.
WORKER_NODES = 3

class STS:
    """Small wrapper around an STS client used to look up the caller's account."""

    def __init__(self, sts_client) -> None:
        # The client is stored as-is; every call below is delegated to it.
        self.sts = sts_client

    def get_account_number(self) -> str:
        """Return the "Account" field of the client's get_caller_identity() reply.

        Raises:
            Exception: whatever the lookup raised; it is logged first and then
                re-raised unchanged.
        """
        try:
            identity = self.sts.get_caller_identity()
            account = identity["Account"]
        except Exception as err:
            logger.error(f"Failed to get profile account number: {err}")
            raise err
        return account


def _parse_bool(value):
    """Convert a command-line string such as "True" or "false" to a bool.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string was falsy before this parser existed; keep it that way.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def argument_parser():
    """Build the command-line parser for the quickstart run.

    Options:
        --profile       AWS profile name (optional).
        --project       Project name, default 'comp23'.
        --keep-alive    Boolean; accepts `--keep-alive True/False` or the bare
                        flag `--keep-alive` (meaning True).  Default is False.
        --s3-locations  One or more S3 URLs of problems to run.

    Returns:
        The configured argparse.ArgumentParser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--profile', required=False, help="AWS profile")
    parser.add_argument('--project', required=False, default='comp23', help="ADVANCED USERS ONLY: Name of the project (default: 'comp23').")
    # Previously the raw string was stored, so `--keep-alive False` was truthy
    # and kept the cluster alive.  Parse it as a real boolean instead.
    parser.add_argument('--keep-alive', required=False, nargs='?', const=True, default=False, type=_parse_bool,
                        help="If True, then cluster is not destroyed after solving.")
    parser.add_argument('--s3-locations', nargs='+', required=False, help="S3 URLs for problem location of the form: s3://BUCKET_NAME/OBJECT_LOCATION.  Default is to run test.cnf from the default bucket.  You can specify multiple URLs (space separated)")
    return parser

def _run_command(cmd):
    """Run ``cmd`` as a child process and report how it ended.

    Args:
        cmd: Argument list handed to ``subprocess.run`` (no shell is involved).

    Returns:
        The child's return code, or ``None`` if the process could not be
        started at all (that failure is logged here).
    """
    logger.info(f"command that is being executed is: {' '.join(cmd)}")
    try:
        return subprocess.run(cmd).returncode
    except Exception as e:
        logger.error(f"Unexpected error {e}: unable to run command: {cmd}.")
        return None


def main() -> int:
    """Set up the cluster, submit each problem, then tear the cluster down.

    Returns:
        The process exit status: 0 on success, 1 if the AWS session could not
        be created, -1 if any of the helper scripts failed.
    """
    args = argument_parser().parse_args()

    try:
        if args.profile:
            session = boto3.Session(profile_name=args.profile)
        else:
            session = boto3.Session()
    except ProfileNotFound:
        logger.error("Unable to create AWS session.  Please check that default profile is set up in the ~/.aws/config file and has appropriate rights (or if --profile was provided, that this profile has appropriate rights)")
        return 1

    account_number = STS(session.client('sts')).get_account_number()

    # Always computed: it is used in the completion message below even when
    # the caller supplies explicit --s3-locations.
    # NOTE(review): assumes intermediate output lands in the project bucket
    # regardless of where the problem file lives -- confirm against send_message.
    s3_bucket = f"s3://{account_number}-us-east-1-{args.project}"
    problem_locations = args.s3_locations or [f"{s3_bucket}/test.cnf"]

    profile_args = ['--profile', args.profile] if args.profile else []
    workers = str(WORKER_NODES)

    # run the ecs-config script
    logger.info(f"Setting up ECS cluster with {WORKER_NODES} worker nodes using ecs-config.")
    logger.info("This operation will likely take 5-7 minutes.")
    logger.info("***Note that while configuration is in process, the system will report Autoscaling failures.  This is normal, but should not persist for more than 10 minutes.***")
    setup_status = _run_command(['python3', 'ecs-config', '--workers', workers, 'setup'] + profile_args)
    if setup_status is None:
        return -1
    if setup_status:
        logger.error(f"ecs-config failed with error code {setup_status}")
        logger.error("Have you set up the default profile in your ~/.aws/config file?")
        return -1

    # A failed problem does not abort the run: the remaining problems are still
    # attempted and the cluster is still torn down, but the failure is
    # remembered so the script does not claim success at the end.
    failed_locations = []
    for problem_location in problem_locations:
        logger.info("Sending a message to the cluster to run the next problem.")
        logger.info(f"It is stored in the S3 location: {problem_location}")
        logger.info(f"Please go inspect the CloudWatch logs for project {args.project} to see it running, as described in the Infrastructure README.")
        send_status = _run_command(
            ['python3', 'send_message', '--location', problem_location,
             '--workers', workers, '--await-response', 'True'] + profile_args)
        if send_status is None:
            failed_locations.append(problem_location)
        elif send_status:
            logger.error("Unexpected error: Command to send message to queue failed.")
            failed_locations.append(problem_location)
        else:
            logger.info(f"Send message completed.  Intermediate files and stdout/stderr are available in {s3_bucket}/tmp")

    if args.keep_alive:
        logger.info("Keep-alive option selected; not deleting cluster")
    else:
        logger.info("Deleting the cluster.  This will require a minute or two.")
        logger.info("Tearing down the cluster.")
        shutdown_status = _run_command(['python3', 'ecs-config', 'shutdown'] + profile_args)
        if shutdown_status is None:
            return -1
        if shutdown_status:
            logger.error("Unexpected error returned by cluster teardown.  PLEASE MANUALLY CHECK WHETHER CLUSTER WAS DELETED.")
            return -1

    if failed_locations:
        logger.error(f"Run complete, but sending failed for: {', '.join(failed_locations)}")
        return -1

    logger.info("Run complete; quickstart-run successful!")
    return 0


if __name__ == "__main__":
    sys.exit(main())