From e79b7f501c19784d6160b105a8b84e7fdf28e113 Mon Sep 17 00:00:00 2001
From: jacquesguan <Jianjian.Guan@streamcomputing.com>
Date: Fri, 8 Apr 2022 02:56:34 +0000
Subject: [PATCH] [mlir][Vector] Fold extractelement splat.

This revision supports to fold vector.extractelement (splat X) -> X.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D122960
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   |  7 ++++++-
 mlir/test/Dialect/Vector/canonicalize.mlir | 10 ++++++++++
 2 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 07546c0fd51f..7d9febec632c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -950,7 +950,12 @@ OpFoldResult vector::ExtractElementOp::fold(ArrayRef<Attribute> operands) {
 
   Attribute src = operands[0];
   Attribute pos = operands[1];
-  if (!src || !pos)
+
+  // Fold extractelement (splat X) -> X.
+  if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
+    return splat.getInput();
+
+  if (!pos || !src)
     return {};
 
   auto srcElements = src.cast<DenseElementsAttr>().getValues<Attribute>();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 8b6640bb0678..033f17ae2fe1 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1409,3 +1409,13 @@ func @extract_element_fold() -> i32 {
   %1 = vector.extractelement %v[%i : i32] : vector<4xi32>
   return %1 : i32
 }
+
+// CHECK-LABEL: func @extract_element_splat_fold
+//  CHECK-SAME: (%[[ARG:.+]]: i32)
+//       CHECK:   return %[[ARG]]
+func @extract_element_splat_fold(%a : i32) -> i32 {
+  %v = vector.splat %a : vector<4xi32>
+  %i = arith.constant 2 : i32
+  %1 = vector.extractelement %v[%i : i32] : vector<4xi32>
+  return %1 : i32
+}
-- 
GitLab