cloud-sat/kissat-inc/src/transitive.c
2023-03-26 19:15:17 +08:00

294 lines
7.9 KiB
C

#include "allocate.h"
#include "analyze.h"
#include "internal.h"
#include "logging.h"
#include "print.h"
#include "proprobe.h"
#include "report.h"
#include "terminate.h"
#include "trail.h"
#include "transitive.h"
#include <stddef.h>
/*
 * Marks 'lit' as true and its negation as false in the solver's value
 * table, then records 'lit' on the trail.  Both 'lit' and its negation
 * are required (by assertion) to be unassigned on entry.
 */
static void
transitive_assign(kissat *solver, unsigned lit)
{
  LOG("transitive assign %s", LOGLIT(lit));

  const unsigned negation = NOT(lit);
  value *const table = solver->values;

  assert(!table[lit]);
  assert(!table[negation]);

  /* The two slots are distinct, so the order of the stores is irrelevant. */
  table[negation] = -1;
  table[lit] = 1;

  PUSH_STACK(solver->trail, lit);
}
/*
 * Undoes every assignment made after the trail had 'saved' entries:
 * literals are popped off the trail one by one and both polarities are
 * reset to "unassigned".  Afterwards the propagation position is set to
 * 'saved' and the decision level is reset to zero.
 */
static void
transitive_backtrack(kissat *solver, unsigned saved)
{
  assert(saved <= SIZE_STACK(solver->trail));

  value *const table = solver->values;

  for (;;)
  {
    if (SIZE_STACK(solver->trail) <= saved)
      break;

    const unsigned lit = POP_STACK(solver->trail);
    LOG("transitive unassign %s", LOGLIT(lit));

    const unsigned negation = NOT(lit);
    assert(table[lit] > 0);
    assert(table[negation] < 0);

    table[negation] = 0;
    table[lit] = 0;
  }

  solver->level = 0;
  solver->propagated = saved;
}
/*
 * Reorders every literal's watch list so that all binary watches come
 * first, followed by the non-binary watches.  Relative order inside each
 * group is preserved.  A non-binary watch occupies two consecutive slots
 * (head and tail), which are always moved together.
 */
static void
prioritize_binaries(kissat *solver)
{
  assert(solver->watching);

  /* Scratch buffer holding the non-binary watches of the current literal. */
  statches large;
  INIT_STACK(large);

  watches *const base = solver->watches;

  for (all_literals(lit))
  {
    assert(EMPTY_STACK(large));

    watches *const ws = base + lit;
    watch *const first = BEGIN_WATCHES(*ws);
    const watch *const last = END_WATCHES(*ws);

    watch *out = first;
    const watch *in = first;

    /* Pass 1: compact binary watches to the front, stash the rest. */
    while (in != last)
    {
      const watch head = *in++;
      if (head.type.binary)
      {
        *out++ = head;
        continue;
      }
      const watch tail = *in++;
      PUSH_STACK(large, head);
      PUSH_STACK(large, tail);
    }

    /* Pass 2: append the stashed non-binary watches behind the binaries. */
    for (const watch *r = BEGIN_STACK(large), *stop = END_STACK(large);
         r != stop; r++)
      *out++ = *r;

    assert(out == last);
    CLEAR_STACK(large);
  }

  RELEASE_STACK(large);
}
/*
 * Examines the binary watches of literal 'src' and removes those binary
 * clauses (src, dst) for which 'dst' can also be reached by propagating
 * NOT(src) through other binary watches only ("transitive" clauses).
 *
 * For every candidate watch the function temporarily assigns NOT(src) at
 * level 1, propagates over binary watches, and then backtracks again, so
 * the trail and value table are restored before the next candidate.
 *
 * If the propagation reaches a literal that is already false, NOT(src) is
 * a failed literal: 'src' is assigned as a unit, added to the proof, and
 * propagated with kissat_probing_propagate.
 *
 * Parameters:
 *   solver       the solver; 'src' must be unassigned (asserted).
 *   src          literal whose binary watches are examined.
 *   limit        the per-candidate loop stops once the 'transitive_ticks'
 *                statistic exceeds this value.
 *   reduced_ptr  incremented by the number of binary clauses removed.
 *   units        incremented by one if a failed literal was found.
 *
 * Returns true if at least one binary clause was removed or a unit was
 * derived, false otherwise.
 */
static bool
transitive_reduce(kissat *solver,
unsigned src, uint64_t limit,
uint64_t *reduced_ptr, unsigned *units)
{
bool res = false;
assert(!VALUE(src));
LOG("transitive reduce %s", LOGLIT(src));
watches *all_watches = solver->watches;
watches *src_watches = all_watches + src;
watch *end_src = END_WATCHES(*src_watches);
watch *begin_src = BEGIN_WATCHES(*src_watches);
// Account for touching the watch list of 'src' in both tick counters.
unsigned ticks = kissat_cache_lines(src_watches->size, sizeof(watch));
ADD(transitive_ticks, ticks + 1);
solver->dps_ticks += 1 + ticks;
const unsigned not_src = NOT(src);
unsigned reduced = 0;
bool failed = false;
// Outer loop: one iteration per candidate binary watch of 'src'.
for (watch *p = begin_src; p != end_src; p++)
{
const watch src_watch = *p;
// Stop at the first non-binary watch.  The only caller visible in this
// file runs prioritize_binaries first, which puts binaries in front.
if (!src_watch.type.binary)
break;
const unsigned dst = src_watch.binary.lit;
// Only candidates with dst >= src are tried here.
// NOTE(review): presumably so each binary clause is considered from only
// one of its two literals -- confirm.
if (dst < src)
continue;
if (VALUE(dst))
continue;
assert(solver->propagated == SIZE_STACK(solver->trail));
// Remember the trail position so everything assigned below can be undone.
unsigned saved = solver->propagated;
assert(!solver->level);
solver->level = 1;
transitive_assign(solver, not_src);
const bool redundant = src_watch.binary.redundant;
bool transitive = false;
unsigned propagated = 0;
// Propagate over binary watches until 'dst' is reached, a conflict is
// found, or the trail is exhausted.
while (!transitive && !failed &&
solver->propagated < SIZE_STACK(solver->trail))
{
const unsigned lit = PEEK_STACK(solver->trail, solver->propagated);
solver->propagated++;
propagated++;
LOG("transitive propagate %s", LOGLIT(lit));
assert(VALUE(lit) > 0);
const unsigned not_lit = NOT(lit);
// 'lit' is true, so the watches of its negation are the ones visited.
watches *lit_watches = all_watches + not_lit;
const watch *end_lit = END_WATCHES(*lit_watches);
const watch *begin_lit = BEGIN_WATCHES(*lit_watches);
ticks = kissat_cache_lines(lit_watches->size, sizeof(watch));
ADD(transitive_ticks, ticks + 1);
solver->dps_ticks += ticks + 1;
for (const watch *q = begin_lit; q != end_lit; q++)
{
// Never use the candidate watch itself to justify its own removal.
if (p == q)
continue;
const watch lit_watch = *q;
if (!lit_watch.type.binary)
break;
// Entries of the 'src' list already marked as removed earlier in this
// call (see the ILLEGAL_LIT store below) are skipped.
if (not_lit == src && lit_watch.binary.lit == ILLEGAL_LIT)
continue;
// When the candidate is not redundant, redundant watches are not used
// for the propagation.
if (!redundant && lit_watch.binary.redundant)
continue;
const unsigned other = lit_watch.binary.lit;
if (other == dst)
{
// 'dst' is implied without the candidate clause.
transitive = true;
break;
}
const value value = VALUE(other);
if (value < 0)
{
// 'other' is implied but already false: NOT(src) leads to a conflict.
LOG("both %s and %s reachable from %s",
LOGLIT(NOT(other)), LOGLIT(other), LOGLIT(src));
failed = true;
break;
}
if (!value)
transitive_assign(solver, other);
}
}
assert(solver->probing);
ADD(propagations, propagated);
ADD(probing_propagations, propagated);
ADD(transitive_propagations, propagated);
// Undo the temporary assignments before touching the watch lists.
transitive_backtrack(solver, saved);
if (transitive)
{
LOGBINARY(src, dst, "transitive reduce");
INC(transitive_reduced);
// Remove the mirrored watch (pointing back to 'src') from the list of 'dst'.
watches *dst_watches = all_watches + dst;
watch dst_watch = src_watch;
assert(dst_watch.binary.lit == dst);
assert(dst_watch.binary.redundant == redundant);
dst_watch.binary.lit = src;
REMOVE_WATCHES(*dst_watches, dst_watch);
kissat_delete_binary(solver,
src_watch.binary.redundant,
src_watch.binary.hyper, src, dst);
// The entry in the 'src' list is only marked here; the marked entries
// are squeezed out after the loop, so 'p' and 'end_src' stay valid.
p->binary.lit = ILLEGAL_LIT;
reduced++;
res = true;
}
if (failed)
break;
if (solver->statistics.transitive_ticks > limit)
break;
if (TERMINATED(16))
break;
}
// Once 'dps_ticks' has reached 'dps_period' (and 'dps' equals 1), one period
// is subtracted and the 'cbk_start_new_period' callback is invoked.
// NOTE(review): the meaning of 'dps' is not visible in this file -- confirm.
if (solver->dps == 1 && solver->dps_ticks >= solver->dps_period)
{
solver->dps_ticks -= solver->dps_period;
solver->cbk_start_new_period(solver->issuer);
}
if (reduced)
{
*reduced_ptr += reduced;
assert(begin_src == BEGIN_WATCHES(WATCHES(src)));
assert(end_src == END_WATCHES(WATCHES(src)));
// Compact the watch list of 'src', dropping the entries marked with
// ILLEGAL_LIT above.
watch *q = begin_src;
for (const watch *p = begin_src; p != end_src; p++)
{
const watch src_watch = *q++ = *p;
if (!src_watch.type.binary)
{
// A non-binary watch takes two slots; copy the second one as well.
*q++ = *++p;
continue;
}
if (src_watch.binary.lit == ILLEGAL_LIT)
q--;
}
assert(end_src - q == (ptrdiff_t)reduced);
SET_END_OF_WATCHES(*src_watches, q);
}
if (failed)
{
// Propagating NOT(src) produced a conflict, so 'src' is assigned as a unit.
LOG("transitive failed literal %s", LOGLIT(not_src));
INC(failed);
*units += 1;
res = true;
kissat_assign_unit(solver, src);
CHECK_AND_ADD_UNIT(src);
ADD_UNIT_TO_PROOF(src);
clause *conflict = kissat_probing_propagate(solver, 0);
if (conflict)
{
// A conflict here is analyzed; afterwards the solver is asserted to be
// inconsistent.
(void)kissat_analyze(solver, conflict);
assert(solver->inconsistent);
}
else
{
assert(solver->unflushed);
kissat_flush_trail(solver);
}
}
return res;
}
/*
 * Runs one round of transitive reduction over the literals.
 *
 * Does nothing if the solver is already inconsistent, if the 'transitive'
 * option is disabled, or if termination was requested.  Otherwise the
 * watch lists are reordered with prioritize_binaries and literals are
 * probed round-robin with transitive_reduce, starting at the position
 * stored in 'solver->transitive' and wrapping around at LITS.  The round
 * ends when the start position is reached again, the solver becomes
 * inconsistent, the tick limit is exceeded, or termination is requested.
 * A phase message and a report line are emitted at the end.
 */
void kissat_transitive_reduction(kissat *solver)
{
  if (solver->inconsistent)
    return;

  assert(solver->watching);
  assert(solver->probing);
  assert(!solver->level);

  if (!GET_OPTION(transitive) || TERMINATED(17))
    return;

  START(transitive);
  prioritize_binaries(solver);

  unsigned units = 0;
  uint64_t reduced = 0;
  bool success = false;

  SET_EFFICIENCY_BOUND(limit, transitive, transitive_ticks, search_ticks, 0);

  assert(solver->transitive < LITS);
  const unsigned end = solver->transitive;
#ifndef QUIET
  const unsigned active = solver->active;
#endif
  unsigned probed = 0;

  do
  {
    /* Take the current literal and advance the cursor cyclically. */
    const unsigned lit = solver->transitive;
    if (++solver->transitive == LITS)
      solver->transitive = 0;

    if (ACTIVE(IDX(lit)))
    {
      probed++;

      const bool changed =
          transitive_reduce(solver, lit, limit, &reduced, &units);
      if (changed)
        success = true;

      if (solver->inconsistent ||
          solver->statistics.transitive_ticks > limit ||
          TERMINATED(18))
        break;
    }
  } while (solver->transitive != end);

  kissat_phase(solver, "transitive", GET(probings),
               "probed %u (%.0f%%): reduced %" PRIu64 ", units %u",
               probed, kissat_percent(probed, 2 * active), reduced, units);
  STOP(transitive);
  REPORT(!success, 't');
#ifdef QUIET
  (void)success;
#endif
}