// cloud-sat/src/workers/sharer.cpp
#include "../light.hpp"
#include "basesolver.hpp"
#include "sharer.hpp"
#include "clause.hpp"
#include <unistd.h>
#include "../distributed/comm_tag.h"
#include <boost/thread/thread.hpp>
const int BUF_SIZE = 1024 * 1024;
std::vector<std::pair<MPI_Request*, int*>> send_data_struct;
MPI_Request receive_request;
int buf[BUF_SIZE];
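// Serialize a batch of clauses and ship it to the next worker in the ring
// with a non-blocking MPI send. Wire format per clause: [size, lbd,
// lit_0 .. lit_{size-1}], concatenated for the whole batch. Rank 0 is
// presumably reserved for the leader, so only ranks 1..num_procs-1
// take part in the ring.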
void share_clauses_to_next_node(const vec<clause_store *> &cls) {
    // Reclaim completed sends: free the request and buffer of every
    // transfer that MPI reports as finished.
    for (size_t i = 0; i < send_data_struct.size(); ) {
        auto &request = send_data_struct[i].first;
        auto &send_buf = send_data_struct[i].second;
        int flag;
        if (MPI_Test(request, &flag, MPI_STATUS_IGNORE) == MPI_SUCCESS && flag == 1) {
            // The send has completed; release the request and its buffer.
            delete request;
            delete[] send_buf;
            // Swap with the last element and pop_back. Do not advance i, so
            // the element swapped into slot i is tested on the next pass.
            std::swap(send_data_struct[i], send_data_struct.back());
            send_data_struct.pop_back();
            printf("c [worker] free send request, now: %d\n", (int)send_data_struct.size());
        } else {
            i++;
        }
    }
    // Compute the total payload size, then serialize every clause as
    // [size, lbd, literals...] into one contiguous buffer.
    MPI_Request *send_request = new MPI_Request();
    int send_length = 0;
    for (int i = 0; i < cls.size(); i++) {
        send_length += cls[i]->size + 2;
    }
    int *send_buf = new int[send_length];
    int index = 0;
    for (int i = 0; i < cls.size(); i++) {
        auto &c = cls[i];
        send_buf[index++] = c->size;
        send_buf[index++] = c->lbd;
        for (int j = 0; j < c->size; j++) {
            send_buf[index++] = c->data[j];
        }
    }
    assert(index == send_length);
    // Hand the buffer to MPI with a non-blocking send to the next ring member.
    int num_procs, rank;
    MPI_Comm_size(MPI_COMM_WORLD, &num_procs);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
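    // Ring successor among worker ranks 1..num_procs-1, skipping rank 0.
    // For example, with num_procs = 4: rank 1 -> 2, rank 2 -> 3, rank 3 -> 1.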
    int target = rank % (num_procs - 1) + 1;
    MPI_Isend(send_buf, send_length, MPI_INT, target, SHARE_CLAUSES_TAG, MPI_COMM_WORLD, send_request);
    send_data_struct.push_back(std::make_pair(send_request, send_buf));
    printf("c [worker] send clauses: %d\n", send_length);
}
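// Drain every message that has arrived from the previous worker in the ring,
// decode it with the same [size, lbd, literals...] layout used on the send
// side, and re-arm the non-blocking receive. Returns true if at least one
// batch was received.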
bool receive_clauses_from_last_node(vec<clause_store *> &clauses) {
    clauses.clear();
    int flag;
    MPI_Status status;
    bool received = false;
    // A message has completed: unpack it clause by clause.
    while (MPI_Test(&receive_request, &flag, &status) == MPI_SUCCESS && flag == 1) {
        int index = 0;
        int count;
        MPI_Get_count(&status, MPI_INT, &count);
        while (index < count) {
            clause_store *cl = new clause_store(buf[index++]);
            cl->lbd = buf[index++];
            for (int i = 0; i < cl->size; i++) {
                cl->data[i] = buf[index++];
            }
            clauses.push(cl);
        }
        assert(index == count);
        // Post the next non-blocking receive.
        int num_procs, rank;
        MPI_Comm_size(MPI_COMM_WORLD, &num_procs);
        MPI_Comm_rank(MPI_COMM_WORLD, &rank);
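        // Ring predecessor: the inverse of the target mapping above.
        // With num_procs = 4: rank 1 <- 3, rank 2 <- 1, rank 3 <- 2.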
        int from = (rank - 2 + num_procs - 1) % (num_procs - 1) + 1;
        // LOGGER->info("receive clauses: %v", count);
        MPI_Irecv(buf, BUF_SIZE, MPI_INT, from, SHARE_CLAUSES_TAG, MPI_COMM_WORLD, &receive_request);
        received = true;
    }
    return received;
}
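// Sharing thread body: every share_intv microseconds, collect exported
// clauses from each producer, send them around the ring, import whatever
// arrived from the previous node, and feed both sets to the local consumers.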
void *share_worker(void *arg) {
    int nums = 0;
    sharer *sq = (sharer *)arg;
    auto clk_st = std::chrono::high_resolution_clock::now();
    double share_time = 0;
    int num_procs, rank;
    MPI_Comm_size(MPI_COMM_WORLD, &num_procs);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    int from = (rank - 2 + num_procs - 1) % (num_procs - 1) + 1;
    MPI_Irecv(buf, BUF_SIZE, MPI_INT, from, SHARE_CLAUSES_TAG, MPI_COMM_WORLD, &receive_request);
    while (true) {
        ++nums;
        usleep(sq->share_intv);
        auto clk_now = std::chrono::high_resolution_clock::now();
        int solve_time = std::chrono::duration_cast<std::chrono::milliseconds>(clk_now - clk_st).count();
        // LOGGER->info("round %v, time: %v.%v", nums, solve_time / 1000, solve_time % 1000);
        printf("c [worker] round %d, time: %d.%03d\n", nums, solve_time / 1000, solve_time % 1000);
        if (terminated) {
            // Cancel the outstanding receive and any unfinished sends,
            // then release their resources before leaving the loop.
            MPI_Cancel(&receive_request);
            for (size_t i = 0; i < send_data_struct.size(); i++) {
                auto &request = send_data_struct[i].first;
                auto &send_buf = send_data_struct[i].second;
                MPI_Cancel(request);
                MPI_Wait(request, MPI_STATUS_IGNORE);
                delete request;
                delete[] send_buf;
            }
            send_data_struct.clear();
            break;
        }
// printf("start sharing %d\n", sq->share_intv);
for (int i = 0; i < sq->producers.size(); i++) {
sq->cls.clear();
sq->producers[i]->export_clauses_to(sq->cls);
//printf("c size %d\n", sq->cls.size());
int number = sq->cls.size();
printf("c [worker] thread-%d: get %d exported clauses\n", i, number);
//分享当前节点产生的子句
if(sq->cls.size() > 0) share_clauses_to_next_node(sq->cls);
// 导入外部网络传输的子句
vec<clause_store*> clauses;
if(receive_clauses_from_last_node(clauses)) {
for (int j = 0; j < sq->consumers.size(); j++) {
for (int k = 0; k < clauses.size(); k++)
clauses[k]->increase_refs(1);
sq->consumers[j]->import_clauses_from(clauses);
}
// 传递外部网络传输的子句给下个节点
share_clauses_to_next_node(clauses);
for (int k = 0; k < clauses.size(); k++) {
clauses[k]->free_clause();
}
}
// 导入当前节点产生的子句
int percent = sq->sort_clauses(i);
if (percent < 75) {
sq->producers[i]->broaden_export_limit();
}
else if (percent > 98) {
sq->producers[i]->restrict_export_limit();
}
for (int j = 0; j < sq->consumers.size(); j++) {
if (sq->producers[i]->id == sq->consumers[j]->id) continue;
for (int k = 0; k < sq->cls.size(); k++)
sq->cls[k]->increase_refs(1);
sq->consumers[j]->import_clauses_from(sq->cls);
}
for (int k = 0; k < sq->cls.size(); k++) {
sq->cls[k]->free_clause();
}
}
        auto clk_ed = std::chrono::high_resolution_clock::now();
        share_time += 0.001 * std::chrono::duration_cast<std::chrono::milliseconds>(clk_ed - clk_now).count();
    }
    printf("c sharing nums: %d\nc sharing time: %.2lf\n", nums, share_time);
    // if (terminated) puts("terminated set to 1");
    return NULL;
}
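// Deterministic period-based import (DPS mode): the producer at period p
// imports the clauses its peers published for period p - 1 - margin,
// blocking until every peer has moved past that period. Returns 1 on
// success, 0 when no period is old enough to import yet, and -1 if a peer
// terminated before reaching the import period.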
int sharer::import_clauses(int id) {
    int current_period = producers[id]->get_period() - 1, import_period = current_period - margin;
    if (import_period < 0) return 0;
    basesolver *t = producers[id];
    for (int i = 0; i < producers.size(); i++) {
        if (i == id) continue;
        basesolver *s = producers[i];
        // Wait for s to reach period $import_period.
        // printf("c %d start waiting, since import_p is %d, current_p is %d\n", id, import_period, s->get_period());
        boost::mutex::scoped_lock lock(s->mtx);
        while (s->period <= import_period && s->terminate_period > import_period && !s->terminated) {
            s->cond.wait(lock);
        }
        if (s->terminated) return -1;
        if (s->terminate_period <= import_period) return -1;
        period_clauses *pc = s->pq.find(import_period);
        if (pc->period != import_period) {
            printf("thread %d, now period = %d, import period = %d, import thread %d\n", id, current_period, import_period, i);
            puts("*******************************************************");
        }
        // printf("c %d finish waiting %d %d\n", id, import_period, s->period_queue[pos]->period);
        t->import_clauses_from(pc->cls);
    }
    t->unfree.push(import_period);
    return 1;
    // printf("c thread %d, period %d, import finish\n", id, current_period);
}
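// Pack this round's exported clauses into a literal budget of share_lits.
// Clauses are bucketed by size and the shortest buckets are drained first,
// so the budget favors short clauses. Returns the percentage of the budget
// actually used, which the caller uses to widen or tighten the producer's
// export limit.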
int sharer::sort_clauses(int x) {
    for (int i = 0; i < cls.size(); i++) {
        int sz = cls[i]->size;
        while (sz > bucket[x].size()) bucket[x].push();
        if (sz * (bucket[x][sz - 1].size() + 1) <= share_lits)
            bucket[x][sz - 1].push(cls[i]);
        // else
        //     cls[i]->free_clause();
    }
    cls.clear();
    int space = share_lits;
    for (int i = 0; i < bucket[x].size(); i++) {
        int clause_num = space / (i + 1);
        // printf("%d %d\n", clause_num, bucket[x][i].size());
        if (!clause_num) break;
        if (clause_num >= bucket[x][i].size()) {
            space -= bucket[x][i].size() * (i + 1);
            for (int j = 0; j < bucket[x][i].size(); j++)
                cls.push(bucket[x][i][j]);
            bucket[x][i].clear();
        }
        else {
            space -= clause_num * (i + 1);
            for (int j = 0; j < clause_num; j++) {
                cls.push(bucket[x][i].last());
                bucket[x][i].pop();
            }
        }
    }
    return (share_lits - space) * 100 / share_lits;
}
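// Set up clause sharing. In DPS mode a single sharer only records the
// producers and the margin; no sharing thread is started here, and the
// workers presumably drive the import themselves via sharer::import_clauses.
// Otherwise one sharer thread runs share_worker with every solver thread
// registered as both producer and consumer.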
void light::share() {
    // printf("c sharing start\n");
    if (OPT(DPS)) {
        sharer *s = new sharer(0, OPT(share_intv), OPT(share_lits), OPT(DPS));
        s->margin = OPT(margin);
        for (int j = 0; j < OPT(threads); j++) {
            s->producers.push(workers[j]);
            workers[j]->in_sharer = s;
        }
        sharers.push(s);
    }
    else {
        int sharers_number = 1;
        for (int i = 0; i < sharers_number; i++) {
            sharer *s = new sharer(i, OPT(share_intv), OPT(share_lits), OPT(DPS));
            for (int j = 0; j < OPT(threads); j++) {
                s->producers.push(workers[j]);
                s->consumers.push(workers[j]);
                workers[j]->in_sharer = s;
            }
            sharers.push(s);
        }
        sharer_ptrs = new pthread_t[sharers_number];
        for (int i = 0; i < sharers_number; i++) {
            pthread_create(&sharer_ptrs[i], NULL, share_worker, sharers[i]);
        }
    }
}