diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index eda77392041baffc560992f840f626488292b051..07546c0fd51fff55d870fdf6093f8b95a3340ef5 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1496,6 +1496,7 @@ public: Operation *defOp = extractOp.getVector().getDefiningOp(); if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp)) return failure(); + Value source = defOp->getOperand(0); if (extractOp.getType() == source.getType()) return failure(); @@ -1504,10 +1505,10 @@ public: }; unsigned broadcastSrcRank = getRank(source.getType()); unsigned extractResultRank = getRank(extractOp.getType()); - // We only consider the case where the rank of the source is smaller than - // the rank of the extract dst. The other cases are handled in the folding - // patterns. - if (extractResultRank <= broadcastSrcRank) + // We only consider the case where the rank of the source is less than or + // equal to the rank of the extract dst. The other cases are handled in the + // folding patterns. + if (extractResultRank < broadcastSrcRank) return failure(); rewriter.replaceOpWithNewOp<vector::BroadcastOp>( extractOp, extractOp.getType(), source); diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index a083851a4bb981ab271671675e408562d87069e0..8b6640bb06784a2edbd28b3bf0fba226a8d03875 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -566,6 +566,18 @@ func @fold_extract_broadcast(%a : f32) -> vector<4xf32> { // ----- +// CHECK-LABEL: fold_extract_broadcast +// CHECK-SAME: %[[A:.*]]: vector<1xf32> +// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32> +// CHECK: return %[[R]] : vector<8xf32> +func @fold_extract_broadcast(%a : vector<1xf32>) -> vector<8xf32> { + %b = vector.broadcast %a : vector<1xf32> to vector<1x8xf32> + %r = vector.extract %b[0] : vector<1x8xf32> + return %r : vector<8xf32> +} + +// ----- + // CHECK-LABEL: func @fold_extract_shapecast // CHECK-SAME: (%[[A0:.*]]: vector<5x1x3x2xf32>, %[[A1:.*]]: vector<8x4x2xf32> // CHECK: %[[R0:.*]] = vector.extract %[[A0]][1, 0, 1, 1] : vector<5x1x3x2xf32>