diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 07546c0fd51fff55d870fdf6093f8b95a3340ef5..7d9febec632ca649ff0aace603fbf65542c0d198 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -950,7 +950,12 @@ OpFoldResult vector::ExtractElementOp::fold(ArrayRef<Attribute> operands) { Attribute src = operands[0]; Attribute pos = operands[1]; - if (!src || !pos) + + // Fold extractelement (splat X) -> X. + if (auto splat = getVector().getDefiningOp<vector::SplatOp>()) + return splat.getInput(); + + if (!pos || !src) return {}; auto srcElements = src.cast<DenseElementsAttr>().getValues<Attribute>(); diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 8b6640bb06784a2edbd28b3bf0fba226a8d03875..033f17ae2fe127360852f3b498c34dc696d6119c 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1409,3 +1409,13 @@ func @extract_element_fold() -> i32 { %1 = vector.extractelement %v[%i : i32] : vector<4xi32> return %1 : i32 } + +// CHECK-LABEL: func @extract_element_splat_fold +// CHECK-SAME: (%[[ARG:.+]]: i32) +// CHECK: return %[[ARG]] +func @extract_element_splat_fold(%a : i32) -> i32 { + %v = vector.splat %a : vector<4xi32> + %i = arith.constant 2 : i32 + %1 = vector.extractelement %v[%i : i32] : vector<4xi32> + return %1 : i32 +}