#include <chrono>
#include <thread>
#include <fstream>
#include <mpi.h>

#include "light.hpp"
#include "utils/cmdline.h"
#include "paras.hpp"

#include "leader.hpp"
#include "worker.hpp"

#include "clause.hpp"

int main(int argc, char **argv) {

    int num_procs, rank;
    MPI_Init(&argc, &argv);
    MPI_Comm_size(MPI_COMM_WORLD, &num_procs);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    light* S = new light();
    S->arg_parse(argc, argv);

    // 初始化环形数据结构

    int worker_procs = num_procs - 1;
    if(worker_procs >= 8) {
        // 分出一些 worker 跑其他策略
        int sat_procs = worker_procs / 8;
        int unsat_procs = sat_procs;
        int default_procs = worker_procs - sat_procs - unsat_procs;

        S->next_node = rank + 1;

        // [1, sat_procs] for sat
        if(rank >= 1 && rank <= sat_procs) {
            S->worker_type = light::SAT;
            if(S->next_node > sat_procs) {
                S->next_node = 1;
            }
        }
        // [sat_procs+1, sat_procs+unsat_procs] for unsat
        if(rank >= sat_procs+1 && rank <= sat_procs+unsat_procs) {
            S->worker_type = light::UNSAT;
            if(S->next_node > sat_procs+unsat_procs) {
                S->next_node = sat_procs+1;
            }
        }
        // [sat_procs+unsat_procs+1, worker_procs]
        if(rank >= sat_procs+unsat_procs+1 && rank <= worker_procs) {
            S->worker_type = light::DEFAULT;
            if(S->next_node > worker_procs) {
                S->next_node = sat_procs+unsat_procs+1;
            }
        }

    } else {
        // 总线程太小就跑默认策略
        S->worker_type = light::DEFAULT;
        S->next_node = rank + 1;
        if(S->next_node > worker_procs) {
            S->next_node = 1;
        }
    }

    // leader
    if(rank == 0) leader_main(S, num_procs, rank);
    else worker_main(S, num_procs, rank);

    MPI_Barrier(MPI_COMM_WORLD);
    MPI_Finalize();

    return 0;
}