Skip to content
Snippets Groups Projects
Commit af371f9f authored by River Riddle's avatar River Riddle
Browse files

Reland [GreedPatternRewriter] Preprocess constants while building worklist...

Reland [GreedPatternRewriter] Preprocess constants while building worklist when not processing top down

Reland Note: Adds a fix to properly mark a commutative operation as folded if we change the order
             of its operands. This was uncovered by the fact that we no longer re-process constants.

This avoids accidentally reversing the order of constants during successive
application, e.g. when running the canonicalizer. This helps reduce the number
of iterations, and also avoids unnecessary changes to input IR.

Fixes #51892

Differential Revision: https://reviews.llvm.org/D122692
parent f004ecf6
No related branches found
No related tags found
No related merge requests found
Showing
with 197 additions and 69 deletions
......@@ -569,12 +569,12 @@ end subroutine test_proc_dummy_other
! CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.char<1,40>>,
! CHECK-SAME: %[[VAL_1:.*]]: index,
! CHECK-SAME: %[[VAL_2:.*]]: tuple<!fir.boxproc<() -> ()>, i64> {fir.char_proc}) -> !fir.boxchar<1> {
! CHECK: %[[VAL_3:.*]] = arith.constant 40 : index
! CHECK: %[[VAL_4:.*]] = arith.constant 12 : index
! CHECK: %[[VAL_5:.*]] = arith.constant false
! CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
! CHECK: %[[VAL_7:.*]] = arith.constant 32 : i8
! CHECK: %[[VAL_8:.*]] = arith.constant 0 : index
! CHECK-DAG: %[[VAL_3:.*]] = arith.constant 40 : index
! CHECK-DAG: %[[VAL_4:.*]] = arith.constant 12 : index
! CHECK-DAG: %[[VAL_5:.*]] = arith.constant false
! CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1 : index
! CHECK-DAG: %[[VAL_7:.*]] = arith.constant 32 : i8
! CHECK-DAG: %[[VAL_8:.*]] = arith.constant 0 : index
! CHECK: %[[VAL_9:.*]] = fir.convert %[[VAL_0]] : (!fir.ref<!fir.char<1,40>>) -> !fir.ref<!fir.char<1,?>>
! CHECK: %[[VAL_10:.*]] = fir.address_of(@_QQcl.{{.*}}) : !fir.ref<!fir.char<1,12>>
! CHECK: %[[VAL_11:.*]] = fir.extract_value %[[VAL_2]], [0 : index] : (tuple<!fir.boxproc<() -> ()>, i64>) -> !fir.boxproc<() -> ()>
......
......@@ -45,6 +45,16 @@ public:
function_ref<void(Operation *)> preReplaceAction = nullptr,
bool *inPlaceUpdate = nullptr);
/// Tries to fold a pre-existing constant operation. `constValue` represents
/// the value of the constant, and can be optionally passed if the value is
/// already known (e.g. if the constant was discovered by m_Constant). This is
/// purely an optimization opportunity for callers that already know the value
/// of the constant. Returns false if an existing constant for `op` already
/// exists in the folder, in which case `op` is replaced and erased.
/// Otherwise, returns true and `op` is inserted into the folder (and
/// hoisted if necessary).
bool insertKnownConstant(Operation *op, Attribute constValue = {});
/// Notifies that the given constant `op` should be remove from this
/// OperationFolder's internal bookkeeping.
///
......@@ -114,12 +124,24 @@ private:
using ConstantMap =
DenseMap<std::tuple<Dialect *, Attribute, Type>, Operation *>;
/// Returns true if the given operation is an already folded constant that is
/// owned by this folder.
bool isFolderOwnedConstant(Operation *op) const;
/// Tries to perform folding on the given `op`. If successful, populates
/// `results` with the results of the folding.
LogicalResult tryToFold(
OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results,
function_ref<void(Operation *)> processGeneratedConstants = nullptr);
/// Try to process a set of fold results, generating constants as necessary.
/// Populates `results` on success, otherwise leaves it unchanged.
LogicalResult
processFoldResults(OpBuilder &builder, Operation *op,
SmallVectorImpl<Value> &results,
ArrayRef<OpFoldResult> foldResults,
function_ref<void(Operation *)> processGeneratedConstants);
/// Try to get or create a new constant entry. On success this returns the
/// constant operation, nullptr otherwise.
Operation *tryGetOrCreateConstant(ConstantMap &uniquedConstants,
......
......@@ -75,8 +75,14 @@ LogicalResult OperationFolder::tryToFold(
// If this is a unique'd constant, return failure as we know that it has
// already been folded.
if (referencedDialects.count(op))
if (isFolderOwnedConstant(op)) {
// Check to see if we should rehoist, i.e. if a non-constant operation was
// inserted before this one.
Block *opBlock = op->getBlock();
if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode()))
op->moveBefore(&opBlock->front());
return failure();
}
// Try to fold the operation.
SmallVector<Value, 8> results;
......@@ -104,6 +110,59 @@ LogicalResult OperationFolder::tryToFold(
return success();
}
bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) {
Block *opBlock = op->getBlock();
// If this is a constant we unique'd, we don't need to insert, but we can
// check to see if we should rehoist it.
if (isFolderOwnedConstant(op)) {
if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode()))
op->moveBefore(&opBlock->front());
return true;
}
// Get the constant value of the op if necessary.
if (!constValue) {
matchPattern(op, m_Constant(&constValue));
assert(constValue && "expected `op` to be a constant");
} else {
// Ensure that the provided constant was actually correct.
#ifndef NDEBUG
Attribute expectedValue;
matchPattern(op, m_Constant(&expectedValue));
assert(
expectedValue == constValue &&
"provided constant value was not the expected value of the constant");
#endif
}
// Check for an existing constant operation for the attribute value.
Region *insertRegion = getInsertionRegion(interfaces, opBlock);
auto &uniquedConstants = foldScopes[insertRegion];
Operation *&folderConstOp = uniquedConstants[std::make_tuple(
op->getDialect(), constValue, *op->result_type_begin())];
// If there is an existing constant, replace `op`.
if (folderConstOp) {
op->replaceAllUsesWith(folderConstOp);
op->erase();
return false;
}
// Otherwise, we insert `op`. If `op` is in the insertion block and is either
// already at the front of the block, or the previous operation is already a
// constant we unique'd (i.e. one we inserted), then we don't need to do
// anything. Otherwise, we move the constant to the insertion block.
Block *insertBlock = &insertRegion->front();
if (opBlock != insertBlock || (&insertBlock->front() != op &&
!isFolderOwnedConstant(op->getPrevNode())))
op->moveBefore(&insertBlock->front());
folderConstOp = op;
referencedDialects[op].push_back(op->getDialect());
return true;
}
/// Notifies that the given constant `op` should be remove from this
/// OperationFolder's internal bookkeeping.
void OperationFolder::notifyRemoval(Operation *op) {
......@@ -156,19 +215,30 @@ Value OperationFolder::getOrCreateConstant(OpBuilder &builder, Dialect *dialect,
return constOp ? constOp->getResult(0) : Value();
}
bool OperationFolder::isFolderOwnedConstant(Operation *op) const {
return referencedDialects.count(op);
}
/// Tries to perform folding on the given `op`. If successful, populates
/// `results` with the results of the folding.
LogicalResult OperationFolder::tryToFold(
OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results,
function_ref<void(Operation *)> processGeneratedConstants) {
SmallVector<Attribute, 8> operandConstants;
SmallVector<OpFoldResult, 8> foldResults;
// If this is a commutative operation, move constants to be trailing operands.
bool updatedOpOperands = false;
if (op->getNumOperands() >= 2 && op->hasTrait<OpTrait::IsCommutative>()) {
std::stable_partition(
op->getOpOperands().begin(), op->getOpOperands().end(),
[&](OpOperand &o) { return !matchPattern(o.get(), m_Constant()); });
auto isNonConstant = [&](OpOperand &o) {
return !matchPattern(o.get(), m_Constant());
};
auto *firstConstantIt =
llvm::find_if_not(op->getOpOperands(), isNonConstant);
auto *newConstantIt = std::stable_partition(
firstConstantIt, op->getOpOperands().end(), isNonConstant);
// Remember if we actually moved anything.
updatedOpOperands = firstConstantIt != newConstantIt;
}
// Check to see if any operands to the operation is constant and whether
......@@ -177,10 +247,21 @@ LogicalResult OperationFolder::tryToFold(
for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
matchPattern(op->getOperand(i), m_Constant(&operandConstants[i]));
// Attempt to constant fold the operation.
if (failed(op->fold(operandConstants, foldResults)))
return failure();
// Attempt to constant fold the operation. If we failed, check to see if we at
// least updated the operands of the operation. We treat this as an in-place
// fold.
SmallVector<OpFoldResult, 8> foldResults;
if (failed(op->fold(operandConstants, foldResults)) ||
failed(processFoldResults(builder, op, results, foldResults,
processGeneratedConstants)))
return success(updatedOpOperands);
return success();
}
LogicalResult OperationFolder::processFoldResults(
OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results,
ArrayRef<OpFoldResult> foldResults,
function_ref<void(Operation *)> processGeneratedConstants) {
// Check to see if the operation was just updated in place.
if (foldResults.empty())
return success();
......@@ -204,8 +285,10 @@ LogicalResult OperationFolder::tryToFold(
// Check if the result was an SSA value.
if (auto repl = foldResults[i].dyn_cast<Value>()) {
if (repl.getType() != op->getResult(i).getType())
if (repl.getType() != op->getResult(i).getType()) {
results.clear();
return failure();
}
results.emplace_back(repl);
continue;
}
......@@ -233,6 +316,7 @@ LogicalResult OperationFolder::tryToFold(
notifyRemoval(&op);
op.erase();
}
results.clear();
return failure();
}
......
......@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Rewrite/PatternApplicator.h"
#include "mlir/Transforms/FoldUtils.h"
......@@ -140,8 +141,18 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
if (!config.useTopDownTraversal) {
// Add operations to the worklist in postorder.
for (auto &region : regions)
region.walk([this](Operation *op) { addToWorklist(op); });
for (auto &region : regions) {
region.walk([this](Operation *op) {
// If we aren't processing top-down, check for existing constants when
// populating the worklist. This avoids accidentally reversing the
// constant order during processing.
Attribute constValue;
if (matchPattern(op, m_Constant(&constValue)))
if (!folder.insertKnownConstant(op, constValue))
return;
addToWorklist(op);
});
}
} else {
// Add all nested operations to the worklist in preorder.
for (auto &region : regions)
......
......@@ -244,9 +244,9 @@ func @transfer_read_progressive(%A : memref<?x?xf32>, %base: index) -> vector<3x
// CHECK: }
// CHECK: %[[cst:.*]] = memref.load %[[alloc]][] : memref<vector<3x15xf32>>
// FULL-UNROLL: %[[C7:.*]] = arith.constant 7.000000e+00 : f32
// FULL-UNROLL: %[[VEC0:.*]] = arith.constant dense<7.000000e+00> : vector<3x15xf32>
// FULL-UNROLL: %[[C0:.*]] = arith.constant 0 : index
// FULL-UNROLL-DAG: %[[C7:.*]] = arith.constant 7.000000e+00 : f32
// FULL-UNROLL-DAG: %[[VEC0:.*]] = arith.constant dense<7.000000e+00> : vector<3x15xf32>
// FULL-UNROLL-DAG: %[[C0:.*]] = arith.constant 0 : index
// FULL-UNROLL: %[[DIM:.*]] = memref.dim %[[A]], %[[C0]] : memref<?x?xf32>
// FULL-UNROLL: cmpi sgt, %[[DIM]], %[[base]] : index
// FULL-UNROLL: %[[VEC1:.*]] = scf.if %{{.*}} -> (vector<3x15xf32>) {
......
......@@ -5,33 +5,33 @@
// CHECK: %[[MEMREF:.*]]: memref<?xf32>
func @num_worker_threads(%arg0: memref<?xf32>) {
// CHECK: %[[scalingCstInit:.*]] = arith.constant 8.000000e+00 : f32
// CHECK: %[[bracketLowerBound4:.*]] = arith.constant 4 : index
// CHECK: %[[scalingCst4:.*]] = arith.constant 4.000000e+00 : f32
// CHECK: %[[bracketLowerBound8:.*]] = arith.constant 8 : index
// CHECK: %[[scalingCst8:.*]] = arith.constant 2.000000e+00 : f32
// CHECK: %[[bracketLowerBound16:.*]] = arith.constant 16 : index
// CHECK: %[[scalingCst16:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[bracketLowerBound32:.*]] = arith.constant 32 : index
// CHECK: %[[scalingCst32:.*]] = arith.constant 8.000000e-01 : f32
// CHECK: %[[bracketLowerBound64:.*]] = arith.constant 64 : index
// CHECK: %[[scalingCst64:.*]] = arith.constant 6.000000e-01 : f32
// CHECK: %[[workersIndex:.*]] = async.runtime.num_worker_threads : index
// CHECK: %[[inBracket4:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound4]] : index
// CHECK: %[[scalingFactor4:.*]] = arith.select %[[inBracket4]], %[[scalingCst4]], %[[scalingCstInit]] : f32
// CHECK: %[[inBracket8:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound8]] : index
// CHECK: %[[scalingFactor8:.*]] = arith.select %[[inBracket8]], %[[scalingCst8]], %[[scalingFactor4]] : f32
// CHECK: %[[inBracket16:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound16]] : index
// CHECK: %[[scalingFactor16:.*]] = arith.select %[[inBracket16]], %[[scalingCst16]], %[[scalingFactor8]] : f32
// CHECK: %[[inBracket32:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound32]] : index
// CHECK: %[[scalingFactor32:.*]] = arith.select %[[inBracket32]], %[[scalingCst32]], %[[scalingFactor16]] : f32
// CHECK: %[[inBracket64:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound64]] : index
// CHECK: %[[scalingFactor64:.*]] = arith.select %[[inBracket64]], %[[scalingCst64]], %[[scalingFactor32]] : f32
// CHECK: %[[workersInt:.*]] = arith.index_cast %[[workersIndex]] : index to i32
// CHECK: %[[workersFloat:.*]] = arith.sitofp %[[workersInt]] : i32 to f32
// CHECK: %[[scaledFloat:.*]] = arith.mulf %[[scalingFactor64]], %[[workersFloat]] : f32
// CHECK: %[[scaledInt:.*]] = arith.fptosi %[[scaledFloat]] : f32 to i32
// CHECK: %[[scaledIndex:.*]] = arith.index_cast %[[scaledInt]] : i32 to index
// CHECK-DAG: %[[scalingCstInit:.*]] = arith.constant 8.000000e+00 : f32
// CHECK-DAG: %[[bracketLowerBound4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[scalingCst4:.*]] = arith.constant 4.000000e+00 : f32
// CHECK-DAG: %[[bracketLowerBound8:.*]] = arith.constant 8 : index
// CHECK-DAG: %[[scalingCst8:.*]] = arith.constant 2.000000e+00 : f32
// CHECK-DAG: %[[bracketLowerBound16:.*]] = arith.constant 16 : index
// CHECK-DAG: %[[scalingCst16:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG: %[[bracketLowerBound32:.*]] = arith.constant 32 : index
// CHECK-DAG: %[[scalingCst32:.*]] = arith.constant 8.000000e-01 : f32
// CHECK-DAG: %[[bracketLowerBound64:.*]] = arith.constant 64 : index
// CHECK-DAG: %[[scalingCst64:.*]] = arith.constant 6.000000e-01 : f32
// CHECK: %[[workersIndex:.*]] = async.runtime.num_worker_threads : index
// CHECK: %[[inBracket4:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound4]] : index
// CHECK: %[[scalingFactor4:.*]] = arith.select %[[inBracket4]], %[[scalingCst4]], %[[scalingCstInit]] : f32
// CHECK: %[[inBracket8:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound8]] : index
// CHECK: %[[scalingFactor8:.*]] = arith.select %[[inBracket8]], %[[scalingCst8]], %[[scalingFactor4]] : f32
// CHECK: %[[inBracket16:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound16]] : index
// CHECK: %[[scalingFactor16:.*]] = arith.select %[[inBracket16]], %[[scalingCst16]], %[[scalingFactor8]] : f32
// CHECK: %[[inBracket32:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound32]] : index
// CHECK: %[[scalingFactor32:.*]] = arith.select %[[inBracket32]], %[[scalingCst32]], %[[scalingFactor16]] : f32
// CHECK: %[[inBracket64:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound64]] : index
// CHECK: %[[scalingFactor64:.*]] = arith.select %[[inBracket64]], %[[scalingCst64]], %[[scalingFactor32]] : f32
// CHECK: %[[workersInt:.*]] = arith.index_cast %[[workersIndex]] : index to i32
// CHECK: %[[workersFloat:.*]] = arith.sitofp %[[workersInt]] : i32 to f32
// CHECK: %[[scaledFloat:.*]] = arith.mulf %[[scalingFactor64]], %[[workersFloat]] : f32
// CHECK: %[[scaledInt:.*]] = arith.fptosi %[[scaledFloat]] : f32 to i32
// CHECK: %[[scaledIndex:.*]] = arith.index_cast %[[scaledInt]] : i32 to index
%lb = arith.constant 0 : index
%ub = arith.constant 100 : index
......
......@@ -42,9 +42,9 @@ func @main() -> (tensor<i32>) attributes {} {
}
// CHECK-LABEL: func @main()
// CHECK-NEXT: arith.constant 0
// CHECK-NEXT: arith.constant 10
// CHECK-NEXT: cf.br ^[[bb1:.*]](%{{.*}}: i32)
// CHECK-DAG: arith.constant 0
// CHECK-DAG: arith.constant 10
// CHECK: cf.br ^[[bb1:.*]](%{{.*}}: i32)
// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32):
// CHECK-NEXT: arith.cmpi slt, %{{.*}}, %{{.*}}
// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
......@@ -106,9 +106,9 @@ func @main() -> (tensor<i32>) attributes {} {
}
// CHECK-LABEL: func @main()
// CHECK-NEXT: arith.constant 0
// CHECK-NEXT: arith.constant 10
// CHECK-NEXT: cf.br ^[[bb1:.*]](%{{.*}}: i32)
// CHECK-DAG: arith.constant 0
// CHECK-DAG: arith.constant 10
// CHECK: cf.br ^[[bb1:.*]](%{{.*}}: i32)
// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32):
// CHECK-NEXT: arith.cmpi slt, %{{.*}}, %{{.*}}
// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
......@@ -171,9 +171,9 @@ func @main() -> (tensor<i32>) attributes {} {
}
// CHECK-LABEL: func @main()
// CHECK-NEXT: arith.constant 0
// CHECK-NEXT: arith.constant 10
// CHECK-NEXT: cf.br ^[[bb1:.*]](%{{.*}}: i32)
// CHECK-DAG: arith.constant 0
// CHECK-DAG: arith.constant 10
// CHECK: cf.br ^[[bb1:.*]](%{{.*}}: i32)
// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32):
// CHECK-NEXT: arith.cmpi slt, %{{.*}}, %{{.*}}
// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb2(%{{.*}} : i32)
......
......@@ -301,7 +301,7 @@ func @aligned_promote_fill(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
return
}
// CHECK-LABEL: func @aligned_promote_fill
// CHECK: %[[cf:.*]] = arith.constant {{.*}} : f32
// CHECK: %[[cf:.*]] = arith.constant 1.{{.*}} : f32
// CHECK: %[[s0:.*]] = memref.subview {{.*}}: memref<?x?xf32, #map{{.*}}> to memref<?x?xf32, #map{{.*}}>
// CHECK: %[[a0:.*]] = memref.alloc() {alignment = 32 : i64} : memref<32000000xi8>
// CHECK: %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<32000000xi8> to memref<?x?xf32>
......
......@@ -78,11 +78,11 @@ func @dense1(%arga: tensor<32x16xf32, #DenseMatrix>,
// CHECK-LABEL: func @dense2(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16xf32> {linalg.inplaceable = true}) -> tensor<32x16xf32> {
// CHECK: %[[VAL_2:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[VAL_3:.*]] = arith.constant 32 : index
// CHECK: %[[VAL_4:.*]] = arith.constant 16 : index
// CHECK: %[[VAL_5:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 32 : index
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 16 : index
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xf32>
// CHECK: %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16xf32>
// CHECK: scf.for %[[VAL_9:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
......
......@@ -24,9 +24,9 @@
// CHECK-SAME: %[[VAL_2:.*2]]: f32,
// CHECK-SAME: %[[VAL_3:.*3]]: f32,
// CHECK-SAME: %[[VAL_4:.*4]]: tensor<32x16xf32> {linalg.inplaceable = true}) -> tensor<32x16xf32> {
// CHECK: %[[VAL_5:.*]] = arith.constant 2.200000e+00 : f32
// CHECK: %[[VAL_6:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_7:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 2.200000e+00 : f32
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_8:.*]] = arith.addf %[[VAL_2]], %[[VAL_3]] : f32
// CHECK: %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_6]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
// CHECK: %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_6]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
......
......@@ -183,9 +183,9 @@ func @tensor.from_elements_3d(%f0 : f32) -> tensor<3x2x2xf32> {
// CHECK-LABEL: func @tensor.generate(
// CHECK-SAME: %[[ARG:.*]]: tensor<*xf32>,
// CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<?xindex> {
// CHECK: %[[CASTED:.*]] = bufferization.to_memref %[[ARG]] : memref<*xf32>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[CASTED:.*]] = bufferization.to_memref %[[ARG]] : memref<*xf32>
// CHECK: %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<?xindex>
// CHECK: scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) {
// CHECK: %[[ELEM:.*]] = memref.dim %[[CASTED]], %[[I]] : memref<*xf32>
......
......@@ -27,8 +27,8 @@ func @pad_non_zero_sizes(%input: tensor<?x?x8xf32>, %low0: index, %high1: index)
return %0 : tensor<?x?x8xf32>
}
// CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[EQ0:.+]] = arith.cmpi eq, %[[LOW0]], %[[C0]] : index
// CHECK: %[[EQ1:.+]] = arith.cmpi eq, %[[HIGH1]], %[[C0]] : index
// CHECK: %[[AND:.+]] = arith.andi %[[EQ0]], %[[EQ1]] : i1
......
// RUN: mlir-opt -test-patterns %s | FileCheck %s
// RUN: mlir-opt -test-patterns -test-patterns %s | FileCheck %s
func @foo() -> i32 {
%c42 = arith.constant 42 : i32
......@@ -22,3 +22,14 @@ func @test_fold_before_previously_folded_op() -> (i32, i32) {
%1 = "test.cast"() {test_fold_before_previously_folded_op} : () -> (i32)
return %0, %1 : i32, i32
}
func @test_dont_reorder_constants() -> (i32, i32, i32) {
// Test that we don't reorder existing constants during folding if it isn't necessary.
// CHECK: %[[CST:.+]] = arith.constant 1
// CHECK-NEXT: %[[CST:.+]] = arith.constant 2
// CHECK-NEXT: %[[CST:.+]] = arith.constant 3
%0 = arith.constant 1 : i32
%1 = arith.constant 2 : i32
%2 = arith.constant 3 : i32
return %0, %1, %2 : i32, i32, i32
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment