#!/usr/bin/python
# -*- coding: UTF-8 -*-

from multiprocessing import set_forkserver_preload
import os
import os.path
from posixpath import split
from random import sample
import re
import shutil
from time import monotonic, sleep
from tokenize import Number

# Global limits used when scoring every solver run.
CUTOFF = 1000 # per-instance time limit; solver.cal_soln resets slower entries to the default `states` (unit presumably seconds -- confirm)
PUNISH = 2 #PAR2: an unsolved instance is charged CUTOFF * PUNISH
MEMS_MAX = 61440 # 60G; only used as the default of states.mems

class states(object):
    """Result record of one solver on one instance.

    The class attributes are the defaults of a run that produced no answer;
    callers overwrite them per instance once a result has been parsed.
    """
    res = "unknown"     # "sat", "unsat" or "unknown"
    time = CUTOFF*PUNISH    # run time; defaults to the penalised timeout
    mems = MEMS_MAX     # memory figure; not read by the code visible in this file
    mono = False        # only this one can solve
    best = False        # show the best performance
    ls_time = 0         # LS_time (presumably local-search time -- confirm); not read in this file

class solver(object):
    """Accumulate per-sample statistics for one solver.

    ``datas`` maps an instance name to its ``states`` record.  The base class
    never fills ``datas`` itself: a subclass (or the caller) must store the
    record before ``cal_soln`` is invoked for that instance.

    Typical cycle for one sample: ``reset()``, one ``cal_soln(name)`` per
    instance, then ``deal_avg()`` to turn the accumulated sums into means.
    """

    def __init__(self, res_dir, name):
        self.res_dir    = res_dir  # save the results files
        self.print_name = name     # names want to show
        self.datas      = dict()   # datas[ins] save the instances
        # Zero the counters right away so the object is usable (and
        # printable) even before the first explicit reset().
        self.reset()

    def reset(self):
        """Clear all counters and time sums before processing a new sample."""
        # SAT-ins UNSAT-ins solved-ins all-ins
        self.sat_num = self.unsat_num = self.solved_num = self.all_num = 0
        self.avg_sat_time = self.avg_unsat_time = self.avg_solved_time = self.avg_all_time = 0.0
        self.PAR_sat_time = self.PAR_unsat_time = self.PAR_solved_time = self.PAR_all_time = 0.0
        self.mono_num = 0
        self.best_num = 0

    def cal_soln(self, ins_name):
        """Add the stored result of *ins_name* to the running sums.

        A record slower than ``CUTOFF`` is replaced by a default ``states``
        (unknown, penalised time) and therefore scored as unsolved.

        Raises:
            KeyError: if ``datas`` has no record for *ins_name*.
        """
        self.all_num += 1
        # Apply the cutoff BEFORE reading the record.  Previously the old
        # record was captured first, so an over-cutoff answer was still
        # counted as solved although datas[ins_name] already said "unknown".
        if self.datas[ins_name].time > CUTOFF:
            self.datas[ins_name] = states()
        state = self.datas[ins_name]

        if state.res == "sat":
            self.sat_num        += 1
            self.avg_sat_time   += state.time
            self.PAR_sat_time   += state.time
        elif state.res == "unsat":
            self.unsat_num      += 1
            self.avg_unsat_time += state.time
            self.PAR_unsat_time += state.time
        else:
            # Unsolved: plain average is charged the cutoff, PAR the penalty.
            self.avg_all_time   += CUTOFF
            self.PAR_all_time   += CUTOFF * PUNISH
            return

        # Bookkeeping shared by both solved outcomes.
        self.solved_num      += 1
        self.avg_solved_time += state.time
        self.avg_all_time    += state.time
        self.PAR_solved_time += state.time
        self.PAR_all_time    += state.time

    def deal_avg(self):
        """Convert the accumulated time sums into per-instance means.

        Groups with a zero count are left untouched (their sums are 0.0).
        """
        if self.sat_num > 0:
            self.avg_sat_time    /= self.sat_num
            self.PAR_sat_time    /= self.sat_num
        if self.unsat_num > 0:
            self.avg_unsat_time  /= self.unsat_num
            self.PAR_unsat_time  /= self.unsat_num
        if self.solved_num > 0:
            self.avg_solved_time /= self.solved_num
            self.PAR_solved_time /= self.solved_num
        if self.all_num > 0:
            self.avg_all_time    /= self.all_num
            self.PAR_all_time    /= self.all_num

    def to_string(self, state):
        """Return an 18-column cell ``"<res> <time>"`` for one record.

        ``[M]`` is appended when the record is flagged ``mono``; otherwise
        ``[B]`` when it is flagged ``best``.
        """
        line = str(state.res) + " " + str(round(state.time, 2))
        if state.mono:
            line += "[M]"
        elif state.best:
            line += "[B]"
        return line.ljust(18)


class solver_SAT_standard_gnomon(solver):
    """Solver whose per-instance result lives in ``<res_dir>/<ins_name>``.

    The file is scanned for an ``s UNSATISFIABLE`` / ``s SATISFIABLE`` status
    line and for ``real <float>`` timing lines (presumably the output of a
    ``time``-style wrapper, in seconds -- confirm against the run scripts).
    """

    def cal_soln(self, ins_name):
        """Parse the result file of *ins_name* on first use, then score it.

        The parsed record is cached in ``datas``, so each file is read once.

        Raises:
            OSError: if the result file cannot be opened.
        """
        if ins_name not in self.datas:
            record = states()
            self.datas[ins_name] = record
            real_file_path = self.res_dir + "/" + ins_name
            # Context manager so the handle is closed even for thousands of
            # instances (the bare open().read() relied on garbage collection).
            with open(real_file_path, "r") as res_file:
                fstr = res_file.read()

            if re.search(r"s\s+UNSATISFIABLE", fstr):
                record.res = "unsat"
            elif re.search(r"s\s+SATISFIABLE", fstr):
                record.res = "sat"

            if record.res != "unknown":
                # The last "real" line is the one that is used.
                times = re.findall(r"real\s+(\d+\.\d+)", fstr)
                if times:
                    record.time = float(times[-1])
                    if record.time > CUTOFF*PUNISH:
                        record.res = "unknown"
                else:
                    # An answer without a timing line cannot be scored;
                    # treat it as unsolved instead of raising IndexError.
                    record.res = "unknown"

        return super().cal_soln(ins_name)

    def to_string(self, state):
        """Format *state* exactly like the base class does."""
        return super().to_string(state)

# Column widths (in characters) of the Markdown summary table.
SOLVER_LEN = 20  # "solver" column
SAMPLE_LEN = 20  # "sample" column
NUMBER_LEN = 8   # every numeric column
# True until the table header has been printed; cleared by calculater on first use.
print_title = True
class calculater(object):
    """Compare several solvers over one or more instance lists.

    For every sample (a text file with one instance name per line) each
    solver is scored on each instance, the fastest / only solver is flagged,
    and one Markdown table row per solver is printed to stdout.
    """
    solvers     = []
    sample_dirs = []    # sample dirs, [sample_dir, sample_name]s

    def __init__(self, solvers, sample_dirs):
        """Store the solver objects and the ``[list_file, name]`` samples."""
        self.solvers = solvers
        self.sample_dirs = sample_dirs

    def __show_in_mark_down(self, samp_name):
        """Print the table header (once per process) and one row per solver.

        Expects ``self.sample_ins_ct`` to hold the instance count of the
        sample and every solver to have run ``deal_avg()`` already.
        """
        global print_title

        # Build the separator unconditionally: it used to be created only
        # together with the header, so a second calculater (header already
        # printed) crashed later on the missing ``split_line`` attribute.
        widths = [SAMPLE_LEN, SOLVER_LEN] + [NUMBER_LEN] * 10
        self.split_line = "| " + " | ".join("-" * w for w in widths) + " |"

        if print_title:
            print_title = False
            title = "| sample".ljust(SAMPLE_LEN + 2)
            title += " | solver".ljust(SOLVER_LEN + 3)
            for head in ("#SAT", "avg_t", "#UNSAT", "avg_t", "#ALL",
                         "PAR2_t", "best", "mono", "s", "TIME"):
                title += (" | " + head).ljust(NUMBER_LEN + 3)
            title += " |"
            print(title)
            print(self.split_line)

        total = self.sample_ins_ct
        # Total PAR time of the first solver, used as the reference.
        sota = self.solvers[0].PAR_all_time * total if self.solvers else 0.0

        for slv in self.solvers:
            unsolved_penalty = CUTOFF * PUNISH * (total - slv.solved_num)
            solved_total = slv.solved_num * slv.PAR_solved_time

            # Ratio against the reference solver; undefined (shown as "-")
            # when this solver has no solved time to divide by.
            if solved_total:
                s_text = str(round((sota - unsolved_penalty) / solved_total, 2))
            else:
                s_text = "-"

            # Mean time with the unsolved penalty scaled by 1/1.5.
            # NOTE(review): the 1.5 factor is undocumented -- confirm intent.
            if total:
                mean_time = (solved_total + unsolved_penalty / 1.5) / total
            else:
                mean_time = 0.0

            cells = [
                (samp_name + "(" + str(total) + ")").ljust(SAMPLE_LEN),
                slv.print_name.ljust(SOLVER_LEN),
                str(slv.sat_num).ljust(NUMBER_LEN),
                str(round(slv.avg_sat_time, 2)).ljust(NUMBER_LEN),
                str(slv.unsat_num).ljust(NUMBER_LEN),
                str(round(slv.avg_unsat_time, 2)).ljust(NUMBER_LEN),
                str(slv.solved_num).ljust(NUMBER_LEN),
                str(round(slv.PAR_all_time, 2)).ljust(NUMBER_LEN),
                str(slv.best_num).ljust(NUMBER_LEN),
                str(slv.mono_num).ljust(NUMBER_LEN),
                s_text.ljust(NUMBER_LEN),
                str(round(mean_time, 2)).ljust(NUMBER_LEN),
            ]
            print("| " + " | ".join(cells) + " |")

    def cal_and_show(self):
        """Score every sample and print its summary rows.

        Raises:
            OSError: if a sample list file cannot be opened.
        """
        for sample in self.sample_dirs:
            print("".join(slv.print_name.ljust(18) for slv in self.solvers))
            samp_dir  = sample[0]
            samp_name = sample[1]
            print_line_ct = 0
            sample_ins_ct = 0
            for slv in self.solvers:
                slv.reset()

            # ``with`` closes the list file once the sample is processed.
            with open(samp_dir) as ins_list:
                for raw_name in ins_list:
                    ins_name = raw_name.strip()
                    if not ins_name:
                        # A blank (e.g. trailing) line is not an instance.
                        continue
                    sample_ins_ct += 1

                    best_time = CUTOFF*PUNISH
                    solved_ct = 0
                    for slv in self.solvers:
                        slv.cal_soln(ins_name)
                        best_time = min(slv.datas[ins_name].time, best_time)
                        if slv.datas[ins_name].res != "unknown":
                            solved_ct += 1

                    # Somebody beat the penalised timeout: flag the fastest
                    # solver(s), and "mono" when it was the only one.
                    if best_time != CUTOFF*PUNISH:
                        for slv in self.solvers:
                            record = slv.datas[ins_name]
                            if record.time == best_time:
                                record.best = True
                                slv.best_num += 1
                                if solved_ct == 1:
                                    record.mono = True
                                    slv.mono_num += 1

                    # Per-instance detail line plus flags that can drive the
                    # debug listing below.
                    line = ""
                    no_answer     = True
                    answer_this   = "unknown"
                    all_can_solve = True
                    for slv in self.solvers:
                        stt = slv.datas[ins_name]
                        line += slv.to_string(stt)
                        if stt.res != "unknown":
                            no_answer = False
                            answer_this = stt.res
                        else:
                            all_can_solve = False
                    line += ins_name
                    have_diff_res = (not all_can_solve) and (not no_answer)

                    # Debug switch: replace False by no_answer, all_can_solve,
                    # have_diff_res, (have_diff_res and answer_this == "sat"),
                    # ... to list the matching instances.
                    show_line = False
                    if show_line:
                        print_line_ct += 1
                        print(line)

            self.sample_ins_ct = sample_ins_ct
            for slv in self.solvers:
                slv.deal_avg()
            self.__show_in_mark_down(samp_name)
            if print_line_ct > 0:
                print("print line ct = ", print_line_ct)
            else:
                print(self.split_line)


def gen_samples(dir):
    """Collect every file below *dir* as a ``[path, sample_name]`` pair.

    Args:
        dir: root directory that is walked recursively.

    Returns:
        A list of ``[file_path, sample_name]`` lists in ``os.walk`` order,
        where ``sample_name`` is the file name without a trailing ``.txt``.
    """
    suffix = ".txt"
    samples = []
    for root, _dirs, files in os.walk(dir):
        for file_name in files:
            # Remove the suffix only.  str.strip(".txt") strips any of the
            # characters '.', 't', 'x' from BOTH ends ("test.txt" -> "es").
            if file_name.endswith(suffix):
                sample_name = file_name[:-len(suffix)]
            else:
                sample_name = file_name
            samples.append([os.path.join(root, file_name), sample_name])
    return samples

if __name__ == "__main__":
    # (result directory, display name) of every solver run to compare.
    # The first entry is the reference row for the "s" column.
    solver_specs = [
        # ("/pub/netdisk1/qianyh/aws-batch-comp-infrastructure-sample/docker/runner/exp-result", "mallob"),
        ("./light-3m-no-pre", "light-3m-no-pre"),
        ("./light-3m", "light-3m"),
        ("./light-no-bug", "light-no-bug"),
        ("./light-circle", "light-circle"),
        ("./light-circle-unique", "light-circle-unique"),
        ("./first_version", "bug-cloud"),
        # ("/pub/data/chenzh/res/huawei_sat/kissat-mab", "origin-mab"),
        # ("/pub/data/chenzh/res/huawei_simp/kissat-mab", "preprocess-mab"),
    ]
    solvers = [
        solver_SAT_standard_gnomon(res_dir, shown_name)
        for res_dir, shown_name in solver_specs
    ]

    # [instance list file, sample name] pairs; one table section per entry.
    samples = [
        ["/pub/data/chenzh/data/sat2022/vbs.txt", "dump_sat"],
        # ["./nohup.log", "dump_sat"],
    ]

    clt = calculater(solvers, samples)
    clt.cal_and_show()