/* Source listing metadata (not part of the program):
   2023-03-26 19:15:17 +08:00
   394 lines
   8.8 KiB
   C */

#include "eliminate.h"
#include "gates.h"
#include "inline.h"
#include "sort.h"
#include "terminate.h"
#include "xors.h"
/*
 * Sets the mark of each of the 'size' literals in 'lits'.  Bit 'i' of
 * 'signs' selects whether 'lits[i]' is marked as given (bit clear) or
 * with its lowest bit toggled (bit set).  None of the selected literals
 * may be marked already (asserted).
 */
static void
mark_literals (kissat * solver, value * marks,
               unsigned size, unsigned *lits, unsigned signs)
{
  assert (size < 32);
  const unsigned *const end = lits + size;
  unsigned mask = 1;
  for (const unsigned *p = lits; p != end; p++)
    {
      /* 'mask' walks through the sign bits in step with the literals. */
      const unsigned lit = (signs & mask) ? (*p ^ 1u) : *p;
      assert (!marks[lit]);
      marks[lit] = 1;
      LOG ("marked %s", LOGLIT (lit));
      mask <<= 1;
    }
#ifndef LOGGING
  (void) solver;
#endif
}
/*
 * Clears the mark of each of the 'size' literals in 'lits', using the
 * same sign selection as when marking: bit 'i' of 'signs' decides whether
 * 'lits[i]' itself or the literal with its lowest bit toggled is cleared.
 * Every selected literal must currently be marked (asserted).
 */
static void
unmark_literals (kissat * solver, value * marks,
                 unsigned size, unsigned *lits, unsigned signs)
{
  assert (size < 32);
  unsigned remaining = signs;
  for (unsigned i = 0; i < size; i++)
    {
      /* The lowest bit of 'remaining' is the sign of position 'i'. */
      const unsigned lit = lits[i] ^ (remaining & 1u);
      assert (marks[lit]);
      marks[lit] = 0;
      LOG ("unmarked %s", LOGLIT (lit));
      remaining >>= 1;
    }
#ifndef LOGGING
  (void) solver;
#endif
}
/*
 * Copies the literals of clause 'c' into 'lits', dropping every literal
 * whose entry in 'values' is negative and placing 'lit' last.
 *
 * Debug builds assert that 'lit' occurs exactly once in 'c', that it is
 * unassigned and that no other literal has a positive value.
 *
 * Returns the number of literals written, which is at most 'c->size', so
 * 'lits' needs room for that many entries.
 */
static unsigned
copy_literals (kissat * solver, unsigned lit,
               const value * values, unsigned *lits, clause * c)
{
  unsigned copied = 0;
#ifndef NDEBUG
  bool found_lit = false;
#endif
  const unsigned *const end = c->lits + c->size;
  const unsigned *p = c->lits;
  while (p != end)
    {
      const unsigned other = *p++;
      if (other != lit)
        {
          const value assigned = values[other];
          assert (assigned <= 0);
          if (assigned < 0)
            {
              LOG ("skipping falsified %s", LOGLIT (other));
            }
          else
            {
              LOG ("copying %s", LOGLIT (other));
              lits[copied++] = other;
            }
          continue;
        }
#ifndef NDEBUG
      assert (!found_lit);
      assert (!values[other]);
      found_lit = true;
#endif
    }
  assert (found_lit);
  /* 'lit' itself always goes to the end of the copy. */
  lits[copied++] = lit;
  const unsigned size = copied;
  LOGLITS (size, lits, "copied", size);
#ifndef LOGGING
  (void) solver;
#endif
  return size;
}
/*
 * Returns true if an odd number of the lowest 'size' bits of 'signs'
 * are set.  Bits at position 'size' and above are ignored.
 */
static bool
odd_parity (unsigned size, unsigned signs)
{
  unsigned ones = 0;
  for (unsigned pos = 0; pos < size; pos++)
    {
      if (signs & (1u << pos))
        ones++;
    }
  return (ones & 1u) != 0;
}
/*
 * Advances the marking from sign pattern 'prev' to the next larger
 * pattern whose lowest 'size' bits have even parity and updates 'marks'
 * accordingly: for every position whose bit differs between the two
 * patterns the mark moves between 'lits[i]' and 'NOT (lits[i])'.
 *
 * Returns the new pattern, or zero once the pattern reaches '1u << size'.
 * In that case all low bits are clear, so the marks are back on the
 * literals in 'lits' themselves.
 */
static unsigned
next_marking (kissat * solver, value * marks,
              unsigned size, unsigned *lits, unsigned prev)
{
  LOG ("prev signs %s", FORMAT_SIGNS (size, prev));
  assert (2 < size);
  assert (size < 32);
  const unsigned limit = (1u << size);
  assert (prev < limit);
  unsigned next = prev + 1;
  while (odd_parity (size, next))
    next++;
  LOG ("next signs %s", FORMAT_SIGNS (size, next));
  const unsigned changed = prev ^ next;
  unsigned bit = 1;
  for (unsigned i = 0; i < size; i++, bit <<= 1)
    {
      if (!(changed & bit))
        continue;
      const unsigned lit = lits[i];
      const unsigned not_lit = NOT (lit);
      /* A newly set bit moves the mark from 'lit' to 'not_lit', a newly
         cleared bit moves it back. */
      const bool now_set = (next & bit) != 0;
      const unsigned from = now_set ? lit : not_lit;
      const unsigned to = now_set ? not_lit : lit;
      assert (marks[from]);
      assert (!marks[to]);
      marks[from] = 0;
      LOG ("unmarked %s", LOGLIT (from));
      marks[to] = 1;
      LOG ("marked %s", LOGLIT (to));
    }
#ifndef LOGGING
  (void) solver;
#endif
  if (next == limit)
    return 0;
  return next;
}
/*
 * Checks whether the large clause referenced by 'ref' matches the current
 * marking: every literal that is not assigned false must be marked.
 *
 * If a literal with a positive value is found, the clause is handed to
 * 'kissat_eliminate_clause' and the match fails.  If the clause matches
 * with fewer than 'size' non-false literals, 'solver->resolve_gate' is
 * set to true.
 *
 * Returns true on a match and false otherwise.
 */
static bool
match_lits_ref (kissat * solver, const value * marks, const value * values,
                unsigned size, reference ref)
{
  clause *c = kissat_dereference_clause (solver, ref);
  unsigned matched = 0;
  for (all_literals_in_clause (lit, c))
    {
      const value assigned = values[lit];
      if (assigned > 0)
        {
          kissat_eliminate_clause (solver, c, INVALID_LIT);
          return false;
        }
      if (assigned == 0)
        {
          if (!marks[lit])
            return false;
          matched++;
        }
    }
  assert (matched <= size);
  if (matched != size)
    solver->resolve_gate = true;
  return true;
}
/*
 * Checks whether the clause behind 'watch' matches the current marking.
 *
 * A binary watch matches when its other literal is marked; such a match
 * always sets 'solver->resolve_gate' to true.  A large watch is delegated
 * to 'match_lits_ref'.
 */
static bool
match_lits_watch (kissat * solver,
                  const value * marks, const value * values,
                  unsigned size, watch watch)
{
  if (!watch.type.binary)
    return match_lits_ref (solver, marks, values, size, watch.large.ref);

  const unsigned other = watch.binary.lit;
  if (!marks[other])
    return false;
  assert (size > 2);
  solver->resolve_gate = true;
  return true;
}
/*
 * Scans the watches in '[begin, end)' for the first one matching the
 * current marking (see 'match_lits_watch').  '*steps' is incremented once
 * for every watch inspected.
 *
 * Returns a pointer to the matching watch or zero if there is none.
 */
static watch *
find_lits_watch (kissat * solver, watch * begin, watch * end,
                 const value * marks, const value * values,
                 unsigned size, uint64_t * steps)
{
  assert (begin <= end);
  watch *candidate = begin;
  while (candidate != end)
    {
      ++*steps;
      if (match_lits_watch (solver, marks, values, size, *candidate))
        return candidate;
      candidate++;
    }
  return 0;
}
/* Ordering predicate handed to SORT_STACK: compares two pointers by
   address. */
#define LESS_POINTER(P,Q) \
((P) < (Q))
/* Sorts the 'watch *' entries of the stack 'patches' in place with
   SORT_STACK, ordered by LESS_POINTER (i.e. by address), so that repeated
   pointers end up next to each other. */
static void
sort_watch_pointers (kissat * solver, patches * patches)
{
SORT_STACK (watch *, *patches, LESS_POINTER);
}
/*
 * Tries to find an XOR gate among the clauses watched by 'lit' and by its
 * negation.
 *
 * Each sufficiently small large clause watched by 'lit' serves as a base
 * clause.  Its unassigned literals are copied and marked, and then all
 * further even-parity sign patterns of those literals are enumerated.
 * For each pattern a matching clause is searched in the watches of 'lit'
 * (if 'lit' itself is marked) or of 'NOT (lit)' (otherwise).
 *
 * On success the matched watches are pushed onto 'solver->gates[negative]'
 * (clauses of 'lit') and 'solver->gates[!negative]' (clauses of its
 * negation), 'solver->gate_eliminated' is set and true is returned.
 *
 * Returns false if the 'xors' option is off, the clause size limit is
 * below three, either literal has fewer than two large watches, or no
 * base clause yields a complete set of matches within the step limit.
 */
bool
kissat_find_xor_gate (kissat * solver, unsigned lit, unsigned negative)
{
if (!GET_OPTION (xors))
return false;
const unsigned size_limit = solver->bounds.xork.clause_size;
if (size_limit < 3)
return false;
assert (size_limit < 32);
watches *watches0 = &WATCHES (lit);
watch *begin0 = BEGIN_WATCHES (*watches0);
watch *end0 = END_WATCHES (*watches0);
if (begin0 == end0)
return false;
/* Count non-binary watches of 'lit', stopping as soon as a second one
   is seen: at least two are required. */
uint64_t large_clauses0 = 0;
for (watch * p = begin0; p != end0; p++)
if (!p->type.binary && large_clauses0++)
break;
if (large_clauses0 < 2)
return false;
const unsigned not_lit = NOT (lit);
watches *watches1 = &WATCHES (not_lit);
watch *begin1 = BEGIN_WATCHES (*watches1);
watch *end1 = END_WATCHES (*watches1);
/* Same requirement for the watches of the negated literal. */
uint64_t large_clauses1 = 0;
for (watch * p = begin1; p != end1; p++)
if (!p->type.binary && large_clauses1++)
break;
if (large_clauses1 < 2)
return false;
/* Scratch buffer for the literals of the base clause.  Only clauses
   with at most 'size_limit' literals are copied into it. */
unsigned lits[size_limit];
const value *values = solver->values;
value *marks = solver->marks;
const unsigned steps_limit = solver->bounds.eliminate.occurrences;
uint64_t steps = 0;
/* Try every large clause watched by 'lit' as base clause. */
for (watch * p = begin0; p != end0; p++)
{
if (p->type.binary)
continue;
if (steps > steps_limit)
break;
if (TERMINATED (26))
break;
steps++;
clause *c = kissat_dereference_clause (solver, p->large.ref);
if (c->size > size_limit)
continue;
unsigned size = copy_literals (solver, lit, values, lits, c);
assert (size <= 32);
if (size < 3)
continue;
solver->resolve_gate = false;
assert (EMPTY_STACK (solver->xorted[0]));
assert (EMPTY_STACK (solver->xorted[1]));
/* The base clause itself is the first match on the 'lit' side. */
PUSH_STACK (solver->xorted[0], p);
unsigned signs = 0;
mark_literals (solver, marks, size, lits, signs);
/* Enumerate the remaining even-parity sign patterns.  The loop ends
   when the step budget is used up, when 'next_marking' returns zero
   after the last pattern, or when a pattern has no matching clause. */
while (steps <= steps_limit &&
(signs = next_marking (solver, marks, size, lits, signs)))
{
if (marks[lit])
{
/* 'lit' is marked positively: search among its own watches. */
watch *q = find_lits_watch (solver, begin0, end0,
marks, values, size, &steps);
if (!q)
{
LOGLITS (size, lits,
"could not match signs %s of copied",
FORMAT_SIGNS (size, signs));
break;
}
LOGWATCH (lit, *q, "literal %s XOR", LOGLIT (lit));
PUSH_STACK (solver->xorted[0], q);
}
else
{
/* Otherwise the match has to come from the watches of 'not_lit'. */
watch *q = find_lits_watch (solver, begin1, end1,
marks, values, size, &steps);
if (!q)
break;
LOGWATCH (not_lit, *q, "found %s literal XOR",
LOGLIT (not_lit));
PUSH_STACK (solver->xorted[1], q);
}
}
/* 'signs' is the pattern that is currently marked (zero if the
   enumeration ran to completion), so this clears all marks again. */
unmark_literals (solver, marks, size, lits, signs);
unsigned nsort[2] = {
SIZE_STACK (solver->xorted[0]), SIZE_STACK (solver->xorted[1])
};
/* A match was recorded for each of the 2^(size-1) even-parity
   patterns, including the base clause. */
if (nsort[0] + nsort[1] == (1u << (size - 1)))
{
assert (nsort[0] == (1u << (size - 2)));
assert (nsort[1] == (1u << (size - 2)));
/* Sort the watch pointers so that a watch which matched several
   patterns appears consecutively and is pushed only once. */
sort_watch_pointers (solver, &solver->xorted[0]);
const watch *prev = 0;
for (unsigned i = 0; i < nsort[0]; i++)
{
const watch *p0 = PEEK_STACK (solver->xorted[0], i);
const watch w0 = *p0;
if (p0 == prev)
LOGWATCH (lit, w0, "dropping repeated");
else
{
LOGWATCH (lit, w0, "%s %s XOR",
FORMAT_ORDINAL (i + 1), LOGLIT (lit));
PUSH_STACK (solver->gates[negative], w0);
}
prev = p0;
}
sort_watch_pointers (solver, &solver->xorted[1]);
prev = 0;
for (unsigned i = 0; i < nsort[1]; i++)
{
const watch *p1 = PEEK_STACK (solver->xorted[1], i);
const watch w1 = *p1;
if (p1 == prev)
LOGWATCH (not_lit, w1, "dropping repeated");
else
{
LOGWATCH (not_lit, w1, "%s %s XOR",
FORMAT_ORDINAL (i + 1), LOGLIT (not_lit));
PUSH_STACK (solver->gates[!negative], w1);
}
prev = p1;
}
assert (!EMPTY_STACK (solver->gates[0]));
#ifdef LOGGING
/* For logging only: temporarily negate the first literal, print the
   XOR over the first 'size - 1' literals, then restore it. */
assert (size > 1);
assert (lits[size - 1] == lit);
lits[0] = NOT (lits[0]);
LOGXOR (lit, size - 1, lits, "found");
lits[0] = NOT (lits[0]);
#endif
}
CLEAR_STACK (solver->xorted[0]);
CLEAR_STACK (solver->xorted[1]);
/* An empty 'gates[0]' means nothing was pushed for this base clause:
   continue with the next candidate. */
if (EMPTY_STACK (solver->gates[0]))
continue;
solver->gate_eliminated = GATE_ELIMINATED (xors);
return true;
}
assert (EMPTY_STACK (solver->xorted[0]));
assert (EMPTY_STACK (solver->xorted[1]));
return false;
}