修改成字句ls

This commit is contained in:
YuhangQ 2023-03-17 05:37:29 +00:00
parent 10b021e407
commit 4ba4c8e82e
13 changed files with 407 additions and 77 deletions

2
.gitignore vendored
View File

@ -1,3 +1,3 @@
*.o *.o
*.d *.d*
output.txt output.txt

2
.vscode/launch.json vendored
View File

@ -9,7 +9,7 @@
"type": "cppdbg", "type": "cppdbg",
"request": "launch", "request": "launch",
"program": "${workspaceFolder}/atpg", "program": "${workspaceFolder}/atpg",
"args": ["c432.bench"], "args": ["test.bench"],
"stopAtEntry": false, "stopAtEntry": false,
"cwd": "${fileDirname}", "cwd": "${fileDirname}",
"environment": [], "environment": [],

1
CNC-LS Submodule

@ -0,0 +1 @@
Subproject commit d9122607522ea757b3412f1cff247b9db6c79c55

BIN
atpg

Binary file not shown.

View File

@ -4,6 +4,7 @@
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "assert.h" #include "assert.h"
#include "clause.h"
void Circuit::init_topo_index() { void Circuit::init_topo_index() {
@ -28,6 +29,10 @@ void Circuit::init_topo_index() {
} }
} }
for(Gate* g : gates) {
id2gate[g->id] = g;
}
// 计算反向拓扑序 // 计算反向拓扑序
topo = 1; topo = 1;
std::unordered_map<Gate*, int> outs; std::unordered_map<Gate*, int> outs;

View File

@ -61,7 +61,7 @@ std::vector<Gate*> POs;
std::vector<Gate*> gates; std::vector<Gate*> gates;
std::vector<Gate*> rtopo_gates; std::vector<Gate*> rtopo_gates;
std::vector<Gate*> stems; // PIs and POs are stems by default std::vector<Gate*> stems; // PIs and POs are stems by default
std::unordered_map<int, Gate*> id2gate;
std::unordered_map<std::string, Gate*> name2gate; std::unordered_map<std::string, Gate*> name2gate;
// 读入和输出电路 // 读入和输出电路
@ -100,12 +100,13 @@ void ls_statistics();
void ls_update_weight(); void ls_update_weight();
Gate* ls_pick(); int ls_pick();
Gate* ls_pick_falsified(); Gate* ls_pick_falsified();
void ls_flip(Gate* stem); void ls_flip_var(int var);
void ls_flip_stem(Gate* stem);
ll ls_pick_score(Gate* stem); ll ls_pick_score(int var);
ll ls_circuit_score(); ll ls_circuit_score();

275
clause.cpp Normal file
View File

@ -0,0 +1,275 @@
#include "clause.h"
ll Clause::total_cost;
void Clause::update_satisfied_lit_count() {
using namespace ClauseLS;
ll old_lit_count = satisfied_lit_count;
satisfied_lit_count = 0;
for(auto& lit : lits) {
if(lit_value[abs(lit)]) {
satisfied_lit_count += (lit > 0);
} else {
satisfied_lit_count += (lit < 0);
}
}
if(old_lit_count == 0 && satisfied_lit_count > 0) {
falsified_clauses.erase(this);
total_cost -= weight;
satisfied_clauses.insert(this);
}
if(old_lit_count > 0 && satisfied_lit_count == 0) {
satisfied_clauses.erase(this);
total_cost += weight;
falsified_clauses.insert(this);
}
}
namespace ClauseLS {
int num_vars;
int num_clauses;
std::unordered_set<Clause*> satisfied_clauses;
std::unordered_set<Clause*> falsified_clauses;
std::vector<Clause*> clauses;
std::vector<Clause*> *lit_related_clauses;
int *lit_value;
int *CC;
void flip(int var) {
// printf("value: [ ");
// for(int i=1; i<=num_vars; i++) {
// printf("%d ", lit_value[i]);
// }
// printf("]\n");
// for(auto& c : clauses) {
// printf("lits: [ ");
// for(auto& lit : c->lits) {
// printf("%d ", lit);
// }
// printf(" ] satifs_cnt: %d\n", c->satisfied_lit_count);
// }
lit_value[var] = !lit_value[var];
for(auto& clause : lit_related_clauses[var]) {
clause->update_satisfied_lit_count();
for(auto& lit : clause->lits) {
CC[abs(lit)] = 1;
}
}
CC[var] = 0;
}
void add_to_tmp_clause(int x) {
static std::vector<int> lits;
if(x != 0) {
lits.push_back(x);
} else {
Clause* clause = new Clause();
clause->lits = lits;
clause->satisfied_lit_count = 0;
clauses.push_back(clause);
lits.clear();
}
}
void init_data_structs(Circuit* circuit) {
build_clauses(circuit);
printf("====== Clause Statistics ====== \n");
printf("num_vars:\t%d\n", num_vars);
printf("num_clauses:\t%ld\n", num_clauses);
printf("================================ \n");
lit_related_clauses = new std::vector<Clause*>[num_vars + 1];
lit_value = new int[num_vars + 1];
CC = new int[num_vars + 1];
for(int i=1; i<=num_vars; i++) {
lit_value[i] = 0;
CC[i] = 1;
}
for(auto& clause : clauses) {
clause->weight = 1;
Clause::total_cost += clause->weight;
falsified_clauses.insert(clause);
//printf("fs: %d. ss: %d\n", falsified_clauses.size(), satisfied_clauses.size());
clause->update_satisfied_lit_count();
for(auto& lit : clause->lits) {
lit_related_clauses[abs(lit)].push_back(clause);
}
}
}
void reset_data_structs() {
for(int i=1; i<=num_vars; i++) {
lit_value[i] = 0;
}
}
void build_clauses(Circuit* circuit) {
int extra_variable = circuit->gates.size();
int last_id, last_v;
for(Gate* g : circuit->gates) {
if(g->pi) continue;
switch(g->type) {
case Gate::NOT:
add_to_tmp_clause(-g->id);
add_to_tmp_clause(-(g->fan_ins[0]->id));
add_to_tmp_clause(0);
add_to_tmp_clause(g->id);
add_to_tmp_clause(-(g->fan_ins[0]->id));
add_to_tmp_clause(0);
break;
case Gate::BUF:
add_to_tmp_clause(-g->id);
add_to_tmp_clause((g->fan_ins[0]->id));
add_to_tmp_clause(0);
add_to_tmp_clause(g->id);
add_to_tmp_clause(-(g->fan_ins[0]->id));
add_to_tmp_clause(0);
break;
case Gate::AND:
for(int i=1; i<g->fan_ins.size(); i++) {
add_to_tmp_clause(-g->id);
add_to_tmp_clause(g->fan_ins[i]->id);
add_to_tmp_clause(0);
}
for(int i=1; i<g->fan_ins.size(); i++) {
add_to_tmp_clause(-g->fan_ins[i]->id);
}
add_to_tmp_clause(g->id);
add_to_tmp_clause(0);
break;
case Gate::NAND:
for(int i=1; i<g->fan_ins.size(); i++) {
add_to_tmp_clause(g->id);
add_to_tmp_clause(g->fan_ins[i]->id);
add_to_tmp_clause(0);
}
for(int i=1; i<g->fan_ins.size(); i++) {
add_to_tmp_clause(-g->fan_ins[i]->id);
}
add_to_tmp_clause(-g->id);
add_to_tmp_clause(0);
break;
case Gate::OR:
for(int i=1; i<g->fan_ins.size(); i++) {
add_to_tmp_clause(g->fan_ins[i]->id);
}
add_to_tmp_clause(-g->id);
add_to_tmp_clause(0);
for(int i=1; i<g->fan_ins.size(); i++) {
add_to_tmp_clause(g->id);
add_to_tmp_clause(-g->fan_ins[i]->id);
add_to_tmp_clause(0);
}
break;
case Gate::NOR:
for(int i=1; i<g->fan_ins.size(); i++) {
add_to_tmp_clause(g->fan_ins[i]->id);
}
add_to_tmp_clause(g->id);
add_to_tmp_clause(0);
for(int i=1; i<g->fan_ins.size(); i++) {
add_to_tmp_clause(-g->id);
add_to_tmp_clause(-g->fan_ins[i]->id);
add_to_tmp_clause(0);
}
break;
case Gate::XOR:
last_v = g->fan_ins[0]->id;
for(int i=1; i<g->fan_ins.size() - 1; i++) {
int new_v = ++extra_variable;
add_to_tmp_clause(-g->fan_ins[i]->id); add_to_tmp_clause(-last_v); add_to_tmp_clause(-new_v);
add_to_tmp_clause(0);
add_to_tmp_clause(-g->fan_ins[i]->id); add_to_tmp_clause(last_v); add_to_tmp_clause(new_v);
add_to_tmp_clause(0);
add_to_tmp_clause(g->fan_ins[i]->id); add_to_tmp_clause(-last_v); add_to_tmp_clause(new_v);
add_to_tmp_clause(0);
add_to_tmp_clause(g->fan_ins[i]->id); add_to_tmp_clause(last_v); add_to_tmp_clause(-new_v);
add_to_tmp_clause(0);
last_v = new_v;
}
last_id = g->fan_ins[g->fan_ins.size() - 1]->id;
add_to_tmp_clause(-g->id); add_to_tmp_clause(-last_v); add_to_tmp_clause(-last_id);
add_to_tmp_clause(0);
add_to_tmp_clause(-g->id); add_to_tmp_clause(last_v); add_to_tmp_clause(last_id);
add_to_tmp_clause(0);
add_to_tmp_clause(g->id); add_to_tmp_clause(-last_v); add_to_tmp_clause(last_id);
add_to_tmp_clause(0);
add_to_tmp_clause(g->id); add_to_tmp_clause(last_v); add_to_tmp_clause(-last_id);
add_to_tmp_clause(0);
break;
case Gate::XNOR:
last_v = g->fan_ins[0]->id;
for(int i=1; i<g->fan_ins.size() - 1; i++) {
int new_v = ++extra_variable;
add_to_tmp_clause(-g->fan_ins[i]->id); add_to_tmp_clause(-last_v); add_to_tmp_clause(-new_v);
add_to_tmp_clause(0);
add_to_tmp_clause(-g->fan_ins[i]->id); add_to_tmp_clause(last_v); add_to_tmp_clause(new_v);
add_to_tmp_clause(0);
add_to_tmp_clause(g->fan_ins[i]->id); add_to_tmp_clause(-last_v); add_to_tmp_clause(new_v);
add_to_tmp_clause(0);
add_to_tmp_clause(g->fan_ins[i]->id); add_to_tmp_clause(last_v); add_to_tmp_clause(-new_v);
add_to_tmp_clause(0);
last_v = new_v;
}
last_id = g->fan_ins[g->fan_ins.size() - 1]->id;
add_to_tmp_clause(g->id); add_to_tmp_clause(-last_v); add_to_tmp_clause(-last_id);
add_to_tmp_clause(0);
add_to_tmp_clause(g->id); add_to_tmp_clause(last_v); add_to_tmp_clause(last_id);
add_to_tmp_clause(0);
add_to_tmp_clause(-g->id); add_to_tmp_clause(-last_v); add_to_tmp_clause(last_id);
add_to_tmp_clause(0);
add_to_tmp_clause(-g->id); add_to_tmp_clause(last_v); add_to_tmp_clause(-last_id);
add_to_tmp_clause(0);
break;
default:
exit(-1);
break;
}
}
num_vars = extra_variable;
num_clauses = clauses.size();
}
}

36
clause.h Normal file
View File

@ -0,0 +1,36 @@
#pragma once
#include <vector>
#include <unordered_map>
#include <unordered_set>
#include "circuit.h"
class Clause {
public:
int satisfied_lit_count;
int weight;
static ll total_cost;
std::vector<int> lits;
void update_satisfied_lit_count();
};
namespace ClauseLS {
extern int num_vars;
extern int num_clauses;
extern std::vector<Clause*> clauses;
extern std::vector<Clause*> *lit_related_clauses;
extern int *lit_value;
extern int *CC;
extern std::unordered_set<Clause*> satisfied_clauses;
extern std::unordered_set<Clause*> falsified_clauses;
void init_data_structs(Circuit* circuit);
void reset_data_structs();
void add_to_tmp_clause(int x);
void build_clauses(Circuit* circuit);
void flip(int var);
}

10
gate_encode.txt Normal file
View File

@ -0,0 +1,10 @@
NOT:
1.
a <=> ~b
2.
a => ~b
~b => a
3.

124
ls.cpp
View File

@ -7,6 +7,8 @@
#include "assert.h" #include "assert.h"
#include <chrono> #include <chrono>
#include "clause.h"
bool Circuit::local_search() { bool Circuit::local_search() {
ls_reset_data(); ls_reset_data();
@ -20,25 +22,21 @@ bool Circuit::local_search() {
for(int i=0; i<MAX_STEPS; i++) { for(int i=0; i<MAX_STEPS; i++) {
auto start = std::chrono::system_clock::now(); auto start = std::chrono::system_clock::now();
printf("[FLIP] stem: %lld, fault:%lld, stem_cnt: %lld, fault_cnt:%lld, fpl_score: %lld citcuit-score: %lld\n", stem_total_cost, fault_total_weight, stem_total_cnt, fault_total_cnt, fault_propagate_score, ls_circuit_score());
//printf("[FLIP] stem: %lld, fault:%lld, stem_cnt: %lld, fault_cnt:%lld, fpl_score: %lld citcuit-score: %lld\n", stem_total_cost, fault_total_weight, stem_total_cnt, fault_total_cnt, fault_propagate_score, ls_circuit_score()); int id = ls_pick();
Gate* stem = ls_pick(); printf("pick: %d\n", id);
if(stem == nullptr) { //printf("flip: %d. fc: %d. %d\n", id, ClauseLS::falsified_clauses.size(), Clause::total_cost);
//printf("[UP] stem: %lld, fault:%lld, stem_cnt: %lld, fault_cnt:%lld, fpl_score: %lld citcuit-score: %lld\n", stem_total_cost, fault_total_weight, stem_total_cnt, fault_total_cnt, fault_propagate_score, ls_circuit_score());
ls_update_weight();
stem = ls_pick_falsified();
}
if(stem_total_cnt == stems.size()) { if(stem_total_cnt == stems.size()) {
//printf("FIND SOLUTION!\n"); //printf("FIND SOLUTION!\n");
//printf("[SOL] flip: %lld, stem: %lld, fault:%lld. flip_cnt: %d, stem_cnt: %d, fault_cnt:%d\n", flip_total_weight, stem_total_weight, fault_total_weight, flip_total_cnt, stem_total_cnt, fault_total_cnt); printf("[UP] stem: %lld, fault:%lld, stem_cnt: %lld, fault_cnt:%lld, fpl_score: %lld citcuit-score: %lld\n", stem_total_cost, fault_total_weight, stem_total_cnt, fault_total_cnt, fault_propagate_score, ls_circuit_score());
break; break;
} }
ls_flip(stem); ls_flip_var(id);
assert(is_valid_circuit()); assert(is_valid_circuit());
@ -82,11 +80,12 @@ void Circuit::ls_statistics() {
void Circuit::ls_update_weight() { void Circuit::ls_update_weight() {
if(rand() % 10000 <= SP * 10000) { if(rand() % 10000 <= SP * 10000) {
for(Gate* g : gates) { for(auto& clause : ClauseLS::satisfied_clauses) {
if(g->stem && g->stem_satisfied && (g->stem_weight - STEM_INC >= 1)) { if(clause->weight - CLAUSE_FALSIFIED_INC < 1) continue;
g->stem_weight -= STEM_INC; clause->weight -= CLAUSE_FALSIFIED_INC;
} }
for(Gate* g : gates) {
if(g->fault_detected[0] && g->fault_weight[0] - FAULT_INC >= 1) { if(g->fault_detected[0] && g->fault_weight[0] - FAULT_INC >= 1) {
g->fault_weight[0] -= FAULT_INC; g->fault_weight[0] -= FAULT_INC;
fault_propagate_score -= FAULT_INC * (g->fault_propagate_length[0]); fault_propagate_score -= FAULT_INC * (g->fault_propagate_length[0]);
@ -98,21 +97,13 @@ void Circuit::ls_update_weight() {
} }
} }
} else { } else {
for(auto& clause : ClauseLS::falsified_clauses) {
if(clause->weight + CLAUSE_FALSIFIED_INC > CLAUSE_FALSIFIED_MAX) continue;
clause->weight += CLAUSE_FALSIFIED_INC;
clause->total_cost += CLAUSE_FALSIFIED_INC;
}
for(Gate* g : gates) { for(Gate* g : gates) {
if(g->stem && !g->stem_satisfied && (g->stem_weight + STEM_INC < STEM_WEIGHT_MAX)) {
g->stem_weight += STEM_INC;
stem_total_cost += STEM_INC;
for(Gate* suc : g->suc_stems) {
if(suc->stem_weight - STEM_INC >= 1) {
suc->stem_weight -= STEM_INC;
if(!suc->stem_satisfied) {
stem_total_cost -= STEM_INC;
}
}
}
}
if(!g->fault_detected[0] && g->fault_weight[0] > 0 && (g->fault_weight[0] + FAULT_INC < FAULT_WEIGHT_MAX)) { if(!g->fault_detected[0] && g->fault_weight[0] > 0 && (g->fault_weight[0] + FAULT_INC < FAULT_WEIGHT_MAX)) {
g->fault_weight[0] += FAULT_INC; g->fault_weight[0] += FAULT_INC;
fault_propagate_score += FAULT_INC * (g->fault_propagate_length[0]); fault_propagate_score += FAULT_INC * (g->fault_propagate_length[0]);
@ -126,57 +117,39 @@ void Circuit::ls_update_weight() {
} }
} }
Gate* Circuit::ls_pick() { int Circuit::ls_pick() {
Gate* stem = nullptr; int var = -1;
ll max_score = 0; ll max_score = 0;
std::vector<Gate*> stems_random;
std::vector<Gate*> candidates;
for(int i=0; i<stems.size(); i++) { for(int i=0; i<SAMPLING_COUNT; i++) {
if(stems[i]->CC) { int t_var = rand() % ClauseLS::num_vars + 1;
stems_random.push_back(stems[i]); if(!ClauseLS::CC[t_var]) continue;
}
}
for(int i=0; i<stems_random.size(); i++) { ll t_score = ls_pick_score(t_var);
std::swap(stems_random[i], stems_random[rand()%stems_random.size()]);
}
const int max_index = std::min((int)stems_random.size(), SAMPLING_COUNT);
for(int i=0; i<max_index; i++) {
Gate* t_stem = stems_random[i];
ll t_score = ls_pick_score(t_stem);
if(t_score > max_score) { if(t_score > max_score) {
max_score = t_score; max_score = t_score;
stem = t_stem; var = t_var;
} }
} }
return stem; if(var == -1) {
} ls_update_weight();
Gate* Circuit::ls_pick_falsified() { printf("[UP]\n");
std::vector<Gate*> candidates;
for(Gate *g : stems) {
if(g->stem_satisfied) continue;
// for(Gate* pre : g->pre_stems) printf("fals: %d\n", ClauseLS::falsified_clauses.size());
// candidates.push_back(pre);
// for(Gate* suc : g->suc_stems) auto it = std::next(ClauseLS::falsified_clauses.begin(), rand() % ClauseLS::falsified_clauses.size());
// candidates.push_back(suc);
candidates.push_back(g); auto& lits = (*it)->lits;
var = abs(lits[rand()%lits.size()]);
} }
if(candidates.size() == 0) { assert(var != -1);
candidates.push_back(stems[rand()%stems.size()]);
}
return candidates[rand()%candidates.size()]; return var;
} }
void Circuit::ls_init_stems() { void Circuit::ls_init_stems() {
@ -241,15 +214,27 @@ void Circuit::ls_init_stems() {
} }
} }
ll Circuit::ls_pick_score(Gate* stem) { void Circuit::ls_flip_var(int var) {
ClauseLS::flip(var);
if(id2gate.count(var) && id2gate[var]->stem) {
ls_flip_stem(id2gate[var]);
}
}
ll Circuit::ls_pick_score(int var) {
ll old_score = ls_circuit_score(); ll old_score = ls_circuit_score();
ls_flip(stem); ls_flip_var(var);
ll new_score = ls_circuit_score(); ll new_score = ls_circuit_score();
ls_flip(stem); ls_flip_var(var);
assert(old_score == ls_circuit_score()); assert(old_score == ls_circuit_score());
@ -257,7 +242,8 @@ ll Circuit::ls_pick_score(Gate* stem) {
} }
ll Circuit::ls_circuit_score() { ll Circuit::ls_circuit_score() {
ll score = -stem_total_cost + fault_propagate_score + fault_total_weight; //ll score = - Clause::total_cost + fault_propagate_score + fault_total_weight;
ll score = - Clause::total_cost;
return score; return score;
} }
@ -277,7 +263,7 @@ void Circuit::ls_random_circuit() {
// init assignment // init assignment
for(Gate* s : stems) { for(Gate* s : stems) {
s->value = rand() % 2; s->value = ClauseLS::lit_value[s->id];
} }
// recal value by topo // recal value by topo
@ -345,7 +331,7 @@ void Circuit::ls_reset_data() {
} }
void Circuit::ls_flip(Gate* stem) { void Circuit::ls_flip_stem(Gate* stem) {
stem->value = !stem->value; stem->value = !stem->value;
// update CC // update CC

View File

@ -3,6 +3,7 @@
#include <assert.h> #include <assert.h>
#include "circuit.h" #include "circuit.h"
#include "clause.h"
int main(int args, char* argv[]) { int main(int args, char* argv[]) {
@ -29,6 +30,9 @@ int main(int args, char* argv[]) {
circuit->global_fault_undetected_count = circuit->gates.size() * 2; circuit->global_fault_undetected_count = circuit->gates.size() * 2;
// init clause local search
ClauseLS::init_data_structs(circuit);
while(true) { while(true) {
bool ls = circuit->local_search(); bool ls = circuit->local_search();
bool is_valid = circuit->is_valid_circuit(); bool is_valid = circuit->is_valid_circuit();

View File

@ -1,9 +1,9 @@
#pragma once #pragma once
const double SP = 0.01; const double SP = 0.01;
const int MAX_STEPS = 10000; const int MAX_STEPS = 1000000000;
const int SAMPLING_COUNT = 25; const int SAMPLING_COUNT = 1000;
const int STEM_INC = 2; const int STEM_INC = 2;
const int STEM_WEIGHT_MAX = 1e9; const int STEM_WEIGHT_MAX = 1e9;
@ -11,3 +11,5 @@ const int STEM_WEIGHT_MAX = 1e9;
const int FAULT_INC = 1; const int FAULT_INC = 1;
const int FAULT_WEIGHT_MAX = 20; const int FAULT_WEIGHT_MAX = 20;
const int CLAUSE_FALSIFIED_INC = 2;
const int CLAUSE_FALSIFIED_MAX = 1e9;

10
test.bench Normal file
View File

@ -0,0 +1,10 @@
INPUT(1)
INPUT(2)
INPUT(3)
OUTPUT(5)
4 = XOR(1, 2)
5 = XOR(3, 4)