#include "eliminate.h" #include "gates.h" #include "inline.h" #include "sort.h" #include "terminate.h" #include "xors.h" static void mark_literals (kissat * solver, value * marks, unsigned size, unsigned *lits, unsigned signs) { assert (size < 32); for (unsigned i = 0, bit = 1; i < size; i++, bit <<= 1) { unsigned sign = ((bit & signs) != 0); const unsigned lit = lits[i] ^ sign; assert (!marks[lit]); marks[lit] = 1; LOG ("marked %s", LOGLIT (lit)); } #ifndef LOGGING (void) solver; #endif } static void unmark_literals (kissat * solver, value * marks, unsigned size, unsigned *lits, unsigned signs) { assert (size < 32); for (unsigned i = 0, bit = 1; i < size; i++, bit <<= 1) { unsigned sign = ((bit & signs) != 0); const unsigned lit = lits[i] ^ sign; assert (marks[lit]); marks[lit] = 0; LOG ("unmarked %s", LOGLIT (lit)); } #ifndef LOGGING (void) solver; #endif } static unsigned copy_literals (kissat * solver, unsigned lit, const value * values, unsigned *lits, clause * c) { const unsigned *end = c->lits + c->size; unsigned *q = lits; #ifndef NDEBUG bool found_lit = false; #endif for (const unsigned *p = c->lits; p != end; p++) { const unsigned other = *p; if (other == lit) { #ifndef NDEBUG assert (!found_lit); assert (!values[other]); found_lit = true; #endif } else { const value value = values[other]; assert (value <= 0); if (value < 0) { LOG ("skipping falsified %s", LOGLIT (other)); continue; } LOG ("copying %s", LOGLIT (other)); *q++ = other; } } assert (found_lit); *q++ = lit; const unsigned size = q - lits; LOGLITS (size, lits, "copied", size); #ifndef LOGGING (void) solver; #endif return size; } static bool odd_parity (unsigned size, unsigned signs) { bool res = false; for (unsigned i = 0; i < size; i++) if (signs & (1u << i)) res = !res; return res; } static unsigned next_marking (kissat * solver, value * marks, unsigned size, unsigned *lits, unsigned prev) { LOG ("prev signs %s", FORMAT_SIGNS (size, prev)); assert (2 < size); assert (size < 32); const unsigned limit = (1u << size); assert (prev < limit); unsigned next; for (next = prev + 1; odd_parity (size, next); next++) ; LOG ("next signs %s", FORMAT_SIGNS (size, next)); for (unsigned i = 0, bit = 1; i < size; i++, bit <<= 1) { const unsigned lit = lits[i]; const unsigned not_lit = NOT (lit); if (!(prev & bit) && (next & bit)) { assert (marks[lit]); assert (!marks[not_lit]); marks[lit] = 0; LOG ("unmarked %s", LOGLIT (lit)); marks[not_lit] = 1; LOG ("marked %s", LOGLIT (not_lit)); } else if ((prev & bit) && !(next & bit)) { assert (marks[not_lit]); assert (!marks[lit]); marks[not_lit] = 0; LOG ("unmarked %s", LOGLIT (not_lit)); marks[lit] = 1; LOG ("marked %s", LOGLIT (lit)); } } #ifndef LOGGING (void) solver; #endif return next == limit ? 0 : next; } static bool match_lits_ref (kissat * solver, const value * marks, const value * values, unsigned size, reference ref) { clause *c = kissat_dereference_clause (solver, ref); unsigned found = 0; for (all_literals_in_clause (lit, c)) { const value value = values[lit]; if (value > 0) { kissat_eliminate_clause (solver, c, INVALID_LIT); return false; } if (value < 0) continue; if (!marks[lit]) return false; found++; } assert (found <= size); if (found < size) solver->resolve_gate = true; return true; } static bool match_lits_watch (kissat * solver, const value * marks, const value * values, unsigned size, watch watch) { if (watch.type.binary) { const unsigned other = watch.binary.lit; if (!marks[other]) return false; assert (size > 2); solver->resolve_gate = true; return true; } else { const reference ref = watch.large.ref; return match_lits_ref (solver, marks, values, size, ref); } } static watch * find_lits_watch (kissat * solver, watch * begin, watch * end, const value * marks, const value * values, unsigned size, uint64_t * steps) { assert (begin <= end); for (watch * p = begin; p != end; p++) { *steps += 1; if (match_lits_watch (solver, marks, values, size, *p)) return p; } return 0; } #define LESS_POINTER(P,Q) \ ((P) < (Q)) static void sort_watch_pointers (kissat * solver, patches * patches) { SORT_STACK (watch *, *patches, LESS_POINTER); } bool kissat_find_xor_gate (kissat * solver, unsigned lit, unsigned negative) { if (!GET_OPTION (xors)) return false; const unsigned size_limit = solver->bounds.xork.clause_size; if (size_limit < 3) return false; assert (size_limit < 32); watches *watches0 = &WATCHES (lit); watch *begin0 = BEGIN_WATCHES (*watches0); watch *end0 = END_WATCHES (*watches0); if (begin0 == end0) return false; uint64_t large_clauses0 = 0; for (watch * p = begin0; p != end0; p++) if (!p->type.binary && large_clauses0++) break; if (large_clauses0 < 2) return false; const unsigned not_lit = NOT (lit); watches *watches1 = &WATCHES (not_lit); watch *begin1 = BEGIN_WATCHES (*watches1); watch *end1 = END_WATCHES (*watches1); uint64_t large_clauses1 = 0; for (watch * p = begin1; p != end1; p++) if (!p->type.binary && large_clauses1++) break; if (large_clauses1 < 2) return false; unsigned lits[size_limit]; const value *values = solver->values; value *marks = solver->marks; const unsigned steps_limit = solver->bounds.eliminate.occurrences; uint64_t steps = 0; for (watch * p = begin0; p != end0; p++) { if (p->type.binary) continue; if (steps > steps_limit) break; if (TERMINATED (26)) break; steps++; clause *c = kissat_dereference_clause (solver, p->large.ref); if (c->size > size_limit) continue; unsigned size = copy_literals (solver, lit, values, lits, c); assert (size <= 32); if (size < 3) continue; solver->resolve_gate = false; assert (EMPTY_STACK (solver->xorted[0])); assert (EMPTY_STACK (solver->xorted[1])); PUSH_STACK (solver->xorted[0], p); unsigned signs = 0; mark_literals (solver, marks, size, lits, signs); while (steps <= steps_limit && (signs = next_marking (solver, marks, size, lits, signs))) { if (marks[lit]) { watch *q = find_lits_watch (solver, begin0, end0, marks, values, size, &steps); if (!q) { LOGLITS (size, lits, "could not match signs %s of copied", FORMAT_SIGNS (size, signs)); break; } LOGWATCH (lit, *q, "literal %s XOR", LOGLIT (lit)); PUSH_STACK (solver->xorted[0], q); } else { watch *q = find_lits_watch (solver, begin1, end1, marks, values, size, &steps); if (!q) break; LOGWATCH (not_lit, *q, "found %s literal XOR", LOGLIT (not_lit)); PUSH_STACK (solver->xorted[1], q); } } unmark_literals (solver, marks, size, lits, signs); unsigned nsort[2] = { SIZE_STACK (solver->xorted[0]), SIZE_STACK (solver->xorted[1]) }; if (nsort[0] + nsort[1] == (1u << (size - 1))) { assert (nsort[0] == (1u << (size - 2))); assert (nsort[1] == (1u << (size - 2))); sort_watch_pointers (solver, &solver->xorted[0]); const watch *prev = 0; for (unsigned i = 0; i < nsort[0]; i++) { const watch *p0 = PEEK_STACK (solver->xorted[0], i); const watch w0 = *p0; if (p0 == prev) LOGWATCH (lit, w0, "dropping repeated"); else { LOGWATCH (lit, w0, "%s %s XOR", FORMAT_ORDINAL (i + 1), LOGLIT (lit)); PUSH_STACK (solver->gates[negative], w0); } prev = p0; } sort_watch_pointers (solver, &solver->xorted[1]); prev = 0; for (unsigned i = 0; i < nsort[1]; i++) { const watch *p1 = PEEK_STACK (solver->xorted[1], i); const watch w1 = *p1; if (p1 == prev) LOGWATCH (not_lit, w1, "dropping repeated"); else { LOGWATCH (not_lit, w1, "%s %s XOR", FORMAT_ORDINAL (i + 1), LOGLIT (not_lit)); PUSH_STACK (solver->gates[!negative], w1); } prev = p1; } assert (!EMPTY_STACK (solver->gates[0])); #ifdef LOGGING assert (size > 1); assert (lits[size - 1] == lit); lits[0] = NOT (lits[0]); LOGXOR (lit, size - 1, lits, "found"); lits[0] = NOT (lits[0]); #endif } CLEAR_STACK (solver->xorted[0]); CLEAR_STACK (solver->xorted[1]); if (EMPTY_STACK (solver->gates[0])) continue; solver->gate_eliminated = GATE_ELIMINATED (xors); return true; } assert (EMPTY_STACK (solver->xorted[0])); assert (EMPTY_STACK (solver->xorted[1])); return false; }