diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td index fa4d33f4b942e9e326b0280a75e5dd5e2b6c1932..a3554f79f253ed04087cdfeedb53986416a76570 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -375,4 +375,28 @@ class LLVM_VPBinaryI<string mnem> : LLVM_VPBinaryBase<mnem, AnyInteger>; class LLVM_VPBinaryF<string mnem> : LLVM_VPBinaryBase<mnem, AnyFloat>; +class LLVM_VPUnaryBase<string mnem, Type element> + : LLVM_OneResultIntrOp<"vp." # mnem, [0], [], [NoSideEffect]>, + Arguments<(ins LLVM_VectorOf<element>:$op, + LLVM_VectorOf<I1>:$mask, I32:$evl)>; + +class LLVM_VPUnaryF<string mnem> : LLVM_VPUnaryBase<mnem, AnyFloat>; + +class LLVM_VPTernaryBase<string mnem, Type element> + : LLVM_OneResultIntrOp<"vp." # mnem, [0], [], [NoSideEffect]>, + Arguments<(ins LLVM_VectorOf<element>:$op1, LLVM_VectorOf<element>:$op2, + LLVM_VectorOf<element>:$op3, LLVM_VectorOf<I1>:$mask, + I32:$evl)>; + +class LLVM_VPTernaryF<string mnem> : LLVM_VPTernaryBase<mnem, AnyFloat>; + +class LLVM_VPReductionBase<string mnem, Type element> + : LLVM_OneResultIntrOp<"vp.reduce." 
# mnem, [], [1], [NoSideEffect]>, + Arguments<(ins element:$start_value, LLVM_VectorOf<element>:$val, + LLVM_VectorOf<I1>:$mask, I32:$evl)>; + +class LLVM_VPReductionI<string mnem> : LLVM_VPReductionBase<mnem, AnyInteger>; + +class LLVM_VPReductionF<string mnem> : LLVM_VPReductionBase<mnem, AnyFloat>; + +#endif // LLVMIR_OP_BASE diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index a0b6dcd026602f96b6df6eeff6841da1d99af44d..29a3d5f068535beffbb07410fe069c609c0d8f6e 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -1948,7 +1948,28 @@ def LLVM_VPFMulOp : LLVM_VPBinaryF<"fmul">; def LLVM_VPFDivOp : LLVM_VPBinaryF<"fdiv">; def LLVM_VPFRemOp : LLVM_VPBinaryF<"frem">; - +// Float Unary +def LLVM_VPFNegOp : LLVM_VPUnaryF<"fneg">; + +// Float Ternary +def LLVM_VPFmaOp : LLVM_VPTernaryF<"fma">; + +// Integer Reduction +def LLVM_VPReduceAddOp : LLVM_VPReductionI<"add">; +def LLVM_VPReduceMulOp : LLVM_VPReductionI<"mul">; +def LLVM_VPReduceAndOp : LLVM_VPReductionI<"and">; +def LLVM_VPReduceOrOp : LLVM_VPReductionI<"or">; +def LLVM_VPReduceXorOp : LLVM_VPReductionI<"xor">; +def LLVM_VPReduceSMaxOp : LLVM_VPReductionI<"smax">; +def LLVM_VPReduceSMinOp : LLVM_VPReductionI<"smin">; +def LLVM_VPReduceUMaxOp : LLVM_VPReductionI<"umax">; +def LLVM_VPReduceUMinOp : LLVM_VPReductionI<"umin">; + +// Float Reduction +def LLVM_VPReduceFAddOp : LLVM_VPReductionF<"fadd">; +def LLVM_VPReduceFMulOp : LLVM_VPReductionF<"fmul">; +def LLVM_VPReduceFMaxOp : LLVM_VPReductionF<"fmax">; +def LLVM_VPReduceFMinOp : LLVM_VPReductionF<"fmin">; #endif // LLVMIR_OPS diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir index bde9ce9cc0fcc4150da13eccd00c5f0b3ad50884..fa8666ace6e97a8697b50f2f7f97ede53c53242f 100644 --- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir @@ -518,6 +518,7 @@ 
llvm.func @stack_restore(%arg0: !llvm.ptr<i8>) { // CHECK-LABEL: @vector_predication_intrinsics llvm.func @vector_predication_intrinsics(%A: vector<8xi32>, %B: vector<8xi32>, %C: vector<8xf32>, %D: vector<8xf32>, + %i: i32, %f: f32, %mask: vector<8xi1>, %evl: i32) { // CHECK: call <8 x i32> @llvm.vp.add.v8i32 "llvm.intr.vp.add" (%A, %B, %mask, %evl) : @@ -574,6 +575,55 @@ llvm.func @vector_predication_intrinsics(%A: vector<8xi32>, %B: vector<8xi32>, // CHECK: call <8 x float> @llvm.vp.frem.v8f32 "llvm.intr.vp.frem" (%C, %D, %mask, %evl) : (vector<8xf32>, vector<8xf32>, vector<8xi1>, i32) -> vector<8xf32> + // CHECK: call <8 x float> @llvm.vp.fneg.v8f32 + "llvm.intr.vp.fneg" (%C, %mask, %evl) : + (vector<8xf32>, vector<8xi1>, i32) -> vector<8xf32> + // CHECK: call <8 x float> @llvm.vp.fma.v8f32 + "llvm.intr.vp.fma" (%C, %D, %D, %mask, %evl) : + (vector<8xf32>, vector<8xf32>, vector<8xf32>, vector<8xi1>, i32) -> vector<8xf32> + + // CHECK: call i32 @llvm.vp.reduce.add.v8i32 + "llvm.intr.vp.reduce.add" (%i, %A, %mask, %evl) : + (i32, vector<8xi32>, vector<8xi1>, i32) -> i32 + // CHECK: call i32 @llvm.vp.reduce.mul.v8i32 + "llvm.intr.vp.reduce.mul" (%i, %A, %mask, %evl) : + (i32, vector<8xi32>, vector<8xi1>, i32) -> i32 + // CHECK: call i32 @llvm.vp.reduce.and.v8i32 + "llvm.intr.vp.reduce.and" (%i, %A, %mask, %evl) : + (i32, vector<8xi32>, vector<8xi1>, i32) -> i32 + // CHECK: call i32 @llvm.vp.reduce.or.v8i32 + "llvm.intr.vp.reduce.or" (%i, %A, %mask, %evl) : + (i32, vector<8xi32>, vector<8xi1>, i32) -> i32 + // CHECK: call i32 @llvm.vp.reduce.xor.v8i32 + "llvm.intr.vp.reduce.xor" (%i, %A, %mask, %evl) : + (i32, vector<8xi32>, vector<8xi1>, i32) -> i32 + // CHECK: call i32 @llvm.vp.reduce.smax.v8i32 + "llvm.intr.vp.reduce.smax" (%i, %A, %mask, %evl) : + (i32, vector<8xi32>, vector<8xi1>, i32) -> i32 + // CHECK: call i32 @llvm.vp.reduce.smin.v8i32 + "llvm.intr.vp.reduce.smin" (%i, %A, %mask, %evl) : + (i32, vector<8xi32>, vector<8xi1>, i32) -> i32 + // CHECK: call i32 
@llvm.vp.reduce.umax.v8i32 + "llvm.intr.vp.reduce.umax" (%i, %A, %mask, %evl) : + (i32, vector<8xi32>, vector<8xi1>, i32) -> i32 + // CHECK: call i32 @llvm.vp.reduce.umin.v8i32 + "llvm.intr.vp.reduce.umin" (%i, %A, %mask, %evl) : + (i32, vector<8xi32>, vector<8xi1>, i32) -> i32 + + // CHECK: call float @llvm.vp.reduce.fadd.v8f32 + "llvm.intr.vp.reduce.fadd" (%f, %C, %mask, %evl) : + (f32, vector<8xf32>, vector<8xi1>, i32) -> f32 + // CHECK: call float @llvm.vp.reduce.fmul.v8f32 + "llvm.intr.vp.reduce.fmul" (%f, %C, %mask, %evl) : + (f32, vector<8xf32>, vector<8xi1>, i32) -> f32 + // CHECK: call float @llvm.vp.reduce.fmax.v8f32 + "llvm.intr.vp.reduce.fmax" (%f, %C, %mask, %evl) : + (f32, vector<8xf32>, vector<8xi1>, i32) -> f32 + // CHECK: call float @llvm.vp.reduce.fmin.v8f32 + "llvm.intr.vp.reduce.fmin" (%f, %C, %mask, %evl) : + (f32, vector<8xf32>, vector<8xi1>, i32) -> f32 + + llvm.return } @@ -650,3 +700,18 @@ llvm.func @vector_predication_intrinsics(%A: vector<8xi32>, %B: vector<8xi32>, // CHECK-DAG: declare <8 x float> @llvm.vp.fmul.v8f32(<8 x float>, <8 x float>, <8 x i1>, i32) #0 // CHECK-DAG: declare <8 x float> @llvm.vp.fdiv.v8f32(<8 x float>, <8 x float>, <8 x i1>, i32) #0 // CHECK-DAG: declare <8 x float> @llvm.vp.frem.v8f32(<8 x float>, <8 x float>, <8 x i1>, i32) #0 +// CHECK-DAG: declare <8 x float> @llvm.vp.fneg.v8f32(<8 x float>, <8 x i1>, i32) #0 +// CHECK-DAG: declare <8 x float> @llvm.vp.fma.v8f32(<8 x float>, <8 x float>, <8 x float>, <8 x i1>, i32) #0 +// CHECK-DAG: declare i32 @llvm.vp.reduce.add.v8i32(i32, <8 x i32>, <8 x i1>, i32) #0 +// CHECK-DAG: declare i32 @llvm.vp.reduce.mul.v8i32(i32, <8 x i32>, <8 x i1>, i32) #0 +// CHECK-DAG: declare i32 @llvm.vp.reduce.and.v8i32(i32, <8 x i32>, <8 x i1>, i32) #0 +// CHECK-DAG: declare i32 @llvm.vp.reduce.or.v8i32(i32, <8 x i32>, <8 x i1>, i32) #0 +// CHECK-DAG: declare i32 @llvm.vp.reduce.xor.v8i32(i32, <8 x i32>, <8 x i1>, i32) #0 +// CHECK-DAG: declare i32 @llvm.vp.reduce.smax.v8i32(i32, <8 x 
i32>, <8 x i1>, i32) #0 +// CHECK-DAG: declare i32 @llvm.vp.reduce.smin.v8i32(i32, <8 x i32>, <8 x i1>, i32) #0 +// CHECK-DAG: declare i32 @llvm.vp.reduce.umax.v8i32(i32, <8 x i32>, <8 x i1>, i32) #0 +// CHECK-DAG: declare i32 @llvm.vp.reduce.umin.v8i32(i32, <8 x i32>, <8 x i1>, i32) #0 +// CHECK-DAG: declare float @llvm.vp.reduce.fadd.v8f32(float, <8 x float>, <8 x i1>, i32) #0 +// CHECK-DAG: declare float @llvm.vp.reduce.fmul.v8f32(float, <8 x float>, <8 x i1>, i32) #0 +// CHECK-DAG: declare float @llvm.vp.reduce.fmax.v8f32(float, <8 x float>, <8 x i1>, i32) #0 +// CHECK-DAG: declare float @llvm.vp.reduce.fmin.v8f32(float, <8 x float>, <8 x i1>, i32) #0