From 2cd50f0162bd32cac664b5651a9ce44b3f9b7575 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Fri, 7 Feb 2025 08:01:08 -0800 Subject: [PATCH] [LLVMGPU] Support masked contraction in operand upcasting Currently, there is operands of contraction upcasting that happens in LLVMGPUVectorLowering pass. This commit adds support if its was masked where the upcasting should happen outside of the masking op. Signed-off-by: Manupa Karunaratne --- .../Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp | 21 ++++++++++++------- .../Codegen/LLVMGPU/test/vector_lowering.mlir | 21 +++++++++++++++++++ 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp index 07b901f29e3c..68338fce36af 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp @@ -26,11 +26,13 @@ namespace mlir::iree_compiler { namespace { struct PromoteContractOperands final - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + : public vector::MaskableOpRewritePattern { + using MaskableOpRewritePattern::MaskableOpRewritePattern; - LogicalResult matchAndRewrite(vector::ContractionOp contractOp, - PatternRewriter &rewriter) const override { + FailureOr + matchAndRewriteMaskableOp(vector::ContractionOp contractOp, + vector::MaskingOpInterface maskOp, + PatternRewriter &rewriter) const override { Type operandElType = getElementTypeOrSelf(contractOp.getLhsType()); Type resultElType = getElementTypeOrSelf(contractOp.getResultType()); @@ -44,11 +46,16 @@ struct PromoteContractOperands final Value rhs = promoteToElementType(loc, rewriter, contractOp.getRhs(), resultElType); - rewriter.replaceOpWithNewOp( - contractOp, lhs, rhs, contractOp.getAcc(), contractOp.getIndexingMaps(), + auto replacement = rewriter.create( + loc, lhs, rhs, contractOp.getAcc(), contractOp.getIndexingMaps(), contractOp.getIteratorTypes()); - return success(); + if (!maskOp) { + return replacement.getResult(); + } + auto maskedOp = vector::maskOperation( + rewriter, replacement, maskOp.getMask(), maskOp.getPassthru()); + return maskedOp->getResult(0); } Value promoteToElementType(Location loc, RewriterBase &rewriter, Value v, diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_lowering.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_lowering.mlir index 57f73c21ea88..eabce5c1b305 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_lowering.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_lowering.mlir @@ -17,3 +17,24 @@ module { // CHECK: %[[SPLAT:.+]] = vector.splat %[[ELEM]] : vector<8xf16> // CHECK: %[[INSERT:.+]] = vector.insert %[[SPLAT]], %[[INIT]] [0] : vector<8xf16> into vector<1x8xf16> // CHECK: return %[[INSERT]] + +// ----- + +module { + func.func @contraction_masked(%lhs: vector<3xf16>, %rhs: vector<2x3xf16>, %acc: vector<2xf32>, %mask: vector<3x2xi1>) -> vector<2xf32> { + %ret = vector.mask %mask { vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d1)>], iterator_types = ["reduction", "parallel"], kind = #vector.kind} %lhs, %rhs, %acc : vector<3xf16>, vector<2x3xf16> into vector<2xf32> } : vector<3x2xi1> -> vector<2xf32> + return %ret: vector<2xf32> + } +} + +// CHECK-LABEL: func.func @contraction_masked +// CHECK-SAME: %[[LHS:.+]]: vector<3xf16>, %[[RHS:.+]]: vector<2x3xf16>, %[[ACC:.+]]: vector<2xf32>, %[[MASK:.+]]: vector<3x2xi1> +// CHECK: %[[TPRHS:.+]] = vector.transpose %[[RHS]], [1, 0] : vector<2x3xf16> to vector<3x2xf16> +// CHECK: %[[RHS_EXTRACT:.+]] = vector.extract %[[TPRHS]][0] : vector<2xf16> from vector<3x2xf16> +// CHECK: %[[LHS_EXTRACT:.+]] = vector.extract %[[LHS]][0] : f16 from vector<3xf16> +// CHECK: %[[RHS_CAST:.+]] = arith.extf %[[RHS_EXTRACT]] : vector<2xf16> to vector<2xf32> +// CHECK: %[[LHS_CAST:.+]] = arith.extf %[[LHS_EXTRACT]] : f16 to f32 +// CHECK: %[[MASK_EXTRACT:.+]] = vector.extract %[[MASK]][0] : vector<2xi1> from vector<3x2xi1> +// CHECK: %[[LHS_SPLAT:.+]] = vector.splat %[[LHS_CAST]] : vector<2xf32> +// CHECK: %[[FMA:.+]] = vector.fma %[[RHS_CAST]], %[[LHS_SPLAT]], %[[ACC]] : vector<2xf32> +// CHECK: arith.select %[[MASK_EXTRACT]], %[[FMA]], %[[ACC]] : vector<2xi1>, vector<2xf32>