296 lines
10 KiB
C++
296 lines
10 KiB
C++
#include "../light.hpp"
|
||
#include "basesolver.hpp"
|
||
#include "sharer.hpp"
|
||
#include "clause.hpp"
|
||
#include <unistd.h>
|
||
#include "../distributed/comm_tag.h"
|
||
#include <boost/thread/thread.hpp>
|
||
|
||
const int BUF_SIZE = 1024 * 1024;
|
||
std::vector<std::pair<MPI_Request*, int*>> send_data_struct;
|
||
MPI_Request receive_request;
|
||
int buf[BUF_SIZE];
|
||
|
||
void share_clauses_to_next_node(const vec<clause_store *> &cls) {
|
||
|
||
// 清理 send_data_struct,把发送完毕的发送数据结构清理掉
|
||
for(int i=0; i<send_data_struct.size(); i++) {
|
||
auto& request = send_data_struct[i].first;
|
||
auto& send_buf = send_data_struct[i].second;
|
||
|
||
// 已完成发送,释放内存空间
|
||
int flag;
|
||
if(MPI_Test(request, &flag, MPI_STATUS_IGNORE) == MPI_SUCCESS && flag == 1) {
|
||
delete request;
|
||
delete []send_buf;
|
||
// 与数组最后一个交换,然后 pop_back;
|
||
std::swap(send_data_struct[i], send_data_struct[send_data_struct.size()-1]);
|
||
send_data_struct.pop_back();
|
||
|
||
printf("c [worker] free send request, now: %d\n", send_data_struct.size());
|
||
}
|
||
}
|
||
|
||
// 定义发送数据
|
||
MPI_Request *send_request = new MPI_Request();
|
||
int *send_buf;
|
||
int send_length = 0;
|
||
|
||
// 初始化发送数据
|
||
for(int i=0; i<cls.size(); i++) {
|
||
send_length += (cls[i]->size + 2);
|
||
}
|
||
|
||
send_buf = new int[send_length];
|
||
int index = 0;
|
||
|
||
for(int i=0; i<cls.size(); i++) {
|
||
auto& c = cls[i];
|
||
send_buf[index++] = c->size;
|
||
send_buf[index++] = c->lbd;
|
||
for(int j=0; j<c->size; j++) {
|
||
send_buf[index++] = c->data[j];
|
||
}
|
||
}
|
||
|
||
assert(index == send_length);
|
||
|
||
// 调用 MPI 发送共享子句
|
||
|
||
int num_procs, rank;
|
||
MPI_Comm_size(MPI_COMM_WORLD, &num_procs);
|
||
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
|
||
|
||
int target = rank % (num_procs - 1) + 1;
|
||
|
||
MPI_Isend(send_buf, send_length, MPI_INT, target, SHARE_CLAUSES_TAG, MPI_COMM_WORLD, send_request);
|
||
|
||
send_data_struct.push_back(std::make_pair(send_request, send_buf));
|
||
|
||
printf("c [worker] send clauses: %d\n", send_length);
|
||
}
|
||
|
||
/**
 * Collect, without blocking, every clause batch that has arrived from the
 * previous worker in the MPI ring.
 *
 * A single MPI_Irecv into the global `buf` (BUF_SIZE ints) is kept posted.
 * Each completed receive is decoded using the per-clause format
 * [size, lbd, lit_0 .. lit_{size-1}], then a new receive is posted before
 * testing again, so all batches already delivered are drained in one call.
 * A record whose header or literals would extend past the received count is
 * treated as malformed and decoding of that message stops there.
 *
 * @param clauses cleared, then filled with newly allocated clause_store
 *                objects, one per decoded clause.
 * @return true if at least one message was received (the decoded batch may
 *         still be empty), false otherwise. Always false with fewer than two
 *         MPI processes, in which case receive_request is not touched.
 */
bool receive_clauses_from_last_node(vec<clause_store*> &clauses) {
    clauses.clear();

    int num_procs = 0, rank = 0;
    MPI_Comm_size(MPI_COMM_WORLD, &num_procs);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    // No ring without at least one worker besides rank 0 (and the modulo
    // below would divide by zero).
    if (num_procs < 2) return false;
    // Predecessor in the ring of worker ranks 1..num_procs-1.
    const int from = (rank - 2 + num_procs - 1) % (num_procs - 1) + 1;

    bool received = false;
    int flag = 0;
    MPI_Status status;

    while (MPI_Test(&receive_request, &flag, &status) == MPI_SUCCESS && flag == 1) {
        int count = 0;
        if (MPI_Get_count(&status, MPI_INT, &count) != MPI_SUCCESS || count == MPI_UNDEFINED) {
            count = 0;
        }

        int index = 0;
        while (index + 2 <= count) {
            const int lits = buf[index];
            // Reject records that would read past the received data.
            if (lits < 0 || index + 2 + lits > count) break;
            clause_store *cl = new clause_store(lits);
            cl->lbd = buf[index + 1];
            index += 2;
            for (int i = 0; i < lits; i++) {
                cl->data[i] = buf[index++];
            }
            clauses.push(cl);
        }
        assert(index == count);

        // `buf` has been fully decoded; it may now be reused for the next message.
        MPI_Irecv(buf, BUF_SIZE, MPI_INT, from, SHARE_CLAUSES_TAG, MPI_COMM_WORLD, &receive_request);
        received = true;
    }

    return received;
}
|
||
|
||
void * share_worker(void *arg) {
|
||
int nums = 0;
|
||
sharer * sq = (sharer *)arg;
|
||
auto clk_st = std::chrono::high_resolution_clock::now();
|
||
double share_time = 0;
|
||
|
||
int num_procs, rank;
|
||
MPI_Comm_size(MPI_COMM_WORLD, &num_procs);
|
||
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
|
||
|
||
int from = (rank - 2 + num_procs - 1) % (num_procs - 1) + 1;
|
||
MPI_Irecv(buf, BUF_SIZE, MPI_INT, from, SHARE_CLAUSES_TAG, MPI_COMM_WORLD, &receive_request);
|
||
|
||
while (true) {
|
||
++nums;
|
||
usleep(sq->share_intv);
|
||
auto clk_now = std::chrono::high_resolution_clock::now();
|
||
int solve_time = std::chrono::duration_cast<std::chrono::milliseconds>(clk_now - clk_st).count();
|
||
//LOGGER->info("round %v, time: %v.%v", nums, solve_time / 1000, solve_time % 1000);
|
||
|
||
printf("c [worker] round %d, time: %d.%d\n", nums, solve_time / 1000, solve_time % 1000);
|
||
if (terminated) {
|
||
|
||
MPI_Cancel(&receive_request);
|
||
for(int i=0; i<send_data_struct.size(); i++) {
|
||
auto& request = send_data_struct[i].first;
|
||
auto& send_buf = send_data_struct[i].second;
|
||
MPI_Cancel(request);
|
||
delete []send_buf;
|
||
}
|
||
break;
|
||
}
|
||
// printf("start sharing %d\n", sq->share_intv);
|
||
for (int i = 0; i < sq->producers.size(); i++) {
|
||
sq->cls.clear();
|
||
sq->producers[i]->export_clauses_to(sq->cls);
|
||
|
||
//printf("c size %d\n", sq->cls.size());
|
||
int number = sq->cls.size();
|
||
|
||
printf("c [worker] thread-%d: get %d exported clauses\n", i, number);
|
||
|
||
//分享当前节点产生的子句
|
||
if(sq->cls.size() > 0) share_clauses_to_next_node(sq->cls);
|
||
|
||
// 导入外部网络传输的子句
|
||
vec<clause_store*> clauses;
|
||
if(receive_clauses_from_last_node(clauses)) {
|
||
for (int j = 0; j < sq->consumers.size(); j++) {
|
||
for (int k = 0; k < clauses.size(); k++)
|
||
clauses[k]->increase_refs(1);
|
||
sq->consumers[j]->import_clauses_from(clauses);
|
||
}
|
||
|
||
// 传递外部网络传输的子句给下个节点
|
||
share_clauses_to_next_node(clauses);
|
||
|
||
for (int k = 0; k < clauses.size(); k++) {
|
||
clauses[k]->free_clause();
|
||
}
|
||
}
|
||
|
||
// 导入当前节点产生的子句
|
||
int percent = sq->sort_clauses(i);
|
||
if (percent < 75) {
|
||
sq->producers[i]->broaden_export_limit();
|
||
}
|
||
else if (percent > 98) {
|
||
sq->producers[i]->restrict_export_limit();
|
||
}
|
||
|
||
for (int j = 0; j < sq->consumers.size(); j++) {
|
||
if (sq->producers[i]->id == sq->consumers[j]->id) continue;
|
||
for (int k = 0; k < sq->cls.size(); k++)
|
||
sq->cls[k]->increase_refs(1);
|
||
sq->consumers[j]->import_clauses_from(sq->cls);
|
||
}
|
||
for (int k = 0; k < sq->cls.size(); k++) {
|
||
sq->cls[k]->free_clause();
|
||
}
|
||
}
|
||
auto clk_ed = std::chrono::high_resolution_clock::now();
|
||
share_time += 0.001 * std::chrono::duration_cast<std::chrono::milliseconds>(clk_ed - clk_now).count();
|
||
}
|
||
printf("c sharing nums: %d\nc sharing time: %.2lf\n", nums, share_time);
|
||
// if (terminated) puts("terminated set to 1");
|
||
return NULL;
|
||
}
|
||
|
||
int sharer::import_clauses(int id) {
|
||
int current_period = producers[id]->get_period() - 1, import_period = current_period - margin;
|
||
if (import_period < 0) return 0;
|
||
basesolver *t = producers[id];
|
||
for (int i = 0; i < producers.size(); i++) {
|
||
if (i == id) continue;
|
||
basesolver *s = producers[i];
|
||
//wait for s reach period $import_period
|
||
// printf("c %d start waiting, since import_p is %d, current_p is %d\n", id, import_period, s->get_period());
|
||
|
||
boost::mutex::scoped_lock lock(s->mtx);
|
||
while (s->period <= import_period && s->terminate_period > import_period && !s->terminated) {
|
||
s->cond.wait(lock);
|
||
}
|
||
|
||
if (s->terminated) return -1;
|
||
if (s->terminate_period <= import_period) return -1;
|
||
|
||
period_clauses *pc = s->pq.find(import_period);
|
||
if (pc->period != import_period) {
|
||
printf("thread %d, now period = %d, import period = %d, import thread %d\n", id, current_period, import_period, i);
|
||
puts("*******************************************************");
|
||
}
|
||
// printf("c %d finish waiting %d %d\n", id, import_period, s->period_queue[pos]->period);
|
||
t->import_clauses_from(pc->cls);
|
||
}
|
||
t->unfree.push(import_period);
|
||
return 1;
|
||
// printf("c thread %d, period %d, import finish\n", id, current_period);
|
||
}
|
||
|
||
int sharer::sort_clauses(int x) {
|
||
for (int i = 0; i < cls.size(); i++) {
|
||
int sz = cls[i]->size;
|
||
while (sz > bucket[x].size()) bucket[x].push();
|
||
if (sz * (bucket[x][sz - 1].size() + 1) <= share_lits)
|
||
bucket[x][sz - 1].push(cls[i]);
|
||
// else
|
||
// cls[i]->free_clause();
|
||
}
|
||
cls.clear();
|
||
int space = share_lits;
|
||
for (int i = 0; i < bucket[x].size(); i++) {
|
||
int clause_num = space / (i + 1);
|
||
// printf("%d %d\n", clause_num, bucket[x][i].size());
|
||
if (!clause_num) break;
|
||
if (clause_num >= bucket[x][i].size()) {
|
||
space -= bucket[x][i].size() * (i + 1);
|
||
for (int j = 0; j < bucket[x][i].size(); j++)
|
||
cls.push(bucket[x][i][j]);
|
||
bucket[x][i].clear();
|
||
}
|
||
else {
|
||
space -= clause_num * (i + 1);
|
||
for (int j = 0; j < clause_num; j++) {
|
||
cls.push(bucket[x][i].last());
|
||
bucket[x][i].pop();
|
||
}
|
||
}
|
||
}
|
||
return (share_lits - space) * 100 / share_lits;
|
||
}
|
||
|
||
void light::share() {
|
||
// printf("c sharing start\n");
|
||
if (OPT(DPS)) {
|
||
sharer* s = new sharer(0, OPT(share_intv), OPT(share_lits), OPT(DPS));
|
||
s->margin = OPT(margin);
|
||
for (int j = 0; j < OPT(threads); j++) {
|
||
s->producers.push(workers[j]);
|
||
workers[j]->in_sharer = s;
|
||
}
|
||
sharers.push(s);
|
||
}
|
||
else {
|
||
int sharers_number = 1;
|
||
for (int i = 0; i < sharers_number; i++) {
|
||
sharer* s = new sharer(i, OPT(share_intv), OPT(share_lits), OPT(DPS));
|
||
for (int j = 0; j < OPT(threads); j++) {
|
||
s->producers.push(workers[j]);
|
||
s->consumers.push(workers[j]);
|
||
workers[j]->in_sharer = s;
|
||
}
|
||
sharers.push(s);
|
||
}
|
||
|
||
sharer_ptrs = new pthread_t[sharers_number];
|
||
for (int i = 0; i < sharers_number; i++) {
|
||
pthread_create(&sharer_ptrs[i], NULL, share_worker, sharers[i]);
|
||
}
|
||
}
|
||
} |