#!/usr/bin/env python3
import argparse
import logging
import boto3
import json       

# Log records are rendered as "<time> - <logger name> - <level> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)

# Script-wide logger, emitting INFO and above.
logger = logging.getLogger("Runner")
logger.setLevel(logging.INFO)

# Seconds passed as WaitTimeSeconds to each receive_message call; also the
# amount added to the elapsed-time counter after every empty poll.
QUEUE_WAIT_TIME = 10



class SqsService:
    """Thin wrapper around an SQS client for the SAT-comp job queues.

    Solver requests are sent to the queue whose URL ends with
    ``SatCompQueue``; response messages are polled from the queue whose URL
    ends with ``SatCompOutputQueue``.
    """

    # URL suffixes that identify the request and response queues.
    _INPUT_QUEUE_SUFFIX = 'SatCompQueue'
    _OUTPUT_QUEUE_SUFFIX = 'SatCompOutputQueue'

    def __init__(self, client):
        """Store the SQS client used for every queue operation.

        Args:
            client: SQS client object providing ``list_queues``,
                ``send_message``, ``receive_message`` and ``delete_message``.
        """
        self.sqs = client

    def _find_queue_url(self, suffix):
        """Return the first listed queue URL that ends with *suffix*.

        Any failure is logged and then re-raised unchanged.

        Raises:
            LookupError: no listed queue URL ends with *suffix*.
            Exception: whatever the client's ``list_queues`` call raises.
        """
        try:
            response = self.sqs.list_queues()
            # The response may lack 'QueueUrls' entirely when no queues
            # exist, so fall back to an empty list instead of a KeyError.
            for url in response.get('QueueUrls', []):
                if url.endswith(suffix):
                    return url

            raise LookupError(f"No queue ending with '{suffix}'")
        except Exception as e:
            logger.error(f"Failed to get SQS queue: {e}")
            raise

    def get_satcomp_queue(self):
        """List the SQS queues and return the solver request queue.

        Returns: URL of the queue ending with 'SatCompQueue'.

        Raises:
            LookupError: no such queue is listed.
        """
        return self._find_queue_url(self._INPUT_QUEUE_SUFFIX)

    def get_satcomp_output_queue(self):
        """List the SQS queues and return the solver response queue.

        Returns: URL of the queue ending with 'SatCompOutputQueue'.

        Raises:
            LookupError: no such queue is listed.
        """
        return self._find_queue_url(self._OUTPUT_QUEUE_SUFFIX)

    def send_message(self, location, workers, timeout, solverName, language, solverOptions):
        """Send a solve request to the request queue.

        The message body is a JSON document of the form::

            {
                "formula" : {
                    "value" : <s3 url>,
                    "language": "SMTLIB2" | "DIMACS",
                },
                "solverConfig" : {
                    "solverName" : "",
                    "solverOptions" : [],
                    "taskTimeoutSeconds" : 5
                },
                "num_workers": 0
            }

        Args:
            location: value stored under ``formula.value``.
            workers: value stored under ``num_workers``.
            timeout: value stored under ``solverConfig.taskTimeoutSeconds``.
            solverName: value stored under ``solverConfig.solverName``.
            language: value stored under ``formula.language``.
            solverOptions: value stored under ``solverConfig.solverOptions``.

        Returns: the response of the client's ``send_message`` call.

        Raises:
            LookupError: the request queue could not be found.
            Exception: whatever the client's ``send_message`` call raises
                (logged before being re-raised).
        """
        queue = self.get_satcomp_queue()

        message_body = {
            "formula": {
                "value": location,
                "language": language,
            },
            "solverConfig": {
                "solverName": solverName,
                "solverOptions": solverOptions,
                "taskTimeoutSeconds": timeout,
            },
            "num_workers": workers,
        }

        message_body_str = json.dumps(message_body, indent=4)
        try:
            return self.sqs.send_message(
                QueueUrl=queue,
                MessageBody=message_body_str,
            )
        except Exception as e:
            logger.error(f"Failed to send message: Exception: {e}")
            raise

    def receive_and_delete_message(self, timeout):
        """Poll the response queue, log each received message and delete it.

        The queue is polled with waits of ``QUEUE_WAIT_TIME`` seconds until a
        non-empty batch arrives or the accumulated wait reaches
        ``timeout + 5`` seconds. Elapsed time is counted as one full
        ``QUEUE_WAIT_TIME`` per empty poll rather than measured.

        A ``receive_message`` response carrying messages looks like::

            {
                'Messages': [
                    {
                        'MessageId': 'string',
                        'ReceiptHandle': 'string',
                        'MD5OfBody': 'string',
                        'Body': 'string',
                        ...
                    },
                ]
            }

        Args:
            timeout: solver timeout in seconds; polling continues for a few
                seconds beyond it.

        Returns: the ``receive_message`` response containing the message(s),
            or None when nothing arrived before the deadline.

        Raises:
            LookupError: the response queue could not be found.
        """
        queue = self.get_satcomp_output_queue()
        logger.info(f"Receiving and deleting message from queue {queue}")
        total_time = 0
        # 5 extra seconds of slack on top of the solver timeout.
        while total_time < timeout + 5:
            logger.info(f"Waiting up to {QUEUE_WAIT_TIME}s for a message.")
            response = self.sqs.receive_message(
                QueueUrl=queue,
                WaitTimeSeconds=QUEUE_WAIT_TIME,
            )
            # An empty poll has no 'Messages' key at all, so use .get() to
            # keep waiting instead of failing with KeyError.
            messages = response.get("Messages")
            if messages:
                for msg in messages:
                    body = msg["Body"]
                    logger.info(f"Response from receive_message was: {body}")
                    self.sqs.delete_message(
                        QueueUrl=queue,
                        ReceiptHandle=msg["ReceiptHandle"],
                    )
                return response

            total_time += QUEUE_WAIT_TIME
        logger.error("Solver did not complete within expected timeout.")
        return None


if __name__ == "__main__":
    # Command-line entry point: send one solve request to the SatCompQueue
    # and, when asked, wait for the reply on the SatCompOutputQueue.
    parser = argparse.ArgumentParser()
    parser.add_argument('--profile', required = False, help = "AWS profile")
    parser.add_argument('--location', required = True, help = "S3 location for CNF file")
    parser.add_argument('--workers', required = True, type=int, help = "Required Worker nodes count")
    parser.add_argument('--timeout', type=int,  help = "Timeout value for the infrastructure to interrupt the solver", default = 60)
    parser.add_argument('--name', help = "Name of solver to be invoked (passed through to the solver).  Default: empty string", default = "")
    parser.add_argument('--format', help = "Problem format for the problem to be solved.", default = "")
    parser.add_argument('--args', nargs='+', help="Arguments to pass through to the solver (--args accepts a space-delimited list of arguments).  Default: empty list", default = [])
    # Parsed as a real boolean: a bare `--await-response` or an explicit
    # true-like value enables polling, while `--await-response false` (which
    # used to be a truthy non-empty string) disables it.
    parser.add_argument(
        '--await-response',
        required = False,
        nargs = '?',
        const = True,
        default = False,
        type = lambda value: value.strip().lower() in ('1', 'true', 't', 'yes', 'y', 'on'),
        help = "If true, then solver will poll output queue for response message, display, and delete it.",
    )
    args = parser.parse_args()

    profile = args.profile

    session = boto3.Session(profile_name=profile)
    sqs_client = session.client('sqs')
    sqs = SqsService(sqs_client)

    # Send message; report failure through a non-zero exit status so calling
    # scripts can detect it.
    try:
        sqs.send_message(args.location, args.workers, args.timeout, args.name, args.format, args.args)
    except Exception as e:
        logger.error(f"Failed to send message. {e}")
        raise SystemExit(1)

    # Optionally wait for the response; a failure here is reported separately
    # from a failed send.
    if args.await_response:
        try:
            sqs.receive_and_delete_message(args.timeout)
        except Exception as e:
            logger.error(f"Failed to receive response message. {e}")
            raise SystemExit(1)