2023-03-28 08:48:09 +00:00

172 lines
5.9 KiB
Python
Executable File

#!/usr/bin/env python3
import argparse
import logging
import boto3
import json
# Long-poll duration, in seconds, passed as WaitTimeSeconds to each
# receive_message call on the output queue.
QUEUE_WAIT_TIME = 10

# Root handler format: timestamp, logger name, level, then the message.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT)

# Shared logger for SqsService and the command-line entry point; emits INFO and above.
logger = logging.getLogger("Runner")
logger.setLevel(logging.INFO)
class SqsService:
    """Thin wrapper around an SQS client for the SAT-comp queues.

    Solve requests are sent to the queue whose URL ends in ``SatCompQueue``;
    result messages are polled from the queue whose URL ends in
    ``SatCompOutputQueue``.
    """

    def __init__(self, client):
        """Store the SQS client.

        Args:
            client: a boto3 SQS client (e.g. ``session.client('sqs')``), or any
                object exposing ``list_queues``, ``send_message``,
                ``receive_message`` and ``delete_message`` with the same
                keyword arguments.
        """
        self.sqs = client

    def _find_queue_url(self, suffix):
        """Return the URL of the first listed SQS queue whose URL ends with *suffix*.

        NOTE: ``list_queues`` is called without pagination, so only the first
        page of queue URLs is searched.

        Raises:
            LookupError: if no listed queue URL ends with *suffix*.
            Exception: any error from ``list_queues`` is logged and re-raised.
        """
        try:
            response = self.sqs.list_queues()
        except Exception as e:
            logger.error(f"Failed to get SQS queue: {e}")
            raise
        # SQS omits 'QueueUrls' entirely when there are no queues to list.
        for url in response.get('QueueUrls', []):
            if url.endswith(suffix):
                return url
        message = f"No queue ending with '{suffix}'"
        logger.error(f"Failed to get SQS queue: {message}")
        raise LookupError(message)

    def get_satcomp_queue(self):
        """Return the URL of the SAT-comp input queue (URL ends with 'SatCompQueue').

        Raises:
            LookupError: if no such queue exists.
        """
        return self._find_queue_url('SatCompQueue')

    def get_satcomp_output_queue(self):
        """Return the URL of the SAT-comp output queue (URL ends with 'SatCompOutputQueue').

        Raises:
            LookupError: if no such queue exists.
        """
        return self._find_queue_url('SatCompOutputQueue')

    def send_message(self, location, workers, timeout, solverName, language, solverOptions):
        """Send a solve request to the SAT-comp input queue.

        The message body is JSON (indented by 4) with the structure::

            {
                "formula": {
                    "value": <s3 url>,
                    "language": "SMTLIB2" | "DIMACS"
                },
                "solverConfig": {
                    "solverName": "",
                    "solverOptions": [],
                    "taskTimeoutSeconds": 5
                },
                "num_workers": 0
            }

        Args:
            location: S3 URL of the problem file (``formula.value``).
            workers: worker node count (``num_workers``).
            timeout: solver timeout in seconds (``solverConfig.taskTimeoutSeconds``).
            solverName: passed through as ``solverConfig.solverName``.
            language: problem format (``formula.language``).
            solverOptions: passed through as ``solverConfig.solverOptions``.

        Returns:
            The response returned by the client's ``send_message``.

        Raises:
            LookupError: if the input queue cannot be found.
            Exception: errors from sending are logged and re-raised.
        """
        queue = self.get_satcomp_queue()
        message_body = {
            "formula": {
                "value": location,
                "language": language,
            },
            "solverConfig": {
                "solverName": solverName,
                "solverOptions": solverOptions,
                "taskTimeoutSeconds": timeout,
            },
            "num_workers": workers,
        }
        try:
            return self.sqs.send_message(
                QueueUrl=queue,
                MessageBody=json.dumps(message_body, indent=4),
            )
        except Exception as e:
            logger.error(f"Failed to send message: Exception: {e}")
            raise

    # Shape of a receive_message response (only the fields used below):
    #
    # {
    #     'Messages': [
    #         {
    #             'MessageId': 'string',
    #             'ReceiptHandle': 'string',
    #             'Body': 'string',
    #             ...  # MD5s, Attributes, MessageAttributes
    #         },
    #     ]
    # }
    #
    # 'Messages' is absent when a long poll ends without any message.
    def receive_and_delete_message(self, timeout):
        """Poll the output queue for a message, log its body and delete it.

        Long-polls in rounds of ``QUEUE_WAIT_TIME`` seconds until a message
        arrives or the accumulated wait reaches ``timeout + 5`` seconds.

        Args:
            timeout: the solver timeout in seconds that was sent with the request.

        Returns:
            The ``receive_message`` response holding the received message(s),
            or ``None`` if nothing arrived before the deadline.
        """
        queue = self.get_satcomp_output_queue()
        logger.info(f"Receiving and deleting message from queue {queue}")
        total_time = 0
        # The +5 presumably gives the solver a small grace period past its timeout.
        while total_time < timeout + 5:
            logger.info(f"Waiting up to {QUEUE_WAIT_TIME}s for a message.")
            response = self.sqs.receive_message(
                QueueUrl=queue,
                WaitTimeSeconds=QUEUE_WAIT_TIME,
            )
            # Use .get: an empty poll has no 'Messages' key and must not raise.
            messages = response.get("Messages", [])
            if messages:
                for msg in messages:
                    body = msg["Body"]
                    logger.info(f"Response from receive_message was: {body}")
                    self.sqs.delete_message(
                        QueueUrl=queue,
                        ReceiptHandle=msg["ReceiptHandle"],
                    )
                return response
            total_time += QUEUE_WAIT_TIME
        logger.error("Solver did not complete within expected timeout.")
        return None
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--profile', required = False, help = "AWS profile")
parser.add_argument('--location', required = True, help = "S3 location for CNF file")
parser.add_argument('--workers', required = True, type=int, help = "Required Worker nodes count")
parser.add_argument('--timeout', type=int, help = "Timeout value for the infrastructure to interrupt the solver", default = 60)
parser.add_argument('--name', help = "Name of solver to be invoked (passed through to the solver). Default: empty string", default = "")
parser.add_argument('--format', help = "Problem format for the problem to be solved.", default = "")
parser.add_argument('--args', nargs='+', help="Arguments to pass through to the solver (--args accepts a space-delimited list of arguments). Default: empty list", default = [])
parser.add_argument('--await-response', required = False, help = "If true, then solver will poll output queue for response message, display, and delete it.")
args = parser.parse_args()
profile = args.profile
# Send message
session = boto3.Session(profile_name=profile)
sqs_client = session.client('sqs')
sqs = SqsService(sqs_client)
try:
sqs.send_message(args.location, args.workers, args.timeout, args.name, args.format, args.args)
if args.await_response:
sqs.receive_and_delete_message(args.timeout)
except Exception as e:
logger.info(f"Failed to send message. {e}")