diff --git a/compiler/plugins/target/CUDA/CUDATarget.cpp b/compiler/plugins/target/CUDA/CUDATarget.cpp index fe198006a166..ddf865b5857b 100644 --- a/compiler/plugins/target/CUDA/CUDATarget.cpp +++ b/compiler/plugins/target/CUDA/CUDATarget.cpp @@ -545,7 +545,7 @@ class CUDATargetBackend final : public TargetBackend { }; // Mark the entry point as a kernel. - setMetadataValueI32("kernel", 1); + llvmFunc->setCallingConv(llvm::CallingConv::PTX_Kernel); // Set the maximum number of threads in the thread block (CTA). auto exportOp = exportOpMap[funcOp.getName()]; diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/ConvertStreamableOps.cpp b/compiler/src/iree/compiler/Bindings/Native/Transforms/ConvertStreamableOps.cpp index 385b836a716a..913b3fce29b7 100644 --- a/compiler/src/iree/compiler/Bindings/Native/Transforms/ConvertStreamableOps.cpp +++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/ConvertStreamableOps.cpp @@ -269,7 +269,8 @@ static LogicalResult convertStreamableCall(StreamableFunc &streamableFunc, auto calculateCallOp = builder.create( callOp.getLoc(), resultDimTypes, streamableFunc.resultDimsFunc.getLeafReference().getValue(), - callOp.getOperands(), ArrayAttr{}); + callOp.getOperands(), /*tied_operands=*/ArrayAttr{}, + callOp.getArgAttrsAttr(), callOp.getResAttrsAttr()); llvm::append_range(resultDims, calculateCallOp.getResults()); } else { // Get the shape dimensions from existing call arguments or tied operands. diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel index 6bcf531c14d2..8f7cb459a6d5 100644 --- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel @@ -128,6 +128,7 @@ iree_compiler_cc_library( "LowerUKernelsToCalls.cpp", "MaterializeEncoding.cpp", "MaterializeEncodingIntoNop.cpp", + "MaterializeEncodingIntoPadding.cpp", "MaterializeEncodingPatterns.cpp", "MaterializeTuningSpecsPass.cpp", "MemrefCopyToLinalg.cpp", diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt index b204157eb758..1b67c5db261e 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt @@ -120,6 +120,7 @@ iree_cc_library( "LowerUKernelsToCalls.cpp" "MaterializeEncoding.cpp" "MaterializeEncodingIntoNop.cpp" + "MaterializeEncodingIntoPadding.cpp" "MaterializeEncodingPatterns.cpp" "MaterializeTuningSpecsPass.cpp" "MemrefCopyToLinalg.cpp" diff --git a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp index e5d1b93cbf82..03383ceb4df3 100644 --- a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp @@ -8,12 +8,11 @@ #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h" #include "iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h" #include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h" -#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/IR/BuiltinAttributes.h" -#include +#include namespace mlir::iree_compiler { @@ -101,14 +100,25 @@ getEncodingInfoFromLayouts(RankedTensorType type) { if (!encodingAttr) { return std::nullopt; } - auto layoutsAttr = encodingAttr.getLayouts(); + ArrayAttr layoutsAttr = encodingAttr.getLayouts(); if (!layoutsAttr) { return std::nullopt; } ArrayRef layouts = layoutsAttr.getValue(); assert(layouts.size() == 1 && "only single layout is supported"); - return cast(layouts[0]) - .getEncodingInfo(type); + if (auto layout = dyn_cast(layouts[0])) { + return layout.getEncodingInfo(type); + } + return std::nullopt; +} + +bool isNonZeroPadding(IREE::Encoding::PadEncodingLayoutAttr padLayout) { + if (!padLayout) { + return false; + } + + return !llvm::all_of(padLayout.getPadding().asArrayRef(), + [](int32_t padValue) { return padValue == 0; }); } } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h index 08a8a5aadbe6..8ee381cd442e 100644 --- a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h +++ b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h @@ -97,13 +97,17 @@ FailureOr lowerUnsetEncodingToUnpackOp( Value packedValue, const MaterializeEncodingTypeConverter &typeConverter, MaterializeEncodingValueFn materializeEncodingValueFn); -/// Pouplates the set of patterns that lowers operations with encoding types to +/// Populates the set of patterns that lowers operations with encoding types to /// operations without encodings. void populateMaterializeEncodingPatterns( RewritePatternSet &patterns, MaterializeEncodingConversionTarget &target, MaterializeEncodingTypeConverter &typeConverter, MaterializeEncodingValueFn materializeEncodingValueFn); +/// Returns true when `padLayout` adds non-zero padding to at least one +/// dimension. +bool isNonZeroPadding(IREE::Encoding::PadEncodingLayoutAttr padLayout); + } // namespace mlir::iree_compiler #endif // IREE_COMPILER_SRC_IREE_COMPILER_CODEGEN_COMMON_ENCODINGUTILS_H_ diff --git a/compiler/src/iree/compiler/Codegen/Common/LinkTuningSpecsPass.cpp b/compiler/src/iree/compiler/Codegen/Common/LinkTuningSpecsPass.cpp index 6b1502c23aac..3fa8485c3759 100644 --- a/compiler/src/iree/compiler/Codegen/Common/LinkTuningSpecsPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/LinkTuningSpecsPass.cpp @@ -121,7 +121,8 @@ emitLinkedTuningSpec(ModuleOp module, ArrayRef specsToLink) { operand = builder .create( loc, anyOpType, symbol, - transform::FailurePropagationMode::Suppress, operand) + transform::FailurePropagationMode::Suppress, operand, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr) .getResults() .front(); } diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPadding.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPadding.cpp new file mode 100644 index 000000000000..38795e2ad977 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPadding.cpp @@ -0,0 +1,266 @@ +// Copyright 2025 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include "iree/compiler/Codegen/Common/EncodingUtils.h" +#include "iree/compiler/Codegen/Common/PassUtils.h" +#include "iree/compiler/Codegen/Common/Passes.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h" +#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h" +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" +#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/MemRef/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" + +namespace mlir::iree_compiler { + +#define GEN_PASS_DEF_MATERIALIZEENCODINGINTOPADDINGPASS +#include "iree/compiler/Codegen/Common/Passes.h.inc" + +using namespace IREE::Encoding; + +namespace { + +// Returns the pad encoding layout, or nullptr if this is not the only layout or +// if there's no encoding at all. +static PadEncodingLayoutAttr getPadLayout(RankedTensorType type) { + auto encoding = + dyn_cast_or_null(type.getEncoding()); + if (!encoding) { + return nullptr; + } + ArrayAttr layouts = encoding.getLayouts(); + if (!layouts || layouts.size() != 1) { + return nullptr; + } + + return dyn_cast(*layouts.begin()); +} + +// Returns a padded tensor type (without encoding) for tensor types with the pad +// encoding layout, or the same type for all other tensors. +static RankedTensorType getPaddedType(RankedTensorType type) { + PadEncodingLayoutAttr layout = getPadLayout(type); + if (!isNonZeroPadding(layout)) { + return dropEncoding(type); + } + + ArrayRef padding = layout.getPadding().asArrayRef(); + auto newShape = llvm::to_vector_of(type.getShape()); + for (auto [newDim, padValue] : llvm::zip_equal(newShape, padding)) { + assert(padValue == 0 || !ShapedType::isDynamic(newDim) && + "Padding dynamic dims not supported"); + newDim += padValue; + } + + return RankedTensorType::get(newShape, type.getElementType()); +} + +static bool hasNonZeroPadding(RankedTensorType type) { + return isNonZeroPadding(getPadLayout(type)); +} + +struct MaterializePadEncodingTypeConverter final + : MaterializeEncodingTypeConverter { + MaterializePadEncodingTypeConverter(MLIRContext *ctx) + : MaterializeEncodingTypeConverter( + IREE::Codegen::EncodingNopLayoutAttr::get(ctx)) { + addConversion([](RankedTensorType type) -> std::optional { + if (!getPadLayout(type)) { + // Return `nullopt` so that other conversion functions have a chance to + // handle this type. + return std::nullopt; + } + return getPaddedType(type); + }); + } +}; + +/// Pattern to convert `flow.dispatch.tensor.store` operation when +/// materializing the encoding. We extract a smaller tensor for the padded +/// source. This way we do not create partial loads prematurely, which would be +/// difficult to undo later on. +struct MaterializeFlowDispatchTensorLoadOp final + : OpMaterializeEncodingPattern { + using OpMaterializeEncodingPattern::OpMaterializeEncodingPattern; + + LogicalResult + matchAndRewrite(IREE::Flow::DispatchTensorLoadOp loadOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Only handle operations where the load covers the entire + // `!flow.dispatch.tensor` type. + if (!loadOp.isLoadOfWholeSource()) { + return rewriter.notifyMatchFailure(loadOp, "unhandled partial loads"); + } + + IREE::Flow::DispatchTensorType sourceType = loadOp.getSourceType(); + auto boundTensorType = cast(sourceType.getBoundType()); + if (!hasNonZeroPadding(boundTensorType)) { + // Let the Nop pattern handle this. + return rewriter.notifyMatchFailure(loadOp, "no padding applied"); + } + + auto &typeConverter = + *getTypeConverter(); + auto paddedType = + typeConverter.convertType(boundTensorType); + assert(paddedType != boundTensorType && "Expected conversion with padding"); + + SmallVector newMixedSizes = + getMixedValues(paddedType.getShape(), loadOp.getSourceDims(), rewriter); + + SmallVector newOffsets(newMixedSizes.size(), + rewriter.getIndexAttr(0)); + SmallVector newStrides(newMixedSizes.size(), + rewriter.getIndexAttr(1)); + SmallVector newStaticDims; + SmallVector newDynamicDims; + dispatchIndexOpFoldResults(newMixedSizes, newDynamicDims, newStaticDims); + + Location loc = loadOp.getLoc(); + Value newLoad = rewriter.create( + loc, adaptor.getSource(), newDynamicDims, newOffsets, newMixedSizes, + newStrides); + auto extractType = RankedTensorType::get(boundTensorType.getShape(), + boundTensorType.getElementType()); + SmallVector extractSizes = getMixedValues( + boundTensorType.getShape(), loadOp.getSourceDims(), rewriter); + rewriter.replaceOpWithNewOp( + loadOp, extractType, newLoad, newOffsets, extractSizes, newStrides); + return success(); + } +}; + +/// Pattern to convert `flow.dispatch.tensor.store` operation when +/// materializing the encoding. We create a larger empty tensor for the +/// destination and insert the value into it. This way we do not create partial +/// stores prematurely, which would be difficult to undo later on. +struct MaterializeFlowDispatchTensorStoreOp final + : OpMaterializeEncodingPattern { + using OpMaterializeEncodingPattern::OpMaterializeEncodingPattern; + + LogicalResult + matchAndRewrite(IREE::Flow::DispatchTensorStoreOp storeOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Only handle operations where the store covers the entire + // `!flow.dispatch.tensor` type. + if (!storeOp.isStoreToWholeTarget()) { + return rewriter.notifyMatchFailure(storeOp, "unhandled partial stores"); + } + + IREE::Flow::DispatchTensorType targetType = storeOp.getTargetType(); + auto boundTensorType = cast(targetType.getBoundType()); + if (!hasNonZeroPadding(boundTensorType)) { + // Let the Nop pattern handle this. + return rewriter.notifyMatchFailure(storeOp, "no padding applied"); + } + + auto &typeConverter = + *getTypeConverter(); + auto paddedType = + typeConverter.convertType(boundTensorType); + assert(paddedType != boundTensorType && "Expected conversion with padding"); + + Location loc = storeOp.getLoc(); + SmallVector dynamicResultSizes{storeOp->getOperands()}; + Value empty = + rewriter.create(loc, paddedType, dynamicResultSizes); + + SmallVector offsets(paddedType.getRank(), + rewriter.getIndexAttr(0)); + SmallVector strides(paddedType.getRank(), + rewriter.getIndexAttr(1)); + SmallVector sizes = + tensor::getMixedSizes(rewriter, loc, adaptor.getValue()); + Value insertOp = rewriter.create( + loc, adaptor.getValue(), empty, offsets, sizes, strides); + + SmallVector newMixedSizes = getMixedValues( + paddedType.getShape(), storeOp.getTargetDims(), rewriter); + SmallVector newStaticDims; + SmallVector newDynamicDims; + dispatchIndexOpFoldResults(newMixedSizes, newDynamicDims, newStaticDims); + + rewriter.replaceOpWithNewOp( + storeOp, insertOp, adaptor.getTarget(), newDynamicDims, offsets, + newMixedSizes, strides); + return success(); + } +}; + +struct MaterializeEncodingIntoPaddingPass final + : impl::MaterializeEncodingIntoPaddingPassBase< + MaterializeEncodingIntoPaddingPass> { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + FunctionOpInterface operation = getOperation(); + + auto materializeEncodingValueFn = + [](RankedTensorType, OpBuilder &, + Location) -> FailureOr { + return failure(); + }; + + RewritePatternSet materializeEncodingPattern(context); + MaterializePadEncodingTypeConverter typeConverter(context); + MaterializeEncodingConversionTarget target(*context); + populateMaterializeEncodingPatterns(materializeEncodingPattern, target, + typeConverter, + materializeEncodingValueFn); + + // The majority of this conversion is based on the 'Nop' materialization, + // with the exception of a few ops that have to account for padding. + // We add custom patterns with much higher priority to run before the + // equivalent 'Nop' patterns. + materializeEncodingPattern.add( + context, typeConverter, materializeEncodingValueFn, + PatternBenefit{100}); + + if (failed(applyPartialConversion(operation, target, + std::move(materializeEncodingPattern)))) { + operation.emitOpError("materialization failed"); + return signalPassFailure(); + } + + // Add patterns to resolve dims ops and cleanups. + { + RewritePatternSet patterns(context); + memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); + context->getOrLoadDialect() + ->getCanonicalizationPatterns(patterns); + // TODO: Drop these when we deprecate partial loads/stores. + IREE::Flow::populateTensorSliceOpWithDispatchTensorOpFoldingPatterns( + patterns, context); + if (failed(applyPatternsGreedily(operation, std::move(patterns)))) { + operation.emitOpError("folding patterns failed"); + return signalPassFailure(); + } + } + } +}; +} // namespace + +void addEncodingToPaddingPasses(FunctionLikeNest &passManager) { + passManager.addPass(createMaterializeEncodingIntoPaddingPass) + .addPass(createBufferizeCopyOnlyDispatchesPass) + .addPass(createCanonicalizerPass); +} + +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.h b/compiler/src/iree/compiler/Codegen/Common/Passes.h index bb013466e681..b1f99cc34f96 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Passes.h +++ b/compiler/src/iree/compiler/Codegen/Common/Passes.h @@ -50,9 +50,12 @@ void addIREEComprehensiveBufferizePasses( void addConstantBufferizePasses(OpPassManager &funcPassManager); -/// Populate Encoding to Nop pass and canonicalizer pass to the pipeline +/// Populate Encoding to Nop pass and canonicalizer pass to the pipeline. void addEncodingToNopPasses(FunctionLikeNest &passManager); +/// Populate Encoding to padding pass and canonicalizer pass to the pipeline. +void addEncodingToPaddingPasses(FunctionLikeNest &passManager); + /// Links nested transform dialect tuning specs named sequences into a single /// entry point. Returns the new named sequence op (inserted into the `module`) /// that includes the nested tuning specs, or a null op when no nested named diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td index 785e1d477b20..811fa9ccc588 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td @@ -456,6 +456,17 @@ def MaterializeEncodingIntoNopPass : let summary = "Drop the encodings from tensor types with encodings."; } +def MaterializeEncodingIntoPaddingPass : + InterfacePass<"iree-codegen-materialize-encoding-into-padding", "mlir::FunctionOpInterface"> { + let summary = "Materialize `#iree_encoding.pad_encoding_layout` attributes."; + let description = [{ + Handles padding introduced by `pad_encoding_layout` encoding layouts, which + requires `flow.dispatch.tensor.load`/`.store` to be adjusted to account for + padding regions. + Materializes any other encoding layouts into nop. + }]; +} + def MaterializeTuningSpecsPass : Pass<"iree-codegen-materialize-tuning-specs", "ModuleOp"> { let summary = "Load tuning spec transform dialect libraries and encode them in the module"; diff --git a/compiler/src/iree/compiler/Codegen/Common/TileLargeTensors.cpp b/compiler/src/iree/compiler/Codegen/Common/TileLargeTensors.cpp index 4a38227e3ae1..fc67bf5d99dd 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TileLargeTensors.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TileLargeTensors.cpp @@ -179,11 +179,12 @@ static void processRegion(RewriterBase &rewriter, Region *region, // Try to greedily tile + fuse linalg ops. if (auto linalgOp = dyn_cast(op)) { - // Skip copies and transposes. This is based on an expectation that such - // ops are introduced carefully and don't represent significant - // computation anyway. Equivalent generics are still tiled as they - // typically arise organically. - if (isa(op)) { + // Skip copies, transposes, and fills. This is based on an expectation + // that such ops are introduced carefully and don't represent + // significant computation anyway. Equivalent generics are still tiled + // as they typically arise organically. Fills in particular are almost + // never found on their own and will be fused when tiling if need be. + if (isa(op)) { continue; } tileToMaxVectorSize(rewriter, linalgOp, maxVectorSize); diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp index 7b904e40a231..86c0a107d847 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp @@ -964,14 +964,6 @@ DiagnosedSilenceableFailure transform_dialect::IREEBufferizeOp::apply( return listener.checkAndResetError(); } - // 3. Post-bufferization passes are fine. - PassManager pm(getContext()); - addIREEPostBufferizationPasses(pm); - if (failed(pm.run(target))) { - return mlir::emitDefiniteFailure(target) - << "post-bufferization passes failed"; - } - results.set(getOperation()->getOpResult(0), {target}); return listener.checkAndResetError(); } diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel index 43a40795ac63..10eba99c02ed 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel @@ -61,6 +61,7 @@ iree_lit_test_suite( "llvmcpu_materialize_encoding.mlir", "lower_ukernel_to_calls.mlir", "materialize_encoding_into_nop.mlir", + "materialize_encoding_into_padding.mlir", "materialize_tuning_specs.mlir", "materialize_tuning_specs_default_missing.mlir", "materialize_tuning_specs_invalid_spec.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt index cfeef07333a6..69c26592ff3c 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt @@ -57,6 +57,7 @@ iree_lit_test_suite( "llvmcpu_materialize_encoding.mlir" "lower_ukernel_to_calls.mlir" "materialize_encoding_into_nop.mlir" + "materialize_encoding_into_padding.mlir" "materialize_tuning_specs.mlir" "materialize_tuning_specs_default_missing.mlir" "materialize_tuning_specs_invalid_spec.mlir" diff --git a/compiler/src/iree/compiler/Codegen/Common/test/materialize_encoding_into_padding.mlir b/compiler/src/iree/compiler/Codegen/Common/test/materialize_encoding_into_padding.mlir new file mode 100644 index 000000000000..8774f4f2dfa9 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/test/materialize_encoding_into_padding.mlir @@ -0,0 +1,169 @@ +// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-materialize-encoding-into-padding))" \ +// RUN: --split-input-file %s | FileCheck %s + +#binding_ro = #hal.pipeline.binding +#binding = #hal.pipeline.binding +#encoding_mmt = #iree_encoding.encoding +#pad_encoding = #iree_encoding.encoding]> +func.func @set_pad_encoding_and_store() { + %c0 = arith.constant 0 : index + %0 = hal.interface.constant.load layout() ordinal(0) : i32 + %1 = arith.index_castui %0 : i32 to index + %3 = hal.interface.binding.subspan layout() binding(0) alignment(64) offset(%1) flags("ReadOnly|Indirect") + : !flow.dispatch.tensor> + %4 = hal.interface.binding.subspan layout() binding(1) alignment(64) offset(%c0) flags(Indirect) + : !flow.dispatch.tensor> + %5 = flow.dispatch.tensor.load %3, offsets = [0, 0], sizes = [2048, 2048], strides = [1, 1] + : !flow.dispatch.tensor> -> tensor<2048x2048xf16> + %6 = iree_encoding.set_encoding %5 : tensor<2048x2048xf16> -> tensor<2048x2048xf16, #encoding_mmt> + flow.dispatch.tensor.store %6, %4, offsets = [0, 0], sizes = [2048, 2048], strides = [1, 1] + : tensor<2048x2048xf16, #encoding_mmt> -> !flow.dispatch.tensor> + return +} + +// CHECK-LABEL: @set_pad_encoding_and_store +// CHECK: %[[A:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0) +// CHECK-SAME: !flow.dispatch.tensor> +// CHECK: %[[B:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1) +// CHECK-SAME: !flow.dispatch.tensor> +// CHECK: %[[LD:.+]] = flow.dispatch.tensor.load %[[A]], offsets = [0, 0], sizes = [2048, 2048], strides = [1, 1] +// CHECK-SAME: !flow.dispatch.tensor> -> tensor<2048x2048xf16> +// CHECK: flow.dispatch.tensor.store %[[LD]], %[[B]], offsets = [0, 0], sizes = [2048, 2048], strides = [1, 1] +// CHECK-SAME: tensor<2048x2048xf16> -> !flow.dispatch.tensor> + +// ----- + +#binding_ro = #hal.pipeline.binding +#binding = #hal.pipeline.binding +#encoding_mmt = #iree_encoding.encoding +#pad_encoding = #iree_encoding.encoding]> +func.func @set_zero_pad_encoding_and_store() { + %c0 = arith.constant 0 : index + %0 = hal.interface.constant.load layout() ordinal(0) : i32 + %1 = arith.index_castui %0 : i32 to index + %3 = hal.interface.binding.subspan layout() binding(0) alignment(64) offset(%1) flags("ReadOnly|Indirect") + : !flow.dispatch.tensor> + %4 = hal.interface.binding.subspan layout() binding(1) alignment(64) offset(%c0) flags(Indirect) + : !flow.dispatch.tensor> + %5 = flow.dispatch.tensor.load %3, offsets = [0, 0], sizes = [2048, 2048], strides = [1, 1] + : !flow.dispatch.tensor> -> tensor<2048x2048xf16> + %6 = iree_encoding.set_encoding %5 : tensor<2048x2048xf16> -> tensor<2048x2048xf16, #encoding_mmt> + flow.dispatch.tensor.store %6, %4, offsets = [0, 0], sizes = [2048, 2048], strides = [1, 1] + : tensor<2048x2048xf16, #encoding_mmt> -> !flow.dispatch.tensor> + return +} + +// CHECK-LABEL: @set_zero_pad_encoding_and_store +// CHECK: %[[A:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0) +// CHECK-SAME: !flow.dispatch.tensor> +// CHECK: %[[B:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1) +// CHECK-SAME: !flow.dispatch.tensor> +// CHECK: %[[LD:.+]] = flow.dispatch.tensor.load %[[A]], offsets = [0, 0], sizes = [2048, 2048], strides = [1, 1] +// CHECK-SAME: !flow.dispatch.tensor> -> tensor<2048x2048xf16> +// CHECK: flow.dispatch.tensor.store %[[LD]], %[[B]], offsets = [0, 0], sizes = [2048, 2048], strides = [1, 1] +// CHECK-SAME: tensor<2048x2048xf16> -> !flow.dispatch.tensor> + +// ----- + +#binding_ro = #hal.pipeline.binding +#binding = #hal.pipeline.binding +#encoding_mmt = #iree_encoding.encoding +#pad_encoding = #iree_encoding.encoding]> +func.func @dynamic_set_zero_pad_encoding_and_store() { + %c0 = arith.constant 0 : index + %0 = hal.interface.constant.load layout() ordinal(0) : i32 + %1 = arith.index_castui %0 : i32 to index + %2 = hal.interface.constant.load layout() ordinal(1) : i32 + %dynamic_sz = arith.index_castui %2 : i32 to index + %3 = hal.interface.binding.subspan layout() binding(0) alignment(64) offset(%1) flags("ReadOnly|Indirect") + : !flow.dispatch.tensor> + %4 = hal.interface.binding.subspan layout() binding(1) alignment(64) offset(%c0) flags(Indirect) + : !flow.dispatch.tensor> + %5 = flow.dispatch.tensor.load %3, offsets = [0, 0], sizes = [%dynamic_sz, 2048], strides = [1, 1] + : !flow.dispatch.tensor>{%dynamic_sz} -> tensor + %6 = iree_encoding.set_encoding %5 : tensor -> tensor + flow.dispatch.tensor.store %6, %4, offsets = [0, 0], sizes = [%dynamic_sz, 2048], strides = [1, 1] + : tensor -> !flow.dispatch.tensor>{%dynamic_sz} + return +} + +// CHECK-LABEL: @dynamic_set_zero_pad_encoding_and_store +// CHECK: %[[CST:.+]] = hal.interface.constant.load {{.+}} ordinal(1) : i32 +// CHECK: %[[SZ:.+]] = arith.index_castui %[[CST]] +// CHECK: %[[A:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0) +// CHECK-SAME: !flow.dispatch.tensor> +// CHECK: %[[B:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1) +// CHECK-SAME: !flow.dispatch.tensor> +// CHECK: %[[LD:.+]] = flow.dispatch.tensor.load %[[A]], offsets = [0, 0], sizes = [%[[SZ]], 2048], strides = [1, 1] +// CHECK-SAME: !flow.dispatch.tensor>{%[[SZ]]} -> tensor +// CHECK: flow.dispatch.tensor.store %[[LD]], %[[B]], offsets = [0, 0], sizes = [%[[SZ]], 2048], strides = [1, 1] +// CHECK-SAME: tensor -> !flow.dispatch.tensor>{%[[SZ]]} + +// ----- + +#binding_ro = #hal.pipeline.binding +#binding = #hal.pipeline.binding +#encoding_mmt_lhs = #iree_encoding.encoding +#pad_encoding_lhs = #iree_encoding.encoding]> +#encoding_mmt_rhs = #iree_encoding.encoding +#pad_encoding_rhs = #iree_encoding.encoding]> +#encoding_mmt_out = #iree_encoding.encoding +func.func @load_from_padded_and_mmt() { + %c0 = arith.constant 0 : index + %c8650752 = arith.constant 8650752 : index + %c17301504 = arith.constant 17301504 : index + %cst = arith.constant 0.000000e+00 : f16 + %0 = hal.interface.binding.subspan layout() binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") + : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan layout() binding(0) alignment(64) offset(%c8650752) flags("ReadOnly|Indirect") + : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan layout() binding(1) alignment(64) offset(%c17301504) flags(Indirect) + : !flow.dispatch.tensor> + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [2048, 2048], strides = [1, 1] + : !flow.dispatch.tensor> -> tensor<2048x2048xf16, #encoding_mmt_lhs> + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [2048, 2048], strides = [1, 1] + : !flow.dispatch.tensor> -> tensor<2048x2048xf16, #encoding_mmt_rhs> + %5 = tensor.empty() : tensor<2048x2048xf16, #encoding_mmt_out> + %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2048x2048xf16, #encoding_mmt_out>) -> tensor<2048x2048xf16, #encoding_mmt_out> + %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%3, %4 : tensor<2048x2048xf16, #encoding_mmt_lhs>, tensor<2048x2048xf16, #encoding_mmt_rhs>) + outs(%6 : tensor<2048x2048xf16, #encoding_mmt_out>) { + ^bb0(%in: f16, %in_0: f16, %out: f16): + %9 = arith.mulf %in, %in_0 : f16 + %10 = arith.addf %out, %9 : f16 + linalg.yield %10 : f16 + } -> tensor<2048x2048xf16, #encoding_mmt_out> + %8 = iree_encoding.unset_encoding %7 : tensor<2048x2048xf16, #encoding_mmt_out> -> tensor<2048x2048xf16> + flow.dispatch.tensor.store %8, %2, offsets = [0, 0], sizes = [2048, 2048], strides = [1, 1] + : tensor<2048x2048xf16> -> !flow.dispatch.tensor> + return +} + +// CHECK-LABEL: @load_from_padded_and_mmt +// CHECK: %[[A:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0) +// CHECK-SAME: !flow.dispatch.tensor> +// CHECK: %[[B:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0) +// CHECK-SAME: !flow.dispatch.tensor> +// CHECK: %[[C:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1) +// CHECK-SAME: !flow.dispatch.tensor> +// CHECK: %[[LD_A:.+]] = flow.dispatch.tensor.load %[[A]], offsets = [0, 0], sizes = [2048, 2048], strides = [1, 1] +// CHECK-SAME: !flow.dispatch.tensor> -> tensor<2048x2048xf16> +// CHECK: %[[LD_B:.+]] = flow.dispatch.tensor.load %[[B]], offsets = [0, 0], sizes = [2048, 2048], strides = [1, 1] +// CHECK-SAME: !flow.dispatch.tensor> -> tensor<2048x2048xf16> +// +// CHECK: tensor.empty() : tensor<2048x2048xf16> +// CHECK: %[[FILL:.+]] = linalg.fill {{.+}} : tensor<2048x2048xf16> +// CHECK: %[[MMT:.+]] = linalg.generic +// CHECK-SAME: ins(%[[LD_A]], %[[LD_B]] : tensor<2048x2048xf16>, tensor<2048x2048xf16>) +// CHECK-SAME: outs(%[[FILL]] : tensor<2048x2048xf16>) +// +// CHECK: flow.dispatch.tensor.store %[[MMT]], %[[C]], offsets = [0, 0], sizes = [2048, 2048], strides = [1, 1] +// CHECK-SAME: tensor<2048x2048xf16> -> !flow.dispatch.tensor> diff --git a/compiler/src/iree/compiler/Codegen/Common/test/tile_large_tensors.mlir b/compiler/src/iree/compiler/Codegen/Common/test/tile_large_tensors.mlir index 7007fc7bbb6e..d021b459d031 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/tile_large_tensors.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/tile_large_tensors.mlir @@ -113,3 +113,18 @@ func.func @no_tile_copy(%arg0: tensor<64x256xf32>) -> tensor<64x256xf32> { // CHECK-NOT: scf.for // CHECK: %[[COPY:.+]] = linalg.copy // CHECK: return %[[COPY]] + +// ----- + +func.func @no_tile_fill(%arg0: f32) -> tensor<64x256xf32> { + %empty = tensor.empty() : tensor<64x256xf32> + %0 = linalg.fill + ins(%arg0 : f32) + outs(%empty : tensor<64x256xf32>) -> tensor<64x256xf32> + return %0 : tensor<64x256xf32> +} + +// CHECK-LABEL: func.func @no_tile_fill +// CHECK-NOT: scf.for +// CHECK: %[[FILL:.+]] = linalg.fill +// CHECK: return %[[FILL]] diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel index 788730dbb753..25997bcdfade 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel @@ -164,6 +164,8 @@ iree_compiler_cc_library( "@llvm-project//mlir:TransformDialect", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:UBDialect", + "@llvm-project//mlir:UBToLLVM", "@llvm-project//mlir:ValueBoundsOpInterface", "@llvm-project//mlir:VectorDialect", "@llvm-project//mlir:VectorToArmSME", diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt index 77b92f3f66a4..19043b59c898 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt @@ -138,6 +138,8 @@ iree_cc_library( MLIRTransformDialect MLIRTransformUtils MLIRTransforms + MLIRUBDialect + MLIRUBToLLVM MLIRValueBoundsOpInterface MLIRVectorDialect MLIRVectorToArmSME diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp index 0f5fb32a7c1b..1c0fc6b86cd8 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp @@ -37,6 +37,7 @@ #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Conversion/TosaToArith/TosaToArith.h" +#include "mlir/Conversion/UBToLLVM/UBToLLVM.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -1058,9 +1059,9 @@ void ConvertToLLVMPass::runOnOperation() { vector::populateVectorStepLoweringPatterns(patterns); populateVectorToLLVMConversionPatterns(typeConverter, patterns, reassociateFpReductions); + ub::populateUBToLLVMConversionPatterns(typeConverter, patterns); vector::populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); - if (isAArch64(targetAttr) && (hasAnySVEFeature(targetAttr) || hasSMEFeature(targetAttr))) { populateArmSVELegalizeForLLVMExportPatterns(typeConverter, patterns); diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/DispatchABI.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/DispatchABI.cpp index 4a37389f6528..47c545412bd3 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/DispatchABI.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/DispatchABI.cpp @@ -792,8 +792,8 @@ MemRefDescriptor HALDispatchABI::loadBinding(Operation *forOp, int64_t ordinal, int64_t rank = memRefType.getRank(); // Build MemRef descriptor for this interface binding. - auto desc = MemRefDescriptor::undef(builder, loc, - typeConverter->convertType(memRefType)); + auto desc = MemRefDescriptor::poison( + builder, loc, typeConverter->convertType(memRefType)); desc.setAllocatedPtr(builder, loc, basePtrValue); desc.setAlignedPtr(builder, loc, basePtrValue); auto llvmIndexType = typeConverter->convertType(builder.getIndexType()); diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/transform_dialect_bufferize.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/transform_dialect_bufferize.mlir index a2944d9d125c..866f43956822 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/transform_dialect_bufferize.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/transform_dialect_bufferize.mlir @@ -35,7 +35,10 @@ module attributes {transform.with_named_sequence} { transform.print %8 : !transform.any_op transform.iree.eliminate_empty_tensors %8 : (!transform.any_op) -> () %9 = transform.iree.bufferize %8 : (!transform.any_op) -> !transform.any_op - // %9 = transform.structured.match ops{["func.func"]} in %8 : (!transform.any_op) -> !transform.any_op + %10 = transform.structured.match ops{["func.func"]} in %9 : (!transform.any_op) -> !transform.op<"func.func"> + transform.apply_patterns to %10 { + transform.apply_patterns.canonicalization + } : !transform.op<"func.func"> transform.yield } } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel index a5c1bce4beda..4e8338511036 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel @@ -215,6 +215,8 @@ iree_compiler_cc_library( "@llvm-project//mlir:TransformDialect", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:UBDialect", + "@llvm-project//mlir:UBToLLVM", "@llvm-project//mlir:ValueBoundsOpInterface", "@llvm-project//mlir:VectorDialect", "@llvm-project//mlir:VectorToGPU", diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt index 5c206210ab30..d789999267ce 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt @@ -163,6 +163,8 @@ iree_cc_library( MLIRTransformDialect MLIRTransformUtils MLIRTransforms + MLIRUBDialect + MLIRUBToLLVM MLIRValueBoundsOpInterface MLIRVectorDialect MLIRVectorToGPU diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp index 620a8f4fa310..131f45edcc23 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp @@ -406,7 +406,7 @@ class ConvertIREEBindingSubspanOp : public ConvertToLLVMPattern { int64_t rank = memrefType.getRank(); // Build MemRef descriptor for this interface binding. - auto desc = MemRefDescriptor::undef( + auto desc = MemRefDescriptor::poison( rewriter, loc, typeConverter->convertType(memrefType)); desc.setAllocatedPtr(rewriter, loc, llvmBufferBasePtr); desc.setAlignedPtr(rewriter, loc, llvmBufferBasePtr); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToNVVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToNVVM.cpp index 44172fe4758b..fcec6283617b 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToNVVM.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToNVVM.cpp @@ -22,6 +22,7 @@ #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h" #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" +#include "mlir/Conversion/UBToLLVM/UBToLLVM.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" @@ -29,6 +30,7 @@ #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" @@ -53,7 +55,7 @@ struct ConvertToNVVMPass final void getDependentDialects(DialectRegistry ®istry) const override { registry .insert(); + NVVM::NVVMDialect, affine::AffineDialect, ub::UBDialect>(); } void runOnOperation() override { ModuleOp m = getOperation(); @@ -161,6 +163,7 @@ struct ConvertToNVVMPass final populateGpuToNVVMConversionPatterns(converter, llvmPatterns); populateNVGPUToNVVMConversionPatterns(converter, llvmPatterns); populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns); + ub::populateUBToLLVMConversionPatterns(converter, llvmPatterns); /// Target specification. LLVMConversionTarget target(getContext()); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp index 850efcf62fab..60355b8f0db0 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp @@ -22,6 +22,7 @@ #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/Conversion/UBToLLVM/UBToLLVM.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" @@ -29,6 +30,7 @@ #include "mlir/Dialect/GPU/Transforms/Passes.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Dialect/MemRef/Transforms/Transforms.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" @@ -99,9 +101,9 @@ struct ConvertToROCDLPass final ConvertToROCDLPass>::ConvertToROCDLPassBase; void getDependentDialects(DialectRegistry ®istry) const override { - registry - .insert(); + registry.insert(); } void runOnOperation() override { ModuleOp m = getOperation(); @@ -238,6 +240,8 @@ struct ConvertToROCDLPass final LLVMConversionTarget target(getContext()); populateFuncToLLVMFuncOpConversionPattern(converter, llvmPatterns); configureGpuToROCDLConversionLegality(target); + ub::populateUBToLLVMConversionPatterns(converter, llvmPatterns); + if (failed(applyPartialConversion(m, target, std::move(llvmPatterns)))) signalPassFailure(); } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index 545a29fd4e45..155bfda2166f 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -1157,7 +1157,9 @@ static void buildLLVMGPUCodegenConfigurationPassPipelineImpl( FunctionLikeNest funcPassManager(modulePassManager); funcPassManager.addPass(createGPUGeneralizeNamedOpsPass); addCommonTargetExecutablePreprocessingPasses(funcPassManager); - addEncodingToNopPasses(funcPassManager); + // This materializes into 'nop' in the absence of pad encoding layout + // attributes. + addEncodingToPaddingPasses(funcPassManager); funcPassManager.addPass(createBlockDynamicDimensionsPass); funcPassManager.addPass(createConfigTrackingCanonicalizerPass); funcPassManager.addPass(createCSEPass); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_bufferize_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_bufferize_spec.mlir index 89853b6352a4..24da7997aa53 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_bufferize_spec.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_bufferize_spec.mlir @@ -4,10 +4,14 @@ module attributes { transform.with_named_sequence } { %tensor_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op transform.iree.eliminate_empty_tensors %tensor_func : (!transform.any_op) -> () %memref_func = transform.iree.bufferize %tensor_func : (!transform.any_op) -> !transform.any_op + %func_op_bufferized = transform.structured.match ops{["func.func"]} in %memref_func : (!transform.any_op) -> !transform.op<"func.func"> + transform.apply_patterns to %func_op_bufferized { + transform.apply_patterns.canonicalization + } : !transform.op<"func.func"> // Annotate the exported function as already translated. %none = transform.param.constant #iree_codegen.translation_info -> !transform.any_param - transform.annotate %memref_func "translation_info" = %none : !transform.any_op, !transform.any_param + transform.annotate %func_op_bufferized "translation_info" = %none : !transform.op<"func.func">, !transform.any_param transform.yield } } // module diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_lowering.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_lowering.mlir index 57f73c21ea88..9f902ff08b14 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_lowering.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_lowering.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmgpu-vector-lowering))" --split-input-file %s | FileCheck %s +// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmgpu-vector-lowering,canonicalize,cse))" --split-input-file %s | FileCheck %s module { func.func @broadcast_read_lowering(%arg0: memref<4096x32xf16>) -> vector<1x8xf16> { @@ -11,9 +11,8 @@ module { } // CHECK-LABEL: func.func @broadcast_read_lowering // CHECK-SAME: (%[[ARG0:.+]]: memref<4096x32xf16>) -// CHECK: %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<1x8xf16> // CHECK: %[[LOAD:.+]] = vector.load %[[ARG0]]{{.*}} : memref<4096x32xf16> // CHECK: %[[ELEM:.+]] = vector.extract %[[LOAD]][0] : f16 from vector<1xf16> // CHECK: %[[SPLAT:.+]] = vector.splat %[[ELEM]] : vector<8xf16> -// CHECK: %[[INSERT:.+]] = vector.insert %[[SPLAT]], %[[INIT]] [0] : vector<8xf16> into vector<1x8xf16> +// CHECK: %[[INSERT:.+]] = vector.broadcast %[[SPLAT]] : vector<8xf16> to vector<1x8xf16> // CHECK: return %[[INSERT]] diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel b/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel index d264a26551f9..7ef31be6fc81 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel @@ -154,6 +154,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:TransformDialect", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:UBToSPIRV", "@llvm-project//mlir:VectorDialect", "@llvm-project//mlir:VectorInterfaces", "@llvm-project//mlir:VectorToGPU", diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt index 08ec5885dc97..867786177b35 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt @@ -129,6 +129,7 @@ iree_cc_library( MLIRTransformDialect MLIRTransformUtils MLIRTransforms + MLIRUBToSPIRV MLIRVectorDialect MLIRVectorInterfaces MLIRVectorToGPU diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp index ce6984a1cd11..5f2992e667bd 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp @@ -34,6 +34,7 @@ #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h" #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h" #include "mlir/Conversion/TensorToSPIRV/TensorToSPIRV.h" +#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h" #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" @@ -656,6 +657,8 @@ void ConvertToSPIRVPass::runOnOperation() { // Pull in builtin func to spirv.func conversion. populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns); + ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns); + // Add IREE HAL interface op conversions. patterns.add< HALInterfaceLoadConstantConverter, diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_elementwise_ops.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_elementwise_ops.mlir index 2c0d654521a2..299633abdb8f 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_elementwise_ops.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_elementwise_ops.mlir @@ -1,5 +1,5 @@ // RUN: iree-opt --split-input-file \ -// RUN: --pass-pipeline='builtin.module(func.func(iree-codegen-generic-vectorization,iree-spirv-initial-vector-lowering,iree-codegen-optimize-tensor-insert-extract-slices,iree-spirv-final-vector-lowering))' \ +// RUN: --pass-pipeline='builtin.module(func.func(iree-codegen-generic-vectorization,iree-spirv-initial-vector-lowering,iree-codegen-optimize-tensor-insert-extract-slices,iree-spirv-final-vector-lowering,canonicalize,cse))' \ // RUN: %s | FileCheck %s func.func @add(%lhs: tensor<2x8xf32>, %rhs: tensor<2x8xf32>) -> tensor<2x8xf32> { @@ -48,7 +48,7 @@ func.func @transpose_leading_one_dim(%input: tensor<4x1x1xf32>) -> tensor<1x1x4x // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index // CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index -// CHECK-DAG: %[[ZERO:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32> +// CHECK-DAG: %[[ZERO:.+]] = ub.poison : vector<4xf32> // CHECK: %[[R0:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]{{.+}} : tensor<4x1x1xf32>, vector<1xf32> // CHECK: %[[R1:.+]] = vector.transfer_read %[[INPUT]][%[[C1]], %[[C0]], %[[C0]]]{{.+}} : tensor<4x1x1xf32>, vector<1xf32> @@ -93,7 +93,7 @@ func.func @transpose_add(%lhs: tensor<4x2xf32>, %rhs: tensor<2xf32>) -> tensor<2 // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index // CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index -// CHECK-DAG: %[[OINIT:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32> +// CHECK-DAG: %[[OINIT:.+]] = ub.poison : vector<4xf32> // CHECK: %[[LHS0:.+]] = vector.transfer_read %[[LHS]][%[[C0]], %[[C0]]]{{.+}} : tensor<4x2xf32>, vector<2xf32> // CHECK: %[[LHS1:.+]] = vector.transfer_read %[[LHS]][%[[C1]], %[[C0]]]{{.+}} : tensor<4x2xf32>, vector<2xf32> diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir index 03644e7c9bec..f0c42b154e1e 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir @@ -1,5 +1,5 @@ // RUN: iree-opt --split-input-file \ -// RUN: --pass-pipeline='builtin.module(func.func(iree-codegen-generic-vectorization,iree-spirv-initial-vector-lowering,iree-codegen-optimize-tensor-insert-extract-slices,iree-spirv-final-vector-lowering))' \ +// RUN: --pass-pipeline='builtin.module(func.func(iree-codegen-generic-vectorization,iree-spirv-initial-vector-lowering,iree-codegen-optimize-tensor-insert-extract-slices,iree-spirv-final-vector-lowering,canonicalize,cse))' \ // RUN: %s | FileCheck %s func.func @matmul_1x4x4(%lhs: tensor<1x4xf32>, %rhs: tensor<4x4xf32>, %init: tensor<1x4xf32>) -> tensor<1x4xf32> { @@ -139,9 +139,7 @@ func.func @matmul_broadcast_add(%init: tensor<1x8xf32>, %a: tensor<1x8xf32>, %b: // CHECK: %[[EXT0:.+]] = vector.extract %[[READ]][0] : f32 from vector<1xf32> // CHECK: %[[BCST0:.+]] = vector.splat %[[EXT0]] : vector<4xf32> // CHECK: %[[ADD0:.+]] = arith.addf %{{.+}}, %[[BCST0]] : vector<4xf32> -// CHECK: %[[EXT1:.+]] = vector.extract %[[READ]][0] : f32 from vector<1xf32> -// CHECK: %[[BCST1:.+]] = vector.splat %[[EXT1]] : vector<4xf32> -// CHECK: %[[ADD1:.+]] = arith.addf %{{.+}}, %[[BCST1]] : vector<4xf32> +// CHECK: %[[ADD1:.+]] = arith.addf %{{.+}}, %[[BCST0]] : vector<4xf32> // CHECK: %[[WRITE0:.+]] = vector.transfer_write %[[ADD0]], %[[INIT]][%[[C0]], %[[C0]]] // CHECK: %[[WRITE1:.+]] = vector.transfer_write %[[ADD1]], %[[WRITE0]][%[[C0]], %[[C4]]] // CHECK: return %[[WRITE1]] @@ -287,7 +285,7 @@ func.func @matmul_4x4x4_i8_to_i32_dot_prod(%lhs: tensor<4x4xi8>, %rhs : tensor<4 // CHECK-SAME: (%[[LHS:.+]]: tensor<4x4xi8>, %[[RHS:.+]]: tensor<4x4xi8>) // CHECK-DAG: %[[C0I8:.+]] = arith.constant 0 : i8 // CHECK-DAG: %[[C0I32:.+]] = arith.constant 0 : i32 -// CHECK-DAG: %[[V4I8:.+]] = arith.constant dense<0> : vector<4xi8> +// CHECK-DAG: %[[V4I8:.+]] = ub.poison : vector<4xi8> // CHECK-DAG: %[[V4I32:.+]] = arith.constant dense<0> : vector<4xi32> // CHECK-DAG: %[[V1I32:.+]] = arith.constant dense<0> : vector<1xi32> // CHECK-DAG: %[[IDX0:.+]] = arith.constant 0 : index diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp index f4c289aec413..be7734a2d08a 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp @@ -1566,7 +1566,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, if (!argAttrs.empty() || !resAttrs.empty()) { assert(type.getNumInputs() == argAttrs.size()); assert(type.getNumResults() == resAttrs.size()); - function_interface_impl::addArgAndResultAttrs( + call_interface_impl::addArgAndResultAttrs( builder, state, argAttrs, resAttrs, builder.getStringAttr("arg_attrs"), builder.getStringAttr("res_attrs")); } @@ -1998,7 +1998,7 @@ struct FoldTensorLoadWithExtractSlice }; /// Pattern to fold `tensor.insert_slice` with `flow.dispatch.tensor.store` -/// oeprations. +/// operations. // TODO(ravishankarm): Eventually this should go in as a canonicalization at the // Flow level. struct FoldInsertSliceWithTensorStoreOp @@ -2014,13 +2014,13 @@ struct FoldInsertSliceWithTensorStoreOp return failure(); // Check that the `dest` of the `tensor.insert_slice` and target of the - // `flow.dispatch.tensor.store` are the same interface binding. + // `flow.dispatch.tensor.store` are the same interface binding, if these + // are still in a dispatch region. std::optional destBinding = getBindingArgument(insertSliceOp.getDest()); std::optional targetBinding = getBindingArgument(dispatchTensorStoreOp.getTarget()); - if (!destBinding || !targetBinding || - destBinding.value() != targetBinding.value()) { + if (destBinding != targetBinding) { return failure(); } diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td index 0241a843b906..a73271ffea41 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td +++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td @@ -964,7 +964,9 @@ def FLOW_CallOp : FLOW_Op<"call", [ Variadic:$arguments, FLOW_ShapeDynamicDims:$argument_dims, FLOW_ShapeDynamicDims:$result_dims, - OptionalAttr:$tied_operands + OptionalAttr:$tied_operands, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let results = (outs Variadic:$results diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp index 748483349f98..e7522ff0694c 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp @@ -924,7 +924,8 @@ struct CmdCallOpPattern rewriter.replaceOpWithNewOp( callOp, resultTypes, callOp.getCallee(), operands, - /*tied_operands=*/ArrayAttr{}); + /*tied_operands=*/ArrayAttr{}, callOp.getArgAttrsAttr(), + callOp.getResAttrsAttr()); return success(); } }; diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp index cb5bb411810a..a06838e8c361 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp @@ -1778,7 +1778,7 @@ ParseResult ExecutableConstantBlockOp::parse(OpAsmParser &parser, bool isVariadic = false; SmallVector resultAttrs; SmallVector resultTypes; - if (mlir::function_interface_impl::parseFunctionSignature( + if (mlir::function_interface_impl::parseFunctionSignatureWithArguments( parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes, resultAttrs)) { return failure(); @@ -1826,7 +1826,7 @@ ParseResult ExecutableConstantBlockOp::parse(OpAsmParser &parser, // Add the attributes to the function arguments. assert(resultAttrs.size() == resultTypes.size()); - mlir::function_interface_impl::addArgAndResultAttrs( + mlir::call_interface_impl::addArgAndResultAttrs( builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp index 9d7cc63165b9..83064ccdd96e 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp @@ -18,9 +18,11 @@ #include "llvm/Support/Path.h" #include "llvm/Support/ToolOutputFile.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Dominance.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/FileUtilities.h" +#include "mlir/Transforms/CSE.h" #include "mlir/Transforms/Passes.h" // NOTE: redundant bindings will result in unique buffer locations during the @@ -447,14 +449,9 @@ buildBenchmarkModule(IREE::HAL::ExecutableOp sourceExecutableOp, if (!hasAnyBenchmarks) return {}; - // Run CSE and the canonicalizer to pretty up the output. - PassManager passManager(moduleOp->getContext()); - passManager.addPass(mlir::createCanonicalizerPass()); - passManager.addPass(mlir::createCSEPass()); - if (failed(passManager.run(*moduleOp))) { - moduleOp->emitError("failed to run canonicalizer; malformed output"); - return {}; - } + IRRewriter rewriter(moduleOp->getContext()); + DominanceInfo domInfo; + mlir::eliminateCommonSubExpressions(rewriter, domInfo, moduleOp.get()); return moduleOp; } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp index 00288cc640d6..9b0ca9e209c1 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp @@ -960,7 +960,7 @@ struct ConvertCallOp : public AffinityOpConversionPattern { op.getLoc(), resultTypes, adaptor.getCalleeAttr(), callOperands, callOperandSizes, callOperandOffsets, callOperandEnds, callOperandLengths, resultSizes, adaptor.getTiedOperandsAttr(), - executionAffinityAttr); + op.getArgAttrsAttr(), op.getResAttrsAttr(), executionAffinityAttr); newOp->setDialectAttrs(op->getDialectAttrs()); replaceOpWithMultiple(op, newOp->getResults(), resultSizes, rewriter); return success(); diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp index 13988a999b2f..7f50ee2f355d 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp @@ -2764,7 +2764,7 @@ void AsyncFuncOp::build(OpBuilder &builder, OperationState &state, if (!argAttrs.empty() || !resAttrs.empty()) { assert(type.getNumInputs() == argAttrs.size()); assert(type.getNumResults() == resAttrs.size()); - function_interface_impl::addArgAndResultAttrs( + call_interface_impl::addArgAndResultAttrs( builder, state, argAttrs, resAttrs, builder.getStringAttr("arg_attrs"), builder.getStringAttr("res_attrs")); } @@ -3555,7 +3555,7 @@ void CmdFuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, if (!argAttrs.empty() || !resAttrs.empty()) { assert(type.getNumInputs() == argAttrs.size()); assert(type.getNumResults() == resAttrs.size()); - function_interface_impl::addArgAndResultAttrs( + call_interface_impl::addArgAndResultAttrs( builder, state, argAttrs, resAttrs, builder.getStringAttr("arg_attrs"), builder.getStringAttr("res_attrs")); } diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td index 62a44d5bac66..4a1b7d28a728 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td @@ -2684,6 +2684,8 @@ def Stream_AsyncCallOp : Stream_Op<"async.call", [ Variadic:$resource_operand_lengths, Variadic:$result_sizes, OptionalAttr:$tied_operands, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs, OptionalAttr:$affinity ); let results = (outs @@ -3368,6 +3370,8 @@ def Stream_CmdCallOp : Stream_Op<"cmd.call", [ Variadic:$resource_operand_lengths, Variadic:$result_sizes, OptionalAttr:$tied_operands, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs, Stream_ResourceAccessArrayAttr:$resource_operand_accesses ); let results = (outs diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp index b7ce5dfece90..e99948db4e20 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp @@ -927,7 +927,8 @@ static LogicalResult applyAsyncCallOp(IREE::Stream::AsyncCallOp asyncOp, newResourceOperands, newResourceSizes, newResourceOffsets, newResourceLengths, /*result_sizes=*/ValueRange{}, - /*tied_operands=*/nullptr, builder.getArrayAttr(newResourceAccesses)); + /*tied_operands=*/nullptr, asyncOp.getArgAttrsAttr(), + asyncOp.getResAttrsAttr(), builder.getArrayAttr(newResourceAccesses)); newOp->setDialectAttrs(asyncOp->getDialectAttrs()); asyncOp.erase(); return success(); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp index b3c4cc89f19a..92310ab34721 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp @@ -358,6 +358,22 @@ updateTensorSizeOfOp(RewriterBase &rewriter, return success(); } +/// Updates the target encoding of `op` with resolved layouts. +static LogicalResult +updateTensorFillOp(RewriterBase &rewriter, IREE::Stream::TensorFillOp op, + const SetVector &layoutResolvers) { + auto encodingType = dyn_cast(op.getTargetEncoding()); + std::optional encodingAttr = + getEncodingWithNewLayouts(encodingType, layoutResolvers); + if (!encodingAttr) { + return success(); + } + rewriter.modifyOpInPlace(op, [&] { + op.setTargetEncoding(cloneWithEncoding(encodingType, encodingAttr.value())); + }); + return success(); +} + /// Returns failure if `op` has encoding. The EncodingAttr has padding /// semantic, a constant op with such encoding can not be resolved at this /// moment. @@ -375,7 +391,70 @@ updateTensorConstantOp(RewriterBase &rewriter, return success(); } -/// Updates the result_encoding for `op`. The op have to define a +/// Returns a failure if there are encodings in target encoding type or update +/// encoding type. +static LogicalResult updateTensorUpdateOp(RewriterBase &rewriter, + IREE::Stream::TensorUpdateOp op) { + auto targetEncodingType = dyn_cast(op.getTargetEncoding()); + if (targetEncodingType && targetEncodingType.getEncoding()) { + return failure(); + } + auto updateEncodingType = dyn_cast(op.getUpdateEncoding()); + if (updateEncodingType && updateEncodingType.getEncoding()) { + return failure(); + } + return success(); +} + +/// Returns a failure if there are encodings in source encoding type or result +/// encoding type. +static LogicalResult updateTensorCloneOp(RewriterBase &rewriter, + IREE::Stream::TensorCloneOp op) { + auto sourceEncodingType = dyn_cast(op.getSourceEncoding()); + if (sourceEncodingType && sourceEncodingType.getEncoding()) { + return failure(); + } + auto resultEncodingType = dyn_cast(op.getResultEncoding()); + if (resultEncodingType && resultEncodingType.getEncoding()) { + return failure(); + } + return success(); +} + +/// Returns a failure if there are encodings in source encoding type or result +/// encoding type. +static LogicalResult updateTensorSliceOp(RewriterBase &rewriter, + IREE::Stream::TensorSliceOp op) { + auto sourceEncodingType = dyn_cast(op.getSourceEncoding()); + if (sourceEncodingType && sourceEncodingType.getEncoding()) { + return failure(); + } + auto resultEncodingType = dyn_cast(op.getResultEncoding()); + if (resultEncodingType && resultEncodingType.getEncoding()) { + return failure(); + } + return success(); +} + +/// Updates the source_encoding for `op`. The op has to define a +/// `source_encoding` parameter. +template +static LogicalResult +updateSourceEncoding(RewriterBase &rewriter, OpTy op, + const SetVector &layoutResolvers) { + auto encodingType = dyn_cast(op.getSourceEncoding()); + std::optional encodingAttr = + getEncodingWithNewLayouts(encodingType, layoutResolvers); + if (!encodingAttr) { + return success(); + } + rewriter.modifyOpInPlace(op, [&] { + op.setSourceEncoding(cloneWithEncoding(encodingType, encodingAttr.value())); + }); + return success(); +} + +/// Updates the result_encoding for `op`. The op has to define a /// `result_encoding` parameter. template static LogicalResult @@ -393,6 +472,16 @@ updateResultEncoding(RewriterBase &rewriter, OpTy op, return success(); } +/// Adds the resolved layouts to all tensor types on stream tensor ops, if +/// encodings are present. Most of stream tensor ops implement +/// AffinityOpInterface, where a stream affinity indicates the kind of +/// enviroment the ops are expected run in. When an encoding is present in the +/// tensor type, the method resolves the layouts, strips outdated information, +/// and adds the resolved layouts to the encodings. The updated encodings should +/// have enough information for other lowering transformations. +/// TODO(hanchung): Add support for stream.tensor.load ops and +/// stream.tensor.store ops. They are not affinity ops, so additional analysis +/// will be needed in the work. static LogicalResult addLayoutsToTensorPhaseOps( ModuleOp moduleOp, IREE::Stream::AffinityAnalysis &affinityAnalysis, FunctionOpInterface funcOp, @@ -424,7 +513,6 @@ static LogicalResult addLayoutsToTensorPhaseOps( return affinityOp.emitError("failed on making layout resolvers"); } - // TODO(hanchung): Update other Stream operations. LogicalResult result = TypeSwitch(affinityOp) .Case([&](auto op) { @@ -442,7 +530,18 @@ static LogicalResult addLayoutsToTensorPhaseOps( .Case([&](auto op) { return updateTensorConstantOp(rewriter, op, layoutResolvers); }) - .Default([](auto *op) { return failure(); }); + .Case([&](auto op) { + return updateTensorFillOp(rewriter, op, layoutResolvers); + }) + .Case( + [&](auto op) { return updateTensorCloneOp(rewriter, op); }) + .Case( + [&](auto op) { return updateTensorSliceOp(rewriter, op); }) + .Case( + [&](auto op) { return updateTensorUpdateOp(rewriter, op); }) + .Default([](Operation *op) { + return op->emitOpError("Unhandled stream op"); + }); if (failed(result)) { return failure(); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir index 2c71a86e1639..f57664bcce95 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir @@ -65,12 +65,39 @@ module { // ----- +#map0 = affine_map<(m, n, k) -> (m, k)> +#map1 = affine_map<(m, n, k) -> (k, n)> +#map2 = affine_map<(m, n, k) -> (m, n)> +#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}> +#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device +#encoding = #iree_encoding.encoding +module { + util.global private @device_a = #device_target_local_0_ + + util.func public @tensor_fill_op(%arg0: f32, %arg1: !stream.resource<*>, %arg2: index, %arg3: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = stream.tensor.fill on(#hal.device.affinity<@device_a>) + %arg0, %arg1[%c0, %c0 for %c1, %c1] : f32 + -> tensor{%arg2} in %arg1 as !stream.resource<*>{%arg3} + util.return + } +} +// CHECK-DAG: #[[$ENCODING:.+]] = #iree_encoding.encoding<{{.+}} layouts = [#iree_encoding.specialized_encoding<123, tensor>] +// CHECK: #[[TARGET:.+]] = #hal.device.target +// CHECK: util.global private @[[$DEVICE:.+]] = #[[TARGET]] +// CHECK-LABEL: util.func public @tensor_fill_op +// CHECK: stream.tensor.fill on(#hal.device.affinity<@[[$DEVICE]]>) +// CHECK-SAME: f32 -> tensor + +// ----- + // Checks that the stream.tensor.constant op with encoding is not supported. #map0 = affine_map<(m, n, k) -> (m, k)> #map1 = affine_map<(m, n, k) -> (k, n)> #map2 = affine_map<(m, n, k) -> (m, n)> -#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_cpu.vmvx_encoding_layout<>}> +#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}> #device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device #encoding = #iree_encoding.encoding module { @@ -85,6 +112,76 @@ module { // ----- +// Checks that the stream.tensor.clone op with encoding is not supported. + +#map0 = affine_map<(m, n, k) -> (m, k)> +#map1 = affine_map<(m, n, k) -> (k, n)> +#map2 = affine_map<(m, n, k) -> (m, n)> +#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}> +#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device +#encoding = #iree_encoding.encoding +module { + util.global private @device_a = #device_target_local_0_ + + // expected-error @+1 {{failed on adding layouts to Stream::TensorPhaseOp with encodings}} + util.func public @tensor_clone_op(%arg0: !stream.resource<*>, %arg1: index, %arg2: index, %arg3: index, %arg4: index) { + %0 = stream.tensor.clone on(#hal.device.affinity<@device_a>) + %arg0 : tensor{%arg1} in !stream.resource<*>{%arg2} + -> tensor{%arg1} in !stream.resource<*>{%arg2} + util.return + } +} + +// ----- + +// Checks that the stream.tensor.slice op with encoding is not supported. + +#map0 = affine_map<(m, n, k) -> (m, k)> +#map1 = affine_map<(m, n, k) -> (k, n)> +#map2 = affine_map<(m, n, k) -> (m, n)> +#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}> +#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device +#encoding = #iree_encoding.encoding +module { + util.global private @device_a = #device_target_local_0_ + + // expected-error @+1 {{failed on adding layouts to Stream::TensorPhaseOp with encodings}} + util.func public @tensor_slice_op_with_encoding(%arg0: !stream.resource<*>, %arg1: index, %arg2: index, %arg3: index, %arg4: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %1 = stream.tensor.slice on(#hal.device.affinity<@device_a>) + %arg0[%c0, %c1 for %arg3, %c1] : tensor{%arg1} in !stream.resource<*>{%arg2} + -> tensor{%arg3} in !stream.resource<*>{%arg4} + util.return + } +} + +// ----- + +// Checks that the stream.tensor.update op with encoding is not supported. + +#map0 = affine_map<(m, n, k) -> (m, k)> +#map1 = affine_map<(m, n, k) -> (k, n)> +#map2 = affine_map<(m, n, k) -> (m, n)> +#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}> +#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device +#encoding = #iree_encoding.encoding +module { + util.global private @device_a = #device_target_local_0_ + + // expected-error @+1 {{failed on adding layouts to Stream::TensorPhaseOp with encodings}} + util.func public @tensor_update_op(%arg0: !stream.resource<*>, %arg1: index, %arg2: !stream.resource<*>, %arg3: index, %arg4: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = stream.tensor.update on(#hal.device.affinity<@device_a>) + %arg0, %arg2[%c0, %c0] : tensor<2x2xf32, #encoding> in !stream.resource<*>{%arg1} + -> tensor{%arg3} in %arg2 as !stream.resource<*>{%arg4} + util.return + } +} + +// ----- + #executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}> #map = affine_map<(d0) -> (d0)> #map0 = affine_map<(m, n, k) -> (m, k)> diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.cpp b/compiler/src/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.cpp index 310e9f45b4ed..346660f9386d 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.cpp @@ -151,7 +151,8 @@ struct ConvertCallOp : public OpConversionPattern { } auto newOp = rewriter.replaceOpWithNewOp( op, resultTypes, op.getCallee(), adaptor.getOperands(), - adaptor.getTiedOperandsAttr()); + adaptor.getTiedOperandsAttr(), adaptor.getArgAttrsAttr(), + adaptor.getResAttrsAttr()); newOp->setDialectAttrs(op->getDialectAttrs()); return success(); } diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp index c9e470fd443e..1d417098d002 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp @@ -1624,7 +1624,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, if (!argAttrs.empty() || !resAttrs.empty()) { assert(type.getNumInputs() == argAttrs.size()); assert(type.getNumResults() == resAttrs.size()); - function_interface_impl::addArgAndResultAttrs( + call_interface_impl::addArgAndResultAttrs( builder, state, argAttrs, resAttrs, builder.getStringAttr("arg_attrs"), builder.getStringAttr("res_attrs")); } @@ -1716,7 +1716,7 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { result.attributes.append(parsedAttributes); assert(resultAttrs.size() == resultTypes.size()); - function_interface_impl::addArgAndResultAttrs( + call_interface_impl::addArgAndResultAttrs( builder, result, arguments, resultAttrs, builder.getStringAttr("arg_attrs"), builder.getStringAttr("res_attrs")); @@ -1913,7 +1913,8 @@ IREE::Util::CallOp IREE::Util::CallOp::cloneAndExpand( return builder.create( getLoc(), newResultTypes, getCallee(), newOperands, - builder.getIndexArrayAttr(newTiedOperands)); + builder.getIndexArrayAttr(newTiedOperands), getArgAttrsAttr(), + getResAttrsAttr()); } //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td index 6dded8984df6..5b57c13172ad 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td @@ -778,7 +778,9 @@ def Util_CallOp : Util_Op<"call", [ let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$operands, - OptionalAttr:$tied_operands + OptionalAttr:$tied_operands, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let results = (outs Variadic:$results @@ -792,7 +794,7 @@ def Util_CallOp : Util_Op<"call", [ CArg<"ArrayRef", "{}">:$attrs ), [{ build($_builder, $_state, callee.getResultTypes(), callee.getName(), - operands, tied_operands); + operands, tied_operands, /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr); $_state.addAttributes(attrs); }]>, ]; diff --git a/compiler/src/iree/compiler/Dialect/Util/TransformOps/UtilTransformOps.cpp b/compiler/src/iree/compiler/Dialect/Util/TransformOps/UtilTransformOps.cpp index e43447d7a836..c327585477ef 100644 --- a/compiler/src/iree/compiler/Dialect/Util/TransformOps/UtilTransformOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/TransformOps/UtilTransformOps.cpp @@ -243,7 +243,8 @@ DiagnosedSilenceableFailure IREE::Util::transform_dialect::CastAndCallOp::apply( auto callOp = rewriter.create( insertionPoint->getLoc(), targetFunction.getResultTypes(), - targetFunction.getName(), inputs, /*tied_operands=*/ArrayAttr{}); + targetFunction.getName(), inputs, /*tied_operands=*/ArrayAttr{}, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr); // Cast the call results back to the expected types. If any conversions fail // this is a definite failure as the call has been constructed at this point. diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/IPO.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/IPO.cpp index 8c284781e4fe..dc844a81da57 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/IPO.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/IPO.cpp @@ -610,10 +610,11 @@ static bool applyCallChanges(FuncAnalysis &analysis, return false; // Fully replace call op because we may have changed result count. - // TODO(benvanik): update tied operands. + // TODO(benvanik): update tied operands, arg_attrs, and res_attrs. auto newCallOp = OpBuilder(callOp).create( callOp.getLoc(), newResultTypes, callOp.getCalleeAttr(), newOperands, - /*tied_operands=*/ArrayAttr{}); + /*tied_operands=*/ArrayAttr{}, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr); newCallOp->setDialectAttrs(callOp->getDialectAttrs()); // Remap live old results -> new results. diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp index 7b43ecae5f65..19c151fbb162 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp @@ -120,10 +120,10 @@ void FuncOp::build(OpBuilder &builder, OperationState &result, StringRef name, assert(type.getNumInputs() == argAttrs.size() && "expected as many argument attribute lists as arguments"); - function_interface_impl::addArgAndResultAttrs( - builder, result, argAttrs, - /*resultAttrs=*/std::nullopt, getArgAttrsAttrName(result.name), - getResAttrsAttrName(result.name)); + call_interface_impl::addArgAndResultAttrs(builder, result, argAttrs, + /*resultAttrs=*/std::nullopt, + getArgAttrsAttrName(result.name), + getResAttrsAttrName(result.name)); } Block *FuncOp::addEntryBlock() { @@ -289,10 +289,10 @@ ParseResult ImportOp::parse(OpAsmParser &parser, OperationState &result) { return parser.emitError(parser.getCurrentLocation()) << "invalid result type list"; } - function_interface_impl::addArgAndResultAttrs( - builder, result, argAttrs, - /*resultAttrs=*/std::nullopt, getArgAttrsAttrName(result.name), - getResAttrsAttrName(result.name)); + call_interface_impl::addArgAndResultAttrs(builder, result, argAttrs, + /*resultAttrs=*/std::nullopt, + getArgAttrsAttrName(result.name), + getResAttrsAttrName(result.name)); if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) { return failure(); } @@ -358,10 +358,10 @@ void ImportOp::build(OpBuilder &builder, OperationState &result, StringRef name, if (!argAttrs.empty()) { assert(type.getNumInputs() == argAttrs.size() && "expected as many argument attribute lists as arguments"); - function_interface_impl::addArgAndResultAttrs( - builder, result, argAttrs, - /*resultAttrs=*/std::nullopt, getArgAttrsAttrName(result.name), - getResAttrsAttrName(result.name)); + call_interface_impl::addArgAndResultAttrs(builder, result, argAttrs, + /*resultAttrs=*/std::nullopt, + getArgAttrsAttrName(result.name), + getResAttrsAttrName(result.name)); } result.addRegion(); diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td index a620f5511dc3..c7626b21d94b 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td @@ -4242,7 +4242,9 @@ def VM_CallOp : VM_CallBaseOp<"call"> { let arguments = (ins VM_FuncRefAttr:$callee, - Variadic:$operands + Variadic:$operands, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let results = (outs Variadic:$results @@ -4301,7 +4303,9 @@ def VM_CallVariadicOp : VM_CallBaseOp<"call.variadic"> { VM_FuncRefAttr:$callee, SignlessIntElementsAttr<16>:$segment_sizes, TypeArrayAttr:$segment_types, - Variadic:$operands + Variadic:$operands, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let results = (outs Variadic:$results diff --git a/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp b/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp index 09ee551b8df3..b2c6c6ea7172 100644 --- a/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp +++ b/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp @@ -396,7 +396,7 @@ class FuncCallOpPattern : public OpConversionPattern { srcOp->getAttrOfType("iree.abi.tied_operands"); rewriter.replaceOpWithNewOp( srcOp, resultTypes, srcOp.getCallee(), adaptor.getOperands(), - tiedOperandsAttr); + tiedOperandsAttr, srcOp.getArgAttrsAttr(), srcOp.getResAttrsAttr()); return success(); } }; diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/Patterns.cpp b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/Patterns.cpp index d0e643d161ee..fb88a6b2212f 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/Patterns.cpp +++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/Patterns.cpp @@ -498,7 +498,8 @@ struct CmdDispatchOpPattern llvm::append_range(callArgs, adaptor.getResourceLengths()); rewriter.replaceOpWithNewOp( dispatchOp, TypeRange{}, callee.getLeafReference(), callArgs, - /*tied_operands=*/ArrayAttr{}); + /*tied_operands=*/ArrayAttr{}, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr); return success(); } }; @@ -564,7 +565,8 @@ struct CmdCallOpPattern : public OpConversionPattern { rewriter.replaceOpWithNewOp( callOp, resultTypes, callOp.getCallee(), operands, - /*tied_operands=*/ArrayAttr{}); + /*tied_operands=*/ArrayAttr{}, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr); return success(); } }; diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/preprocessing_match_ops.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/preprocessing_match_ops.mlir index 59e1c0a67844..4b40bac5b474 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/test/preprocessing_match_ops.mlir +++ b/compiler/src/iree/compiler/Preprocessing/Common/test/preprocessing_match_ops.mlir @@ -136,3 +136,52 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +module attributes {transform.with_named_sequence} { + + // CHECK: func.func @matmul_repeated_operand + func.func @matmul_repeated_operand(%input: tensor<32x64xi8>, %dest: tensor<32x32xi32>) -> tensor<32x32xi32> { + // CHECK-NEXT: linalg.matmul_transpose_b + // CHECK-SAME: match_status = "matched" + %res = linalg.matmul_transpose_b {match_status = "unmatched"} + ins(%input, %input : tensor<32x64xi8>, tensor<32x64xi8>) + outs(%dest : tensor<32x32xi32>) -> tensor<32x32xi32> + return %res : tensor<32x32xi32> + } + + // CHECK: func.func @matmul_non_repeated_operand + func.func @matmul_non_repeated_operand(%input0: tensor<32x64xi8>, %input1: tensor<32x64xi8>, %dest: tensor<32x32xi32>) -> tensor<32x32xi32> { + // CHECK-NEXT: linalg.matmul_transpose_b + // CHECK-SAME: match_status = "unmatched" + %res = linalg.matmul_transpose_b {match_status = "unmatched"} + ins(%input0, %input1 : tensor<32x64xi8>, tensor<32x64xi8>) + outs(%dest : tensor<32x32xi32>) -> tensor<32x32xi32> + return %res : tensor<32x32xi32> + } + + transform.named_sequence @match_matmul_repeated_operand(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op { + %inputs, %outputs = transform.iree.match.cast_compatible_dag_from_root %arg0 { + ^bb0(%arg1: tensor<32x64xi8>, %arg2: tensor<32x32xi32>): + %1 = linalg.matmul_transpose_b {match_status = "unmatched"} + ins(%arg1, %arg1 : tensor<32x64xi8>, tensor<32x64xi8>) + outs(%arg2 : tensor<32x32xi32>) -> tensor<32x32xi32> + } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) + transform.yield %arg0 : !transform.any_op + } + + transform.named_sequence @annotate(%generic: !transform.any_op {transform.readonly}) { + %0 = transform.param.constant "matched" -> !transform.any_param + transform.annotate %generic "match_status" = %0 : !transform.any_op, !transform.any_param + transform.yield + } + + transform.named_sequence @__transform_main(%module: !transform.any_op) { + %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.any_op + transform.foreach_match in %module + @match_matmul_repeated_operand -> @annotate + : (!transform.any_op) -> (!transform.any_op) + transform.yield + } +} diff --git a/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.cpp b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.cpp index 899e2f3a6d66..57fb0bb2f1e7 100644 --- a/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.cpp +++ b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.cpp @@ -177,7 +177,7 @@ IREE::transform_dialect::MatchCastCompatibleDagFromRootOp::matchOperation( return emitDefiniteFailure() << "Invalid block argument in target"; } int64_t argIdx = targetBlockArg.getArgNumber(); - if (inputs[argIdx] && inputs[argIdx] != targetOperand) { + if (inputs[argIdx] && inputs[argIdx] != payloadOperand) { return emitSilenceableError() << "input operand with conflicting uses"; } diff --git a/docs/website/docs/guides/deployment-configurations/cpu.md b/docs/website/docs/guides/deployment-configurations/cpu.md index d9667f032e74..c062613c2f82 100644 --- a/docs/website/docs/guides/deployment-configurations/cpu.md +++ b/docs/website/docs/guides/deployment-configurations/cpu.md @@ -13,27 +13,32 @@ IREE supports efficient program execution on CPU devices by using highly optimized CPU native instruction streams, which are embedded in one of IREE's deployable formats. -To compile a program for CPU execution, pick one of IREE's supported executable -formats: +To compile a program for CPU execution: -| Executable Format | Description | -| ----------------- | ----------------------------------------------------- | -| embedded ELF | portable, high performance dynamic library | -| system library | platform-specific dynamic library (.so, .dll, etc.) | -| VMVX | reference target | +1. Pick a CPU target supported by LLVM. By default, IREE includes these LLVM + targets: -At runtime, CPU executables can be loaded using one of IREE's CPU HAL drivers: + * X86 + * ARM + * AArch64 + * RISCV -* `local-task`: asynchronous, multithreaded driver built on IREE's "task" - system -* `local-sync`: synchronous, single-threaded driver that executes work inline + Other targets may work, but in-tree test coverage and performance work is + focused on that list. + +2. Pick one of IREE's supported executable formats: -!!! todo + | Executable Format | Description | + | ----------------- | ----------------------------------------------------- | + | Embedded ELF | (Default) Portable, high performance dynamic library | + | System library | Platform-specific dynamic library (.so, .dll, etc.) | + | VMVX | Reference target | - Add IREE's CPU support matrix: what architectures are supported; what - architectures are well optimized; etc. +At runtime, CPU executables can be loaded using one of IREE's CPU HAL devices: - +* `local-task`: asynchronous, multithreaded device built on IREE's "task" + system +* `local-sync`: synchronous, single-threaded devices that executes work inline ## :octicons-download-16: Prerequisites @@ -44,7 +49,7 @@ At runtime, CPU executables can be loaded using one of IREE's CPU HAL drivers: Python packages are distributed through multiple channels. See the [Python Bindings](../../reference/bindings/python.md) page for more details. The core [`iree-base-compiler`](https://pypi.org/project/iree-base-compiler/) -package includes the LLVM-based CPU compiler: +package includes the compiler tools: --8<-- "docs/website/docs/guides/deployment-configurations/snippets/_iree-compiler-from-release.md" @@ -52,14 +57,9 @@ package includes the LLVM-based CPU compiler: Please make sure you have followed the [Getting started](../../building-from-source/getting-started.md) page to build -IREE for your host platform and the -[Android cross-compilation](../../building-from-source/android.md) or -[iOS cross-compilation](../../building-from-source/ios.md) page if you are cross -compiling for a mobile device. The `llvm-cpu` compiler backend is compiled in by -default on all platforms. - -Ensure that the `IREE_TARGET_BACKEND_LLVM_CPU` CMake option is `ON` when -configuring for the host. +IREE for your host platform. The `llvm-cpu` compiler backend is compiled in by +default on all platforms, though you should ensure that the +`IREE_TARGET_BACKEND_LLVM_CPU` CMake option is `ON` when configuring. !!! tip `iree-compile` will be built under the `iree-build/tools/` directory. You @@ -71,10 +71,14 @@ You will need to get an IREE runtime that supports the local CPU HAL driver, along with the appropriate executable loaders for your application. You can check for CPU support by looking for the `local-sync` and `local-task` -drivers: +drivers and devices: -```console hl_lines="5 6" ---8<-- "docs/website/docs/guides/deployment-configurations/snippets/_iree-run-module-driver-list.md" +```console hl_lines="10-11" +--8<-- "docs/website/docs/guides/deployment-configurations/snippets/_iree-run-module-driver-list.md:1" +``` + +```console hl_lines="4-5" +--8<-- "docs/website/docs/guides/deployment-configurations/snippets/_iree-run-module-device-list-amd.md" ``` #### :octicons-download-16: Download the runtime from a release @@ -88,16 +92,12 @@ package includes the local CPU HAL drivers: #### :material-hammer-wrench: Build the runtime from source -Please make sure you have followed the -[Getting started](../../building-from-source/getting-started.md) page to build -IREE for your host platform and the -[Android cross-compilation](../../building-from-source/android.md) page if you -are cross compiling for Android. The local CPU HAL drivers are compiled in by -default on all platforms. - -Ensure that the `IREE_HAL_DRIVER_LOCAL_TASK` and -`IREE_HAL_EXECUTABLE_LOADER_EMBEDDED_ELF` (or other executable loader) CMake -options are `ON` when configuring for the target. +Please make sure you have followed one of the +[Building from source](../../building-from-source/index.md) pages to build +IREE for your target platform. The local CPU HAL drivers and devices are +compiled in by default on all platforms, though you should ensure that the +`IREE_HAL_DRIVER_LOCAL_TASK` and `IREE_HAL_EXECUTABLE_LOADER_EMBEDDED_ELF` +(or other executable loader) CMake options are `ON` when configuring. ## Compile and run a program @@ -105,30 +105,36 @@ With the requirements out of the way, we can now compile a model and run it. ### :octicons-file-code-16: Compile a program -The IREE compiler transforms a model into its final deployable format in many -sequential steps. A model authored with Python in an ML framework should use the -corresponding framework's import tool to convert into a format (i.e., -[MLIR](https://mlir.llvm.org/)) expected by the IREE compiler first. +--8<-- "docs/website/docs/guides/deployment-configurations/snippets/_iree-import-onnx-mobilenet.md" -Using MobileNet v2 as an example, you can download the SavedModel with trained -weights from -[TensorFlow Hub](https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification) -and convert it using IREE's -[TensorFlow importer](../ml-frameworks/tensorflow.md). Then run the following -command to compile with the `llvm-cpu` target: +Then run the following command to compile with the `llvm-cpu` target: -``` shell hl_lines="2" +``` shell hl_lines="2-3" iree-compile \ --iree-hal-target-backends=llvm-cpu \ - mobilenet_iree_input.mlir -o mobilenet_cpu.vmfb + --iree-llvmcpu-target-cpu=host \ + mobilenetv2.mlir -o mobilenet_cpu.vmfb ``` -!!! tip "Tip - CPU targets" +???+ tip "Tip - Target CPUs and CPU features" + + By default, the compiler will use a generic CPU target which will result in + poor performance. A target CPU or target CPU feature set should be selected + using one of these options: + + * `--iree-llvmcpu-target-cpu=...` + * `--iree-llvmcpu-target-cpu-features=...` + + When not cross compiling, passing `--iree-llvmcpu-target-cpu=host` is + usually sufficient on most devices. + +???+ tip "Tip - CPU targets" The `--iree-llvmcpu-target-triple` flag tells the compiler to generate code for a specific type of CPU. You can see the list of supported targets with - `iree-compile --iree-llvmcpu-list-targets`, or pass "host" to let LLVM - infer the triple from your host machine (e.g. `x86_64-linux-gnu`). + `iree-compile --iree-llvmcpu-list-targets`, or use the default value of + "host" to let LLVM infer the triple from your host machine + (e.g. `x86_64-linux-gnu`). ```console $ iree-compile --iree-llvmcpu-list-targets @@ -149,28 +155,21 @@ iree-compile \ x86-64 - 64-bit X86: EM64T and AMD64 ``` -!!! tip "Tip - CPU features" - - The `--iree-llvmcpu-target-cpu-features` flag tells the compiler to generate - code using certain CPU "features", like SIMD instruction sets. Like the - target triple, you can pass "host" to this flag to let LLVM infer the - features supported by your host machine. - ### :octicons-terminal-16: Run a compiled program -In the build directory, run the following command: +To run the compiled program: ``` shell hl_lines="2" -tools/iree-run-module \ +iree-run-module \ --device=local-task \ --module=mobilenet_cpu.vmfb \ - --function=predict \ - --input="1x224x224x3xf32=0" + --function=torch-jit-export \ + --input="1x3x224x224xf32=0" ``` -The above assumes the exported function in the model is named as `predict` and -it expects one 224x224 RGB image. We are feeding in an image with all 0 values -here for brevity, see `iree-run-module --help` for the format to specify +The above assumes the exported function in the model is named `torch-jit-export` +and it expects one 224x224 RGB image. We are feeding in an image with all 0 +values here for brevity, see `iree-run-module --help` for the format to specify concrete values. diff --git a/docs/website/docs/guides/deployment-configurations/gpu-cuda.md b/docs/website/docs/guides/deployment-configurations/gpu-cuda.md index 4dc12982e30d..a4ce0a89096d 100644 --- a/docs/website/docs/guides/deployment-configurations/gpu-cuda.md +++ b/docs/website/docs/guides/deployment-configurations/gpu-cuda.md @@ -52,16 +52,12 @@ Next you will need to get an IREE runtime that includes the CUDA HAL driver. You can check for CUDA support by looking for a matching driver and device: -```console hl_lines="3" ---8<-- "docs/website/docs/guides/deployment-configurations/snippets/_iree-run-module-driver-list.md" +```console hl_lines="8" +--8<-- "docs/website/docs/guides/deployment-configurations/snippets/_iree-run-module-driver-list.md:1" ``` ```console hl_lines="3" -$ iree-run-module --list_devices - - cuda://GPU-00000000-1111-2222-3333-444444444444 - local-sync:// - local-task:// +--8<-- "docs/website/docs/guides/deployment-configurations/snippets/_iree-run-module-device-list-nvidia.md" ``` #### :octicons-download-16: Download the runtime from a release @@ -82,69 +78,63 @@ IREE from source, then enable the CUDA HAL driver with the ## Compile and run a program model -With the compiler and runtime ready, we can now compile programs and run them -on GPUs. +With the requirements out of the way, we can now compile a model and run it. ### :octicons-file-code-16: Compile a program -The IREE compiler transforms a model into its final deployable format in many -sequential steps. A model authored with Python in an ML framework should use the -corresponding framework's import tool to convert into a format (i.e., -[MLIR](https://mlir.llvm.org/)) expected by the IREE compiler first. +--8<-- "docs/website/docs/guides/deployment-configurations/snippets/_iree-import-onnx-mobilenet.md" -Using MobileNet v2 as an example, you can download the SavedModel with trained -weights from -[TensorFlow Hub](https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification) -and convert it using IREE's -[TensorFlow importer](../ml-frameworks/tensorflow.md). Then run one of the -following commands to compile: +Then run the following command to compile with the `cuda` target: ```shell hl_lines="2-3" iree-compile \ --iree-hal-target-backends=cuda \ --iree-cuda-target=<...> \ - mobilenet_iree_input.mlir -o mobilenet_cuda.vmfb + mobilenetv2.mlir -o mobilenet_cuda.vmfb ``` -Canonically a CUDA target (`iree-cuda-target`) matching the LLVM NVPTX backend -of the form `sm_` is needed to compile towards each GPU -architecture. If no architecture is specified then we will default to `sm_60`. +???+ tip "Tip - CUDA targets" + + Canonically a CUDA target (`iree-cuda-target`) matching the LLVM NVPTX + backend of the form `sm_` is needed to compile towards each GPU + architecture. If no architecture is specified then we will default to + `sm_60`. -Here is a table of commonly used architectures: + Here is a table of commonly used architectures: -| CUDA GPU | Target Architecture | Architecture Code Name -| ------------------- | ------------------- | ---------------------- -| NVIDIA P100 | `sm_60` | `pascal` -| NVIDIA V100 | `sm_70` | `volta` -| NVIDIA A100 | `sm_80` | `ampere` -| NVIDIA H100 | `sm_90` | `hopper` -| NVIDIA RTX20 series | `sm_75` | `turing` -| NVIDIA RTX30 series | `sm_86` | `ampere` -| NVIDIA RTX40 series | `sm_89` | `ada` + | CUDA GPU | Target Architecture | Architecture Code Name + | ------------------- | ------------------- | ---------------------- + | NVIDIA P100 | `sm_60` | `pascal` + | NVIDIA V100 | `sm_70` | `volta` + | NVIDIA A100 | `sm_80` | `ampere` + | NVIDIA H100 | `sm_90` | `hopper` + | NVIDIA RTX20 series | `sm_75` | `turing` + | NVIDIA RTX30 series | `sm_86` | `ampere` + | NVIDIA RTX40 series | `sm_89` | `ada` -In addition to the canonical `sm_` scheme, `iree-cuda-target` also -supports two additonal schemes to make a better developer experience: + In addition to the canonical `sm_` scheme, `iree-cuda-target` + also supports two additonal schemes to make a better developer experience: -* Architecture code names like `volta` or `ampere` -* GPU product names like `a100` or `rtx3090` + * Architecture code names like `volta` or `ampere` + * GPU product names like `a100` or `rtx3090` -These two schemes are translated into the canonical form under the hood. -We add support for common code/product names without aiming to be exhaustive. -If the ones you want are missing, please use the canonical form. + These two schemes are translated into the canonical form under the hood. + We add support for common code/product names without aiming to be exhaustive. + If the ones you want are missing, please use the canonical form. ### :octicons-terminal-16: Run a compiled program -Run the following command: +To run the compiled program: ``` shell hl_lines="2" iree-run-module \ --device=cuda \ --module=mobilenet_cuda.vmfb \ - --function=predict \ - --input="1x224x224x3xf32=0" + --function=torch-jit-export \ + --input="1x3x224x224xf32=0" ``` -The above assumes the exported function in the model is named as `predict` and -it expects one 224x224 RGB image. We are feeding in an image with all 0 values -here for brevity, see `iree-run-module --help` for the format to specify +The above assumes the exported function in the model is named `torch-jit-export` +and it expects one 224x224 RGB image. We are feeding in an image with all 0 +values here for brevity, see `iree-run-module --help` for the format to specify concrete values. diff --git a/docs/website/docs/guides/deployment-configurations/gpu-rocm.md b/docs/website/docs/guides/deployment-configurations/gpu-rocm.md index 20d34388ecf1..91c6fba2d1bd 100644 --- a/docs/website/docs/guides/deployment-configurations/gpu-rocm.md +++ b/docs/website/docs/guides/deployment-configurations/gpu-rocm.md @@ -54,16 +54,12 @@ Next you will need to get an IREE runtime that includes the HIP HAL driver. You can check for HIP support by looking for a matching driver and device: -```console hl_lines="4" ---8<-- "docs/website/docs/guides/deployment-configurations/snippets/_iree-run-module-driver-list.md" +```console hl_lines="9" +--8<-- "docs/website/docs/guides/deployment-configurations/snippets/_iree-run-module-driver-list.md:1" ``` ```console hl_lines="3" -$ iree-run-module --list_devices - - hip://GPU-00000000-1111-2222-3333-444444444444 - local-sync:// - local-task:// +--8<-- "docs/website/docs/guides/deployment-configurations/snippets/_iree-run-module-device-list-amd.md" ``` #### :octicons-download-16: Download the runtime from a release @@ -89,88 +85,83 @@ on GPUs. ### :octicons-file-code-16: Compile a program -The IREE compiler transforms a model into its final deployable format in many -sequential steps. A model authored with Python in an ML framework should use the -corresponding framework's import tool to convert into a format (i.e., -[MLIR](https://mlir.llvm.org/)) expected by the IREE compiler first. +--8<-- "docs/website/docs/guides/deployment-configurations/snippets/_iree-import-onnx-mobilenet.md" -Using MobileNet v2 as an example, you can download the SavedModel with trained -weights from -[TensorFlow Hub](https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification) -and convert it using IREE's -[TensorFlow importer](../ml-frameworks/tensorflow.md). Then run one of the -following commands to compile: +Then run the following command to compile with the `rocm` target: ```shell hl_lines="2-5" iree-compile \ --iree-hal-target-backends=rocm \ --iree-hip-target=<...> \ - mobilenet_iree_input.mlir -o mobilenet_rocm.vmfb + mobilenetv2.mlir -o mobilenet_rocm.vmfb ``` -Note that IREE comes with bundled bitcode files, which are used for linking -certain intrinsics on AMD GPUs. These will be used automatically or if the -`--iree-hip-bc-dir` is empty. As additional support may be needed for -different chips, users can use this flag to point to an explicit directory. -For example, in ROCm installations on Linux, this is often found under -`/opt/rocm/amdgcn/bitcode`. - -A HIP target (`iree-hip-target`) matching the LLVM AMDGPU backend is needed to -compile towards each GPU chip. Here is a table of commonly used architectures: - -| AMD GPU | SKU Name | Target Architecture | Architecture Code Name | -| ------------------------ | ----------- | ------------------- | ---------------------- | -| AMD MI100 | `mi100` | `gfx908` | `cdna1` | -| AMD MI210 | `mi210` | `gfx90a` | `cdna2` | -| AMD MI250 | `mi250` | `gfx90a` | `cdna2` | -| AMD MI300X (early units) | N/A | `gfx940` | `cdna3` | -| AMD MI300A (early units) | N/A | `gfx941` | `cdna3` | -| AMD MI300A | `mi300a` | `gfx942` | `cdna3` | -| AMD MI300X | `mi300x` | `gfx942` | `cdna3` | -| AMD MI308X | `mi308x` | `gfx942` | `cdna3` | -| AMD MI325X | `mi325x` | `gfx942` | `cdna3` | -| AMD RX7900XTX | `rx7900xtx` | `gfx1100` | `rdna3` | -| AMD RX7900XT | `rx7900xt` | `gfx1100` | `rdna3` | -| AMD PRO W7900 | `w7900` | `gfx1100` | `rdna3` | -| AMD PRO W7800 | `w7800` | `gfx1100` | `rdna3` | -| AMD RX7800XT | `rx7800xt` | `gfx1101` | `rdna3` | -| AMD RX7700XT | `rx7700xt` | `gfx1101` | `rdna3` | -| AMD PRO V710 | `v710` | `gfx1101` | `rdna3` | -| AMD PRO W7700 | `w7700` | `gfx1101` | `rdna3` | - -For a more comprehensive list of prior GPU generations, you can refer to the -[LLVM AMDGPU backend](https://llvm.org/docs/AMDGPUUsage.html#processors). - -The `iree-hip-target` option support three schemes: - -1. The exact GPU product (SKU), e.g., `--iree-hip-target=mi300x`. This allows - the compiler to know about both the target architecture and about additional - hardware details like the number of compute units. This extra information - guides some compiler heuristics and allows for SKU-specific [tuning - specs](../../reference/tuning.md). -2. The GPU architecture, as defined by LLVM, e.g., `--iree-hip-target=gfx942`. - This scheme allows for architecture-specific [tuning - specs](../../reference/tuning.md) only. -3. The architecture code name, e.g., `--iree-hip-target=cdna3`. This scheme gets - translated to closes matching GPU architecture under the hood. - -We support for common code/SKU names without aiming to be exhaustive. If the -ones you want are missing, please use the GPU architecture scheme (2.) as it is -the most general. +???+ tip "Tip - HIP bitcode files" + + That IREE comes with bundled bitcode files, which are used for linking + certain intrinsics on AMD GPUs. These will be used automatically or if the + `--iree-hip-bc-dir` is empty. As additional support may be needed for + different chips, users can use this flag to point to an explicit directory. + For example, in ROCm installations on Linux, this is often found under + `/opt/rocm/amdgcn/bitcode`. + +???+ tip "Tip - HIP targets" + + A HIP target (`iree-hip-target`) matching the LLVM AMDGPU backend is needed + to compile towards each GPU chip. Here is a table of commonly used + architectures: + + | AMD GPU | SKU Name | Target Architecture | Architecture Code Name | + | ------------------------ | ----------- | ------------------- | ---------------------- | + | AMD MI100 | `mi100` | `gfx908` | `cdna1` | + | AMD MI210 | `mi210` | `gfx90a` | `cdna2` | + | AMD MI250 | `mi250` | `gfx90a` | `cdna2` | + | AMD MI300A | `mi300a` | `gfx942` | `cdna3` | + | AMD MI300X | `mi300x` | `gfx942` | `cdna3` | + | AMD MI308X | `mi308x` | `gfx942` | `cdna3` | + | AMD MI325X | `mi325x` | `gfx942` | `cdna3` | + | AMD RX7900XTX | `rx7900xtx` | `gfx1100` | `rdna3` | + | AMD RX7900XT | `rx7900xt` | `gfx1100` | `rdna3` | + | AMD PRO W7900 | `w7900` | `gfx1100` | `rdna3` | + | AMD PRO W7800 | `w7800` | `gfx1100` | `rdna3` | + | AMD RX7800XT | `rx7800xt` | `gfx1101` | `rdna3` | + | AMD RX7700XT | `rx7700xt` | `gfx1101` | `rdna3` | + | AMD PRO V710 | `v710` | `gfx1101` | `rdna3` | + | AMD PRO W7700 | `w7700` | `gfx1101` | `rdna3` | + + For a more comprehensive list of prior GPU generations, you can refer to the + [LLVM AMDGPU backend](https://llvm.org/docs/AMDGPUUsage.html#processors). + + The `iree-hip-target` option support three schemes: + + 1. The exact GPU product (SKU), e.g., `--iree-hip-target=mi300x`. This + allows the compiler to know about both the target architecture and about + additional hardware details like the number of compute units. This extra + information guides some compiler heuristics and allows for SKU-specific + [tuning specs](../../reference/tuning.md). + 2. The GPU architecture, as defined by LLVM, e.g., + `--iree-hip-target=gfx942`. This scheme allows for architecture-specific + [tuning specs](../../reference/tuning.md) only. + 3. The architecture code name, e.g., `--iree-hip-target=cdna3`. This scheme + gets translated to closes matching GPU architecture under the hood. + + We support for common code/SKU names without aiming to be exhaustive. If the + ones you want are missing, please use the GPU architecture scheme (2.) as it + is the most general. ### :octicons-terminal-16: Run a compiled program -Run the following command: +To run the compiled program: ``` shell hl_lines="2" iree-run-module \ --device=hip \ --module=mobilenet_rocm.vmfb \ - --function=predict \ - --input="1x224x224x3xf32=0" + --function=torch-jit-export \ + --input="1x3x224x224xf32=0" ``` -The above assumes the exported function in the model is named as `predict` and -it expects one 224x224 RGB image. We are feeding in an image with all 0 values -here for brevity, see `iree-run-module --help` for the format to specify +The above assumes the exported function in the model is named `torch-jit-export` +and it expects one 224x224 RGB image. We are feeding in an image with all 0 +values here for brevity, see `iree-run-module --help` for the format to specify concrete values. diff --git a/docs/website/docs/guides/deployment-configurations/gpu-vulkan.md b/docs/website/docs/guides/deployment-configurations/gpu-vulkan.md index 48e8a1bba797..e17c104de8fc 100644 --- a/docs/website/docs/guides/deployment-configurations/gpu-vulkan.md +++ b/docs/website/docs/guides/deployment-configurations/gpu-vulkan.md @@ -89,13 +89,9 @@ package includes the SPIR-V compiler: Please make sure you have followed the [Getting started](../../building-from-source/getting-started.md) page to build -IREE for your host platform and the -[Android cross-compilation](../../building-from-source/android.md) page if you -are cross compiling for Android. The SPIR-V compiler backend is compiled in by -default on all platforms. - -Ensure that the `IREE_TARGET_BACKEND_VULKAN_SPIRV` CMake option is `ON` when -configuring for the host. +IREE for your host platform. The SPIR-V compiler backend is compiled in by +default on all platforms, though you should ensure that the +`IREE_TARGET_BACKEND_VULKAN_SPIRV` CMake option is `ON` when configuring. !!! tip `iree-compile` will be built under the `iree-build/tools/` directory. You @@ -107,17 +103,12 @@ Next you will need to get an IREE runtime that supports the Vulkan HAL driver. You can check for Vulkan support by looking for a matching driver and device: -```console hl_lines="7" ---8<-- "docs/website/docs/guides/deployment-configurations/snippets/_iree-run-module-driver-list.md" +```console hl_lines="12" +--8<-- "docs/website/docs/guides/deployment-configurations/snippets/_iree-run-module-driver-list.md:1" ``` ```console hl_lines="6" -$ iree-run-module --list_devices - - cuda://GPU-00000000-1111-2222-3333-444444444444 - local-sync:// - local-task:// - vulkan://00000000-1111-2222-3333-444444444444 +--8<-- "docs/website/docs/guides/deployment-configurations/snippets/_iree-run-module-device-list-amd.md" ``` #### :octicons-download-16: Download the runtime from a release @@ -131,97 +122,83 @@ package includes the Vulkan HAL driver: #### :material-hammer-wrench: Build the runtime from source -Please make sure you have followed the -[Getting started](../../building-from-source/getting-started.md) page to build -IREE for Linux/Windows and the -[Android cross-compilation](../../building-from-source/android.md) page for -Android. The Vulkan HAL driver is compiled in by default on non-Apple platforms. - -Ensure that the `IREE_HAL_DRIVER_VULKAN` CMake option is `ON` when configuring -for the target. +Please make sure you have followed one of the +[Building from source](../../building-from-source/index.md) pages to build +IREE for your target platform. The Vulkan HAL driver is compiled in by default +on supported platforms, though you should ensure that the +`IREE_HAL_DRIVER_VULKAN` CMake option is `ON` when configuring. ## Compile and run a program -With the SPIR-V compiler and Vulkan runtime, we can now compile programs and run -them on GPUs. +With the requirements out of the way, we can now compile a model and run it. ### :octicons-file-code-16: Compile a program -The IREE compiler transforms a model into its final deployable format in many -sequential steps. A model authored with Python in an ML framework should use the -corresponding framework's import tool to convert into a format (i.e., -[MLIR](https://mlir.llvm.org/)) expected by the IREE compiler first. +--8<-- "docs/website/docs/guides/deployment-configurations/snippets/_iree-import-onnx-mobilenet.md" -Using MobileNet v2 as an example, you can download the SavedModel with trained -weights from -[TensorFlow Hub](https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification) -and convert it using IREE's -[TensorFlow importer](../ml-frameworks/tensorflow.md). Then run the following -command to compile with the `vulkan-spirv` target: +Then run the following command to compile with the `vulkan-spirv` target: ``` shell hl_lines="2 3" iree-compile \ --iree-hal-target-backends=vulkan-spirv \ --iree-vulkan-target=<...> \ - mobilenet_iree_input.mlir -o mobilenet_vulkan.vmfb + mobilenetv2.mlir -o mobilenet_vulkan.vmfb ``` -`iree-vulkan-target` specifies the GPU architecture to target. It accepts a few -schemes: - -* LLVM CodeGen backend style: this is using LLVM AMDGPU/NVPTX CodeGen targets - like `gfx1100` for AMD RX 7900XTX and `sm_86` for NVIDIA RTX 3090 GPUs. -* Architecture code name style: e.g., using `rdna3`/`valhall4`/`ampere`/`adreno` - for AMD/ARM/NVIDIA/Qualcomm GPUs. -* Product name style(1): e.g., using `rx7900xtx`/`a100` for corresponding GPUs. - -Here are a few examples showing how you can target various recent common GPUs: - -| GPU | Target Architecture | Architecture Code Name | Product Name -| ------------------- | ------------------- | ---------------------- | ------------ -| AMD RX7900XTX | `gfx1100` | `rdna3` | `rx7900xtx` -| AMD RX7900XT | `gfx1100` | `rdna3` | `rx7900xt` -| AMD RX7800XT | `gfx1101` | `rdna3` | `rx7800xt` -| AMD RX7700XT | `gfx1101` | `rdna3` | `rx7700xt` -| AMD RX6000 series | | `rdna2` | -| AMD RX5000 series | | `rdna1` | -| ARM Mali G715 | | `valhall4` | e.g., `mali-g715` -| ARM Mali G510 | | `valhall3` | e.g., `mali-g510` -| ARM GPUs | | `valhall` | -| NVIDIA RTX40 series | `sm_89` | `ada` | e.g., `rtx4090` -| NVIDIA RTX30 series | `sm_86` | `ampere` | e.g., `rtx3080ti` -| NVIDIA RTX20 series | `sm_75` | `turing` | e.g., `rtx2070super` -| Qualcomm GPUs | | `adreno` | - -If no target is specified, then a safe but more limited default will be used. - -!!! note annotate - Note that We don't support the full spectrum of GPUs here(2). - This is more of a mechanism to help us develop IREE itself--in the long term - we want to perform multiple targetting to generate to multiple architectures - if no target is given. - -1. Note that we only support very limited GPUs that we are actively developing - against in this category, particularly for desktops. -2. It's also impossible to capture all details of a Vulkan implementation - with a target triple, given the allowed variances on extensions, properties, - limits, etc. So the target triple is just an approximation for usage. +???+ tip "Tip - Vulkan targets" + + The `--iree-vulkan-target` specifies the GPU architecture to target. It + accepts a few schemes: + + * LLVM CodeGen backend style: this is using LLVM AMDGPU/NVPTX CodeGen targets + like `gfx1100` for AMD RX 7900XTX and `sm_86` for NVIDIA RTX 3090 GPUs. + * Architecture code name style like `rdna3`/`valhall4`/`ampere`/`adreno` + for AMD/ARM/NVIDIA/Qualcomm GPUs. + * Product name style: e.g., using `rx7900xtx`/`a100` for corresponding GPUs. + + Here are a few examples showing how you can target various recent common GPUs: + + | GPU | Target Architecture | Architecture Code Name | Product Name + | ------------------- | ------------------- | ---------------------- | ------------ + | AMD RX7900XTX | `gfx1100` | `rdna3` | `rx7900xtx` + | AMD RX7900XT | `gfx1100` | `rdna3` | `rx7900xt` + | AMD RX7800XT | `gfx1101` | `rdna3` | `rx7800xt` + | AMD RX7700XT | `gfx1101` | `rdna3` | `rx7700xt` + | AMD RX6000 series | | `rdna2` | + | AMD RX5000 series | | `rdna1` | + | ARM Mali G715 | | `valhall4` | e.g., `mali-g715` + | ARM Mali G510 | | `valhall3` | e.g., `mali-g510` + | ARM GPUs | | `valhall` | + | NVIDIA RTX40 series | `sm_89` | `ada` | e.g., `rtx4090` + | NVIDIA RTX30 series | `sm_86` | `ampere` | e.g., `rtx3080ti` + | NVIDIA RTX20 series | `sm_75` | `turing` | e.g., `rtx2070super` + | Qualcomm GPUs | | `adreno` | + + If no target is specified, then a safe but more limited default will be used. + + Note that we don't support the full spectrum of GPUs here and it is + impossible to capture all details of a Vulkan implementation with a target + triple, given the allowed variances on extensions, properties, limits, etc. + So the target triple is just an approximation for usage. This is more of a + mechanism to help us develop IREE itself--in the long term we want to + perform multiple targetting to generate to multiple architectures if no + target is given. ### :octicons-terminal-16: Run a compiled program -In the build directory, run the following command: +To run the compiled program: ``` shell hl_lines="2" -tools/iree-run-module \ +iree-run-module \ --device=vulkan \ --module=mobilenet_vulkan.vmfb \ - --function=predict \ - --input="1x224x224x3xf32=0" + --function=torch-jit-export \ + --input="1x3x224x224xf32=0" ``` -The above assumes the exported function in the model is named as `predict` and -it expects one 224x224 RGB image. We are feeding in an image with all 0 values -here for brevity, see `iree-run-module --help` for the format to specify +The above assumes the exported function in the model is named `torch-jit-export` +and it expects one 224x224 RGB image. We are feeding in an image with all 0 +values here for brevity, see `iree-run-module --help` for the format to specify concrete values. diff --git a/docs/website/docs/guides/deployment-configurations/snippets/_iree-compiler-from-release.md b/docs/website/docs/guides/deployment-configurations/snippets/_iree-compiler-from-release.md index 883081506982..d1eae7be8fc3 100644 --- a/docs/website/docs/guides/deployment-configurations/snippets/_iree-compiler-from-release.md +++ b/docs/website/docs/guides/deployment-configurations/snippets/_iree-compiler-from-release.md @@ -22,7 +22,7 @@ !!! tip `iree-compile` and other tools are installed to your python module installation path. If you pip install with the user mode, it is under - `${HOME}/.local/bin`, or `%APPDATA%Python` on Windows. You may want to + `${HOME}/.local/bin`, or `%APPDATA%\Python` on Windows. You may want to include the path in your system's `PATH` environment variable: ```shell diff --git a/docs/website/docs/guides/deployment-configurations/snippets/_iree-import-onnx-mobilenet.md b/docs/website/docs/guides/deployment-configurations/snippets/_iree-import-onnx-mobilenet.md new file mode 100644 index 000000000000..72095c4eac6f --- /dev/null +++ b/docs/website/docs/guides/deployment-configurations/snippets/_iree-import-onnx-mobilenet.md @@ -0,0 +1,17 @@ +The IREE compiler transforms a model into its final deployable format in several +sequential steps. A model authored with Python in an ML framework should use the +corresponding framework's import tool to convert into a format (i.e., +[MLIR](https://mlir.llvm.org/)) expected by the IREE compiler first. + +Using a +[MobileNet model](https://github.com/onnx/models/tree/main/validated/vision/classification/mobilenet) +as an example, import using IREE's [ONNX importer](../ml-frameworks/onnx.md): + +```bash +# Download the model you want to compile and run. +wget https://github.com/onnx/models/raw/refs/heads/main/validated/vision/classification/mobilenet/model/mobilenetv2-10.onnx + +# Import to MLIR using IREE's ONNX importer. +pip install iree-base-compiler[onnx] +iree-import-onnx mobilenetv2-10.onnx --opset-version 17 -o mobilenetv2.mlir +``` diff --git a/docs/website/docs/guides/deployment-configurations/snippets/_iree-run-module-device-list-amd.md b/docs/website/docs/guides/deployment-configurations/snippets/_iree-run-module-device-list-amd.md new file mode 100644 index 000000000000..f472f1c06ed1 --- /dev/null +++ b/docs/website/docs/guides/deployment-configurations/snippets/_iree-run-module-device-list-amd.md @@ -0,0 +1,6 @@ +$ iree-run-module --list_devices + + hip://GPU-00000000-1111-2222-3333-444444444444 + local-sync:// + local-task:// + vulkan://00000000-1111-2222-3333-444444444444 diff --git a/docs/website/docs/guides/deployment-configurations/snippets/_iree-run-module-device-list-nvidia.md b/docs/website/docs/guides/deployment-configurations/snippets/_iree-run-module-device-list-nvidia.md new file mode 100644 index 000000000000..b120238269b5 --- /dev/null +++ b/docs/website/docs/guides/deployment-configurations/snippets/_iree-run-module-device-list-nvidia.md @@ -0,0 +1,6 @@ +$ iree-run-module --list_devices + + cuda://GPU-00000000-1111-2222-3333-444444444444 + local-sync:// + local-task:// + vulkan://00000000-1111-2222-3333-444444444444 diff --git a/docs/website/docs/guides/deployment-configurations/snippets/_iree-run-module-driver-list.md b/docs/website/docs/guides/deployment-configurations/snippets/_iree-run-module-driver-list.md index cd92c2ca3e11..7cab4147d4f2 100644 --- a/docs/website/docs/guides/deployment-configurations/snippets/_iree-run-module-driver-list.md +++ b/docs/website/docs/guides/deployment-configurations/snippets/_iree-run-module-driver-list.md @@ -1,5 +1,11 @@ + $ iree-run-module --list_drivers +# ============================================================================ +# Available HAL drivers +# ============================================================================ +# Use --list_devices={driver name} to enumerate available devices. + cuda: NVIDIA CUDA HAL driver (via dylib) hip: HIP HAL driver (via dylib) local-sync: Local execution using a lightweight inline synchronous queue diff --git a/docs/website/docs/guides/ml-frameworks/onnx.md b/docs/website/docs/guides/ml-frameworks/onnx.md index f606431416f9..ae9a95df14a7 100644 --- a/docs/website/docs/guides/ml-frameworks/onnx.md +++ b/docs/website/docs/guides/ml-frameworks/onnx.md @@ -10,14 +10,6 @@ icon: simple/onnx # ONNX support -!!! caution "Caution - under development" - - Support for a broad set of [ONNX operators](https://onnx.ai/onnx/operators/) - and [data types](https://onnx.ai/onnx/intro/concepts.html#supported-types) - is an active investment area. See the - [ONNX Op Support tracking issue](https://github.com/nod-ai/SHARK-ModelDev/issues/215) - for the latest status. - ## :octicons-book-16: Overview Machine learning models using the @@ -46,39 +38,33 @@ graph LR ## :octicons-download-16: Prerequisites -1. Install ONNX: - - ``` shell - python -m pip install onnx - ``` - -2. Install IREE packages, either by - [building from source](../../building-from-source/getting-started.md#python-bindings) - or from pip: +Install IREE packages, either by +[building from source](../../building-from-source/getting-started.md#python-bindings) +or from pip: - === ":octicons-package-16: Stable releases" +=== ":octicons-package-16: Stable releases" - Stable release packages are [published to PyPI](https://pypi.org/). + Stable release packages are [published to PyPI](https://pypi.org/). - ``` shell - python -m pip install \ - iree-base-compiler[onnx] \ - iree-base-runtime - ``` + ``` shell + python -m pip install \ + iree-base-compiler[onnx] \ + iree-base-runtime + ``` - === ":octicons-beaker-16: Nightly releases" +=== ":octicons-beaker-16: Nightly releases" - Nightly pre-releases are published on - [GitHub releases](https://github.com/iree-org/iree/releases). + Nightly pre-releases are published on + [GitHub releases](https://github.com/iree-org/iree/releases). - ``` shell - python -m pip install \ - --find-links https://iree.dev/pip-release-links.html \ - --upgrade \ - --pre \ - iree-base-compiler[onnx] \ - iree-base-runtime - ``` + ``` shell + python -m pip install \ + --find-links https://iree.dev/pip-release-links.html \ + --upgrade \ + --pre \ + iree-base-compiler[onnx] \ + iree-base-runtime + ``` ## :octicons-rocket-16: Quickstart @@ -88,11 +74,15 @@ graph LR 2. Convert the `.onnx` file into MLIR using the `iree-import-onnx` tool: ```shell - iree-import-onnx [model.onnx] -o [model.mlir] + iree-import-onnx \ + [model.onnx] \ + --opset-version 17 \ + -o [model.mlir] ``` This tool produces a MLIR file with the help of the - [torch-mlir](https://github.com/llvm/torch-mlir) project. + [torch-mlir](https://github.com/llvm/torch-mlir) project. Run + `iree-import-onnx --help` for a full list of options. 3. Once imported, the standard set of tools and APIs available for any of IREE's [deployment configurations](../deployment-configurations/index.md) and @@ -102,6 +92,7 @@ graph LR iree-compile \ model.mlir \ --iree-hal-target-backends=llvm-cpu \ + --iree-llvmcpu-target-cpu=host \ -o model_cpu.vmfb iree-run-module \ @@ -123,6 +114,12 @@ Importer tests | [torch-mlir `test/python/onnx_importer`](https://github.com/llv ## :octicons-question-16: Troubleshooting +Support for a broad set of [ONNX operators](https://onnx.ai/onnx/operators/) +and [data types](https://onnx.ai/onnx/intro/concepts.html#supported-types) +is an active investment area. See the +[ONNX Op Support tracking issue](https://github.com/nod-ai/SHARK-ModelDev/issues/215) +for the latest status. + ### Failed to legalize operation that was explicitly marked illegal If you see an error compiling a converted .mlir file like this: diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc index 1450be7f7204..f7bdd17dcb33 100644 --- a/runtime/bindings/python/hal.cc +++ b/runtime/bindings/python/hal.cc @@ -1356,6 +1356,10 @@ void SetupHalBindings(nanobind::module_ m) { .value("BFLOAT_16", IREE_HAL_ELEMENT_TYPE_BFLOAT_16) .value("COMPLEX_64", IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64) .value("COMPLEX_128", IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128) + .value("FLOAT_8_E4M3", IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3) + .value("FLOAT_8_E4M3_FNUZ", IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ) + .value("FLOAT_8_E5M2", IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2) + .value("FLOAT_8_E5M2_FNUZ", IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ) .export_values() .def("__int__", [](enum iree_hal_element_types_t self) { return (uint64_t)self; }); diff --git a/runtime/bindings/python/iree/runtime/_binding.pyi b/runtime/bindings/python/iree/runtime/_binding.pyi index 499e0c6220a9..c6237c2481a6 100644 --- a/runtime/bindings/python/iree/runtime/_binding.pyi +++ b/runtime/bindings/python/iree/runtime/_binding.pyi @@ -258,6 +258,10 @@ class HalElementType: BOOL_8: ClassVar[HalElementType] = ... COMPLEX_128: ClassVar[HalElementType] = ... COMPLEX_64: ClassVar[HalElementType] = ... + FLOAT_8_E4M3: ClassVar[HalElementType] = ... + FLOAT_8_E4M3_FNUZ: ClassVar[HalElementType] = ... + FLOAT_8_E5M2: ClassVar[HalElementType] = ... + FLOAT_8_E5M2_FNUZ: ClassVar[HalElementType] = ... FLOAT_16: ClassVar[HalElementType] = ... FLOAT_32: ClassVar[HalElementType] = ... FLOAT_64: ClassVar[HalElementType] = ... diff --git a/tests/e2e/regression/dynamic_tosa_quantized_fully_connected_issue_10859.mlir b/tests/e2e/regression/dynamic_tosa_quantized_fully_connected_issue_10859.mlir index 460991a9be87..7ee0bf8d5c26 100644 --- a/tests/e2e/regression/dynamic_tosa_quantized_fully_connected_issue_10859.mlir +++ b/tests/e2e/regression/dynamic_tosa_quantized_fully_connected_issue_10859.mlir @@ -1,13 +1,13 @@ // Regression testcase from https://github.com/iree-org/iree/issues/10859 func.func @main(%arg0: tensor<256xi8>, %arg1: tensor<2xi32>, %arg2: tensor<2x32xi8>, %arg3: tensor<32xi32>, %arg4: tensor<32x32xi8>, %arg5: tensor<32xi32>, %arg6: tensor<32x3360xi8>, %arg7: tensor) -> (tensor) { - %0 = tosa.fully_connected %arg7, %arg6, %arg5 {quantization_info = #tosa.conv_quant} : (tensor, tensor<32x3360xi8>, tensor<32xi32>) -> tensor + %0 = tosa.fully_connected %arg7, %arg6, %arg5 {input_zp = -128 : i32, weight_zp = 0 : i32} : (tensor, tensor<32x3360xi8>, tensor<32xi32>) -> tensor %1 = tosa.rescale %0 {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = -128 : i32, per_channel = false, scale32 = true, shift = array} : (tensor) -> tensor %2 = tosa.clamp %1 {max_fp = 0.000000e+00 : f32, max_int = 127 : i64, min_fp = 0.000000e+00 : f32, min_int = -128 : i64} : (tensor) -> tensor - %3 = tosa.fully_connected %2, %arg4, %arg3 {quantization_info = #tosa.conv_quant} : (tensor, tensor<32x32xi8>, tensor<32xi32>) -> tensor + %3 = tosa.fully_connected %2, %arg4, %arg3 {input_zp = -128 : i32, weight_zp = 0 : i32} : (tensor, tensor<32x32xi8>, tensor<32xi32>) -> tensor %4 = tosa.rescale %3 {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = -128 : i32, per_channel = false, scale32 = true, shift = array} : (tensor) -> tensor %5 = tosa.clamp %4 {max_fp = 0.000000e+00 : f32, max_int = 127 : i64, min_fp = 0.000000e+00 : f32, min_int = -128 : i64} : (tensor) -> tensor - %6 = tosa.fully_connected %5, %arg2, %arg1 {quantization_info = #tosa.conv_quant} : (tensor, tensor<2x32xi8>, tensor<2xi32>) -> tensor + %6 = tosa.fully_connected %5, %arg2, %arg1 {input_zp = -128 : i32, weight_zp = 0 : i32} : (tensor, tensor<2x32xi8>, tensor<2xi32>) -> tensor %7 = tosa.rescale %6 {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 44 : i32, per_channel = false, scale32 = true, shift = array} : (tensor) -> tensor %8 = tosa.table %7, %arg0 : (tensor, tensor<256xi8>) -> tensor return %8 : tensor diff --git a/third_party/llvm-project b/third_party/llvm-project index 8cb4b3e21e03..ea7924e1412b 160000 --- a/third_party/llvm-project +++ b/third_party/llvm-project @@ -1 +1 @@ -Subproject commit 8cb4b3e21e03d3e029ade27139eab1a25720c773 +Subproject commit ea7924e1412b545e2065e30446b721a89a5e07d3 diff --git a/third_party/torch-mlir b/third_party/torch-mlir index eefc553ffca4..36c47e652b58 160000 --- a/third_party/torch-mlir +++ b/third_party/torch-mlir @@ -1 +1 @@ -Subproject commit eefc553ffca45fd2432641918a73b810f64bba39 +Subproject commit 36c47e652b58db24b3c6bafc56e103d39a9befe1