cloud-sat/kissat-inc/src/substitute.c
2023-03-26 19:15:17 +08:00

673 lines
18 KiB
C

#include "allocate.h"
#include "backtrack.h"
#include "inline.h"
#include "print.h"
#include "proprobe.h"
#include "report.h"
#include "substitute.h"
#include "terminate.h"
#include "trail.h"
#include "weaken.h"
/* Decide whether running equivalent literal substitution is worthwhile.

   Always answers 'true' if the 'really' option is disabled.  Otherwise the
   'first' call is refused when the number of clauses exceeds the
   'substitutelim' option, and every call is refused unless the estimated
   cost (three times the number of clauses) stays below the search ticks
   spent so far plus the 'substitutemineff' option.  */

static bool
really_substitute (kissat * solver, bool first)
{
  if (!GET_OPTION (really))
    return true;

  const uint64_t num_clauses = CLAUSES;
  if (first && num_clauses > (unsigned) GET_OPTION (substitutelim))
    return false;

  // Estimated effort: 2 x watches + Tarjan visits.
  const uint64_t estimate = 3 * num_clauses;
  const uint64_t spent = solver->statistics.search_ticks;
  return estimate < spent + GET_OPTION (substitutemineff);
}
/* Drain the 'units' stack, assigning each literal as a root-level unit and
   propagating it.

   Units which are already satisfied are skipped.  A unit which is already
   falsified, or whose propagation yields a conflict, makes the solver
   inconsistent and the empty clause is added to checker and proof.  The
   loop stops as soon as the solver is inconsistent, so 'units' is not
   necessarily empty on return.  */

static void
assign_and_propagate_units (kissat * solver, unsigneds * units)
{
  if (EMPTY_STACK (*units))
    return;

  LOG ("assigning and propagating %zu units", SIZE_STACK (*units));

  for (;;)
    {
      if (solver->inconsistent)
        break;
      if (EMPTY_STACK (*units))
        break;

      const unsigned unit = POP_STACK (*units);
      LOG ("trying to assign and propagate unit %s now", LOGLIT (unit));

      const value unit_value = VALUE (unit);
      if (unit_value > 0)
        {
          LOG ("skipping satisfied unit %s", LOGLIT (unit));
          continue;
        }

      if (unit_value < 0)
        {
          LOG ("inconsistent unit %s", LOGLIT (unit));
        }
      else
        {
          kissat_assign_unit (solver, unit);
          INC (failed);
          // A zero result means propagation found no conflict.
          if (!kissat_probing_propagate (solver, 0))
            continue;
          LOG ("propagation of unit failed");
        }

      // Only reached for a falsified unit or a propagation conflict.
      solver->inconsistent = true;
      CHECK_AND_ADD_EMPTY ();
      ADD_EMPTY_TO_PROOF ();
    }
}
/* Determine a representative literal for every literal.

   The literals are traversed depth first along binary watches: for a
   literal 'lit' the binary watches of 'NOT (lit)' are followed to the
   blocking literals of active variables.  The 'mark' / 'reach' / 'scc'
   bookkeeping follows Tarjan's strongly connected component scheme.  All
   literals of one component get the smallest literal of that component as
   representative in 'repr'.

   'repr' must have 'LITS' entries which are all 'INVALID_LIT' on entry
   (this is asserted for every literal that is finished).  On return every
   entry still equal to 'INVALID_LIT' has been set to the literal itself.

   If a literal and its negation end up with the same representative the
   solver is marked inconsistent.  If a root turns out to reach both a
   literal and its negation, the negation of the root is derived as unit;
   such units are assigned and propagated before returning.

   The traversal also accounts its work in 'solver->dps_ticks' and invokes
   'solver->cbk_start_new_period' when 'solver->dps' is one and a period of
   'solver->dps_period' ticks has elapsed.  */

static void
determine_representatives (kissat * solver, unsigned *repr)
{
  size_t bytes = LITS * sizeof (unsigned);
  // 'mark' is the discovery index of a literal (zero means unvisited) and
  // 'reach' the smallest discovery index found reachable so far.
  unsigned *mark = kissat_calloc (solver, LITS, sizeof *mark);
  unsigned *reach = kissat_malloc (solver, LITS * sizeof *reach);
  watches *all_watches = solver->watches;
  const flags *flags = solver->flags;
  unsigned reached = 0;
  unsigneds scc;
  unsigneds work;
  unsigneds units;
  INIT_STACK (scc);
  INIT_STACK (work);
  INIT_STACK (units);
  unsigned trivial_sccs = 0;
  unsigned non_trivial_sccs = 0;
  for (all_literals (root))
    {
      if (solver->inconsistent)
        break;
      if (mark[root])
        continue;
      if (!ACTIVE (IDX (root)))
        continue;
      assert (EMPTY_STACK (scc));
      assert (EMPTY_STACK (work));
      LOG ("substitute root %s", LOGLIT (root));
      PUSH_STACK (work, root);
      // At most one unit is derived per root.
      bool failed = false;
      // Discovery index the root is going to receive below.
      const unsigned mark_root = reached + 1;
      while (!solver->inconsistent && !EMPTY_STACK (work))
        {
          unsigned lit = TOP_STACK (work);
          if (lit == INVALID_LIT)
            {
              // 'INVALID_LIT' on top of 'work' flags that the literal
              // below it was expanded before and is now finished.
              (void) POP_STACK (work);
              lit = POP_STACK (work);
              const unsigned not_lit = NOT (lit);
              unsigned reach_lit = reach[lit];
              unsigned mark_lit = mark[lit];
              assert (reach_lit == mark_lit);
              assert (repr[lit] == INVALID_LIT);
              watches *watches = all_watches + not_lit;
              unsigned ticks =
                kissat_cache_lines (watches->size, sizeof (watch));
              solver->dps_ticks += 1 + ticks;
              // Minimize 'reach' over all successors.
              for (all_binary_blocking_watches (watch, *watches))
                {
                  if (!watch.type.binary)
                    continue;
                  const unsigned other = watch.binary.lit;
                  const unsigned idx_other = IDX (other);
                  if (!flags[idx_other].active)
                    continue;
                  assert (mark[other]);
                  unsigned reach_other = reach[other];
                  if (reach_other < reach_lit)
                    reach_lit = reach_other;
                }
              if (solver->dps == 1 && solver->dps_ticks >= solver->dps_period)
                {
                  solver->dps_ticks -= solver->dps_period;
                  solver->cbk_start_new_period (solver->issuer);
                }
              if (reach_lit != mark_lit)
                {
                  // Not the entry point of its component: just record the
                  // smaller reachable index and leave 'lit' on 'scc'.
                  LOG ("reach[%s] = %u", LOGLIT (lit), reach_lit);
                  reach[lit] = reach_lit;
                  continue;
                }
              // 'lit' is the entry point of a component, which consists of
              // everything on 'scc' from 'lit' up to the top.
              unsigned *end_scc = END_STACK (scc);
              unsigned *begin_scc = end_scc;
              do
                assert (begin_scc != BEGIN_STACK (scc));
              while (*--begin_scc != lit);
              SET_END_OF_STACK (scc, begin_scc);
              const size_t size_scc = end_scc - begin_scc;
              unsigned min_lit = lit;
              if (size_scc > 1)
                {
                  for (const unsigned *p = begin_scc; p != end_scc; p++)
                    {
                      const unsigned other = *p;
                      if (other < min_lit)
                        min_lit = other;
                    }
                  non_trivial_sccs++;
                  // 'size_scc' is a 'size_t' and thus needs '%zu'.
                  LOG ("size %zu SCC entered trough %s representative %s",
                       size_scc, LOGLIT (lit), LOGLIT (min_lit));
                }
              else
                {
                  trivial_sccs++;
                  LOG ("trivial size one SCC with %s", LOGLIT (lit));
                  assert (min_lit == lit);
                }
              for (const unsigned *p = begin_scc; p != end_scc; p++)
                {
                  const unsigned other = *p;
                  LOG ("repr[%s] = %s", LOGLIT (other), LOGLIT (min_lit));
                  repr[other] = min_lit;
                  // Finished literals never lower 'reach' of others again.
                  reach[other] = UINT_MAX;
                  const unsigned not_other = NOT (other);
                  const unsigned repr_not_other = repr[not_other];
                  if (repr_not_other == INVALID_LIT)
                    continue;
                  if (min_lit == repr_not_other)
                    {
                      LOG ("clashing literals %s and %s in same SCC",
                           LOGLIT (other), LOGLIT (not_other));
                      CHECK_AND_ADD_UNIT (min_lit);
                      ADD_UNIT_TO_PROOF (min_lit);
                      CHECK_AND_ADD_EMPTY ();
                      ADD_EMPTY_TO_PROOF ();
                      solver->inconsistent = true;
                      break;
                    }
                  assert (NOT (min_lit) == repr_not_other);
                  if (failed)
                    continue;
                  const unsigned mark_not_other = mark[not_other];
                  assert (mark_not_other != INVALID_LIT);
                  assert (mark[root] == mark_root);
                  // Only a negation discovered during the traversal from
                  // this root counts (its index is not below the root's).
                  if (mark_root > mark_not_other)
                    continue;
                  LOG ("root %s implies both %s and %s",
                       LOGLIT (root), LOGLIT (other), LOGLIT (not_other));
                  const unsigned unit = NOT (root);
                  PUSH_STACK (units, unit);
                  CHECK_AND_ADD_UNIT (unit);
                  ADD_UNIT_TO_PROOF (unit);
                  failed = true;
                }
              if (solver->inconsistent)
                break;
            }
          else if (!mark[lit])
            {
              // First visit: leave 'lit' on 'work', put the finish marker
              // above it and then schedule all unvisited successors.
              PUSH_STACK (work, INVALID_LIT);
              PUSH_STACK (scc, lit);
              mark[lit] = reach[lit] = ++reached;
              LOG ("substitute mark[%s] = %u", LOGLIT (lit), reached);
              const unsigned not_lit = NOT (lit);
              watches *watches = all_watches + not_lit;
              unsigned ticks =
                kissat_cache_lines (watches->size, sizeof (watch));
              solver->dps_ticks += 1 + ticks;
              for (all_binary_blocking_watches (watch, *watches))
                {
                  if (!watch.type.binary)
                    continue;
                  const unsigned other = watch.binary.lit;
                  const unsigned idx_other = IDX (other);
                  if (!flags[idx_other].active)
                    continue;
                  if (!mark[other])
                    PUSH_STACK (work, other);
                }
              if (solver->dps == 1 && solver->dps_ticks >= solver->dps_period)
                {
                  solver->dps_ticks -= solver->dps_period;
                  solver->cbk_start_new_period (solver->issuer);
                }
            }
          else
            // Scheduled more than once and visited in the meantime.
            (void) POP_STACK (work);
        }
    }
  RELEASE_STACK (work);
  RELEASE_STACK (scc);
  LOG ("reached %u literals", reached);
  LOG ("found %u non-trivial SCCs", non_trivial_sccs);
  LOG ("found %u trivial SCCs", trivial_sccs);
  LOG ("found %zu units", SIZE_STACK (units));
  assign_and_propagate_units (solver, &units);
  RELEASE_STACK (units);
  kissat_free (solver, reach, bytes);
  kissat_free (solver, mark, bytes);
  // Literals without representative (not reached) represent themselves.
  for (all_literals (lit))
    if (repr[lit] == INVALID_LIT)
      repr[lit] = lit;
}
/* Flag all active variables which are going to be substituted.

   Returns a zero initialized array with 'VARS' entries in which the entry
   of a variable is 'true' if the variable is active and its positive
   literal has a representative different from itself.  If checking or
   proving is compiled in, the two binary clauses which state that such a
   literal is equivalent to its representative are added to checker and
   proof.  Returns a null pointer if the solver is already inconsistent.
   The caller owns the returned array.  */

static bool *
add_representative_equivalences (kissat * solver, unsigned *repr)
{
  if (solver->inconsistent)
    return 0;

  bool *eliminate = kissat_calloc (solver, VARS, sizeof *eliminate);

  for (all_variables (var))
    {
      if (ACTIVE (var))
        {
          const unsigned pos = LIT (var);
          const unsigned rep = repr[pos];
          if (rep != pos)
            {
              assert (rep < pos);
#ifdef CHECKING_OR_PROVING
              const unsigned neg = NOT (pos);
              const unsigned not_rep = NOT (rep);
              CHECK_AND_ADD_BINARY (neg, rep);
              ADD_BINARY_TO_PROOF (neg, rep);
              CHECK_AND_ADD_BINARY (pos, not_rep);
              ADD_BINARY_TO_PROOF (pos, not_rep);
#endif
              eliminate[var] = true;
            }
        }
    }

  return eliminate;
}
/* Finish substitution of the variables flagged in 'eliminate' and release
   that array.

   Unless the solver became inconsistent, for every flagged variable the two
   equivalence binaries are removed from checker and proof again, the
   variable is counted as substituted and marked eliminated, and clauses
   linking it to its representative are handed to the weakening functions
   (presumably to allow reconstructing its value later - the use of the
   weakened clauses is not visible here).  Which clauses are weakened
   depends on the 'incremental' option and on the current value of the
   representative.  'eliminate' is deallocated if it is non-null.  */

static void
remove_representative_equivalences (kissat * solver,
                                    unsigned *repr, bool * eliminate)
{
  if (!solver->inconsistent)
    {
      const bool incremental = GET_OPTION (incremental);
      value *values = solver->values;

      for (all_variables (var))
        {
          if (!eliminate[var])
            continue;
          assert (ACTIVE (var));

          const unsigned pos = LIT (var);
          const unsigned neg = NOT (pos);
          const unsigned rep = repr[pos];
          const unsigned not_rep = NOT (rep);
          assert (rep < pos);
          assert (not_rep < neg);

          REMOVE_CHECKER_BINARY (neg, rep);
          DELETE_BINARY_FROM_PROOF (neg, rep);
          REMOVE_CHECKER_BINARY (pos, not_rep);
          DELETE_BINARY_FROM_PROOF (pos, not_rep);

          INC (substituted);
          kissat_mark_eliminated_variable (solver, var);

          const value rep_value = values[rep];

          // Non-incremental and representative unassigned: one binary
          // plus the literal itself as unit are weakened.
          if (!incremental && !rep_value)
            {
              kissat_weaken_binary (solver, neg, rep);
              kissat_weaken_unit (solver, pos);
              continue;
            }

          // Otherwise weaken each binary unless the value of the
          // representative already satisfies it.
          if (rep_value <= 0)
            kissat_weaken_binary (solver, neg, rep);
          if (rep_value >= 0)
            kissat_weaken_binary (solver, pos, not_rep);
        }
    }

  if (eliminate)
    kissat_dealloc (solver, eliminate, VARS, sizeof *eliminate);
}
/* Replace the literals of all binary clauses by their representatives.
 *
 * Walks the watch list of every literal and rewrites it in place.  Binary
 * watches whose clause becomes tautological or collapses to a unit are
 * deleted, the others get their blocking literal replaced by its
 * representative.  Watches of a literal which is not its own
 * representative are moved to the watch list of the representative.
 * Non-binary watch entries are not copied back and are thus dropped from
 * the lists.  Units found are assigned and propagated at the end.  Does
 * nothing if the solver is already inconsistent.
 */
static void
substitute_binaries (kissat * solver, unsigned *repr)
{
if (solver->inconsistent)
return;
// The cast below relies on a watch having the size of an 'unsigned'.
assert (sizeof (watch) == sizeof (unsigned));
// Reuse 'solver->delayed' as temporary stack of watches which have to be
// moved to the watch list of another (the representative) literal.
statches *binaries = (statches *) & solver->delayed;
watches *all_watches = solver->watches;
size_t removed = 0;
size_t substituted = 0;
unsigneds units;
INIT_STACK (units);
for (all_literals (lit))
{
const unsigned repr_lit = repr[lit];
const unsigned not_repr_lit = NOT (repr_lit);
assert (EMPTY_STACK (*binaries));
watches *watches = all_watches + lit;
// 'p' reads the original entries while 'q' writes the kept ones.
watch *begin = BEGIN_WATCHES (*watches), *q = begin;
const watch *end = END_WATCHES (*watches), *p = q;
while (p != end)
{
const watch src = *p++;
// Entries without the binary flag are skipped without being kept.
if (!src.type.binary)
continue;
const unsigned other = src.binary.lit;
const unsigned repr_other = repr[other];
LOGBINARY (lit, other, "substituting");
if (repr_other == not_repr_lit)
{
// Both literals map to complementary representatives.
LOGBINARY (repr_other, repr_lit, "becomes tautological");
// The 'lit < other' test makes sure the clause is handled from only
// one of its two literals (presumably both literals of a binary
// clause have a watch for it - verify against the watch scheme).
if (lit < other)
{
removed++;
kissat_delete_binary (solver,
src.binary.redundant,
src.binary.hyper, lit, other);
}
}
else if (repr_other == repr_lit)
{
// Both literals map to the same representative.
const unsigned unit = repr_lit;
LOG ("simplifies to unit %s", LOGLIT (unit));
if (lit < other)
{
removed++;
PUSH_STACK (units, unit);
CHECK_AND_ADD_UNIT (unit);
ADD_UNIT_TO_PROOF (unit);
kissat_delete_binary (solver,
src.binary.redundant,
src.binary.hyper, lit, other);
}
}
else
{
// The clause remains binary: watch the representative of 'other'.
watch dst = src;
dst.binary.lit = repr_other;
if (lit == repr_lit && other == repr_other)
{
LOGBINARY (lit, other, "unchanged");
*q++ = dst;
}
else
{
if (lit == repr_lit)
{
// The watching literal stays, so the watch stays in this list.
LOGBINARY (repr_lit, repr_other,
"substituted in place");
*q++ = dst;
}
else
{
// The watch has to go into the list of 'repr_lit' instead.
LOGBINARY (repr_lit, repr_other, "delayed substituted");
PUSH_STACK (*binaries, dst);
}
if (lit < other)
{
// Add the substituted clause before removing the original one.
substituted++;
ADD_BINARY_TO_PROOF (repr_lit, repr_other);
CHECK_AND_ADD_BINARY (repr_lit, repr_other);
DELETE_BINARY_FROM_PROOF (lit, other);
REMOVE_CHECKER_BINARY (lit, other);
}
}
}
}
// Truncate the watch list to the entries kept.
SET_END_OF_WATCHES (*watches, q);
// If 'lit' represents itself nothing was delayed.
if (lit == repr_lit)
continue;
// Append the delayed watches to the watch list of the representative.
watches = all_watches + repr_lit;
for (all_stack (watch, watch, *binaries))
PUSH_WATCHES (*watches, watch);
CLEAR_STACK (*binaries);
}
LOG ("substituted %zu binary clauses", substituted);
LOG ("removed %zu binary clauses", removed);
assign_and_propagate_units (solver, &units);
RELEASE_STACK (units);
}
/* Replace the literals of all non-garbage clauses visited by 'all_clauses'
   by their representatives.

   For every such clause the new literals are collected on
   'solver->clause.lits', using 'solver->marks' to detect duplicates and
   tautologies.  Literals which are false (directly or through their
   representative) are dropped.  Depending on the result the clause is

     - marked garbage if it is satisfied or tautological,
     - turned into the empty clause, which makes the solver inconsistent,
     - replaced by a unit which is assigned and propagated at the end,
     - replaced by a new binary clause, or
     - overwritten in place with the (possibly fewer) new literals.

   Clauses for which nothing changed are left alone.  Does nothing if the
   solver is already inconsistent.  */

static void
substitute_clauses (kissat * solver, unsigned *repr)
{
  if (solver->inconsistent)
    return;
  const value *values = solver->values;
  value *marks = solver->marks;
  size_t substituted = 0;
  size_t removed = 0;
  unsigneds units;
  INIT_STACK (units);
  for (all_clauses (c))
    {
      if (c->garbage)
        continue;
      LOGCLS (c, "substituting");
      assert (EMPTY_STACK (solver->clause.lits));
      bool shrink = false;      // at least one literal was dropped
      bool satisfied = false;
      bool substitute = false;  // at least one literal was replaced
      bool tautological = false;
      for (all_literals_in_clause (lit, c))
        {
          const value lit_value = values[lit];
          if (lit_value < 0)
            {
              LOG ("dropping falsified %s", LOGLIT (lit));
              shrink = true;
              continue;
            }
          if (lit_value > 0)
            {
              LOGCLS (c, "satisfied by %s", LOGLIT (lit));
              satisfied = true;
              break;
            }
          const unsigned repr_lit = repr[lit];
          const value repr_value = values[repr_lit];
          if (repr_value < 0)
            {
              LOG ("dropping falsified substituted %s (was %s)",
                   LOGLIT (repr_lit), LOGLIT (lit));
              shrink = true;
              continue;
            }
          if (repr_value > 0)
            {
              LOGCLS (c, "satisfied by substituted %s (was %s)",
                      LOGLIT (repr_lit), LOGLIT (lit));
              satisfied = true;
              break;
            }
          if (lit != repr_lit)
            {
              assert (!values[repr_lit]);
              LOG ("substituted literal %s (was %s)",
                   LOGLIT (repr_lit), LOGLIT (lit));
              substitute = true;
            }
          else
            LOG ("copying literal %s", LOGLIT (lit));
          if (marks[repr_lit])
            {
              shrink = true;
              LOG ("skipping duplicated %s", LOGLIT (repr_lit));
              continue;
            }
          const unsigned not_repr_lit = NOT (repr_lit);
          if (marks[not_repr_lit])
            {
              LOG ("substituted clause tautological "
                   "containing both %s and %s",
                   LOGLIT (not_repr_lit), LOGLIT (repr_lit));
              tautological = true;
              break;
            }
          marks[repr_lit] = true;
          PUSH_STACK (solver->clause.lits, repr_lit);
        }
      if (satisfied || tautological)
        {
          kissat_mark_clause_as_garbage (solver, c);
          removed++;
        }
      else if (substitute || shrink)
        {
          const unsigned size = SIZE_STACK (solver->clause.lits);
          if (!size)
            {
              LOG ("simplifies to empty clause");
              CHECK_AND_ADD_EMPTY ();
              ADD_EMPTY_TO_PROOF ();
              solver->inconsistent = true;
              // Leaves the loop over all clauses; nothing is marked since
              // the literal stack is empty.
              break;
            }
          else if (size == 1)
            {
              assert (shrink);
              removed++;
              const unsigned unit = PEEK_STACK (solver->clause.lits, 0);
              LOGCLS (c, "simplifies to unit %s", LOGLIT (unit));
              PUSH_STACK (units, unit);
              CHECK_AND_ADD_UNIT (unit);
              ADD_UNIT_TO_PROOF (unit);
              kissat_mark_clause_as_garbage (solver, c);
            }
          else if (size == 2)
            {
              assert (shrink);
              substituted++;
              const unsigned first = PEEK_STACK (solver->clause.lits, 0);
              const unsigned second = PEEK_STACK (solver->clause.lits, 1);
              LOGCLS (c, "unsubstituted");
              const bool redundant = c->redundant;
              LOGBINARY (first, second, "substituted %s",
                         redundant ? "redundant" : "irredundant");
              kissat_new_binary_clause (solver, redundant, first, second);
              kissat_mark_clause_as_garbage (solver, c);
            }
          else
            {
              substituted++;
              LOGCLS (c, "unsubstituted");
              const unsigned new_size = SIZE_STACK (solver->clause.lits);
              unsigned *new_lits = BEGIN_STACK (solver->clause.lits);
              // Add the new clause before deleting the old one.
              ADD_LITS_TO_PROOF (new_size, new_lits);
              CHECK_AND_ADD_LITS (new_size, new_lits);
              DELETE_CLAUSE_FROM_PROOF (c);
              REMOVE_CHECKER_CLAUSE (c);
              const unsigned old_size = c->size;
              unsigned *old_lits = c->lits;
              assert (new_size <= old_size);
              // Overwrite the literals of the clause in place.
              memcpy (old_lits, new_lits, new_size * sizeof *old_lits);
              assert (shrink == (new_size < old_size));
              if (new_size < old_size)
                {
                  c->size = new_size;
                  c->searched = 2;
                  if (!c->shrunken)
                    {
                      // On first shrinking put 'INVALID_LIT' into the last
                      // slot of the original literal array.
                      c->shrunken = true;
                      c->lits[old_size - 1] = INVALID_LIT;
                    }
                }
              LOGCLS (c, "unsorted substituted");
            }
        }
      else
        LOGCLS (c, "unchanged");
      // Reset the marks of the collected literals for the next clause.
      for (all_stack (unsigned, lit, solver->clause.lits))
        marks[lit] = 0;
      CLEAR_STACK (solver->clause.lits);
    }
  // Report both counters, as done for binary clauses.
  LOG ("substituted %zu large clauses", substituted);
  LOG ("removed %zu substituted large clauses", removed);
  assign_and_propagate_units (solver, &units);
  RELEASE_STACK (units);
}
/* Run one complete round of equivalent literal substitution.

   Allocates the representative table with all bytes set to 0xff,
   determines representatives, substitutes them in binary and then in the
   other clauses, eliminates the substituted variables and reports how many
   active variables were removed.  Returns 'true' only if the solver is
   still consistent and at least one variable was removed, i.e., if another
   round might be useful.  */

static bool
substitute_round (kissat * solver, unsigned round)
{
  assert (!solver->inconsistent);

  const unsigned before = solver->active;

  const size_t bytes = LITS * sizeof (unsigned);
  unsigned *repr = kissat_malloc (solver, bytes);
  memset (repr, 0xff, bytes);

  determine_representatives (solver, repr);
  bool *eliminate = add_representative_equivalences (solver, repr);
  substitute_binaries (solver, repr);
  substitute_clauses (solver, repr);
  remove_representative_equivalences (solver, repr, eliminate);

  kissat_dealloc (solver, repr, LITS, sizeof *repr);

  const unsigned removed = before - solver->active;
  kissat_phase (solver, "substitute", GET (substitutions),
                "round %u removed %u variables %.0f%%",
                round, removed, kissat_percent (removed, before));
  kissat_check_statistics (solver);
  REPORT (!removed, 'd');
#ifdef QUIET
  (void) round;
#endif

  if (solver->inconsistent)
    return false;
  return removed != 0;
}
/* Run substitution rounds until a round makes no progress or the round
   limit is hit, then watch the large clauses again and propagate.

   The limit is the 'substituterounds' option, but only a single round is
   run if the number of clauses exceeds the 'substitutelim' option.  If the
   solver is still consistent afterwards, propagation is restarted from the
   beginning of the trail; a conflict makes the solver inconsistent,
   otherwise the trail is flushed if there is something unflushed.  */

static void
substitute_rounds (kissat * solver)
{
  START (substitute);
  INC (substitutions);

  const bool too_many_clauses =
    CLAUSES > (unsigned) GET_OPTION (substitutelim);
  const unsigned maxrounds =
    too_many_clauses ? 1u : (unsigned) GET_OPTION (substituterounds);

  unsigned round = 1;
  while (round <= maxrounds)
    {
      if (!substitute_round (solver, round))
        break;
      round++;
    }

  if (!solver->inconsistent)
    {
      kissat_watch_large_clauses (solver);
      solver->propagated = 0;
      const bool conflict = kissat_probing_propagate (solver, 0) != 0;
      if (conflict)
        {
          LOG ("unit propagation after substitution results in conflict");
          CHECK_AND_ADD_EMPTY ();
          ADD_EMPTY_TO_PROOF ();
          solver->inconsistent = true;
        }
      else if (solver->unflushed)
        kissat_flush_trail (solver);
    }

  STOP (substitute);
}
/* Entry point of equivalent literal substitution.

   Has to be called while probing, with watches enabled and at decision
   level zero (all asserted).  Returns without doing anything if the solver
   is inconsistent, the 'substitute' option is disabled, termination was
   requested, or 'really_substitute' decides against it.  'first' is handed
   on to 'really_substitute'.  */

void
kissat_substitute (kissat * solver, bool first)
{
  if (solver->inconsistent)
    return;

  assert (solver->probing);
  assert (solver->watching);
  assert (!solver->level);

  // Short-circuit evaluation keeps the order of the three checks.
  if (GET_OPTION (substitute) &&
      !TERMINATED (12) && really_substitute (solver, first))
    substitute_rounds (solver);
}