#!/usr/bin/env python3
|
|
import argparse
|
|
import logging
|
|
import boto3
|
|
import json
|
|
|
|
# Process-wide logging setup: timestamp, logger name, level, then the message.
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# Logger shared by everything in this script; INFO and above is emitted.
logger = logging.getLogger("Runner")
logger.setLevel(logging.INFO)

# Seconds a single receive_message call may wait for a message to arrive
# (passed as WaitTimeSeconds); also the step by which the polling budget
# is advanced after each empty poll.
QUEUE_WAIT_TIME = 10
|
|
|
|
|
|
|
|
class SqsService:
    """Helper around an SQS client for the SatComp input and output queues.

    The input queue is the one whose URL ends with ``SatCompQueue``; the
    output queue is the one whose URL ends with ``SatCompOutputQueue``.
    """

    def __init__(self, client):
        """Keep the client used for every queue operation.

        Args:
            client: Object providing the SQS client methods called by this
                class: ``list_queues``, ``send_message``,
                ``receive_message`` and ``delete_message``.
        """
        self.sqs = client

    def _find_queue(self, suffix):
        """Return the first listed queue URL that ends with ``suffix``.

        Args:
            suffix: Text the queue URL must end with.

        Returns:
            The matching queue URL.

        Raises:
            RuntimeError: If no listed queue URL ends with ``suffix``.
            Exception: Any error raised by ``list_queues`` is logged and
                re-raised unchanged.
        """
        try:
            response = self.sqs.list_queues()
            # The response carries no 'QueueUrls' key at all when there is
            # nothing to list, so treat a missing key as an empty list.
            for url in response.get('QueueUrls', []):
                if url.endswith(suffix):
                    return url
            raise RuntimeError(f"No queue ending with '{suffix}'")
        except Exception as e:
            logger.error(f"Failed to get SQS queue: {e}")
            raise

    def get_satcomp_queue(self):
        """Look up the SatComp input queue.

        Returns:
            URL of the queue whose URL ends with ``SatCompQueue``.

        Raises:
            RuntimeError: If no such queue is listed.
        """
        return self._find_queue('SatCompQueue')

    def get_satcomp_output_queue(self):
        """Look up the SatComp output queue.

        Returns:
            URL of the queue whose URL ends with ``SatCompOutputQueue``.

        Raises:
            RuntimeError: If no such queue is listed.
        """
        return self._find_queue('SatCompOutputQueue')

    def send_message(self, location, workers, timeout, solverName, language, solverOptions):
        """Send one solve request to the SatComp input queue.

        The message body is JSON with this structure::

            {
                "formula": {
                    "value": <s3 url>,
                    "language": "SMTLIB2" | "DIMACS"
                },
                "solverConfig": {
                    "solverName": "",
                    "solverOptions": [],
                    "taskTimeoutSeconds": 5
                },
                "num_workers": 0
            }

        Args:
            location: Stored as ``formula.value``.
            workers: Stored as ``num_workers``.
            timeout: Stored as ``solverConfig.taskTimeoutSeconds``.
            solverName: Stored as ``solverConfig.solverName``.
            language: Stored as ``formula.language``.
            solverOptions: Stored as ``solverConfig.solverOptions``.

        Raises:
            RuntimeError: If the input queue cannot be found.
            Exception: Any error raised while sending is logged and
                re-raised unchanged.
        """
        queue = self.get_satcomp_queue()

        message_body = {
            "formula": {
                "value": location,
                "language": language,
            },
            "solverConfig": {
                "solverName": solverName,
                "solverOptions": solverOptions,
                "taskTimeoutSeconds": timeout,
            },
            "num_workers": workers,
        }
        message_body_str = json.dumps(message_body, indent=4)

        try:
            self.sqs.send_message(QueueUrl=queue, MessageBody=message_body_str)
        except Exception as e:
            logger.error(f"Failed to send message: Exception: {e}")
            raise

    def receive_and_delete_message(self, timeout):
        """Poll the output queue, log each message received, and delete it.

        The queue is polled in windows of ``QUEUE_WAIT_TIME`` seconds until a
        poll returns at least one message or the budget of ``timeout + 5``
        seconds is used up. The budget is advanced by ``QUEUE_WAIT_TIME``
        after every empty poll; elapsed wall-clock time is not measured.

        A non-empty ``receive_message`` response looks like this (abridged;
        only ``Body`` and ``ReceiptHandle`` are read here)::

            {
                'Messages': [
                    {
                        'MessageId': 'string',
                        'ReceiptHandle': 'string',
                        'MD5OfBody': 'string',
                        'Body': 'string',
                        'Attributes': {'string': 'string'},
                        'MD5OfMessageAttributes': 'string',
                        'MessageAttributes': {...}
                    },
                ]
            }

        Args:
            timeout: Solver timeout in seconds; polling continues for up to
                5 seconds beyond it.

        Returns:
            The ``receive_message`` response of the first poll that contained
            messages, or ``None`` if the budget ran out first.

        Raises:
            RuntimeError: If the output queue cannot be found.
        """
        queue = self.get_satcomp_output_queue()
        logger.info(f"Receiving and deleting message from queue {queue}")

        total_time = 0
        while total_time < timeout + 5:
            logger.info(f"Waiting up to {QUEUE_WAIT_TIME}s for a message.")
            response = self.sqs.receive_message(
                QueueUrl=queue,
                WaitTimeSeconds=QUEUE_WAIT_TIME,
            )
            # An empty poll has no 'Messages' key, so look it up defensively
            # and keep polling instead of failing with KeyError.
            messages = response.get("Messages")
            if messages:
                for msg in messages:
                    body = msg["Body"]
                    logger.info(f"Response from receive_message was: {body}")
                    self.sqs.delete_message(
                        QueueUrl=queue,
                        ReceiptHandle=msg["ReceiptHandle"],
                    )
                return response

            total_time += QUEUE_WAIT_TIME

        logger.error("Solver did not complete within expected timeout.")
        return None
|
|
|
|
|
|
def _str_to_bool(value):
    """Interpret a command-line value for a boolean option.

    Only explicit false-like words (and the empty string) are false; every
    other value is true, so values that enabled the option before still do.
    """
    return value.strip().lower() not in ("", "0", "false", "f", "no", "n", "off")


def build_parser():
    """Build the command-line parser for this script.

    Returns:
        An ``argparse.ArgumentParser`` with the options ``--profile``,
        ``--location``, ``--workers``, ``--timeout``, ``--name``,
        ``--format``, ``--args`` and ``--await-response``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--profile', required=False, help="AWS profile")
    parser.add_argument('--location', required=True, help="S3 location for CNF file")
    parser.add_argument('--workers', required=True, type=int, help="Required Worker nodes count")
    parser.add_argument('--timeout', type=int, help="Timeout value for the infrastructure to interrupt the solver", default=60)
    parser.add_argument('--name', help="Name of solver to be invoked (passed through to the solver). Default: empty string", default="")
    parser.add_argument('--format', help="Problem format for the problem to be solved.", default="")
    parser.add_argument('--args', nargs='+', help="Arguments to pass through to the solver (--args accepts a space-delimited list of arguments). Default: empty list", default=[])
    # Without a type, any non-empty value (including "false") counted as
    # true. Parse the value, and let the bare flag mean "true" as well.
    parser.add_argument(
        '--await-response',
        required=False,
        nargs='?',
        const=True,
        default=False,
        type=_str_to_bool,
        help="If true, then solver will poll output queue for response message, display, and delete it.",
    )
    return parser


def main(argv=None):
    """Send a solve request and optionally wait for the solver's response.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 on success, 1 if sending the request or
        receiving the response raised an exception.
    """
    args = build_parser().parse_args(argv)

    session = boto3.Session(profile_name=args.profile)
    sqs = SqsService(session.client('sqs'))

    try:
        sqs.send_message(args.location, args.workers, args.timeout, args.name, args.format, args.args)
    except Exception as e:
        logger.error(f"Failed to send message. {e}")
        return 1

    if args.await_response:
        # Reported separately so a receive failure is not mislabelled as a
        # send failure.
        try:
            sqs.receive_and_delete_message(args.timeout)
        except Exception as e:
            logger.error(f"Failed to receive response. {e}")
            return 1

    return 0


if __name__ == "__main__":
    raise SystemExit(main())
|
|
|