Skip to content
Snippets Groups Projects
Commit 00b293e8 authored by Arjun P's avatar Arjun P
Browse files

[MLIR][Presburger] refactor subtraction to be non-recursive

Subtraction was previously implemented recursively. This refactors it to be
non-recursive to avoid issues with potential stack overflows.

Reviewed By: Groverkss

Differential Revision: https://reviews.llvm.org/D123248
parent b7042b73
No related branches found
No related tags found
No related merge requests found
......@@ -129,18 +129,17 @@ static SmallVector<int64_t, 8> getIneqCoeffsFromIdx(const IntegerRelation &rel,
return getNegatedCoeffs(eqCoeffs);
}
/// Return the set difference b \ s and accumulate the result into `result`.
/// `simplex` must correspond to b.
/// Return the set difference b \ s.
///
/// In the following, U denotes union, ^ denotes intersection, \ denotes set
/// In the following, U denotes union, /\ denotes intersection, \ denotes set
/// difference and ~ denotes complement.
/// Let b be the IntegerRelation and s = (U_i s_i) be the set. We want
/// b \ (U_i s_i).
///
/// Let s_i = ^_j s_ij, where each s_ij is a single inequality. To compute
/// b \ s_i = b ^ ~s_i, we partition s_i based on the first violated inequality:
/// ~s_i = (~s_i1) U (s_i1 ^ ~s_i2) U (s_i1 ^ s_i2 ^ ~s_i3) U ...
/// And the required result is (b ^ ~s_i1) U (b ^ s_i1 ^ ~s_i2) U ...
/// Let s = (U_i s_i). We want b \ (U_i s_i).
///
/// Let s_i = /\_j s_ij, where each s_ij is a single inequality. To compute
/// b \ s_i = b /\ ~s_i, we partition s_i based on the first violated
/// inequality: ~s_i = (~s_i1) U (s_i1 /\ ~s_i2) U (s_i1 /\ s_i2 /\ ~s_i3) U ...
/// And the required result is (b /\ ~s_i1) U (b /\ s_i1 /\ ~s_i2) U ...
/// We recurse by subtracting U_{j > i} S_j from each of these parts and
/// returning the union of the results. Each equality is handled as a
/// conjunction of two inequalities.
......@@ -162,151 +161,192 @@ static SmallVector<int64_t, 8> getIneqCoeffsFromIdx(const IntegerRelation &rel,
/// that some constraints are redundant. These redundant constraints are
/// ignored.
///
/// b should not have duplicate divs because this might lead to existing
/// divs disappearing in the call to mergeLocalIds below, which cannot be
/// handled.
static void subtractRecursively(IntegerRelation &b, Simplex &simplex,
const PresburgerRelation &s, unsigned i,
PresburgerRelation &result) {
if (i == s.getNumDisjuncts()) {
result.unionInPlace(b);
return;
}
static PresburgerRelation getSetDifference(IntegerRelation b,
const PresburgerRelation &s) {
assert(b.isSpaceCompatible(s) && "Spaces should match");
if (b.isEmptyByGCDTest())
return PresburgerRelation::getEmpty(b.getSpaceWithoutLocals());
IntegerRelation sI = s.getDisjunct(i);
// Remove the duplicate divs up front to avoid them possibly disappearing
// in the call to mergeLocalIds below.
sI.removeDuplicateDivs();
// Below, we append some additional constraints and ids to b. We want to
// rollback b to its initial state before returning, which we will do by
// removing all constraints beyond the original number of inequalities
// and equalities, so we store these counts first.
IntegerRelation::CountsSnapshot initBCounts = b.getCounts();
// Similarly, we also want to rollback simplex to its original state.
unsigned initialSnapshot = simplex.getSnapshot();
// Find out which inequalities of sI correspond to division inequalities for
// the local variables of sI.
std::vector<MaybeLocalRepr> repr(sI.getNumLocalIds());
sI.getLocalReprs(repr);
// Add sI's locals to b, after b's locals. Also add b's locals to sI, before
// sI's locals.
b.mergeLocalIds(sI);
unsigned numLocalsAdded =
b.getNumLocalIds() - initBCounts.getSpace().getNumLocalIds();
// Update simplex to also include the new locals in `b` from merging.
simplex.appendVariable(numLocalsAdded);
// Equalities are processed by considering them as a pair of inequalities.
// The first sI.getNumInequalities() elements are for sI's inequalities;
// then a pair of inequalities occurs for each of sI's equalities.
// If the equality is expr == 0, the first element in the pair
// corresponds to expr >= 0, and the second to expr <= 0.
llvm::SmallBitVector canIgnoreIneq(sI.getNumInequalities() +
2 * sI.getNumEqualities());
// Add all division inequalities to `b`.
for (MaybeLocalRepr &maybeInequality : repr) {
assert(maybeInequality.kind == ReprKind::Inequality &&
"Subtraction is not supported when a representation of the local "
"variables of the subtrahend cannot be found!");
unsigned lb = maybeInequality.repr.inequalityPair.lowerBoundIdx;
unsigned ub = maybeInequality.repr.inequalityPair.upperBoundIdx;
b.addInequality(sI.getInequality(lb));
b.addInequality(sI.getInequality(ub));
assert(lb != ub &&
"Upper and lower bounds must be different inequalities!");
// We just added these inequalities to `b`, so there is no point considering
// the parts where these inequalities occur complemented -- such parts are
// empty. Therefore, we mark that these can be ignored.
canIgnoreIneq[lb] = true;
canIgnoreIneq[ub] = true;
}
unsigned offset = simplex.getNumConstraints();
unsigned snapshotBeforeIntersect = simplex.getSnapshot();
simplex.intersectIntegerRelation(sI);
if (simplex.isEmpty()) {
// b ^ s_i is empty, so b \ s_i = b. We move directly to i + 1.
// We are ignoring level i completely, so we restore the state
// *before* going to level i + 1.
b.truncate(initBCounts);
simplex.rollback(initialSnapshot);
subtractRecursively(b, simplex, s, i + 1, result);
return;
}
// Remove duplicate divs up front here to avoid existing
// divs disappearing in the call to mergeLocalIds below.
b.removeDuplicateDivs();
simplex.detectRedundant();
unsigned totalNewSimplexInequalities =
2 * sI.getNumEqualities() + sI.getNumInequalities();
// Redundant inequalities can be safely ignored. This is not required for
// correctness but improves performance and results in a more compact
// representation of the set difference.
for (unsigned j = 0; j < totalNewSimplexInequalities; j++)
canIgnoreIneq[j] = simplex.isMarkedRedundant(offset + j);
simplex.rollback(snapshotBeforeIntersect);
SmallVector<unsigned, 8> ineqsToProcess(totalNewSimplexInequalities);
for (unsigned i = 0; i < totalNewSimplexInequalities; ++i)
if (!canIgnoreIneq[i])
ineqsToProcess.push_back(i);
// Recurse with the part b ^ ~ineq. Note that b is modified throughout
// subtractRecursively. At the time this function is called, the current b is
// actually equal to b ^ s_i1 ^ s_i2 ^ ... ^ s_ij, and ineq is the next
// inequality, s_{i,j+1}. This function recurses into the next level i + 1
// with the part b ^ s_i1 ^ s_i2 ^ ... ^ s_ij ^ ~s_{i,j+1}.
auto recurseWithInequality = [&, i](ArrayRef<int64_t> ineq) {
b.addInequality(ineq);
simplex.addInequality(ineq);
subtractRecursively(b, simplex, s, i + 1, result);
PresburgerRelation result =
PresburgerRelation::getEmpty(b.getSpaceWithoutLocals());
Simplex simplex(b);
// This algorithm is more naturally expressed recursively, but we implement
// it iteratively here to avoid issues with stack sizes.
//
// Each level of the recursion has five stack variables.
struct Frame {
// A snapshot of the simplex state to rollback to.
unsigned simplexSnapshot;
// A CountsSnapshot of `b` to rollback to.
IntegerRelation::CountsSnapshot bCounts;
// The IntegerRelation currently being operated on.
IntegerRelation sI;
// A list of indexes (see getIneqCoeffsFromIdx) of inequalities to be
// processed.
SmallVector<unsigned, 8> ineqsToProcess;
// The index of the last inequality that was processed at this level.
// This is empty when we are coming to this level for the first time.
Optional<unsigned> lastIneqProcessed;
};
SmallVector<Frame, 2> frames;
// When we "recurse", we ensure the current frame is stored in `frames` and
// increment `level`. When we "tail recurse", we just increment `level`,
// without storing any frame. Accordingly, when we return, we return to the
// last level that has a frame associated with it.
unsigned level = 1;
while (level > 0) {
if (level - 1 >= s.getNumDisjuncts()) {
// No more parts to subtract; add to the result and return.
result.unionInPlace(b);
level = frames.size();
continue;
}
// For each inequality ineq, we first recurse with the part where ineq
// is not satisfied, and then add the ineq to b and simplex because
// ineq must be satisfied by all later parts.
auto processInequality = [&](ArrayRef<int64_t> ineq) {
unsigned snapshot = simplex.getSnapshot();
IntegerRelation::CountsSnapshot bCounts = b.getCounts();
recurseWithInequality(getComplementIneq(ineq));
simplex.rollback(snapshot);
b.truncate(bCounts);
b.addInequality(ineq);
simplex.addInequality(ineq);
};
if (level > frames.size()) {
// No frame for this level yet, so we have just recursed into this level.
IntegerRelation sI = s.getDisjunct(level - 1);
// Remove the duplicate divs up front to avoid them possibly disappearing
// in the call to mergeLocalIds below.
sI.removeDuplicateDivs();
// Below, we append some additional constraints and ids to b. We want to
// rollback b to its initial state before returning, which we will do by
// removing all constraints beyond the original number of inequalities
// and equalities, so we store these counts first.
IntegerRelation::CountsSnapshot initBCounts = b.getCounts();
// Similarly, we also want to rollback simplex to its original state.
unsigned initialSnapshot = simplex.getSnapshot();
// Find out which inequalities of sI correspond to division inequalities
// for the local variables of sI.
std::vector<MaybeLocalRepr> repr(sI.getNumLocalIds());
sI.getLocalReprs(repr);
// Add sI's locals to b, after b's locals. Only those locals of sI which
// do not already exist in b will be added. (i.e., duplicate divisions
// will not be added.) Also add b's locals to sI, in such a way that both
// have the same locals in the same order in the end.
b.mergeLocalIds(sI);
// Mark which inequalities of sI are division inequalities and add all
// such inequalities to b.
llvm::SmallBitVector canIgnoreIneq(sI.getNumInequalities() +
2 * sI.getNumEqualities());
for (MaybeLocalRepr &maybeInequality : repr) {
assert(
maybeInequality.kind == ReprKind::Inequality &&
"Subtraction is not supported when a representation of the local "
"variables of the subtrahend cannot be found!");
unsigned lb = maybeInequality.repr.inequalityPair.lowerBoundIdx;
unsigned ub = maybeInequality.repr.inequalityPair.upperBoundIdx;
b.addInequality(sI.getInequality(lb));
b.addInequality(sI.getInequality(ub));
assert(lb != ub &&
"Upper and lower bounds must be different inequalities!");
canIgnoreIneq[lb] = true;
canIgnoreIneq[ub] = true;
}
for (unsigned idx : ineqsToProcess)
processInequality(getIneqCoeffsFromIdx(sI, idx));
}
unsigned offset = simplex.getNumConstraints();
unsigned numLocalsAdded =
b.getNumLocalIds() - initBCounts.getSpace().getNumLocalIds();
simplex.appendVariable(numLocalsAdded);
unsigned snapshotBeforeIntersect = simplex.getSnapshot();
simplex.intersectIntegerRelation(sI);
if (simplex.isEmpty()) {
// b /\ s_i is empty, so b \ s_i = b. We move directly to i + 1.
// We are ignoring level i completely, so we restore the state
// *before* going to the next level. We are "tail recursing", so
// we don't add a frame before going to the next level.
b.truncate(initBCounts);
simplex.rollback(initialSnapshot);
++level;
continue;
}
/// Return the set difference disjunct \ set.
///
/// The disjunct here is modified in subtractRecursively, so it cannot be a
/// const reference even though it is restored to its original state before
/// returning from that function.
static PresburgerRelation getSetDifference(IntegerRelation disjunct,
const PresburgerRelation &set) {
assert(disjunct.isSpaceCompatible(set) && "Spaces should match");
if (disjunct.isEmptyByGCDTest())
return PresburgerRelation::getEmpty(disjunct.getSpaceWithoutLocals());
// Remove duplicate divs up front here as subtractRecursively does not support
// this set having duplicate divs.
disjunct.removeDuplicateDivs();
simplex.detectRedundant();
// Equalities are added to simplex as a pair of inequalities.
unsigned totalNewSimplexInequalities =
2 * sI.getNumEqualities() + sI.getNumInequalities();
for (unsigned j = 0; j < totalNewSimplexInequalities; j++)
canIgnoreIneq[j] = simplex.isMarkedRedundant(offset + j);
simplex.rollback(snapshotBeforeIntersect);
SmallVector<unsigned, 8> ineqsToProcess(totalNewSimplexInequalities);
for (unsigned i = 0; i < totalNewSimplexInequalities; ++i)
if (!canIgnoreIneq[i])
ineqsToProcess.push_back(i);
if (ineqsToProcess.empty()) {
// Nothing to process; return. (we have no frame to pop.)
level = frames.size();
continue;
}
unsigned simplexSnapshot = simplex.getSnapshot();
IntegerRelation::CountsSnapshot bCounts = b.getCounts();
frames.push_back(Frame{simplexSnapshot, bCounts, sI, ineqsToProcess,
/*lastIneqProcessed=*/llvm::None});
// We have completed the initial setup for this level.
// Fallthrough to the main recursive part below.
}
// For each inequality ineq, we first recurse with the part where ineq
// is not satisfied, and then add ineq to b and simplex because
// ineq must be satisfied by all later parts.
if (level == frames.size()) {
Frame &frame = frames.back();
if (frame.lastIneqProcessed) {
// Let the current value of b be b' and
// let the initial value of b when we first came to this level be b.
//
// b' is equal to b /\ s_i1 /\ s_i2 /\ ... /\ s_i{j-1} /\ ~s_ij.
// We had previously recursed with the part where s_ij was not
// satisfied; all further parts satisfy s_ij, so we rollback to the
// state before adding this complement constraint, and add s_ij to b.
simplex.rollback(frame.simplexSnapshot);
b.truncate(frame.bCounts);
SmallVector<int64_t, 8> ineq =
getIneqCoeffsFromIdx(frame.sI, *frame.lastIneqProcessed);
b.addInequality(ineq);
simplex.addInequality(ineq);
}
if (frame.ineqsToProcess.empty()) {
// No ineqs left to process; pop this level's frame and return.
frames.pop_back();
level = frames.size();
continue;
}
// "Recurse" with the part where the ineq is not satisfied.
frame.bCounts = b.getCounts();
frame.simplexSnapshot = simplex.getSnapshot();
unsigned idx = frame.ineqsToProcess.back();
SmallVector<int64_t, 8> ineq =
getComplementIneq(getIneqCoeffsFromIdx(frame.sI, idx));
b.addInequality(ineq);
simplex.addInequality(ineq);
frame.ineqsToProcess.pop_back();
frame.lastIneqProcessed = idx;
++level;
continue;
}
}
PresburgerRelation result =
PresburgerRelation::getEmpty(disjunct.getSpaceWithoutLocals());
Simplex simplex(disjunct);
subtractRecursively(disjunct, simplex, set, 0, result);
return result;
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment