diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index eda77392041baffc560992f840f626488292b051..07546c0fd51fff55d870fdf6093f8b95a3340ef5 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1496,6 +1496,7 @@ public:
     Operation *defOp = extractOp.getVector().getDefiningOp();
     if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
       return failure();
+
     Value source = defOp->getOperand(0);
     if (extractOp.getType() == source.getType())
       return failure();
@@ -1504,10 +1505,10 @@ public:
     };
     unsigned broadcastSrcRank = getRank(source.getType());
     unsigned extractResultRank = getRank(extractOp.getType());
-    // We only consider the case where the rank of the source is smaller than
-    // the rank of the extract dst. The other cases are handled in the folding
-    // patterns.
-    if (extractResultRank <= broadcastSrcRank)
+    // We only consider the case where the rank of the source is less than or
+    // equal to the rank of the extract dst. The other cases are handled in the
+    // folding patterns.
+    if (extractResultRank < broadcastSrcRank)
       return failure();
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
         extractOp, extractOp.getType(), source);
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index a083851a4bb981ab271671675e408562d87069e0..8b6640bb06784a2edbd28b3bf0fba226a8d03875 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -566,6 +566,18 @@ func @fold_extract_broadcast(%a : f32) -> vector<4xf32> {
 
 // -----
 
+// CHECK-LABEL: fold_extract_broadcast
+//  CHECK-SAME:   %[[A:.*]]: vector<1xf32>
+//       CHECK:   %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
+//       CHECK:   return %[[R]] : vector<8xf32>
+func @fold_extract_broadcast(%a : vector<1xf32>) -> vector<8xf32> {
+  %b = vector.broadcast %a : vector<1xf32> to vector<1x8xf32>
+  %r = vector.extract %b[0] : vector<1x8xf32>
+  return %r : vector<8xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_extract_shapecast
 //  CHECK-SAME: (%[[A0:.*]]: vector<5x1x3x2xf32>, %[[A1:.*]]: vector<8x4x2xf32>
 //       CHECK:   %[[R0:.*]] = vector.extract %[[A0]][1, 0, 1, 1] : vector<5x1x3x2xf32>