#include "preprocess.hpp" #include #include "m4ri/m4ri.h" #define MAX_XOR 6 bool cmpvar(int x, int y) { return abs(x) < abs(y); } int preprocess::cal_dup_val(int i) { for (int j = 0; j < clause[i].size(); j++) a[j] = clause[i][j]; std::sort(a, a + clause[i].size(), cmpvar); int v = 0; for (int j = 0; j < clause[i].size(); j++) if (a[j] < 0) v |= (1 << j); return v; } int preprocess::search_xors() { vec xorsp; vec dup_table; for (int i = 1; i <= vars; i++) { seen[i] = 0; occurp[i].clear(); occurn[i].clear(); } for (int i = 1; i <= clauses; i++) { abstract[i] = clause_delete[i] = nxtc[i] = 0; int l = clause[i].size(); for (int j = 0; j < l; j++) { if (clause[i][j] > 0) occurp[clause[i][j]].push(i); else occurn[-clause[i][j]].push(i); abstract[i] |= 1 << (abs(clause[i][j]) & 31); } } for (int i = 1; i <= clauses; i++) { if (nxtc[i] || clause_delete[i]) continue; nxtc[i] = 1; int l = clause[i].size(); if (l <= 2 || l > MAX_XOR) continue; int required_num = 1 << (l - 2), skip = 0, mino = clauses + 1, mino_id = 0; for (int j = 0; j < l; j++) { int idx = abs(clause[i][j]); if (occurp[idx].size() < required_num || occurn[idx].size() < required_num) { skip = 1; break; } if (occurp[idx].size() + occurn[idx].size() < mino) { mino = occurp[idx].size() + occurn[idx].size(); mino_id = idx; } } if (skip) continue; xorsp.clear(); xorsp.push(i); for (int j = 0; j < occurp[mino_id].size(); j++) { int o = occurp[mino_id][j]; if (!clause_delete[o] && !nxtc[o] && clause[o].size() == l && abstract[o] == abstract[i]) xorsp.push(o); } for (int j = 0; j < occurn[mino_id].size(); j++) { int o = occurn[mino_id][j]; if (!clause_delete[o] && !nxtc[o] && clause[o].size() == l && abstract[o] == abstract[i]) xorsp.push(o); } if (xorsp.size() < 2 * required_num) continue; int rhs[2] = {0, 0}; for (int j = 0; j < l; j++) seen[abs(clause[i][j])] = i; dup_table.clear(); dup_table.growTo(1 << MAX_XOR, false); for (int j = 0; j < xorsp.size(); j++) { int o = xorsp[j], dup_v; bool xor_sign = true; for (int k = 0; k < clause[o].size(); k++) { if (!seen[abs(clause[o][k])]) goto Next; if (clause[o][k] < 0) xor_sign = !xor_sign; } dup_v = cal_dup_val(o); if (dup_table[dup_v]) continue; dup_table[dup_v] = true; rhs[xor_sign]++; if (j) nxtc[o] = 1; Next:; } assert(rhs[0] <= 2 * required_num); assert(rhs[1] <= 2 * required_num); if (rhs[0] == 2 * required_num) xors.push(xorgate(i, 0, l)); if (rhs[1] == 2 * required_num) xors.push(xorgate(i, 1, l)); } return xors.size(); } bool cmpxorgate(xorgate a, xorgate b) { return a.sz > b.sz; } int preprocess::ecc_var() { scc_id.clear(); scc_id.growTo(vars + 1, -1); scc.clear(); //sort(xors.data, xors.data + xors.sz, cmpxorgate); vec xids; for (int i = 0; i < xors.size(); i++) { int x = xors[i].c; xids.clear(); for (int j = 0; j < clause[x].size(); j++) if (scc_id[abs(clause[x][j])] != -1) xids.push(scc_id[abs(clause[x][j])]); if (xids.size() == 0) { scc.push(); for (int j = 0; j < clause[x].size(); j++) { scc_id[abs(clause[x][j])] = scc.size() - 1; scc[scc.size() - 1].push(abs(clause[x][j])); } } else if (xids.size() == 1) { int id = xids[0]; for (int j = 0; j < clause[x].size(); j++) { if (scc_id[abs(clause[x][j])] == -1) { scc_id[abs(clause[x][j])] = id; scc[id].push(abs(clause[x][j])); } } } else { int id_max = -1; int sz_max = 0; for (int j = 0; j < xids.size(); j++) if (scc[xids[j]].size() > sz_max) sz_max = scc[xids[j]].size(), id_max = xids[j]; for (int j = 0; j < xids.size(); j++) { if (xids[j] != id_max) { vec& v = scc[xids[j]]; for (int k = 0; k < v.size(); k++) { scc_id[v[k]] = id_max; scc[id_max].push(v[k]); } v.clear(); } } for (int j = 0; j < clause[x].size(); j++) { if (scc_id[abs(clause[x][j])] == -1) { scc_id[abs(clause[x][j])] = id_max; scc[id_max].push(abs(clause[x][j])); } } } } return scc.size(); } int preprocess::ecc_xor() { for (int i = 0; i < scc.size(); i++) seen[i] = -1; for (int i = 0; i < xors.size(); i++) { int id = scc_id[abs(clause[xors[i].c][0])]; if (seen[id] == -1) xor_scc.push(), seen[id] = xor_scc.size() - 1; int id2 = seen[id]; xor_scc[id2].push(i); } return xor_scc.size(); } int preprocess::gauss_elimination() { gauss_eli_unit = gauss_eli_binary = 0; vec v2mzd(vars + 1, -1); vec mzd2v; for (int i = 0; i < xor_scc.size(); i++) { if (xor_scc[i].size() == 1) continue; int id = scc_id[abs(clause[xors[xor_scc[i][0]].c][0])]; assert(scc[id].size() > 3); if (scc[id].size() > 1e7 / xor_scc[i].size()) continue; mzd2v.clear(); std::sort(scc[id].data, scc[id].data + scc[id].size(), cmpvar); for (int j = 0; j < scc[id].size(); j++) { assert(scc[id][j] > 0); assert(scc[id][j] <= vars); v2mzd[scc[id][j]] = j; mzd2v.push(scc[id][j]); } int cols = scc[id].size() + 1; mzd_t* mat = mzd_init(xor_scc[i].size(), cols); for (int row = 0; row < xor_scc[i].size(); row++) { int k = xors[xor_scc[i][row]].c; for (int j = 0; j < clause[k].size(); j++) mzd_write_bit(mat, row, v2mzd[abs(clause[k][j])], 1); if (xors[xor_scc[i][row]].rhs) mzd_write_bit(mat, row, cols - 1, 1); } mzd_echelonize(mat, true); mzd_free(mat); for (int row = 0, rhs; row < xor_scc[i].size(); row++) { vec ones; for (int col = 0; col < cols - 1; col++) if (mzd_read_bit(mat, row, col)) { if (ones.size() == 2) goto NextRow; ones.push(mzd2v[col]); } rhs = mzd_read_bit(mat, row, cols - 1); if (ones.size() == 1) { ++gauss_eli_unit; clause.push(); ++clauses; clause[clauses].push(ones[0] * (rhs ? 1 : -1)); } else if (ones.size() == 2) { ++gauss_eli_binary; // assert(clauses == clause.size() - 1); int p = ones[0], q = rhs ? ones[1] : -ones[1]; clause.push(); ++clauses; clause[clauses].push(p); clause[clauses].push(q); clause.push(); ++clauses; clause[clauses].push(-p); clause[clauses].push(-q); } else if (rhs) {return false;} NextRow:; } } return true; } bool preprocess::preprocess_gauss() { int nxors = search_xors(); // printf("c [GE] XORs: %d (time: 0.00)\n", nxors); if (!nxors) return true; int nvarscc = ecc_var(); // printf("c [GE] VAR SCC: %d\n", nvarscc); int nxorscc = ecc_xor(); // printf("c [GE] XOR SCCs: %d (time: 0.00)\n", nxorscc); int res = gauss_elimination(); // printf("c [GE] unary xor: %d, bin xor: %d, bin added\n", gauss_eli_unit, gauss_eli_binary); // if (!res) {printf("c [GE] UNSAT\n");} xors.clear(true); scc_id.clear(true); for (int i = 0; i < scc.size(); i++) scc[i].clear(true); scc.clear(true); for (int i = 0; i < xor_scc.size(); i++) xor_scc[i].clear(true); xor_scc.clear(true); if (!res) return false; clause_delete.growTo(clauses + 1, 0); nxtc.growTo(clauses + 1, 0); return true; }