From e79b7f501c19784d6160b105a8b84e7fdf28e113 Mon Sep 17 00:00:00 2001 From: jacquesguan <Jianjian.Guan@streamcomputing.com> Date: Fri, 8 Apr 2022 02:56:34 +0000 Subject: [PATCH] [mlir][Vector] Fold extractelement splat. This revision supports to fold vector.extractelement (splat X) -> X. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D122960 --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 7 ++++++- mlir/test/Dialect/Vector/canonicalize.mlir | 10 ++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 07546c0fd51f..7d9febec632c 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -950,7 +950,12 @@ OpFoldResult vector::ExtractElementOp::fold(ArrayRef<Attribute> operands) { Attribute src = operands[0]; Attribute pos = operands[1]; - if (!src || !pos) + + // Fold extractelement (splat X) -> X. + if (auto splat = getVector().getDefiningOp<vector::SplatOp>()) + return splat.getInput(); + + if (!pos || !src) return {}; auto srcElements = src.cast<DenseElementsAttr>().getValues<Attribute>(); diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 8b6640bb0678..033f17ae2fe1 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1409,3 +1409,13 @@ func @extract_element_fold() -> i32 { %1 = vector.extractelement %v[%i : i32] : vector<4xi32> return %1 : i32 } + +// CHECK-LABEL: func @extract_element_splat_fold +// CHECK-SAME: (%[[ARG:.+]]: i32) +// CHECK: return %[[ARG]] +func @extract_element_splat_fold(%a : i32) -> i32 { + %v = vector.splat %a : vector<4xi32> + %i = arith.constant 2 : i32 + %1 = vector.extractelement %v[%i : i32] : vector<4xi32> + return %1 : i32 +} -- GitLab