#!/usr/bin/env python3
import json
import logging
import os
import subprocess
import sys
import threading

# Install the default root handler; every record it emits is rendered as
# "<timestamp> - <logger name> - <level> - <message>".
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)


class Runner:
    """Launch the solver command for one request and capture its output.

    All files live in ``request_directory``: ``input.json`` is read from it,
    and ``stdout.log``, ``stderr.log`` and ``combined_hostfile`` are written
    into it.
    """

    def __init__(self, request_directory: str):
        """Remember the request directory and set up the ``Runner`` logger."""
        self.logger = logging.getLogger("Runner")
        self.logger.setLevel(logging.INFO)
        self.request_directory = request_directory
        # Exported to the environment so child processes inherit it; Python
        # children then write their output unbuffered.
        os.environ['PYTHONUNBUFFERED'] = "1"

    def process_stream(self, stream, str_name, file_handle):
        """Copy ``stream`` line by line into ``file_handle`` until EOF.

        Each line is also logged at INFO level, prefixed with ``str_name``.
        The line is written to ``file_handle`` unchanged; only the log
        message has its trailing newline removed, because the logging
        handler appends its own terminator.
        """
        for line in stream:
            self.logger.info("%s: %s", str_name, line.rstrip("\n"))
            file_handle.write(line)

    def run(self, cmd: list):
        """Execute ``cmd`` and wait for it to finish.

        The child's stdout and stderr are streamed into ``stdout.log`` and
        ``stderr.log`` inside the request directory (and mirrored to the
        logger) by two reader threads.

        Returns:
            dict with the keys ``stdout`` and ``stderr`` (paths of the two
            log files), ``return_code`` (exit status of the child) and
            ``output_directory`` (the request directory).
        """
        self.logger.info("Running command: %s", str(cmd))

        stdout_target_loc = os.path.join(self.request_directory, "stdout.log")
        stderr_target_loc = os.path.join(self.request_directory, "stderr.log")

        with open(stdout_target_loc, "w") as stdout_handle, \
                open(stderr_target_loc, "w") as stderr_handle:
            # errors="replace": an undecodable byte in the solver output must
            # not kill a reader thread, otherwise the pipe is no longer
            # drained and the child can block forever on a full pipe.
            # Using Popen as a context manager closes both pipes on exit.
            with subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                universal_newlines=True,
                errors="replace",
                start_new_session=True,
            ) as proc:
                readers = [
                    threading.Thread(
                        target=self.process_stream,
                        args=(proc.stdout, "STDOUT", stdout_handle),
                    ),
                    threading.Thread(
                        target=self.process_stream,
                        args=(proc.stderr, "STDERR", stderr_handle),
                    ),
                ]
                for reader in readers:
                    reader.start()
                return_code = proc.wait()
                # Join before the pipes and log files are closed so that
                # every remaining line is flushed to disk.
                for reader in readers:
                    reader.join()

        return {
            "stdout": stdout_target_loc,
            "stderr": stderr_target_loc,
            "return_code": return_code,
            "output_directory": self.request_directory
        }

    def get_input_json(self):
        """Return the parsed contents of ``input.json`` in the request directory."""
        input_path = os.path.join(self.request_directory, "input.json")
        with open(input_path) as f:
            return json.load(f)

    def create_hostfile(self, ips, request_dir, slots=4):
        """Write ``combined_hostfile`` into ``request_dir`` and return its path.

        One line ``"<ip> slots=<slots>"`` is written per entry of ``ips``.

        Args:
            ips: iterable of host addresses.
            request_dir: directory in which the hostfile is created.
            slots: slot count written for every host (default 4).
        """
        hostfile_path = os.path.join(request_dir, 'combined_hostfile')
        with open(hostfile_path, 'w') as f:
            f.writelines(f'{ip} slots={slots}\n' for ip in ips)
        return hostfile_path

    def get_command(self, input_json):
        """Build the solver command line from the request's input JSON.

        Reads ``formula_file`` and ``worker_node_ips`` from ``input_json``,
        writes the hostfile for the listed workers, and returns
        ``["/competition/run_solver.sh", <hostfile>, <formula_file>]``.

        Raises:
            ValueError: if ``formula_file`` is missing or empty; without this
                check the ``None`` would only surface later as an obscure
                error when the command is launched.
        """
        formula_file = input_json.get("formula_file")
        if not formula_file:
            raise ValueError('input JSON has no "formula_file" entry')
        # Tolerate an explicit null as well as a missing key.
        worker_node_ips = input_json.get("worker_node_ips") or []

        combined_hostfile = self.create_hostfile(worker_node_ips, self.request_directory)

        return ["/competition/run_solver.sh", combined_hostfile, formula_file]


class MallobParser:
    """Derive a result code from the solver's captured output."""

    @staticmethod
    def get_result(output_file):
        """
        Scan the whole of ``output_file`` for known result markers.

        Returns 20 for "found result UNSAT", 10 for "found result SAT",
        0 when "UNKNOWN" appears, and -1 when none of them is present.
        The markers are checked in that order.

        TODO: Participants should replace this with something more robust for their own solver!
        """
        with open(output_file) as log:
            contents = log.read()

        marker_codes = (
            ("found result UNSAT", 20),
            ("found result SAT", 10),
            ("UNKNOWN", 0),
        )
        for marker, code in marker_codes:
            if marker in contents:
                return code
        return -1

def main(argv=None):
    """Run the solver for one request directory and record the outcome.

    Reads ``input.json`` from the request directory, runs the solver
    command, derives the result code from the captured stdout, and writes
    ``solver_out.json`` into the same directory.

    Args:
        argv: argument vector; defaults to ``sys.argv``. ``argv[1]`` is the
            request directory. Further arguments are ignored.

    Returns:
        Process exit status: 0 on completion, 2 when the request directory
        argument is missing.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 2:
        prog = argv[0] if argv else "runner"
        print(f"usage: {prog} <request_directory>", file=sys.stderr)
        return 2

    request_directory = argv[1]
    runner = Runner(request_directory)

    input_json = runner.get_input_json()
    cmd = runner.get_command(input_json)

    print (f"cmd: {cmd}")
    output = runner.run(cmd)
    result = MallobParser.get_result(output["stdout"])
    # Use the runner's logger, which is set to INFO. The root logger stays at
    # its default WARNING level, so logging.info() would drop this message.
    runner.logger.info("RESULT: %s", result)
    solver_output = {
        "return_code": result,
        "artifacts": {
            "stdout_path": output["stdout"],
            "stderr_path": output["stderr"]
        }
    }

    with open(os.path.join(request_directory, "solver_out.json"), "w") as f:
        json.dump(solver_output, f)
    return 0


if __name__ == "__main__":
    sys.exit(main())