cloud-sat/src/sharer.cpp

310 lines
10 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#include "light.hpp"
#include "solver_api/basesolver.hpp"
#include "sharer.hpp"
#include "unordered_map"
#include "clause.hpp"
#include <unistd.h>
#include "comm_tag.h"
#include <boost/thread/thread.hpp>
// Number of sharing rounds executed so far (incremented by do_clause_sharing).
int nums = 0;
// Total time spent inside do_clause_sharing, in seconds.
double share_time = 0;
// MPI world size and the rank of this process; set by clause_sharing_init.
int num_procs, rank;
// Capacity of the receive buffer, counted in ints (not bytes).
const int BUF_SIZE = 100 * 1024 * 1024;
// In-flight non-blocking sends: each entry pairs the heap-allocated request
// handle with the heap-allocated payload that must outlive the send.
std::vector<std::pair<MPI_Request*, int*>> send_data_struct;
// Handle of the currently posted non-blocking receive into `buf`.
MPI_Request receive_request;
// Receive buffer shared by every MPI_Irecv posted in this file.
int buf[BUF_SIZE];
// Number of clauses decoded from network messages (including skipped ones).
int num_received_clauses_by_network = 0;
// Number of network clauses skipped because they were already imported.
int num_skip_clauses_by_network = 0;
// Records whether a clause (keyed by its hash code) has already been imported.
std::unordered_map<int, bool> clause_imported;
void sharer::share_clauses_to_next_node(int from, const std::vector<shared_ptr<clause_store>> &cls) {
// 环形传递,数据来源如果是目的地,说明数据已轮转一圈,停止发送。
if(from == S->next_node) return;
// 定义发送数据
MPI_Request *send_request = new MPI_Request();
int *send_buf;
int send_length = 1;
// 初始化发送数据
for(int i=0; i<cls.size(); i++) {
send_length += (cls[i]->size + 2);
}
send_buf = new int[send_length];
int index = 0;
send_buf[index++] = from;
for(int i=0; i<cls.size(); i++) {
auto& c = cls[i];
send_buf[index++] = c->size;
send_buf[index++] = c->lbd;
for(int j=0; j<c->size; j++) {
send_buf[index++] = c->data[j];
}
}
assert(index == send_length);
// 调用 MPI 发送共享子句
MPI_Isend(send_buf, send_length, MPI_INT, S->next_node, SHARE_CLAUSES_TAG, MPI_COMM_WORLD, send_request);
send_data_struct.push_back(std::make_pair(send_request, send_buf));
// printf("c [worker] send clauses: %d\n", send_length);
// 清理 send_data_struct把发送完毕的发送数据结构清理掉
for(int i=0; i<send_data_struct.size(); i++) {
// 已完成发送,释放内存空间
int flag;
if(MPI_Test(send_data_struct[i].first, &flag, MPI_STATUS_IGNORE) == MPI_SUCCESS && flag == 1) {
delete send_data_struct[i].first;
delete []send_data_struct[i].second;
// 与数组最后一个交换,然后 pop_back;
std::swap(send_data_struct[i], send_data_struct[send_data_struct.size()-1]);
send_data_struct.pop_back();
// printf("c [worker] free send request, now: %d\n", send_data_struct.size());
}
}
}
/**
 * Drains every message that has already arrived on the posted receive and
 * decodes the clauses it carries.
 *
 * Expected message layout (MPI_INT):
 *   [from, (size, lbd, lit_0 .. lit_{size-1}) for each clause]
 *
 * @param clauses      Output; cleared first, then filled with the decoded
 *                     clauses whose hash is not marked in `clause_imported`.
 * @param transmitter  Output; MPI source rank of the first message handled in
 *                     this call, or -1 if no message was ready.
 * @return The `from` field of the last message handled, or -1 if none.
 *
 * After each message a new MPI_Irecv is posted on the same buffer.
 */
int sharer::receive_clauses_from_last_node(std::vector<shared_ptr<clause_store>> &clauses, int &transmitter) {
    clauses.clear();
    transmitter = -1;

    int origin = -1;
    int completed;
    MPI_Status st;
    for (;;) {
        // Stop as soon as the test fails or the pending receive is not done.
        if (MPI_Test(&receive_request, &completed, &st) != MPI_SUCCESS) break;
        if (completed != 1) break;

        int total_ints;
        MPI_Get_count(&st, MPI_INT, &total_ints);

        if (transmitter == -1) {
            transmitter = st.MPI_SOURCE;
        }
        assert(transmitter == st.MPI_SOURCE);

        int pos = 0;
        origin = buf[pos++];
        while (pos < total_ints) {
            ++num_received_clauses_by_network;

            // Header: clause size, then lbd; literals follow.
            auto cl = std::make_shared<clause_store>(buf[pos++]);
            cl->lbd = buf[pos++];
            for (int k = 0; k < cl->size; ++k) {
                cl->data[k] = buf[pos++];
            }

            if (clause_imported[cl->hash_code()]) {
                // Already imported earlier: count it and drop it.
                ++num_skip_clauses_by_network;
            } else {
                clauses.push_back(cl);
            }
        }
        assert(pos == total_ints);

        // Re-arm the receive for the next message.
        MPI_Irecv(buf, BUF_SIZE, MPI_INT, MPI_ANY_SOURCE, SHARE_CLAUSES_TAG, MPI_COMM_WORLD, &receive_request);
    }
    return origin;
}
/**
 * Prepares network clause sharing: records this process's rank and the world
 * size, then posts the first non-blocking receive into `buf`.
 */
void sharer::clause_sharing_init() {
    // Identify ourselves within MPI_COMM_WORLD.
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &num_procs);

    // Accept shared clauses from any source; completion is polled later.
    MPI_Irecv(buf, BUF_SIZE, MPI_INT, MPI_ANY_SOURCE, SHARE_CLAUSES_TAG,
              MPI_COMM_WORLD, &receive_request);
}
/**
 * Prints the sharing statistics of this node: number of rounds, time spent
 * sharing, clauses received over the network, clauses skipped as duplicates,
 * and the resulting skip percentage.
 */
void sharer::clause_sharing_end() {
    printf("c node%d sharing nums: %d\nc sharing time: %.2lf\n", rank, nums, share_time);
    printf("c node%d sharing received_num_by_network: %d\n", rank, num_received_clauses_by_network);
    printf("c node%d sharing skip_num_by_network: %d\n", rank, num_skip_clauses_by_network);

    // Avoid 0/0 (printed as "nan") when nothing was received over the network.
    double skip_percent = 0.0;
    if (num_received_clauses_by_network > 0) {
        skip_percent = (double) num_skip_clauses_by_network / num_received_clauses_by_network * 100;
    }
    printf("c node%d sharing unique reduce percentage: %.2f%%\n", rank, skip_percent);
}
/**
 * Runs one clause-sharing round.
 *
 * 1. Imports clauses that arrived over the network into every consumer, marks
 *    them as imported, and forwards them to the next node in the ring.
 * 2. For every producer: exports its clauses, drops those already marked as
 *    imported, sends the remainder to the next node, applies the literal
 *    budget via sort_clauses, adapts the producer's export limit (UNSAT
 *    workers only), hands the clauses to every other consumer, and marks
 *    them as imported.
 *
 * Updates the global round counter `nums` and the accumulated `share_time`.
 */
void sharer::do_clause_sharing() {
    // Reference point for the elapsed-time log: the first call of this function.
    static auto clk_st = std::chrono::high_resolution_clock::now();
    ++nums;
    auto clk_now = std::chrono::high_resolution_clock::now();
    int solve_time = std::chrono::duration_cast<std::chrono::milliseconds>(clk_now - clk_st).count();
    // Zero-pad the millisecond part so that e.g. 1005 ms prints as 1.005, not 1.5.
    printf("c node%d(%d)round %d, time: %d.%03d\n", rank, S->worker_type, nums, solve_time / 1000, solve_time % 1000);

    // Import clauses received from the network.
    std::vector<shared_ptr<clause_store>> clauses;
    int transmitter;
    int from = receive_clauses_from_last_node(clauses, transmitter);
    if (from != -1 && clauses.size() > 0) {
        // size() is size_t; cast explicitly to match the %d conversion.
        printf("c node%d(%d)->%d get %d exported clauses from node-%d\n", rank, S->worker_type, S->next_node, static_cast<int>(clauses.size()), transmitter);
        for (int j = 0; j < consumers.size(); j++) {
            consumers[j]->import_clauses_from(clauses);
        }
        for (const auto &c : clauses) {
            clause_imported[c->hash_code()] = true;
        }
        // Pass the network clauses on to the next node.
        share_clauses_to_next_node(from, clauses);
    }

    for (int i = 0; i < producers.size(); i++) {
        cls.clear();
        producers[i]->export_clauses_to(cls);

        // Remove learnt clauses that were already imported. A duplicate is
        // replaced by the current tail element, which has not been checked
        // yet, so the index only advances when the slot is kept.
        int t_size = cls.size();
        int k = 0;
        while (k < t_size) {
            // find() instead of operator[] so that a lookup does not insert.
            auto seen = clause_imported.find(cls[k]->hash_code());
            if (seen != clause_imported.end() && seen->second) {
                std::swap(cls[k], cls[t_size - 1]);
                t_size--;
            } else {
                k++;
            }
        }
        cls.resize(t_size);

        // Share the clauses produced on this node with the next node.
        if (cls.size() > 0) share_clauses_to_next_node(rank, cls);

        // Apply the literal budget; `percent` is how much of it was used.
        int percent = sort_clauses(i);
        if (S->worker_type == light::UNSAT) {
            if (percent < 75) {
                producers[i]->broaden_export_limit();
            }
            else if (percent > 98) {
                producers[i]->restrict_export_limit();
            }
        }

        // Deliver to every consumer except the producer itself.
        for (int j = 0; j < consumers.size(); j++) {
            if (producers[i]->id == consumers[j]->id) continue;
            consumers[j]->import_clauses_from(cls);
        }
        for (const auto &c : cls) {
            clause_imported[c->hash_code()] = true;
        }
    }

    auto clk_ed = std::chrono::high_resolution_clock::now();
    share_time += 0.001 * std::chrono::duration_cast<std::chrono::milliseconds>(clk_ed - clk_now).count();
}
int sharer::import_clauses(int id) {
int current_period = producers[id]->get_period() - 1, import_period = current_period - OPT(margin);
if (import_period < 0) return 0;
basesolver *t = producers[id];
for (int i = 0; i < producers.size(); i++) {
if (i == id) continue;
basesolver *s = producers[i];
//wait for s reach period $import_period
// printf("c %d start waiting, since import_p is %d, current_p is %d\n", id, import_period, s->get_period());
boost::mutex::scoped_lock lock(s->mtx);
while (s->period <= import_period && s->terminate_period > import_period && !s->terminated) {
s->cond.wait(lock);
}
if (s->terminated) return -1;
if (s->terminate_period <= import_period) return -1;
period_clauses *pc = s->pq.find(import_period);
if (pc->period != import_period) {
printf("thread %d, now period = %d, import period = %d, import thread %d\n", id, current_period, import_period, i);
puts("*******************************************************");
}
// printf("c %d finish waiting %d %d\n", id, import_period, s->period_queue[pos]->period);
t->import_clauses_from(pc->cls);
}
t->unfree.push(import_period);
return 1;
// printf("c thread %d, period %d, import finish\n", id, current_period);
}
int sharer::sort_clauses(int x) {
for (int i = 0; i < cls.size(); i++) {
int sz = cls[i]->size;
while (sz > bucket[x].size()) bucket[x].push();
if (sz * (bucket[x][sz - 1].size() + 1) <= OPT(share_lits))
bucket[x][sz - 1].push_back(cls[i]);
// else
// cls[i]->free_clause();
}
cls.clear();
int space = OPT(share_lits);
for (int i = 0; i < bucket[x].size(); i++) {
int clause_num = space / (i + 1);
// printf("%d %d\n", clause_num, bucket[x][i].size());
if (!clause_num) break;
if (clause_num >= bucket[x][i].size()) {
space -= bucket[x][i].size() * (i + 1);
for (int j = 0; j < bucket[x][i].size(); j++)
cls.push_back(bucket[x][i][j]);
bucket[x][i].clear();
}
else {
space -= clause_num * (i + 1);
for (int j = 0; j < clause_num; j++) {
cls.push_back(bucket[x][i].back());
bucket[x][i].pop_back();
}
}
}
return (OPT(share_lits) - space) * 100 / OPT(share_lits);
}
// void light::share() {
// // printf("c sharing start\n");
// if (OPT(DPS)) {
// sharer* s = new sharer(0, OPT(share_intv), OPT(share_lits), OPT(DPS));
// s->margin = OPT(margin);
// for (int j = 0; j < OPT(threads); j++) {
// s->producers.push(workers[j]);
// workers[j]->in_sharer = s;
// }
// sharers.push(s);
// }
// else {
// int sharers_number = 1;
// for (int i = 0; i < sharers_number; i++) {
// sharer* s = new sharer(i, OPT(share_intv), OPT(share_lits), OPT(DPS));
// for (int j = 0; j < OPT(threads); j++) {
// s->producers.push(workers[j]);
// s->consumers.push(workers[j]);
// workers[j]->in_sharer = s;
// }
// sharers.push(s);
// }
// sharer_ptrs = new pthread_t[sharers_number];
// for (int i = 0; i < sharers_number; i++) {
// pthread_create(&sharer_ptrs[i], NULL, share_worker, sharers[i]);
// }
// }
// }