diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 07546c0fd51fff55d870fdf6093f8b95a3340ef5..7d9febec632ca649ff0aace603fbf65542c0d198 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -950,7 +950,12 @@ OpFoldResult vector::ExtractElementOp::fold(ArrayRef<Attribute> operands) {
 
   Attribute src = operands[0];
   Attribute pos = operands[1];
-  if (!src || !pos)
+
+  // Fold extractelement (splat X) -> X.
+  if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
+    return splat.getInput();
+
+  if (!pos || !src)
     return {};
 
   auto srcElements = src.cast<DenseElementsAttr>().getValues<Attribute>();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 8b6640bb06784a2edbd28b3bf0fba226a8d03875..033f17ae2fe127360852f3b498c34dc696d6119c 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1409,3 +1409,13 @@ func @extract_element_fold() -> i32 {
   %1 = vector.extractelement %v[%i : i32] : vector<4xi32>
   return %1 : i32
 }
+
+// CHECK-LABEL: func @extract_element_splat_fold
+//  CHECK-SAME: (%[[ARG:.+]]: i32)
+//       CHECK:   return %[[ARG]]
+func @extract_element_splat_fold(%a : i32) -> i32 {
+  %v = vector.splat %a : vector<4xi32>
+  %i = arith.constant 2 : i32
+  %1 = vector.extractelement %v[%i : i32] : vector<4xi32>
+  return %1 : i32
+}